a quick derivation
As I explained in a previous post, the policy gradient update looks like the following: \[ \theta_{t+1} = \theta_t + \alpha \nabla_\theta J(\pi_\theta) \] where the gradient can be estimated from a trajectory \(\tau\) as \[ \widehat{\nabla_\theta J(\pi_\theta)} = \sum_{s_t, a_t \in \tau}{A}_t \nabla_\theta \log \pi_\theta(a_t|s_t)\] where \(A_t\) is your favourite estimate of the advantage.
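In code, this estimator is usually implemented as a surrogate loss whose gradient is the policy gradient estimate. Here’s a minimal sketch for a linear-softmax policy over discrete actions; the names (`surrogate`, `states`, `actions`, `advantages`) are illustrative, not from any particular library:

```python
import jax
import jax.numpy as jnp

def surrogate(theta, states, actions, advantages):
    # log pi_theta(a_t | s_t) for a toy linear-softmax policy
    logits = states @ theta                              # (T, |A|)
    logp = jax.nn.log_softmax(logits, axis=-1)
    # pick out the log-probabilities of the actions actually taken
    logp_taken = jnp.take_along_axis(logp, actions[:, None], axis=1)[:, 0]
    # the gradient of this sum is the policy gradient estimate
    return jnp.sum(advantages * logp_taken)

grad_estimate = jax.grad(surrogate)
```

Calling `grad_estimate(theta, states, actions, advantages)` then gives \(\sum_t A_t \nabla_\theta \log \pi_\theta(a_t|s_t)\) for that trajectory.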
This is a straightforward derivation, even though the log-derivative trick might take a second to grok. But the most popular policy gradient method, PPO, notably has no log in its gradient estimator. Instead, the thing you feed into jax.grad looks like
\[ \mathcal{L}(\tau, \theta, \pi_{\text{ref}}) = \sum_{s_t, a_t \in \tau}\min(\frac{\pi(a_t|s_t)}{\pi_{\text{ref}}(a_t|s_t)}A_t, \text{clip}(\frac{\pi(a_t|s_t)}{\pi_{\text{ref}}(a_t|s_t)}, 1-\epsilon, 1+\epsilon )A_t) \]
with no logarithm in sight.
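A sketch of that objective in JAX might look like the following, assuming the per-step log-probabilities under the current and reference policies have already been computed (`ppo_surrogate` is a made-up name):

```python
import jax.numpy as jnp

def ppo_surrogate(logp, logp_ref, advantages, eps=0.2):
    # probability ratio pi_theta(a_t|s_t) / pi_ref(a_t|s_t),
    # computed in log space for numerical stability
    ratio = jnp.exp(logp - logp_ref)
    clipped = jnp.clip(ratio, 1.0 - eps, 1.0 + eps)
    # per-step pessimistic min, summed over the trajectory
    return jnp.sum(jnp.minimum(ratio * advantages, clipped * advantages))
```

Note that the only place the current parameters enter is through `logp`, via the ratio; there is no \(\log \pi_\theta\) term being differentiated directly.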
As a quick reminder, the reason we use a logarithm is that we want to estimate \(\nabla_\theta \sum_{a\in \mathcal{A}}\pi_\theta(a|s)Q^{\pi_\theta}(a,s)\). Since we are generally not capable of implementing a quantum superposition of agents which take every possible action, we use the fact that \[\pi_\theta(a|s) \nabla_\theta \log \pi_\theta(a|s) = \pi_\theta(a|s) \frac{1}{\pi_\theta(a|s)} \nabla_\theta \pi_\theta(a|s) = \nabla_\theta \pi_\theta(a|s)\]
to get an estimator compatible with the one-action-at-a-time generative process that produces most RL agent trajectories. This means that if \(a\) is sampled from \(\pi_\theta(\cdot | s)\), the derivative of the log likelihood of the selected actions gives us an unbiased estimator. But in PPO, we run our optimizer for multiple epochs over each batch of trajectories, so after the first epoch the actions were sampled from a stale policy: the estimator is actually computing terms weighted as \(\pi_{\text{ref}}(a|s) \nabla_\theta \log \pi_\theta(a|s)\), which is not the derivative we’re looking for.
Instead, we do the following:
\(\begin{align} \nabla_\theta \sum_{a \in \mathcal{A}}\pi_\theta(a|s) A(a,s) &= \sum_{a \in \mathcal{A}} \frac{\pi_{\text{ref}}(a|s)}{\pi_{\text{ref}}(a|s)} \nabla_\theta \pi_\theta(a|s) A(a,s) \\ &= \sum_{a \in \mathcal{A}} \pi_{\text{ref}}(a|s) \frac{\nabla_\theta \pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)}A(a,s)\\ &= \mathbb{E}_{a \sim \pi_{\text{ref}}(\cdot|s)} \left[\nabla_\theta \frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)}A(a,s)\right] \end{align}\) where in the last step the gradient moves inside the ratio because \(\pi_{\text{ref}}\) does not depend on \(\theta\),
which is exactly the ratio we see in the PPO update. Because the distribution we’re taking the expectation over doesn’t share parameters with the policy we’re differentiating, we don’t need the log.
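The identity can also be checked numerically by computing the expectation exactly over a small action set instead of sampling. This is a toy sketch with made-up numbers; the only assumptions are a softmax policy and a fixed, full-support reference distribution:

```python
import jax
import jax.numpy as jnp

theta = jnp.array([0.5, -0.3, 1.1])
adv = jnp.array([0.2, -1.0, 0.8])
pi_ref = jnp.array([0.5, 0.25, 0.25])  # fixed, independent of theta

pi = lambda th: jax.nn.softmax(th)

# true gradient of sum_a pi_theta(a|s) A(a, s)
true_grad = jax.grad(lambda th: jnp.sum(pi(th) * adv))(theta)

# the quantity inside the expectation: grad of (pi_theta / pi_ref) * A
def per_action(th, a):
    return pi(th)[a] / pi_ref[a] * adv[a]

grads = jnp.stack([jax.grad(per_action)(theta, a) for a in range(3)])
# weight each per-action gradient by pi_ref, i.e. take the expectation exactly
is_grad = jnp.sum(pi_ref[:, None] * grads, axis=0)
```

Here `true_grad` and `is_grad` agree, even though `pi_ref` is not the current policy.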
The astute reader will note that this derivation seems a bit suspicious, because in the first step of PPO \(\pi_{\text{ref}}\) is \(\pi_\theta\), and if we’re allowed to pretend that \(\pi_{\text{ref}}\) is independent of \(\theta\), we could have done the same thing in the earlier derivation and gotten away without the log as well. But since doing so produces \(\frac{\nabla_\theta \pi_\theta(a|s)}{\pi_\theta(a|s)} = \nabla_\theta \log \pi_\theta(a|s)\), we recover the log as a special case of the more general importance sampling update.
I’ve included the step-by-step derivation of this fact below in case you don’t trust me. \(\begin{align} \nabla_\theta \sum_{a \in \mathcal{A}}\pi_\theta(a|s) A(a,s) &= \sum_{a \in \mathcal{A}} \frac{\pi_\theta(a|s)}{\pi_\theta(a|s)} \nabla_\theta \pi_\theta(a|s) A(a,s) \\ &= \sum_{a \in \mathcal{A}} \pi_\theta(a|s) \frac{ \nabla_\theta \pi_\theta(a|s)}{\pi_\theta(a|s)} A(a,s) \\ &= \mathbb{E}_{a \sim \pi_\theta(\cdot|s)}\left[\frac{ \nabla_\theta \pi_\theta(a|s)}{\pi_\theta(a|s)} A(a,s)\right] \\ &= \mathbb{E}_{a \sim \pi_\theta(\cdot|s)}\left[ \nabla_\theta \log \pi_\theta(a|s)\, A(a,s)\right] \end{align}\)
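The same kind of exact check works for the log form: summing over every action, the gradient of the exact expectation should match the \(\pi\)-weighted log-gradient terms. Toy numbers again, with the advantage treated as a constant:

```python
import jax
import jax.numpy as jnp

theta = jnp.array([0.3, -1.2, 0.7])
adv = jnp.array([1.0, 0.0, -0.5])

pi = lambda th: jax.nn.softmax(th)

# left-hand side: gradient of the exact expectation sum_a pi_theta(a|s) A(a, s)
lhs = jax.grad(lambda th: jnp.sum(pi(th) * adv))(theta)

# right-hand side: sum_a pi(a) * grad(log pi(a)) * A(a), the sampled
# estimator's expectation written out in full
def logp(th, a):
    return jax.nn.log_softmax(th)[a]

grads = jnp.stack([jax.grad(logp)(theta, a) for a in range(3)])
rhs = jnp.sum(pi(theta)[:, None] * grads * adv[:, None], axis=0)
```

The two sides agree to floating-point precision, which is all the log-derivative trick promises.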
So if you’re going to take away one thing from this blog post, it’s that it can be more intuitive to mentally replace \(\nabla \log \pi\) with \(\frac{\nabla \pi}{\pi}\) whenever you read policy gradient papers, so that the importance sampling version of the algorithm feels like less of a jump.