It's not a loss.
In reinforcement learning, we are interested in modeling agents interacting with environments by receiving observations and taking actions, with the goal of maximizing some cumulative reward signal. Concretely, we want to learn some function \(\pi(\theta, s)\) which takes some input environment observation \(s\) and parameters \(\theta\) and outputs some probability distribution over actions. After the agent takes an action, the environment transitions to a new state \(s'\) and gives it some reward \(R\). The goal is to find parameters \(\theta^*\) that maximize the expected reward signal \(R\) we get from interacting with some environment by following the policy \(\pi(\cdot, \theta^*)\).
If you want to maximize something, the obvious algorithm is gradient ascent, following the policy gradient. This idea has spawned a cottage industry of policy gradient methods, which take the expected return \(v(s) = \mathbb{E}_{\mu_\theta}[\sum_{t=0}^\infty \gamma^t R(s_t, a_t)| s_0 = s]\) (where \(\mu_\theta\) is the distribution over environment trajectories induced by the policy \(\pi_\theta\), assuming the usual ergodicity etc. conditions), estimate the gradient \(\nabla_\theta v(s_0)\), and then update the parameters in this direction (modulo some regularization, momentum, adaptive step size, etc.).
This is important: the gradient that all policy gradient methods are computing is the gradient of the return with respect to the parameters. Insofar as you can say policy gradients are running gradient descent on a loss, the loss is just the return! And yet somehow, whenever I go on a safari into what people on the internet write about policy gradients, I inevitably run into discussions of some mysterious “loss” which, as far as I can tell, is decidedly not the expected sum of discounted rewards. Instead it looks something like this:
\[ \mathcal{L}(s, a, \theta) = -A(s,a) \log \pi_\theta(a|s)\]
where \(A\) is the estimated advantage, computed from a rollout and some value baseline \(V\).
I don’t know about you, but \(-A(s,a)\log \pi(a|s)\) doesn’t look much like \(\mathbb{E}_{\mu_\theta}\sum \gamma^t R_t\) to me, and when I was initially studying policy gradient methods it was easy to forget that the thing being optimized at the end of the day was indeed the return that every RL researcher comes to know and love. So this weekend I engaged in a practice of radical empathy to understand what is driving the gaping chasm between these two interpretations of what it is that REINFORCE is trying to optimize.
In supervised learning, the estimated objective function and the thing you feed into jax.grad are usually the same thing. When you estimate a gradient, you do so by first estimating the loss on some data and then taking the gradient of the estimated loss. This works because the sampling distribution, which is generally uniform over a fixed dataset, is usually independent of the model parameters, so you can swap the gradient and the expectation
\[\nabla_\theta \mathbb{E}_{x,y \sim \mathcal{D}}[ \mathcal{L}(f_\theta(x), y)] = \mathbb{E}_{x,y \sim \mathcal{D}}[\nabla_\theta \mathcal{L}(f_\theta(x), y)] \]
which means you can get an unbiased estimator of the gradient by sampling a bunch of points, then computing jax.grad on the estimated loss.
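As a tiny sanity check of this swap — using plain numpy with made-up data rather than jax, and a hand-derived gradient rather than jax.grad — averaging per-example gradients of a squared-error loss reproduces the full-batch gradient exactly:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))
y = rng.normal(size=8)
theta = rng.normal(size=3)

def grad_loss(theta, Xb, yb):
    # gradient of the mean squared error 0.5 * mean((Xb @ theta - yb)**2)
    return Xb.T @ (Xb @ theta - yb) / len(yb)

full_batch = grad_loss(theta, X, y)
# averaging per-example gradients reproduces the full-batch gradient,
# because the sampling distribution doesn't depend on theta
per_example = np.mean([grad_loss(theta, X[i:i+1], y[i:i+1]) for i in range(8)], axis=0)
assert np.allclose(full_batch, per_example)
```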
In RL, the parameters we optimize show up in the distribution we’re computing an expectation over, so you can’t swap the \(\nabla_\theta\) and the \(\mathbb{E}_{\mu_\theta}\) without making your math teacher extremely angry. Concretely, \[\nabla_\theta \mathbb{E}_{(s_t, a_t)\sim \mu_\theta^t}\bigg[ \sum_{t=0}^\infty \gamma^t R(s_t, a_t)\bigg] \neq \mathbb{E}_{(s_t, a_t)\sim \mu_\theta^t}\bigg[ \nabla_\theta \sum_{t=0}^\infty \gamma^t R(s_t, a_t)\bigg]\] This is easy to see: once you’ve selected the action \(a_t\), the reward is independent of the parameters of the policy you used to sample it, and so \[ \mathbb{E}_{(s_t, a_t)\sim \mu_\theta^t}\bigg[ \nabla_\theta \sum_{t=0}^\infty \gamma^t R(s_t, a_t)\bigg] = 0\]
which definitely isn’t a useful gradient.
So to estimate the gradient, we have to do something else.
Converting \(\mathbb{E}_{\mu_\theta}\bigg [\sum \gamma^t R_t\bigg]\) into something that you can feed into jax.grad requires a few non-obvious steps that cause the gradient estimator to look very different from the function it is estimating. If you’re used to learning paradigms where you differentiate a sampled loss, you can be forgiven for assuming that this complicated combination of sampled advantages and log likelihoods is the function you’re optimizing. Forgiven, but still technically wrong.
The goal of this blog post is to bring a bit of clarity to the discussion around policy gradient methods so that when you are talking about “the PPO loss” or the “GRPO loss” or the “ABCDEFGHIJKLMNOPO” loss, the type signature of your statement is slightly less murky. To do so, I’ll give a quick walk-through of how the estimated return is converted into a weighted sum of log-likelihoods for gradient estimation. I’ll then explain why, even though this thing looks like a loss, you should be careful about treating it like one. If you aren’t sick of my pedantry by this point, you can stick around for a brief lecture on why the gradient people compute in temporal difference learning also isn’t really a gradient.
The way I wrote down \(\mathbb{E}_{\mu_\theta}[\sum_{t=0}^\infty \gamma^t R_t ]\) obfuscates a great deal of the interdependency between the expectation and the parameters \(\theta\). Once we write things out more explicitly, the mystery of how the log probability shows up will start to disappear.
As a quick note on notation, the function \(v\) denotes the state-value function \(v_\theta (s) = \mathbb{E}_{\mu_\theta}[\sum \gamma^t R(s_t, a_t) | s_0 = s]\), while \(q\) denotes the action-value function \(q_\theta (s, a) = \mathbb{E}_{\mu_\theta}[\sum \gamma^t R(s_t, a_t) | s_0 = s, a_0 = a]\).
The proof I’m going to give is a bit different from the one in Sutton and Barto, and a bit shorter at the expense of a slight abuse of matrix notation. We start by rewriting the gradient computation using a Bellman-like equation. \(\begin{align} \nabla_\theta v_\theta(s_0) &= \nabla_\theta \sum_{a} \pi_\theta(a|s_0) q_\theta(s_0, a) \\ &= \sum_{a} \nabla_\theta \pi_\theta(a|s_0) q_\theta(s_0, a) + \pi_\theta(a|s_0) \nabla_\theta q_\theta(s_0, a) \\ &= \sum_{a}[\nabla_\theta \pi_\theta(a|s_0) q_\theta(s_0, a) + \gamma \pi_\theta(a|s_0)\nabla_\theta \sum_{s'} p(s'|s_0, a) v_\theta(s')] \\ &= \sum_{a}[\nabla_\theta \pi_\theta(a|s_0) q_\theta(s_0, a) + \gamma \pi_\theta(a|s_0) \sum_{s'} p(s'|s_0, a) \nabla_\theta v_\theta(s')] \\ &= \sum_a \nabla_\theta \pi_\theta(a|s_0) q_\theta(s_0, a) + \gamma \mathbb{E}_{s' \sim P^\pi_\theta(s_0)}[\nabla_\theta v_\theta(s')] \\ &= \sum_a \nabla_\theta \pi_\theta(a|s_0) q_\theta(s_0, a) + (\gamma P^\pi_\theta \nabla_\theta v)(s_0) \end{align}\)
This general form should be familiar to any reader who has taken a course in RL or read the Q-learning paper, since we now have a recursion of the form \(v = a + (\gamma P^\pi v)\), which has the closed form \((I - \gamma P^\pi)^{-1}a\). This step sweeps a lot of type signatures under the rug, but if you have the intuition that gradients are linear and so play nice with other linear functions, then you hopefully won’t find the notion of swapping out \(a\) with the vector whose \(s\) entry is \(\sum_a \nabla_\theta \pi_\theta(a|s) q_\theta(s, a)\) too offensive.
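If the matrix notation feels slippery, the closed form is easy to verify numerically for a made-up transition matrix \(P^\pi\) and source term (plain numpy, arbitrary numbers):

```python
import numpy as np

gamma = 0.9
P = np.array([[0.5, 0.5, 0.0],
              [0.1, 0.6, 0.3],
              [0.0, 0.2, 0.8]])    # made-up policy-induced transition matrix P^pi
src = np.array([1.0, 0.0, -1.0])   # the "a" term in the recursion v = a + gamma * P v

# closed form: v = (I - gamma * P)^{-1} a, computed via a linear solve
v = np.linalg.solve(np.eye(3) - gamma * P, src)
# v satisfies the recursion it came from
assert np.allclose(v, src + gamma * P @ v)
```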
Before applying this recursion, we can apply the score function trick to make the sum a bit nicer (remember, we’re looking for something we can use as an estimator, and in practice it isn’t possible to try every action in every state at once). We get this by multiplying by 1, or more precisely \(\frac{\pi(a|s)}{\pi(a|s)}\), to turn the sum into an expectation.
\(\begin{align} \sum_a \nabla_\theta \pi(a|s_0) q(s_0, a) &= \sum_a \frac{\pi(a|s_0)}{\pi(a|s_0)}\nabla_\theta \pi(a|s_0) q(s_0, a) \\ &= \mathbb{E}_{a \sim \pi(a|s_0)} \bigg [ \frac{1}{\pi(a|s_0)}\nabla_\theta \pi(a|s_0) q(s_0, a) \bigg ] \\ &= \mathbb{E}_{a \sim \pi(s_0)} [q(s_0, a) \nabla_\theta \log \pi(a|s_0)] \end{align}\)
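You can check the score function trick numerically on a toy softmax policy (all numbers here are made up, and numpy finite differences stand in for jax.grad):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

q = np.array([1.0, -0.5, 2.0])      # made-up action values, treated as constants
theta = np.array([0.3, -1.0, 0.5])  # policy logits

# left side: finite-difference gradient of sum_a pi(a) * q(a)
def f(th):
    return softmax(th) @ q
eps = 1e-6
lhs = np.array([(f(theta + eps * np.eye(3)[i]) - f(theta - eps * np.eye(3)[i])) / (2 * eps)
                for i in range(3)])

# right side: exact expectation E_a[ q(a) * grad_theta log pi(a) ]
pi = softmax(theta)
grad_log_pi = np.eye(3) - pi        # row a is grad_theta log pi(a) for a softmax policy
rhs = pi @ (q[:, None] * grad_log_pi)

assert np.allclose(lhs, rhs, atol=1e-5)
```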
Now combining these and converting to vectorized notation, we get
\(\begin{align} \nabla_\theta V &= \mathbb{E}_{a \sim \pi(a | \cdot)} [q(\cdot, a)\nabla_\theta \log \pi(a|\cdot)] + \gamma P^\pi \nabla_\theta V \\ &= (I - \gamma P^\pi)^{-1}[ \langle Q, \nabla_\theta \log \pi \rangle_{\pi_\theta} ] \end{align}\)
where we swap the expectation to vector notation because it is less likely to overflow small screens when mathjax renders it.
Now to make things look like an expectation again (which is important because that’s how we get an unbiased sampler), we note that the discounted sum of expected values can be viewed as an expectation of a discounted sum of probabilities, which, when appropriately normalized by \(\frac{1}{1-\gamma}\), is itself a valid probability distribution. This means we can write \(\begin{align} (I - \gamma P^\pi)^{-1}[ \langle Q, \nabla_\theta \log \pi \rangle_{\pi_\theta} ] &= \frac{1}{1-\gamma}\mathbb{E}_{s, a \sim \mu_\gamma^\pi} [Q (s,a) \nabla_\theta \log \pi(a|s)] \end{align}\)
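A quick numerical check of the normalization claim, with a made-up row-stochastic \(P^\pi\): the matrix \((1-\gamma)(I - \gamma P^\pi)^{-1}\) has nonnegative entries and rows summing to one, so each row really is a probability distribution over states:

```python
import numpy as np

gamma = 0.9
P = np.array([[0.5, 0.5, 0.0],
              [0.1, 0.6, 0.3],
              [0.0, 0.2, 0.8]])   # made-up policy-induced transition matrix

# the normalized discounted occupancy (1 - gamma) * (I - gamma P)^{-1}
M = (1 - gamma) * np.linalg.inv(np.eye(3) - gamma * P)

# each row is a valid distribution: nonnegative, sums to one
assert np.all(M >= 0)
assert np.allclose(M.sum(axis=1), 1.0)
```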
So now we have that, sampling from the distribution \(\mu_\gamma^\pi\), the quantity \(Q(s,a) \nabla_\theta \log \pi(a|s)\) gives (up to the constant \(\frac{1}{1-\gamma}\)) an unbiased gradient ascent direction for the value \(v_\theta(s_0)\).
Getting to \(\mathcal{L}\) requires two simple steps: first, because machine learning researchers only know about gradient descent, we turn the maximization problem into a minimization problem by multiplying by -1. Second, we subtract a value baseline \(V(s)\) from \(Q(s,a)\) because this doesn’t change the gradient but does reduce the variance of its estimator. Subtracting a value \(V\) turns a Q-value estimate \(Q(s,a)\) into an advantage estimate \(A(s,a)\).
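Putting the pieces together: differentiating the surrogate \(-A \log \pi_\theta(a|s)\) with the advantage held constant recovers exactly the score-function term derived above. A sketch with made-up numbers, using numpy finite differences in place of jax.grad:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

theta = np.array([0.2, -0.4, 0.1])
a, A = 2, 1.7                      # a sampled action and its (stop-gradiented) advantage

def surrogate(th):
    # the thing people call "the loss": -A * log pi_theta(a), with A a constant
    return -A * np.log(softmax(th)[a])

eps = 1e-6
g = np.array([(surrogate(theta + eps * np.eye(3)[i]) - surrogate(theta - eps * np.eye(3)[i])) / (2 * eps)
              for i in range(3)])

pi = softmax(theta)
score = np.eye(3)[a] - pi          # grad_theta log pi(a) for a softmax policy
assert np.allclose(g, -A * score, atol=1e-5)
```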
Although we do take a gradient through this function to estimate the policy gradient, \(\sum_{(s,a) \in \mathcal{B}}-A(s,a) \log \pi_\theta (a|s)\) isn’t a loss, in the sense that it lacks a number of properties you would expect of a loss function. There are several ways that the `loss = thing inside of grad’ intuition breaks, three of which I list below.
Intuition 1: if we apply grad to a loss, the thing we get should be the gradient of some function. The things that you’re multiplying \(\nabla_\theta \log \pi_\theta\) by are actually also functions of \(\theta\), since they are a result of the policy that \(\theta\) parameterizes, but you’re implicitly stop-gradienting them. When you have a stop_gradient in your loss function, the thing you get out of jax.grad won’t in general be the gradient of that function (see the later discussion of temporal difference methods for more on this).
Intuition 2: we should be able to take multiple gradient steps on the loss function and not break optimization. The function \(\sum_{(s,a) \in \mathcal{B}}-A(s,a) \nabla_\theta \log \pi_\theta (a|s)\) being an unbiased estimate of the gradient \(\nabla_\theta v_\theta(s_0)\) is dependent on the sampling distribution and the advantage estimates coming from the same parameters \(\theta\) as are being used in the log probabilities. This will be true for your first gradient step, but most policy gradient algorithms use a batched approach so that you do many gradient steps on the same data distribution before you update the data-generating policy parameters. So not only are you not actually estimating the gradient of the function you’re applying jax.grad to, but after your first gradient step (unless you do something smart like importance sampling) you’re not even getting an unbiased estimate of the gradient you were trying to approximate. A loss function that you can only take one gradient step on is a pretty bad loss function.
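The importance-sampling fix can be sketched on a single hypothetical sample: reweighting by \(\pi_\theta(a|s) / \pi_{\theta_{\text{old}}}(a|s)\) gives a surrogate whose gradient at \(\theta = \theta_{\text{old}}\) matches the vanilla estimator, while remaining an unbiased (if higher-variance) estimate after \(\theta\) moves. All numbers made up; numpy finite differences stand in for jax.grad:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

theta_old = np.array([0.2, -0.4, 0.1])
a, A = 0, -0.8                     # a sample drawn under theta_old, with its advantage

def is_surrogate(th):
    # importance-weighted surrogate: -A * pi_theta(a) / pi_theta_old(a)
    return -A * softmax(th)[a] / softmax(theta_old)[a]

eps = 1e-6
g = np.array([(is_surrogate(theta_old + eps * np.eye(3)[i]) - is_surrogate(theta_old - eps * np.eye(3)[i])) / (2 * eps)
              for i in range(3)])

# at theta = theta_old, the ratio surrogate and -A * log pi have the same gradient
pi = softmax(theta_old)
assert np.allclose(g, -A * (np.eye(3)[a] - pi), atol=1e-5)
```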
Intuition 3: the variance of our loss estimate should shrink linearly in the number of datapoints we use. The effective number of data points contributing to the variance of the gradient estimate is much lower than you’d expect in supervised learning, because advantage estimates introduce weird dependencies between data points that you usually assume don’t exist in a fixed cost function. If my agent wins a game, then vanilla policy gradient will reinforce all of the actions that led to that win, meaning that my gradients for each state-action pair in that trajectory will be highly correlated. This means that the information-theoretic value of running my agent for \(k\) steps in \(n\) parallel episodes is closer to what I would expect from \(n\) independent samples (the number of independent trajectories) than from \(kn\) independent samples (the number of state-action pairs that I have advantage estimates for).
All of these things conspire to make the RL training objective extremely ill-behaved for a native of supervised learning.
So we’ve established that the policy gradient is a gradient, but that the thing being fed into grad isn’t a loss as we would normally think of in supervised learning.
You might hope that temporal difference methods, the other main category of RL algorithm, will offer a better deal. Unfortunately, although the TD loss is indeed a loss, the `semi-gradient’ used in the update is not a gradient.
As a quick reminder, in value-based methods, we aim to estimate the value of the optimal policy \[Q^{\pi^*}(s,a) = \mathbb{E}_{(s_t, a_t)\sim P^{\pi^*}}\bigg[ \sum_{t=0}^\infty \gamma^t R(s_t, a_t) \,\bigg|\, s_0 = s, a_0=a \bigg ].\]
The way that most value-based methods do this is to collect some data from a behaviour policy \(\pi_B\), estimate \(Q^{\pi_B}\), and then use this value estimate to improve \(\pi_B\) for the next round of data collection. Obviously, if you know the expected return under \(\pi_B\) of every state-action pair in the environment, you can improve on \(\pi_B\) by deterministically picking the action with the highest expected value. This approach, also known as policy iteration, basically amounts to a sequence of supervised learning problems for each behaviour policy \(\pi_B\). If you collect near-infinite data and regress on the sampled returns for near-infinite time, you will converge stably to the optimal policy. Your learner will also be extremely inefficient, so nobody actually does this in practice.
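For concreteness, here is exact policy iteration on a tiny made-up two-state, two-action MDP: evaluation solves the Bellman linear system exactly, improvement acts greedily, and the result is a fixed point (a sketch of the idealized algorithm, not of anything practical):

```python
import numpy as np

# a made-up 2-state, 2-action MDP
P = np.array([[[0.9, 0.1], [0.2, 0.8]],    # P[s, a, s']
              [[0.7, 0.3], [0.05, 0.95]]])
R = np.array([[0.0, 1.0],                  # R[s, a]
              [0.5, 2.0]])
gamma = 0.9

def evaluate(policy):
    # exact policy evaluation: solve v = R_pi + gamma * P_pi v
    P_pi = P[np.arange(2), policy]
    R_pi = R[np.arange(2), policy]
    return np.linalg.solve(np.eye(2) - gamma * P_pi, R_pi)

policy = np.zeros(2, dtype=int)            # arbitrary initial deterministic policy
for _ in range(10):
    v = evaluate(policy)
    policy = (R + gamma * P @ v).argmax(axis=1)   # greedy improvement

# at convergence, the greedy policy is a fixed point of its own evaluation
q = R + gamma * P @ evaluate(policy)
assert np.array_equal(policy, q.argmax(axis=1))
```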
Instead, most practical value-learning methods interleave data collection, target construction, and optimization in a complicated dance where each partner is anticipating the response of the other. Rather than using Monte Carlo returns, most value-based methods use bootstrapping, which involves pretending that your predictions don’t depend on your current parameters and treating them as a God-given source of truth about the future. So rather than following the gradient of the well-behaved loss \[\ell_{MC}(\theta, \tau) =\sum_{s_t, a_t \in \tau} \bigg ( Q_\theta(s_t,a_t) -\sum_{k=t}^{|\tau|} \gamma^{k-t} R_\tau(s_k, a_k) \bigg)^2,\] we instead follow \[\ell_{TD}(\theta, \tau) =\sum_{s_t, a_t \in \tau} \bigg ( Q_\theta(s_t,a_t) -R_\tau(s_t, a_t) - \gamma Q_{\bar{\theta}}(s_{t+1}, a'_{t+1})\bigg)^2\]
which is now regressing with respect to a target \(R + \gamma Q_{\bar{\theta}}\) given by the parameters \(\bar{\theta}\), which if we stick our fingers in our ears and yell “stop-gradient” enough times, we can almost convince ourselves shouldn’t contribute to the gradient of this L2 distance\(^1\). (If you’re mystified by the sudden appearance of \(R + \gamma Q\), I recommend referring to a lovely textbook by Sutton and Barto, chapter 4.) I also use \(a'_{t+1}\) here because some algorithms like \(Q\)-learning use a different action selection rule for deciding which actions to back up.
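The stop-gradient claim is easy to check on a single tabular transition with made-up numbers: the semi-gradient agrees with the true gradient of \(\ell_{TD}\) on the predicted entry \((s_t, a_t)\), but silently drops the term that flows through the bootstrap target:

```python
import numpy as np

Q = np.array([[0.5, -0.2], [1.0, 0.3]])   # made-up tabular Q[s, a]
gamma, r = 0.9, 1.0
s, a, s2, a2 = 0, 1, 1, 0                 # one transition (s, a, r, s', a')

delta = Q[s, a] - (r + gamma * Q[s2, a2])

# semi-gradient: pretend the bootstrap target is a constant
semi = np.zeros_like(Q)
semi[s, a] = 2 * delta

# true gradient of the TD loss, by finite differences
def td_loss(q):
    return (q[s, a] - (r + gamma * q[s2, a2])) ** 2
eps = 1e-6
true = np.zeros_like(Q)
for i in range(2):
    for j in range(2):
        e = np.zeros_like(Q)
        e[i, j] = eps
        true[i, j] = (td_loss(Q + e) - td_loss(Q - e)) / (2 * eps)

# they agree on (s, a), but the true gradient also flows into the target's entry
assert np.allclose(true[s, a], semi[s, a], atol=1e-4)
assert abs(true[s2, a2] - (-2 * gamma * delta)) < 1e-4
assert semi[s2, a2] == 0.0
```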
Although this might sound sacrilegious to a disciple of gradient descent, most of the time following the not-a-gradient converges to the optimal value function faster (certainly in terms of environment samples required, and generally in terms of optimizer steps as well) than the slower gradient-based approach. The reasons for this are complicated, but at a high level amount to a combination of variance reduction and faster communication between state-action pairs. Sutton and Barto (section 6.1) give a long explanation of this with many pictures and a fun example of commuting in traffic, which I can recommend to interested readers.
Of course, if your learning dynamics don’t look like you’re following a gradient then you can be vulnerable to unstable convergence or even divergence – but at the same time, gradient descent can be slow depending on things like the conditioning of the optimization problem. The whole point of second-order methods is that you can find a descent direction that’s better than the gradient, so it shouldn’t be a surprise that if your problem has additional structure, a method that leverages that structure will be faster than blindly stumbling in the direction of local steepest descent.
While you might be forgiven for thinking that the thing you feed into your favourite autodiff library’s grad function should be called a ‘loss’, I think doing so in this case has led to a lot of confusion about what the heck policy gradient methods are actually doing.
A policy gradient is the gradient of the return with respect to the policy parameters.
All of the fancy clipping terms, generalized advantage estimates, group normalizations, and bizarrely weighted log probabilities are just different ways of estimating \(\nabla_\theta V(\theta)\). It’s easy to get lost in the noise of how to estimate the advantage, what dimensions to average over, or what prompts to filter. But ultimately, it’s important to remember that these are just different ways of estimating the same gradient. Otherwise, you’ll go around calling something that isn’t the thing you’re estimating the gradient of your loss function. </rant>.
1. To prove that the stop-gradient makes this whole thing not a gradient, you just pretend that the update is a gradient and then take a derivative w.r.t. the value at some other state-action pair. For simplicity, we define the TD update \(u_{TD}(s,a) = R(s,a) + \gamma (P^\pi Q )(s,a) - Q(s,a)\). If \(u_{TD}(s,a) = \nabla_{Q(s,a)} \ell_{\text{magic}}(Q)\), then it should satisfy \[\nabla_{Q(s', a')} u_{TD}(s,a) = \nabla_{Q(s', a')} \nabla_{Q(s,a)} \ell_{\text{magic}}(Q) = \nabla_{Q(s, a)} \nabla_{Q(s',a')} \ell_{\text{magic}}(Q) = \nabla_{Q(s, a)} u_{TD}(s',a') .\] It’s pretty easy to come up with examples where this equality doesn’t hold.
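Numerically, for a made-up three-state deterministic cycle (treating state-action pairs as plain indices): the Jacobian of \(u_{TD}\) is \(\gamma P^\pi - I\), which is not symmetric, so it cannot be the Hessian of any scalar function:

```python
import numpy as np

gamma = 0.9
P = np.array([[0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0],
              [1.0, 0.0, 0.0]])   # made-up deterministic cycle
R = np.array([1.0, 0.0, 0.0])

def u_td(Q):
    # the TD update as defined in the footnote
    return R + gamma * P @ Q - Q

# Jacobian of u_td by central finite differences (u_td is affine, so this is exact
# up to floating-point error)
eps = 1e-6
J = np.zeros((3, 3))
for j in range(3):
    e = np.zeros(3)
    e[j] = eps
    J[:, j] = (u_td(e) - u_td(-e)) / (2 * eps)

assert np.allclose(J, gamma * P - np.eye(3))  # J[i, j] = gamma * P[i, j] - I[i, j]
assert not np.allclose(J, J.T)                # not symmetric, hence not a Hessian
```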