Zero-Shot RL with the Forward-Backward Representation, Inside Out

In 2021, shortly after I started graduate school, Ahmed Touati and Yann Ollivier published Learning One Representation to Optimize All Rewards, and the paper blew my mind. As the title suggests, it presents an algorithm for learning a representation (the forward-backward representation, or FB) that, in a sense, encodes all possible optimal policies for a given MDP simultaneously. You can see for yourself that this really does work by playing around with their fabulous demo, where you can prompt their model with a reward function and watch it (nearly instantly) derive a Walker policy that optimizes the return.

The idea and the math behind the FB always seemed fairly magical to me. Recently, I derived the FB "inside-out" (relative to the exposition in the paper). This led to some nice intuitions and insights, which I document here.

The Inside Out Derivation

One way to achieve zero-shot transfer in RL – that is, immediately inferring an optimal policy in an arbitrary (but Markovian) task – is by modeling the successor measure (SM, defined below for completeness) for each optimal policy. Before FB, it was not clear how to do this, but in an ideal world where this could be done, we can imagine modeling this object as \(M^{\pi^\star_r}(\cdot\mid x, a)\) – that is, a collection of successor measures for the policies \(\pi^\star_r\) that optimize the rewards \(r\). Now, we know something else that's interesting about optimal policies:

\begin{align*} \pi^\star_r(x) &\in \arg\max_aQ^{\pi^\star_r}(x, a)\\ &= \arg\max_aM^{\pi^\star_r}r(x, a). \end{align*}

Note, however, that \(\pi^\star_r\) can be determined precisely by the reward function \(r\), so let's say \(\pi^\star_r = \mathsf{opt}(r)\). Equivalently, we now have

\begin{align*} \pi^\star_r(x) \in \arg\max_a\bar{M}^rr(x, a) \end{align*}

where \(\bar{M}^r = M^{\mathsf{opt}(r)}\). So, one way to learn \(\bar{M}\) is to randomly sample a reward function \(r\) from some prior, and then train \(\bar{M}^r\) with a TD update (see e.g. algorithms for learning fixed-policy SMs), using the greedy action with respect to \(\bar{M}^rr\) in the bootstrap target. This amounts to learning a reward-function-conditioned SM with a greedy bootstrap target, which is roughly what the FB algorithm does in principle.
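To make the greedy bootstrap concrete, here is a minimal tabular sketch. It is not the FB algorithm: it assumes a small finite MDP with a known transition tensor `P` (shape states × actions × states) and a state-only reward vector `r`, and it replaces the sampled TD update with its expectation (a full backup over next states). The function names are made up for illustration.

```python
import numpy as np

def greedy_sm_backup(M, P, r, gamma):
    """One expected backup of a reward-conditioned SM with a greedy bootstrap.

    M: (n_states, n_actions, n_states) current estimate of M^r(. | x, a)
    P: (n_states, n_actions, n_states) transition probabilities P(x' | x, a)
    r: (n_states,) reward function the SM is conditioned on
    """
    n_states = P.shape[0]
    Q = M @ r                                  # (M^r r)(x, a), shape (n_states, n_actions)
    greedy = Q.argmax(axis=1)                  # greedy action at every state
    M_next = M[np.arange(n_states), greedy]    # M^r(. | x', a*(x')), shape (n_states, n_states)
    # Target: point mass at the source state plus the discounted bootstrap.
    return np.eye(n_states)[:, None, :] + gamma * P @ M_next

def learn_optimal_sm(P, r, gamma, n_iters=500):
    """Iterate the greedy backup from zero for a single reward function."""
    M = np.zeros(P.shape)
    for _ in range(n_iters):
        M = greedy_sm_backup(M, P, r, gamma)
    return M

# "Sample a reward from a prior, then train the SM conditioned on it."
rng = np.random.default_rng(0)
P = rng.random((4, 3, 4))
P /= P.sum(axis=-1, keepdims=True)
rewards = [rng.normal(size=4) for _ in range(3)]
M_bar = [learn_optimal_sm(P, r, gamma=0.9) for r in rewards]
```

In this sketch, \(\bar{M}^rr\) follows exactly the Q-value iteration recursion, so the greedy policy it induces converges to an optimal one for \(r\). Storing one table per sampled reward is of course what fails to scale.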

The Missing Piece: Task Embeddings

So why not stop there? To my understanding, the problem is that conditioning a function approximator on a reward function is nontrivial. To begin with, you cannot even represent a reward function exactly (it generally lives in an infinite-dimensional space). Moreover, the product \(\bar{M}^rr\) would be intractable to compute, as it involves integrating over the state space. The brilliance of the FB algorithm is the realization that by embedding the reward functions into a finite-dimensional space in a special way, both problems are resolved simultaneously.

The idea is the following. Assume \(\bar{M}^r\) has a low-rank factorization (at least approximately) as \(\bar{M}^r(\cdot\mid x, a) = \bar{F}^r(x, a)^\top \bar{B}(\cdot)\), where \(\bar{F}^r\) and \(\bar{B}\) both take values in \(\mathbf{R}^d\). Note that \(\bar{B}\) does not depend on \(r\): the idea is that we want \(\bar{B}\) to act as a "reward function encoder". It should also be noted that this factorization is not actually that restrictive: for large enough \(d\), such a factorization will exist (in the extreme case of arbitrarily large \(d\), we can just take \(\bar{B} = \mathsf{Id}\), which is reward-independent). Under this model, optimal policies are characterized by

\begin{align*} \pi^\star_r(x) &\in \arg\max_{a}\bar{F}^r(x, a)^\top\bar{B}r\\ &= \arg\max_a\bar{F}^r(x, a)^\top z \end{align*}

where \(z := \int_{\mathcal{X}}r(x')\bar{B}(\mathrm{d}x')\in\mathbf{R}^d\). But now, the greedy action depends on \(r\) only through \(z\)! Consequently, we can index optimal policies and the \(\bar{F}\) factor by the "task embeddings" \(z\in\mathbf{R}^d\). Dropping the bars now that everything is indexed by \(z\), this looks like the following:

\begin{align*} \begin{cases} \pi_z(x) \in \arg\max_a F^z(x, a)^\top z\\ z = \int_{\mathcal{X}} r(x')B(\mathrm{d}x')\\ F^z(x, a)^\top B(\cdot) = M^{\pi_z}(\cdot\mid x, a) \end{cases} \end{align*}

Notably, we now can easily represent \(F, B\) with function approximators (their inputs are now all finite-dimensional objects). The paper shows some conditions on when such a factorization exists (and provides error bounds in the case where the factorization is only approximate). But the important lesson here, at least to me, is that FB innovates by compressing the space of reward functions in a clever way.
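For completeness, here is the step that makes the greedy action depend on \(r\) only through \(z\), written out under the factorization assumed above. It uses nothing beyond linearity of the integral, since \(\bar{F}^r(x, a)\) does not depend on the integration variable \(x'\):

\begin{align*} (\bar{M}^rr)(x, a) &= \int_{\mathcal{X}}r(x')\,\bar{M}^r(\mathrm{d}x'\mid x, a)\\ &= \int_{\mathcal{X}}r(x')\,\bar{F}^r(x, a)^\top\bar{B}(\mathrm{d}x')\\ &= \bar{F}^r(x, a)^\top\int_{\mathcal{X}}r(x')\,\bar{B}(\mathrm{d}x')\\ &= \bar{F}^r(x, a)^\top z. \end{align*}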

Now, given a learned FB representation and an arbitrary reward function \(r\):

  1. Infer the task embedding \(z = \int_{\mathcal{X}} r(x')B(\mathrm{d}x')\). While this is still an intractable integral, we can approximate it by sampling. Or, when \(r\) is the indicator of a given state \(x^\star\) (as in goal-conditioned RL), the integral collapses to \(B\) evaluated at the goal, \(z = B(\{x^\star\})\).
  2. Execute the policy \(\pi_z(x)\in\arg\max_aF^z(x, a)^\top z\).
  3. Profit.
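The first two steps above can be sketched in a few lines. This is a hedged illustration, not the authors' implementation: `F_fn(x, a, z)` and `B_fn(x)` are hypothetical stand-ins for trained approximators, the action set is assumed finite, and `B_fn` is assumed to return the density of \(B\) with respect to the distribution the states are sampled from, so that a sample average estimates the integral.

```python
import numpy as np

def infer_task_embedding(reward_fn, B_fn, states):
    """Monte Carlo estimate of z = ∫ r(x') B(dx') from sampled states.

    Assumes B_fn(x) is the density of B w.r.t. the sampling distribution
    of `states` (an assumption of this sketch).
    """
    rewards = np.array([reward_fn(x) for x in states])   # (n_samples,)
    embeddings = np.stack([B_fn(x) for x in states])     # (n_samples, d)
    return (rewards[:, None] * embeddings).mean(axis=0)  # (d,)

def greedy_action(F_fn, x, z, actions):
    """Pick an action maximizing F^z(x, a)^T z over a finite action set."""
    scores = [F_fn(x, a, z) @ z for a in actions]
    return actions[int(np.argmax(scores))]
```

Nothing is trained or fine-tuned at this point: a new reward function only changes `z`, which is what makes the transfer zero-shot.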

The Successor Measure

The successor measure (SM) is a term that I believe was coined by Léonard Blier, Corentin Tallec, and Yann Ollivier in an excellent manuscript from 2021, which generalizes the successor representation to measurable state spaces. For an MDP with transition kernel \(P\) and a policy \(\pi\), the successor measure is a measure on the state space conditioned on a source state and action, defined as

\begin{align*} M^\pi(\cdot\mid x, a) = \sum_{t\geq 0}\gamma^t\Pr_\pi(X_t\in\cdot\mid X_0 = x,\ A_0 = a) \end{align*}

where \(\Pr_\pi\) is the probability measure over trajectories where \(A_t\sim\pi(\cdot\mid X_t)\) for \(t\geq 1\) and \(X_{t+1}\sim P(\cdot\mid X_t, A_t)\). Notably, it is readily verified that \(M^\pi(\mathcal{X}\mid x, a) = (1-\gamma)^{-1}\), so \(M^\pi\) is a finite measure for any source state-action pair. As such, it acts as a linear operator on the space of bounded measurable functions in the usual way,

\begin{align*} (M^\pi f)(x, a) = \int_{\mathcal{X}}f(x')M^\pi(\mathrm{d}x'\mid x, a). \end{align*}

Most notably, this linear operation has a meaningful interpretation in RL. Famously, it maps reward functions to action-value functions,

\begin{align*} M^\pi r = Q^\pi \end{align*}

for any bounded and measurable reward function \(r\), so the successor measure can be used as a device to infer the action-value function for just about any reward function.
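In a finite MDP, these identities can be checked numerically. The sketch below assumes a tabular setting with state-only rewards and a known transition tensor `P` (shape states × actions × states); the function names are illustrative. It builds \(M^\pi\) in closed form and compares \(M^\pi r\) against \(Q^\pi\) obtained by plain iterative policy evaluation.

```python
import numpy as np

def successor_measure(P, pi, gamma):
    """Closed-form tabular SM, shape (n_states, n_actions, n_states).

    P:  (n_states, n_actions, n_states) transition probabilities
    pi: (n_states, n_actions) action probabilities of the policy
    """
    n_states = P.shape[0]
    P_pi = np.einsum("xa,xay->xy", pi, P)                # state-to-state kernel under pi
    N = np.linalg.inv(np.eye(n_states) - gamma * P_pi)   # sum_t gamma^t P_pi^t
    # t = 0 puts unit mass on the source state; later steps follow pi.
    return np.eye(n_states)[:, None, :] + gamma * P @ N

def q_by_iteration(P, pi, r, gamma, n_iters=2000):
    """Q^pi from repeated Bellman backups, as an independent reference."""
    Q = np.zeros(P.shape[:2])
    for _ in range(n_iters):
        V = (pi * Q).sum(axis=1)
        Q = r[:, None] + gamma * P @ V
    return Q

rng = np.random.default_rng(0)
P = rng.random((5, 2, 5))
P /= P.sum(axis=-1, keepdims=True)
pi = rng.random((5, 2))
pi /= pi.sum(axis=-1, keepdims=True)
r = rng.normal(size=5)

M = successor_measure(P, pi, gamma=0.9)
print(np.allclose(M @ r, q_by_iteration(P, pi, r, gamma=0.9)))  # M^pi r versus Q^pi
print(np.allclose(M.sum(axis=-1), 1 / (1 - 0.9)))               # total mass of M^pi
```

Since `M` is computed once and then applied to any reward vector, swapping in a different `r` gives the corresponding \(Q^\pi\) with a single matrix-vector product.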

Author: Harley Wiltzer

Created: 2024-07-17 Wed 21:35