Since discussing Policy Gradient Methods, I've been curious about how LLMs are trained.

It turns out that folks use reinforcement learning to train LLMs (Large Language Models) too, and the approach is very similar to the Policy Gradient Methods I recently discussed.

An LLM predicts the next token given the prior tokens, so there must be a way to evaluate the generated token against what the next token should be, i.e., there must be a loss function or a reward function. Let's explore.

Preliminaries

LLMs can be formally described thusly:

\[ p_{\theta}(x_{t+1}|x_{1:t}) \]

where \(p_{\theta}\) is the LLM with parameters \(\theta\). It outputs a probability distribution over the next token (\(x_{t+1}\)) given the previous sequence of tokens (\(x_{1:t}\)). So, in a nutshell, it predicts the next token given the prior tokens.

To generate tokens continuously (autoregressive generation), the LLM simply repeats this step: it samples the next token (\(x_{t+1}\)), appends it to the sequence, and samples again, i.e.:

\[ x_{t+1} \sim p_{\theta}(\cdot|x_{1:t}) \]
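
To make the sampling loop concrete, here is a minimal Python sketch. The "model" is a hypothetical hand-written bigram table (`BIGRAM_PROBS`), not a real LLM. It only looks at the last token, whereas a real LLM conditions on the whole prefix \(x_{1:t}\). The names `next_token_probs` and `generate` are illustrative.

```python
import random

# Hypothetical toy "model": next-token probabilities keyed by the last token only.
# A real LLM computes p_theta(x_{t+1} | x_{1:t}) from the entire prefix.
BIGRAM_PROBS = {
    "<bos>": {"the": 1.0},
    "the": {"cat": 0.9, "the": 0.1},
    "cat": {"sat": 0.8, "<eos>": 0.2},
    "sat": {"<eos>": 1.0},
}


def next_token_probs(prefix):
    """Return the distribution over the next token given the prefix x_{1:t}."""
    return BIGRAM_PROBS[prefix[-1]]


def generate(prefix, max_new_tokens=10, rng=random):
    """Autoregressive sampling: draw x_{t+1} ~ p(. | x_{1:t}), append, repeat."""
    tokens = list(prefix)
    for _ in range(max_new_tokens):
        candidates, weights = zip(*next_token_probs(tokens).items())
        next_token = rng.choices(candidates, weights=weights, k=1)[0]
        tokens.append(next_token)
        if next_token == "<eos>":  # stop once the end-of-sequence token appears
            break
    return tokens


print(generate(["<bos>"], rng=random.Random(0)))
```

Each pass through the loop is one application of \(x_{t+1} \sim p_{\theta}(\cdot|x_{1:t})\). The `max_new_tokens` cap is an arbitrary safeguard so generation always terminates.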

If instead you want conditional generation, that is, a prompt that conditions how the LLM generates its tokens, the model becomes:

\[ p_{\theta}(x_{1:n}|c) = \prod_{t=1}^{n} p_{\theta}(x_t|c, x_{1:t-1}) \]

This means (I think) that the generated tokens \(x_{1:n}\) are conditioned on the context \(c\), a sequence of prior input tokens that can be thought of as the prompt to the LLM. The probability of the whole generated sequence is the product of the per-token probabilities (what LLMs really output). Each token is conditioned on the prompt \(c\) and on all the tokens generated before it. It must be said that this interpretation, and my knowledge of exactly how it is done, are not very well-defined, so tread lightly here! Anyway, the interesting bit is next.
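
The product above is just the chain rule, and it can be checked numerically. In the sketch below, `TABLE` is a hypothetical conditional model keyed on the most recent token of \([c, x_{1:t-1}]\), not a real LLM. It sums log-probabilities instead of multiplying probabilities, which is the usual way to avoid numerical underflow on long sequences.

```python
import math

# Hypothetical toy conditional model standing in for p_theta(x_t | c, x_{1:t-1}).
# It looks only at the most recent token of (c + x_{1:t-1}); a real LLM uses all of it.
TABLE = {
    "Q:": {"yes": 0.6, "no": 0.4},
    "yes": {"<eos>": 1.0},
    "no": {"<eos>": 1.0},
}


def token_prob(context, prefix, token):
    """p(x_t = token | c, x_{1:t-1}) under the toy table."""
    last = (list(context) + list(prefix))[-1]
    return TABLE.get(last, {}).get(token, 0.0)


def sequence_log_prob(context, tokens):
    """log p(x_{1:n} | c) = sum over t of log p(x_t | c, x_{1:t-1})."""
    total = 0.0
    for t, token in enumerate(tokens):
        p = token_prob(context, tokens[:t], token)
        if p == 0.0:
            return float("-inf")  # the sequence is impossible under the model
        total += math.log(p)
    return total


# Probability of the answer "yes <eos>" to the prompt "Q:" is 0.6 * 1.0.
print(math.exp(sequence_log_prob(["Q:"], ["yes", "<eos>"])))
```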

Treating LLM token generation as an MDP

You can frame LLM token generation as a Markov Decision Process (MDP). An MDP consists of a set of states, actions, transition probabilities, a reward function and a discount factor.

This can be written as \(MDP=(S,A,P,R,\gamma)\). Each generated token is an action, and generating it moves the LLM to a new state:

\[ [c, x_{1:t-1}] \rightarrow [c, x_{1:t}] \rightarrow [c, x_{1:t+1}] \rightarrow \dots \]

Each state is the context \(c\) together with all the tokens generated so far. From the state \([c, x_{1:t}]\), the next action is the token \(x_{t+1}\), and the resulting state is \([c, x_{1:t+1}]\). The transition is deterministic, because the chosen token is simply appended.

The reward is based on the generated token \(x\) and the context it was generated under, namely, \( R(c,x)\) (the reward function).

Now we have all the components needed to frame this as an MDP problem, with the goal of maximising the accumulated discounted return (as is always the goal with an MDP).
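
Here is a minimal sketch of these MDP pieces, under stated assumptions. The reward is hypothetical and sparse: zero until the sequence ends with `"<eos>"`, then 1.0 if the completion contains `"yes"`. Scoring only at the end is one common choice, not the only one. `State`, `transition`, `reward` and `discounted_return` are illustrative names.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class State:
    """MDP state: the context c plus every token generated so far, x_{1:t}."""
    context: Tuple[str, ...]
    generated: Tuple[str, ...] = ()


def transition(state: State, action: str) -> State:
    """Deterministic transition: the action (the next token) is appended."""
    return State(state.context, state.generated + (action,))


def reward(state: State) -> float:
    """Hypothetical sparse reward: 0 until the sequence ends with "<eos>",
    then R(c, x) = 1.0 if the completion contains "yes", else 0.0."""
    if not state.generated or state.generated[-1] != "<eos>":
        return 0.0
    return 1.0 if "yes" in state.generated else 0.0


def discounted_return(rewards: List[float], gamma: float) -> float:
    """G = sum_k gamma^k * r_k, the quantity the agent tries to maximise."""
    return sum((gamma ** k) * r for k, r in enumerate(rewards))


state = State(context=("Q:",))
step_rewards = []
for token in ["yes", "<eos>"]:
    state = transition(state, token)
    step_rewards.append(reward(state))
print(step_rewards)                          # [0.0, 1.0]
print(discounted_return(step_rewards, 0.9))  # 0.9
```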

Since the LLM generates the actions, it can be modelled as the policy that selects actions in reinforcement learning (RL). Therefore, as in RL, you can aim to maximise the objective function (the expected reward):

\[ L_{\theta}(c) = \mathbb{E}_{x \sim p_{\theta}(\cdot|c)}[R(c,x)] \]

where \(c\) is the context/prompt (the prior tokens used for conditioning) and \(x\) is the generated output.
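
The expectation runs over the LLM's own samples, so \(L_{\theta}(c)\) can be estimated by sampling. The sketch below makes clearly toy assumptions: each whole completion is treated as a single outcome, and the policy `POLICY` and the reward `reward` are both hypothetical. It compares the exact expectation with a Monte Carlo average.

```python
import random

# Hypothetical toy policy p(x | c) over whole completions for one fixed prompt c,
# and a hypothetical reward R(c, x). Both are illustrative assumptions.
POLICY = {"yes": 0.6, "no": 0.4}


def reward(c, x):
    return 1.0 if x == "yes" else 0.0


def exact_objective(c):
    """L(c) = sum_x p(x | c) * R(c, x); only feasible because the toy space is tiny."""
    return sum(p * reward(c, x) for x, p in POLICY.items())


def mc_objective(c, num_samples=10_000, rng=random):
    """Monte Carlo estimate of L(c): average R(c, x) over samples x ~ p(. | c)."""
    completions, probs = zip(*POLICY.items())
    samples = rng.choices(completions, weights=probs, k=num_samples)
    return sum(reward(c, x) for x in samples) / num_samples


print(exact_objective("Q:"), mc_objective("Q:", rng=random.Random(0)))
```

For a real LLM the space of completions is far too large to enumerate, so only the sampled estimate is available.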

We can then maximise the objective function by gradient ascent. The REINFORCE algorithm estimates its gradient as:

\[ \nabla_{\theta} L_{\theta}(c) = \mathbb{E}_{x \sim p_{\theta}(\cdot|c)}[\hat{A}(c,x)\nabla_{\theta}\log p_{\theta}(x|c)] \]

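Here, \(\hat{A}(c,x)\) is the advantage estimate. It measures how much better the reward for \(x\) is than a baseline, such as the expected reward for the context \(c\). Plain REINFORCE, without a baseline, uses \(R(c,x)\) directly. Subtracting a baseline that does not depend on \(x\) leaves the expected gradient unchanged but lowers the variance of the gradient estimate, which makes the updates less erratic. The term \(\nabla_{\theta}\log p_{\theta}(x|c)\) is the gradient of the log-probability of the sampled output.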
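
The variance claim can be checked on a toy softmax policy. This sketch assumes a hypothetical three-action policy with uniform logits and rewards `[11, 10, 10]`, deliberately offset by a constant, which does not change the true gradient. It compares the per-sample REINFORCE terms for the first logit with and without a baseline. For a softmax policy with logits \(z\), \(\partial \log p(a) / \partial z_k = \mathbb{1}[k=a] - p_k\). The baseline used is the exact expected reward, which is only available in a toy setting like this one.

```python
import math
import random
import statistics


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_terms(logits, rewards, baseline, num_samples, rng):
    """Per-sample REINFORCE terms (R(a) - b) * d/dz_0 log p(a), for logit 0 only."""
    probs = softmax(logits)
    actions = rng.choices(range(len(logits)), weights=probs, k=num_samples)
    return [(rewards[a] - baseline) * ((1.0 if a == 0 else 0.0) - probs[0]) for a in actions]


def exact_grad0(logits, rewards):
    """Exact d/dz_0 of E[R] for a softmax policy: p_0 * (R_0 - E[R])."""
    probs = softmax(logits)
    expected = sum(p * r for p, r in zip(probs, rewards))
    return probs[0] * (rewards[0] - expected)


logits, rewards = [0.0, 0.0, 0.0], [11.0, 10.0, 10.0]
expected_reward = sum(p * r for p, r in zip(softmax(logits), rewards))
no_baseline = reinforce_terms(logits, rewards, 0.0, 5000, random.Random(0))
with_baseline = reinforce_terms(logits, rewards, expected_reward, 5000, random.Random(0))
print(exact_grad0(logits, rewards))
print(statistics.mean(no_baseline), statistics.pvariance(no_baseline))
print(statistics.mean(with_baseline), statistics.pvariance(with_baseline))
```

Both estimators average to roughly the same gradient. The per-sample variance with the baseline, however, is orders of magnitude smaller.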

The key point is that in policy gradient methods such as this one, each action is sampled from the policy being trained (here, the LLM \(p_{\theta}\)), and the gradients are taken with respect to the policy's parameters \(\theta\). REINFORCE estimates the gradient of the objective from samples drawn from the current policy, which makes it an on-policy reinforcement learning approach.

Note that on-policy means you sample actions from the very policy you're trying to improve. Off-policy means the actions come from another policy (often called the behaviour policy), and that experience is used to evaluate and improve the target policy. Q-Learning is off-policy.

The useful thing about off-policy learning is that you can explore with an exploratory behaviour policy, or learn from actions taken by others, and then use the outcomes to update your target policy. With on-policy learning, you learn ONLY from your own actions. SARSA is on-policy.
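
The on-policy/off-policy difference shows up directly in the tabular update targets of these two algorithms. In this sketch the Q-table `q` is a plain dict of dicts, and the states, actions and values are hypothetical. SARSA bootstraps from the action its policy actually took next. Q-learning bootstraps from the greedy action, regardless of what the behaviour policy did.

```python
def sarsa_target(reward, q, next_state, next_action, gamma):
    """On-policy: bootstrap from the next action the policy actually took."""
    return reward + gamma * q[next_state][next_action]


def q_learning_target(reward, q, next_state, gamma):
    """Off-policy: bootstrap from the greedy action, whatever was actually taken."""
    return reward + gamma * max(q[next_state].values())


def td_update(q, state, action, target, alpha):
    """Move Q(s, a) a fraction alpha of the way toward the target."""
    q[state][action] += alpha * (target - q[state][action])


# Hypothetical table: an exploratory behaviour policy picked "left" in s1,
# even though "right" currently has the higher value.
q = {"s0": {"left": 0.0, "right": 0.0}, "s1": {"left": 0.0, "right": 1.0}}
print(sarsa_target(0.0, q, "s1", "left", gamma=0.9))  # 0.0
print(q_learning_target(0.0, q, "s1", gamma=0.9))     # 0.9
```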

So, putting it all together: the LLM acts as a policy that continually samples actions (tokens). The reward function scores what was generated in its context. Those rewards are used to estimate the gradient of the objective, and the gradient is fed back to the LLM (the policy) via gradient ascent, so that its future generations earn higher rewards.
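
Here is a minimal end-to-end sketch of that loop, built on strong simplifying assumptions. The "LLM" is a softmax over three hypothetical canned completions for one fixed prompt, so each whole completion is a single action. `REWARDS` is a made-up reward function, the baseline is the batch-mean reward, and the hyperparameters (`steps`, `batch_size`, `lr`, `seed`) are arbitrary. Real LLM fine-tuning works with token-level probabilities from a neural network and differs in many details.

```python
import math
import random

# Hypothetical setup: for one fixed prompt c, the "LLM" is a softmax over three
# canned completions, and REWARDS plays the role of R(c, x).
COMPLETIONS = ["thanks!", "ok", "go away"]
REWARDS = {"thanks!": 1.0, "ok": 0.3, "go away": -1.0}


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def train(steps=200, batch_size=16, lr=0.5, seed=0):
    rng = random.Random(seed)
    theta = [0.0] * len(COMPLETIONS)  # policy parameters: one logit per completion
    for _ in range(steps):
        probs = softmax(theta)
        # 1. Sample actions from the current policy (on-policy).
        batch = rng.choices(range(len(theta)), weights=probs, k=batch_size)
        # 2. Score each sampled completion with the reward function.
        rewards = [REWARDS[COMPLETIONS[a]] for a in batch]
        baseline = sum(rewards) / batch_size
        # 3. REINFORCE estimate: mean of (R - b) * grad log p, where the
        #    derivative of log p(a) w.r.t. logit k is 1[k == a] - p_k.
        grad = [0.0] * len(theta)
        for a, r in zip(batch, rewards):
            for k in range(len(theta)):
                indicator = 1.0 if k == a else 0.0
                grad[k] += (r - baseline) * (indicator - probs[k]) / batch_size
        # 4. Gradient ascent: step the parameters up the estimated gradient.
        theta = [t + lr * g for t, g in zip(theta, grad)]
    return dict(zip(COMPLETIONS, softmax(theta)))


print(train())
```

The probability of the highest-reward completion should climb well above its starting value of 1/3. The batch-mean baseline includes each sample's own reward, so with a batch of \(N\) it scales the expected update by \((N-1)/N\) without changing its direction.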