## Zero-Shot RL with the Forward-Backward Representation, Inside Out

In 2021, shortly after I started graduate school, Ahmed Touati and
Yann Ollivier published *Learning One Representation to Optimize All
Rewards*, and this paper blew my mind. As the title suggests, the
paper presents an algorithm for learning a representation (the
*forward-backward* or *FB representation*) that, in a
sense, encodes all possible optimal policies for a given MDP
simultaneously. You can see for yourself that this essentially does
work by playing around with their fabulous demo, where you can
literally prompt their model with a reward function and watch it
(nearly instantly) derive a Walker policy that optimizes the return.

The idea and the math behind the FB representation always seemed fairly magical to me. Recently, I derived the FB representation "inside-out" (relative to the exposition in the paper), which led to some nice intuitions and insights that I will document here.

### The Inside Out Derivation

One way to achieve zero-shot transfer in RL (that is, immediately inferring an optimal policy for an arbitrary but Markovian task) is to model the successor measure (SM, defined below for completeness) of each optimal policy. Before FB, it was not clear how to do this, but in an ideal world where it could be done, we can imagine modeling this object as \(M^{\pi^\star_r}(\cdot\mid x, a)\), that is, a collection of successor measures, one for each policy \(\pi^\star_r\) that optimizes a reward function \(r\). Now, we know something else that's interesting about optimal policies:

\begin{align*} \pi^\star_r(x) &\in \arg\max_a Q^{\pi^\star_r}(x, a)\\ &= \arg\max_a (M^{\pi^\star_r}r)(x, a). \end{align*}

Note also that \(\pi^\star_r\) is determined by the reward function \(r\) (given a fixed rule for breaking ties among maximizing actions), so let's write \(\pi^\star_r = \mathsf{opt}(r)\). Equivalently, we now have

\begin{align*} \pi^\star_r(x) \in \arg\max_a (\bar{M}^r r)(x, a), \end{align*}

where \(\bar{M}^r = M^{\mathsf{opt}(r)}\). So, one way to learn \(\bar{M}\) is to randomly sample a reward function \(r\) from some prior, and then train \(\bar{M}^r\) with a TD update (see, e.g., algorithms for learning fixed-policy SMs), using the greedy action with respect to \(\bar{M}^r r\) in the bootstrap target. This literally amounts to learning a reward-function-conditioned SM with a greedy bootstrap target, and in principle it is roughly what the FB algorithm does.
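
To make the greedy-bootstrap idea concrete, here is a minimal sketch in a small tabular MDP, where we can afford to store one successor measure per sampled reward instead of conditioning a function approximator on \(r\). The choices here are my own assumptions, not the paper's: a transition tensor `P[x, a, y]`, a state-based reward collected on arrival in the next state, a standard-normal prior over rewards, and synchronous expected updates in place of sampled TD updates. The helper names `greedy_sm` and `q_value_iteration` are hypothetical.

```python
import numpy as np


def greedy_sm(P, r, gamma=0.9, iters=1000):
    """Learn M^{opt(r)} with expected TD updates that bootstrap from the greedy action.

    P: (S, A, S) array with P[x, a, y] = Pr(X_1 = y | X_0 = x, A_0 = a).
    r: (S,) state-based reward, collected on arrival in the next state.
    Returns M with M[x, a, y] approximating M^{opt(r)}({y} | x, a).
    """
    S, A, _ = P.shape
    M = np.zeros((S, A, S))
    for _ in range(iters):
        Q = M @ r                            # (S, A): the induced action values M r
        greedy = Q.argmax(axis=1)            # greedy action in every state
        M_greedy = M[np.arange(S), greedy]   # (S, S): rows M(x', greedy(x'), .)
        # Target: indicator of X_1 plus gamma * M(X_1, greedy(X_1), .),
        # averaged over X_1 ~ P(. | x, a).
        M = P + gamma * (P @ M_greedy)
    return M


def q_value_iteration(P, r, gamma=0.9, iters=1000):
    """Reference optimal Q-values under the same reward-on-arrival convention."""
    Q = np.zeros(P.shape[:2])
    for _ in range(iters):
        Q = P @ (r + gamma * Q.max(axis=1))
    return Q


rng = np.random.default_rng(0)
S, A, gamma = 5, 3, 0.9
P = rng.random((S, A, S))
P /= P.sum(axis=-1, keepdims=True)           # normalize into a transition kernel

# "Sample a reward function from some prior": here, i.i.d. standard normal state rewards.
rewards = [rng.normal(size=S) for _ in range(3)]
models = [greedy_sm(P, r, gamma) for r in rewards]
```

For each sampled reward, `models[i] @ rewards[i]` should match `q_value_iteration(P, rewards[i], gamma)` up to floating-point error: the Q-values induced by the learned SM follow exactly the value-iteration recursion, and every row of each learned SM has total mass close to \(1/(1-\gamma)\).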

#### The Missing Piece: Task Embeddings

So why not stop there? To my understanding, the problem is that
conditioning a function approximator on a *reward function* is
nontrivial. To begin with, you cannot even represent a reward function
exactly (in general, it lives in an infinite-dimensional space).
Moreover, the product \(\bar{M}^r r\) would be
intractable to compute, as it involves integrating over the state
space. The brilliance of the FB algorithm is the realization that by
**embedding the reward functions** into a finite-dimensional space in a
special way, both
problems are resolved simultaneously. The idea is the
following. Assume \(\bar{M}^r\) has a low-rank factorization (at
least approximately) of the form \(\bar{M}^r(\cdot\mid x, a) = \bar{F}^r(x, a)^\top
\bar{B}(\cdot)\), where \(\bar{F}^r\) and \(\bar{B}\) take values in
\(\mathbf{R}^d\). Note that \(\bar{B}\) does not depend on \(r\):
the idea here is that we want \(\bar{B}\) to act as a "reward
function encoder". It should also be noted that this factorization is
not actually that restrictive: for large enough \(d\), such a factorization
exists (in the extreme case, e.g. \(d = |\mathcal{X}|\) for a finite state space, we can simply
take \(\bar{B} = \mathsf{Id}\), which is reward-independent). Under this model, optimal policies are
characterized by

\begin{align*} \pi^\star_r(x) &\in \arg\max_a (\bar{M}^r r)(x, a)\\ &= \arg\max_a \int_{\mathcal{X}} r(x')\,\bar{F}^r(x, a)^\top\bar{B}(\mathrm{d}x')\\ &= \arg\max_a \bar{F}^r(x, a)^\top z, \end{align*}

where \(z := \int_{\mathcal{X}}r(x')\bar{B}(\mathrm{d}x')\in\mathbf{R}^d\). But now, apart from the \(\bar{F}^r\) factor itself, the reward enters the greedy action only through \(z\)! Consequently, we can index optimal policies and the \(\bar{F}\) factor by the "task embeddings" \(z\in\mathbf{R}^d\). This looks like the following:

\begin{align*} \begin{cases} \pi_z(x) \in \arg\max_a F^z(x, a)^\top z\\ z = \int_{\mathcal{X}} r(x')B(\mathrm{d}x')\\ F^z(x, a)^\top B(\cdot) = M^{\pi_z}(\cdot\mid x, a) \end{cases} \end{align*}

Notably, we can now easily represent \(F\) and \(B\) with function approximators, since their inputs are all finite-dimensional objects. The paper gives conditions under which such a factorization exists (and provides error bounds when the factorization is only approximate). But the important lesson here, at least to me, is that FB innovates by compressing the space of reward functions in a clever way.
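
As a sanity check of the system above in the simplest setting I can think of, the sketch below uses a small random tabular MDP (an illustrative assumption): it computes an optimal policy by value iteration, builds that policy's successor measure in closed form, factors it as \(F B^\top\) with a truncated SVD, and checks that acting greedily with respect to \(F z\), with \(z = B^\top r\), recovers the same policy. Here the row `B[y]` plays the role of \(B(\{y\})\), so the integral defining \(z\) becomes a sum. The SVD is just one convenient way to obtain a factorization; it is not how FB learns \(F\) and \(B\). All function names are hypothetical, and the reward is assumed to be collected on arrival in the next state.

```python
import numpy as np


def optimal_policy(P, r, gamma, iters=1000):
    """Value iteration with the reward r(x') collected on arrival in x'."""
    Q = np.zeros(P.shape[:2])
    for _ in range(iters):
        Q = P @ (r + gamma * Q.max(axis=1))
    return Q.argmax(axis=1)


def sm_deterministic(P, policy, gamma):
    """Closed-form SM of a deterministic policy, as an (S * A, S) matrix."""
    S, A, _ = P.shape
    P_flat = P.reshape(S * A, S)                 # row x * A + a holds P(. | x, a)
    Pi = np.zeros((S, S * A))                    # Pi[x', x' * A + a'] = 1 iff a' = policy(x')
    Pi[np.arange(S), np.arange(S) * A + policy] = 1.0
    # M = sum_t gamma^t (P Pi)^t P = (I - gamma P Pi)^{-1} P
    return np.linalg.solve(np.eye(S * A) - gamma * P_flat @ Pi, P_flat)


def factorize(M, d):
    """Rank-d factorization M ~= F @ B.T via truncated SVD."""
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    return U[:, :d] * s[:d], Vt[:d].T            # F: (S * A, d), B: (S, d)


rng = np.random.default_rng(0)
S, A, gamma = 6, 3, 0.9
P = rng.random((S, A, S))
P /= P.sum(axis=-1, keepdims=True)
r = rng.normal(size=S)

pi_star = optimal_policy(P, r, gamma)
M = sm_deterministic(P, pi_star, gamma)
F, B = factorize(M, d=S)                         # d = |X| makes the factorization exact
z = B.T @ r                                      # task embedding: sum_y r(y) B({y})
pi_z = (F @ z).reshape(S, A).argmax(axis=1)      # greedy action argmax_a F(x, a)^T z
```

Lowering `d` below \(|\mathcal{X}|\) turns this into a genuinely compressed, approximate factorization; how much that perturbs the greedy policy depends on the MDP.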

Now, given a learned FB representation and an arbitrary reward function \(r\):

- Infer the task embedding \(z = \int_{\mathcal{X}} r(x')B(\mathrm{d}x')\). While this integral is generally intractable, we can approximate it by sampling. Or, when \(r\) is the indicator of a given state \(x^\star\) (as in goal-conditioned RL), \(z = B(\{x^\star\})\).
- Execute the policy \(\pi_z(x)\in\arg\max_aF^z(x, a)^\top z\).
- Profit.
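
The steps above can be sketched as follows. The trained networks are replaced by hypothetical stand-ins (fixed random linear maps behind `F` and `b`), and I assume that \(B\) is represented through a density with respect to a state distribution \(\rho\) that we can sample from, i.e. \(B(\mathrm{d}x') = b(x')\rho(\mathrm{d}x')\), so that \(z = \mathbf{E}_{x'\sim\rho}[r(x')b(x')]\) can be estimated by Monte Carlo. Here \(\rho\) is a standard Gaussian and the actions are discrete; all names and sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
d, state_dim, n_actions = 8, 4, 3

# Hypothetical stand-ins for trained networks: fixed random linear maps.
W_b = rng.normal(size=(d, state_dim))
W_F = rng.normal(size=(n_actions, d, state_dim + d))


def b(xs):
    """Backward density, B(dx) = b(x) rho(dx). Maps (n, state_dim) -> (n, d)."""
    return xs @ W_b.T


def F(x, z):
    """Forward embeddings F^z(x, a) for all actions at once, shape (n_actions, d)."""
    return W_F @ np.concatenate([x, z])


def infer_z(reward, n_samples=200_000):
    """Step 1: Monte Carlo estimate of z = E_{x ~ rho}[r(x) b(x)] with rho = N(0, I)."""
    xs = rng.normal(size=(n_samples, state_dim))
    return (reward(xs)[:, None] * b(xs)).mean(axis=0)


def act(x, z):
    """Step 2: greedy action argmax_a F^z(x, a)^T z."""
    return int(np.argmax(F(x, z) @ z))


# An illustrative linear reward r(x) = c^T x, evaluated on a batch of states.
c = rng.normal(size=state_dim)
z = infer_z(lambda xs: xs @ c)
action = act(np.ones(state_dim), z)
```

Since \(r(x) = c^\top x\) and \(\rho = \mathcal{N}(0, I)\) in this toy setup, the exact embedding is \(\mathbf{E}[b(x)x^\top]c = W_b c\), which the estimate should approach as `n_samples` grows.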

### The Successor Measure

The *successor measure* (SM) is a term that I believe was coined by
Leonard Blier, Corentin Tallec, and Yann Ollivier in an excellent
manuscript from 2021, which generalizes the successor representation
to measurable state spaces. For an MDP with transition kernel \(P\) and
a policy \(\pi\), the
successor measure is a measure on the state space conditioned on a
source state and action, defined as

\begin{align*} M^\pi(S\mid x, a) = \sum_{t=0}^\infty\gamma^t\Pr_\pi(X_{t+1}\in S\mid X_0 = x, A_0 = a)\qquad\text{for measurable } S\subseteq\mathcal{X}, \end{align*}

where \(\Pr_\pi\) is the probability measure over trajectories in which \(A_t\sim\pi(\cdot\mid X_t)\) for \(t\geq 1\) and \(X_{t+1}\sim P(\cdot\mid X_t, A_t)\) for \(t\geq 0\). Notably, it is readily verified that \(M^\pi(\mathcal{X}\mid x, a) = (1-\gamma)^{-1}\), so \(M^\pi\) is a finite measure for every source state-action pair. As such, it acts as a linear operator on the space of bounded measurable functions in the usual way,

\begin{align*} (M^\pi f)(x, a) = \int_{\mathcal{X}}f(x')M^\pi(\mathrm{d}x'\mid x, a). \end{align*}

Most importantly, this linear operator has a meaningful interpretation in RL. Famously, it maps reward functions to action-value functions,

\begin{align*} M^\pi r = Q^\pi \end{align*}

for any bounded and measurable reward function \(r\), so the successor measure can be used as a device to infer the action-value function for just about any reward function.
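
In a finite MDP, both facts above (total mass \((1-\gamma)^{-1}\) and \(M^\pi r = Q^\pi\)) can be checked numerically. The sketch below is illustrative: it uses a random MDP and a random stochastic policy, assumes the convention \(M^\pi(S\mid x, a) = \sum_{t\ge 0}\gamma^t\Pr_\pi(X_{t+1}\in S\mid X_0 = x, A_0 = a)\) (so the reward is collected on arrival in the next state), and computes \(M^\pi = \sum_{t\ge 0}\gamma^t (P\Pi)^t P = (I - \gamma P\Pi)^{-1}P\) in closed form, where \(P\) maps state-action pairs to next-state distributions and \(\Pi\) maps states to state-action pairs via \(\pi\). The helper names are hypothetical.

```python
import numpy as np


def successor_measure(P, pi, gamma):
    """Tabular SM: M[x, a, y] = sum_t gamma^t Pr(X_{t+1} = y | X_0 = x, A_0 = a).

    P: (S, A, S) transition kernel; pi: (S, A) action probabilities pi(a | x).
    """
    S, A, _ = P.shape
    P_flat = P.reshape(S * A, S)                 # row x * A + a holds P(. | x, a)
    Pi = np.zeros((S, S * A))                    # Pi[x, x * A + a] = pi(a | x)
    for x in range(S):
        Pi[x, x * A:(x + 1) * A] = pi[x]
    # P_flat @ Pi is the state-action transition matrix; M = (I - gamma P Pi)^{-1} P
    M_flat = np.linalg.solve(np.eye(S * A) - gamma * P_flat @ Pi, P_flat)
    return M_flat.reshape(S, A, S)


def q_policy_evaluation(P, pi, r, gamma, iters=2000):
    """Iterative evaluation of Q^pi(x, a) = E[sum_t gamma^t r(X_{t+1})]."""
    Q = np.zeros(pi.shape)
    for _ in range(iters):
        V = (pi * Q).sum(axis=1)                 # V(x') = sum_a' pi(a' | x') Q(x', a')
        Q = P @ (r + gamma * V)
    return Q


rng = np.random.default_rng(1)
S, A, gamma = 6, 2, 0.95
P = rng.random((S, A, S))
P /= P.sum(axis=-1, keepdims=True)
pi = rng.random((S, A))
pi /= pi.sum(axis=-1, keepdims=True)
r = rng.normal(size=S)

M = successor_measure(P, pi, gamma)
total_mass = M.sum(axis=-1)   # should be close to 1 / (1 - gamma) = 20 for every (x, a)
Q_from_sm = M @ r             # (S, A): the linear operator M^pi applied to r
```

Once \(M^\pi\) is available, evaluating a new reward costs a single matrix-vector product, which is exactly the property the FB construction exploits.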