Linearly constrained relative entropy minimization

In reinforcement learning, one usually seeks to maximize a reward while staying close to the current policy. However, one could turn the problem around and instead minimize the deviation from the current policy under the constraint that the new policy achieves a higher reward. Such a reformulation is remarkable for the beautiful maths that accompanies it, first discovered by physicists under the name of free energy minimization. Some of the ideas presented here are known in different communities under different names: see, for example, Fenchel duality, entropy, and the log partition function; the Donsker-Varadhan formula; the Compression Lemma. Exponential families in general are wonderfully described by Wainwright and Jordan.

Generic $f$-divergence minimization

It is helpful to keep in mind episode-based policy search as an example application of the abstract framework presented below.

Let $\mu$ be a given probability measure on $X$ (e.g., the current policy), and let $\pi$ be a measure on $X$ that we want to find (e.g., the new policy). Assuming $\pi$ is absolutely continuous with respect to $\mu$, denote by $\xi$ the Radon-Nikodym derivative of $\pi$ with respect to $\mu$, \begin{equation} d\pi=\xi d\mu. \label{xi} \end{equation} The quality of a measure is evaluated using a reward function $R \colon X \to \mathbb{R}$ and is defined as the expectation of $R$ under that measure, \begin{equation*} J(\pi) := \int R d\pi. \end{equation*} All integrals in this note are taken over $X$. The following optimization problem yields $\xi$ as its solution, allowing one to find $\pi$ using \eqref{xi}: \begin{equation} \begin{aligned} \text{minimize} &\quad \int f(\xi) d\mu \\ \text{subject to} &\quad \int R \xi d\mu = c, \\ &\quad \int \xi d\mu = 1. \end{aligned} \label{primal} \end{equation} The function $f$ is assumed to be convex and to satisfy $f(1) = f^\prime(1) = 0$; we additionally assume that the domain of $f$ is $[0,\infty)$ in order to avoid adding an explicit non-negativity constraint on $\xi$. With such a choice of $f$, the objective that we are minimizing is an $f$-divergence from $\mu$ to $\pi$. The first equality constraint requires the expectation of $R$ under the new policy to be equal to $c$, i.e., $J(\pi) = c$; choosing $c > \mathbb{E}_\mu R$ guarantees that $\pi$ achieves a higher reward than $\mu$. The last equality constraint ensures that $\pi$ is normalized.
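To make Problem \eqref{primal} concrete in the episode-based setting, the Python sketch below assumes that $\mu$ is an empirical distribution over a few sampled episodes with hypothetical returns `R[i]` and weights `mu[i]`, so that every integral becomes a weighted sum. The helper `primal_terms` is a name introduced here for illustration; it evaluates the objective and the left-hand sides of both constraints for a candidate weighting `xi`. Besides the KL generator used later in the text, it includes the Pearson $\chi^2$ generator $f(x) = (x-1)^2/2$ as a second example; both satisfy $f(1) = f^\prime(1) = 0$.

```python
import math


def f_kl(x):
    """KL generator f(x) = x log x - (x - 1) on [0, inf), with 0 log 0 = 0."""
    return (x * math.log(x) if x > 0.0 else 0.0) - (x - 1.0)


def f_chi2(x):
    """Pearson chi-squared generator f(x) = (x - 1)^2 / 2 on [0, inf) (example choice)."""
    return 0.5 * (x - 1.0) ** 2


def primal_terms(xi, R, mu, f):
    """Evaluate Problem (primal) when mu puts weight mu[i] on a point with reward R[i].

    Returns (objective, expected reward under pi, total mass of pi)."""
    if any(x < 0.0 for x in xi):
        raise ValueError("xi must be non-negative (the domain of f is [0, inf))")
    objective = sum(m * f(x) for x, m in zip(xi, mu))
    reward = sum(m * r * x for x, r, m in zip(xi, R, mu))
    mass = sum(m * x for x, m in zip(xi, mu))
    return objective, reward, mass


# Hypothetical returns of four episodes sampled from the current policy.
R = [1.0, 2.0, 3.0, 4.0]
mu = [0.25] * 4
print(primal_terms([1.0] * 4, R, mu, f_kl))             # xi = 1 means pi = mu
print(primal_terms([0.4, 0.8, 1.2, 1.6], R, mu, f_kl))  # shifts mass to high returns
```

With `xi = 1` the divergence vanishes and the reward equals $\mathbb{E}_\mu R = 2.5$; the second weighting keeps the total mass at 1 while raising the expected reward to 3, at the price of a positive divergence.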

Primal optimal point and dual objective

The Lagrangian of Problem \eqref{primal} \begin{equation*} L(\xi, \lambda, \nu) = \int f(\xi) d\mu - \lambda \left( \int R \xi d\mu - c \right) + \nu \left( \int \xi d\mu - 1 \right) \end{equation*} is minimized over $\xi$, pointwise, by \begin{equation} \xi^\star(\lambda, \nu) = f_*^\prime (\lambda R - \nu) \label{primal_optimal} \end{equation} where $f_*^\prime$ denotes the derivative of the Legendre transform of $f$; we put the star at the bottom to leave space at the top for the derivative symbol. Substituting $\xi^\star$ back into $L$ and using Fenchel’s equality $f(\xi^\star) + f_*(\lambda R - \nu) = (\lambda R - \nu)\,\xi^\star$, we arrive at the Lagrangian dual \begin{equation} g(\lambda, \nu) = -\int f_* \left( \lambda R - \nu \right) d\mu + \lambda c - \nu. \label{dual} \end{equation} Without any additional assumptions, we cannot eliminate $\lambda$ or $\nu$ from the dual and thus have to rely on numerical optimization to find them.
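As an illustration of that numerical step, the Python sketch below takes the Pearson $\chi^2$ generator $f(x) = (x-1)^2/2$ on $[0,\infty)$, an example choice not discussed in the text. Its Legendre transform is $f_*(y) = y + y^2/2$ for $y \ge -1$ and $-1/2$ otherwise, so that \eqref{primal_optimal} gives $\xi^\star = \max(0, 1 + \lambda R - \nu)$. The sketch assumes an empirical $\mu$ over hypothetical returns and $\mathbb{E}_\mu R \le c < \max R$. Rather than a general-purpose optimizer, it exploits the concavity of $g$: for a fixed $\lambda$, the condition $\partial g/\partial \nu = 0$ (normalization) is solved for $\nu$ by bisection, and an outer bisection on $\lambda$ enforces $\partial g/\partial \lambda = 0$ (the reward constraint). The helper names are hypothetical.

```python
def fstar(y):
    """Legendre transform of f(x) = (x - 1)^2 / 2 restricted to x >= 0."""
    return y + 0.5 * y * y if y >= -1.0 else -0.5


def fstar_prime(y):
    """Derivative of fstar; by (primal_optimal) this is xi*(lambda, nu)."""
    return max(0.0, 1.0 + y)


def bisect(fun, lo, hi, iters=100):
    """Root of a non-decreasing function, assuming fun(lo) <= 0 <= fun(hi)."""
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if fun(mid) < 0.0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def nu_of_lam(lam, R, mu):
    """Solve dg/dnu = 0, i.e. sum_i mu_i xi_i = 1, for nu at a fixed lam."""
    ys = [lam * r for r in R]

    def deficit(nu):
        # 1 - (total mass of pi); non-decreasing in nu
        return 1.0 - sum(m * fstar_prime(y - nu) for y, m in zip(ys, mu))

    # At nu = min(ys) every xi_i >= 1; at nu = 1 + max(ys) every xi_i = 0.
    return bisect(deficit, min(ys), 1.0 + max(ys))


def solve_dual(R, mu, c):
    """Return (lam, nu, xi) maximizing g(lam, nu); assumes E_mu[R] <= c < max(R)."""
    def excess_reward(lam):
        # E_pi[R] - c = -dg/dlam at nu = nu(lam); non-decreasing in lam as g is concave
        nu = nu_of_lam(lam, R, mu)
        return sum(m * r * fstar_prime(lam * r - nu) for r, m in zip(R, mu)) - c

    hi = 1.0
    while excess_reward(hi) < 0.0:
        hi *= 2.0
        if hi > 1e12:
            raise ValueError("c must be smaller than max(R)")
    lam = bisect(excess_reward, 0.0, hi)
    nu = nu_of_lam(lam, R, mu)
    return lam, nu, [fstar_prime(lam * r - nu) for r in R]


def dual_value(lam, nu, R, mu, c):
    """Dual objective g(lam, nu) from (dual)."""
    return -sum(m * fstar(lam * r - nu) for r, m in zip(R, mu)) + lam * c - nu


R = [1.0, 2.0, 3.0, 4.0]   # hypothetical episode returns
mu = [0.25] * 4            # empirical base measure
lam, nu, xi = solve_dual(R, mu, c=3.0)
print(lam, nu, xi, dual_value(lam, nu, R, mu, 3.0))
```

For these returns no weight is clipped at zero, so the answer can be checked by hand: normalization gives $\nu = 2.5\lambda$, the reward constraint gives $2.5 + 1.25\lambda = 3$, hence $\lambda = 0.4$, $\nu = 1$ and $\xi = (0.4, 0.8, 1.2, 1.6)$. The primal objective and $g(\lambda, \nu)$ then both equal $0.1$, as strong duality predicts.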

Special case: KL divergence

With a special choice of $f$ \begin{equation*} f(x) := x \log x - (x - 1), \end{equation*} the $f$-divergence turns into the KL divergence (also called relative entropy). The Legendre transform of this $f$ is exponential, $f_*(y) = e^y - 1$, which makes $\xi^\star$ from \eqref{primal_optimal} an exponential family \begin{equation*} \xi^\star(\lambda, \nu) = e^{\lambda R - \nu}, \end{equation*} allowing us to satisfy the normalization constraint analytically by setting \begin{equation*} \nu^\star(\lambda) = \psi(\lambda) := \log \int e^{\lambda R} d\mu \end{equation*} where $\psi$ is known as the log-partition function. (The log-partition function is closely related to the cumulant generating function of $R$. The latter is usually written as an expectation with respect to the distribution of $R$ itself, whereas the former integrates over $X$ with respect to the fixed base measure $\mu$; because $\mu$ is a probability measure here, a change of variables shows that $\psi$ is exactly the cumulant generating function of $R$ under $\mu$.) The density $\xi^\star$ can then be expressed as a function of $\lambda$ only \begin{equation} \xi^\star(\lambda) = e^{\lambda R - \psi(\lambda)}, \label{xi_kl} \end{equation} which turns the dual \eqref{dual} into \begin{equation} g(\lambda) = \lambda c - \psi(\lambda), \label{dual_kl} \end{equation} since the integral term $\int \left( e^{\lambda R - \psi(\lambda)} - 1 \right) d\mu$ vanishes. There are a few interesting observations one can make about this dual.
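In the KL case no numerical solve for $\nu$ is needed. The Python sketch below, again assuming an empirical $\mu$ over hypothetical returns with strictly positive weights, evaluates $\psi$ with the standard log-sum-exp trick (so that large $\lambda R$ does not overflow), forms $\xi^\star(\lambda)$ from \eqref{xi_kl}, and evaluates the one-dimensional dual \eqref{dual_kl}. The function names are illustrative.

```python
import math


def log_partition(lam, R, mu):
    """psi(lam) = log sum_i mu_i exp(lam R_i), computed stably via log-sum-exp."""
    a = [lam * r + math.log(m) for r, m in zip(R, mu)]
    a_max = max(a)
    return a_max + math.log(sum(math.exp(x - a_max) for x in a))


def xi_kl(lam, R, mu):
    """Optimal density xi*(lam) = exp(lam R - psi(lam)), as in (xi_kl)."""
    psi = log_partition(lam, R, mu)
    return [math.exp(lam * r - psi) for r in R]


def dual_kl(lam, c, R, mu):
    """One-dimensional dual g(lam) = lam c - psi(lam), as in (dual_kl)."""
    return lam * c - log_partition(lam, R, mu)


R = [1.0, 2.0, 3.0, 4.0]   # hypothetical returns
mu = [0.25] * 4            # empirical base measure
for lam in (0.0, 0.5, 1000.0):
    xi = xi_kl(lam, R, mu)
    total_mass = sum(m * x for x, m in zip(xi, mu))
    print(lam, total_mass, dual_kl(lam, 3.0, R, mu))
```

The total mass is 1 up to rounding for every $\lambda$, which is the normalization constraint being satisfied analytically by $\nu = \psi(\lambda)$; at $\lambda = 1000$ a naive evaluation of $e^{\lambda R}$ would overflow.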

Legendre transform of the log-partition function

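Computing the maximum of \eqref{dual_kl} for a given $c$ is equivalent to evaluating the Legendre transform of $\psi$ \begin{equation} \psi_*(c) := \sup_{\lambda} \left\{ c \lambda - \psi(\lambda) \right\} \label{psi_star_variational} \end{equation} at the given $c$; in other words, \begin{equation} g(\lambda^\star(c)) = \psi_*(c) = c\lambda^\star(c) - \psi(\lambda^\star(c)). \label{psi_star} \end{equation} Since $\psi$ is convex, $g$ is concave, and its maximizer $\lambda^\star(c)$ is characterized by the stationarity condition $\psi^\prime(\lambda^\star) = c$. By strong duality, the Legendre transform of the log-partition function equals the KL divergence from $\mu$ to $\pi^\star$, \begin{equation} \psi_*(c) = KL(\pi^\star(\lambda^\star(c)) \| \mu). \label{psi_legendre} \end{equation} This can also be checked directly: $KL(\pi^\star \| \mu) = \int \xi^\star \log \xi^\star d\mu = \int (\lambda^\star R - \psi(\lambda^\star)) d\pi^\star = \lambda^\star c - \psi(\lambda^\star)$, where the last step uses the reward constraint $\int R \, d\pi^\star = c$. Although it may be hard in practice to find $\psi_*$ analytically, it is an important conceptual tool.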
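Because $g^\prime(\lambda) = c - \psi^\prime(\lambda)$ and $\psi^\prime(\lambda)$ is the mean of $R$ under the tilted measure $\pi(\lambda)$, an increasing function of $\lambda$, the supremum in \eqref{psi_star_variational} can be located by bisection on $\psi^\prime(\lambda) = c$. The Python sketch below does this for an empirical $\mu$ over hypothetical returns, assuming $\min R < c < \max R$ so that the root exists, and then checks \eqref{psi_legendre} by computing $KL(\pi^\star \| \mu) = \sum_i \mu_i \xi^\star_i \log \xi^\star_i$ directly. The helper names are illustrative.

```python
import math


def log_partition(lam, R, mu):
    """psi(lam) = log sum_i mu_i exp(lam R_i), via log-sum-exp."""
    a = [lam * r + math.log(m) for r, m in zip(R, mu)]
    a_max = max(a)
    return a_max + math.log(sum(math.exp(x - a_max) for x in a))


def tilted_mean(lam, R, mu):
    """psi'(lam): expectation of R under d pi = exp(lam R - psi(lam)) d mu."""
    psi = log_partition(lam, R, mu)
    return sum(m * r * math.exp(lam * r - psi) for r, m in zip(R, mu))


def lam_star(c, R, mu, iters=200):
    """Maximizer of g(lam) = lam c - psi(lam), i.e. the root of psi'(lam) = c."""
    if not min(R) < c < max(R):
        raise ValueError("need min(R) < c < max(R)")
    lo, hi = -1.0, 1.0
    while tilted_mean(lo, R, mu) > c:   # widen the bracket to the left
        lo *= 2.0
    while tilted_mean(hi, R, mu) < c:   # widen the bracket to the right
        hi *= 2.0
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if tilted_mean(mid, R, mu) < c:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def psi_star(c, R, mu):
    """Legendre transform psi_*(c) = c lam* - psi(lam*), as in (psi_star)."""
    lam = lam_star(c, R, mu)
    return c * lam - log_partition(lam, R, mu)


def kl_to_base(lam, R, mu):
    """KL(pi(lam) || mu) = sum_i mu_i xi_i log xi_i, with log xi_i = lam R_i - psi."""
    psi = log_partition(lam, R, mu)
    return sum(m * math.exp(lam * r - psi) * (lam * r - psi) for r, m in zip(R, mu))


R = [1.0, 2.0, 3.0, 4.0]   # hypothetical returns
mu = [0.25] * 4
c = 3.0                    # target reward, above E_mu[R] = 2.5
lam = lam_star(c, R, mu)
print(lam, psi_star(c, R, mu), kl_to_base(lam, R, mu))
```

The last two printed numbers agree, and $\psi_*$ vanishes at $c = \mathbb{E}_\mu R$, where $\lambda^\star = 0$ and $\pi^\star = \mu$.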

Bregman divergence of the Legendre transform

We can give one more little twist to things by expressing the exponent in \eqref{xi_kl} through the Bregman divergence generated by $\psi_*$. Recall that in general the Bregman divergence generated by a differentiable convex function $\phi$ is given by \begin{equation*} d_\phi(y, x) = \phi(y) - \phi(x) - \nabla \phi(x)^T (y - x). \end{equation*} Using this definition, together with the expression \eqref{psi_star} for $\psi_*$, we can rewrite the exponent in \eqref{xi_kl} at optimum as \begin{align*} \lambda^\star R - \psi(\lambda^\star) &= \left\{ c\lambda^\star - \psi(\lambda^\star) \right\} + \lambda^\star (R - c) \\ &= \psi_*(c) + \psi_*^\prime(c)(R-c) \\ &= \psi_*(R) - \left\{ \psi_*(R) - \psi_*(c) - \psi_*^\prime(c)(R-c) \right\} \\ &= \psi_*(R) - d_{\psi_*} (R, c), \end{align*} where the second equality uses $\psi_*^\prime(c) = \lambda^\star(c)$, which follows from the envelope theorem applied to \eqref{psi_star_variational}: the partial derivative of $c\lambda - \psi(\lambda)$ with respect to $c$ is $\lambda$, evaluated at the maximizer $\lambda^\star(c)$. Therefore, the optimal density $\xi^\star$ as a function of $c$ can be computed as \begin{equation} \xi^\star(c) = e^{\psi_*(R) - d_{\psi_*}(R, c)}. \label{xi_c} \end{equation} Formula \eqref{xi_c} provides a nice interpretation of $\xi^\star$: it weights rewarding experiences exponentially, using $\psi_*$ as a measure of goodness, and at the same time penalizes deviation of $R$ from $c$ through the Bregman divergence generated by $\psi_*$.
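The only non-trivial step in the derivation above is $\psi_*^\prime(c) = \lambda^\star(c)$. The Python sketch below checks it by central finite differences for an empirical $\mu$ over hypothetical returns, and then compares the two forms of the exponent, $\lambda^\star R - \psi(\lambda^\star)$ and $\psi_*(R) - d_{\psi_*}(R, c)$. Since the maximizer $\lambda^\star(r)$ exists only for $r$ strictly between $\min R$ and $\max R$ in this discrete setting, the comparison is made at reward values strictly inside that range. The data and helper names are illustrative assumptions.

```python
import math

R = [1.0, 2.0, 3.0, 4.0]   # hypothetical returns
MU = [0.25] * 4            # empirical base measure


def psi(lam):
    """Log-partition function of the empirical measure, via log-sum-exp."""
    a = [lam * r + math.log(m) for r, m in zip(R, MU)]
    a_max = max(a)
    return a_max + math.log(sum(math.exp(x - a_max) for x in a))


def dpsi(lam):
    """psi'(lam): mean of R under the tilted measure."""
    p = psi(lam)
    return sum(m * r * math.exp(lam * r - p) for r, m in zip(R, MU))


def lam_star(c):
    """Root of psi'(lam) = c by bisection; requires min(R) < c < max(R)."""
    if not min(R) < c < max(R):
        raise ValueError("need min(R) < c < max(R)")
    lo, hi = -1.0, 1.0
    while dpsi(lo) > c:
        lo *= 2.0
    while dpsi(hi) < c:
        hi *= 2.0
    for _ in range(200):
        mid = 0.5 * (lo + hi)
        lo, hi = (mid, hi) if dpsi(mid) < c else (lo, mid)
    return 0.5 * (lo + hi)


def psi_star(c):
    """Legendre transform of psi evaluated through its maximizer."""
    lam = lam_star(c)
    return c * lam - psi(lam)


def bregman(r, c):
    """d_{psi_*}(r, c), using psi_*'(c) = lam_star(c)."""
    return psi_star(r) - psi_star(c) - lam_star(c) * (r - c)


c, h = 3.0, 1e-4
fd = (psi_star(c + h) - psi_star(c - h)) / (2.0 * h)
print(fd, lam_star(c))                 # finite difference vs. lambda*(c)
lam = lam_star(c)
for r in (2.0, 3.0):                   # reward values strictly inside (min R, max R)
    print(lam * r - psi(lam), psi_star(r) - bregman(r, c))
```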

Example: quadratic Gaussian problem

To get a more hands-on appreciation of the theory, let’s take a look at a concrete example. Let the base measure be Gaussian \begin{equation*} \mu := \mathcal{N}(\mu, \sigma^2) \end{equation*} (despite the letter $\mu$ being already taken, we nevertheless denote the mean by $\mu$ in order to keep the standard notation for the parameters of a Gaussian distribution; no confusion should arise from such abuse of notation), and let the reward be quadratic \begin{equation*} R(x) := -x^2. \end{equation*} The log-partition function can then be computed in closed form by completing the square, \begin{equation} \psi(\lambda) = \log \int e^{-\lambda x^2} \mathcal{N} (x|\mu, \sigma^2) dx = -\frac{\lambda \mu^2}{2\lambda\sigma^2 + 1} - \frac{1}{2} \log (2\lambda\sigma^2 + 1), \label{log_part_func} \end{equation} which is finite for $\lambda > -1/(2\sigma^2)$. Setting the derivative of the dual \eqref{dual_kl} to zero gives \begin{equation*} \psi^\prime (\lambda^\star) = c, \end{equation*} which is an equation on $\lambda^\star$. In this simple case, it is a quadratic equation in $1/(2\lambda^\star\sigma^2 + 1)$, \begin{equation} -\frac{\mu^2}{(2\lambda^\star\sigma^2 + 1)^2} - \frac{\sigma^2}{2\lambda^\star\sigma^2 + 1} = c. \label{lam_star_eq} \end{equation} Only one of its roots keeps $2\lambda^\star\sigma^2 + 1$ positive, as required for $\psi$ to be finite, \begin{equation} \lambda^\star(c) = -\left( \frac{1}{2\sigma^2} + \frac{1}{4c} + \frac{ \sqrt{ 1 + \frac{4\mu^2(-c)}{\sigma^4} } }{4c} \right), \label{root} \end{equation} and this root is positive exactly when $c$ satisfies the improvement condition $c > -\mu^2 - \sigma^2 = \mathbb{E}_\mu R$, which guarantees that the expectation of $R$ is greater under $\pi$ than under $\mu$. Note that $c$ must be strictly negative: $R \leq 0$, with equality only at $x = 0$, which has probability zero under any $\pi$ that has a density with respect to $\mu$; in particular, the division by $c$ in \eqref{root} is well defined.
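The closed forms \eqref{log_part_func} and \eqref{root} are easy to sanity-check numerically. The Python sketch below, with hypothetical parameter values $\mu = 1$ and $\sigma^2 = 0.25$ (named `m` and `s2` in code to avoid the clash of symbols), compares \eqref{log_part_func} with a composite Simpson's-rule approximation of the defining integral truncated to $\mu \pm 12\sigma$, and checks that \eqref{root} solves \eqref{lam_star_eq}.

```python
import math


def psi_closed(lam, m, s2):
    """Closed-form log-partition function (log_part_func); needs 2 lam s2 + 1 > 0."""
    a = 2.0 * lam * s2 + 1.0
    return -lam * m * m / a - 0.5 * math.log(a)


def psi_quad(lam, m, s2, n=4000):
    """log of the integral of exp(-lam x^2) N(x | m, s2) dx by composite Simpson's rule.

    The integral is truncated to [m - 12 s, m + 12 s]; n must be even."""
    s = math.sqrt(s2)
    lo, hi = m - 12.0 * s, m + 12.0 * s
    h = (hi - lo) / n

    def integrand(x):
        # exp(-lam x^2) times the Gaussian density, with the exponents combined
        return math.exp(-lam * x * x - (x - m) ** 2 / (2.0 * s2)) / math.sqrt(2.0 * math.pi * s2)

    total = integrand(lo) + integrand(hi)
    for k in range(1, n):
        total += (4.0 if k % 2 else 2.0) * integrand(lo + k * h)
    return math.log(total * h / 3.0)


def dpsi(lam, m, s2):
    """psi'(lam), the left-hand side of (lam_star_eq)."""
    a = 2.0 * lam * s2 + 1.0
    return -m * m / a ** 2 - s2 / a


def lam_star(c, m, s2):
    """Closed-form root (root) of psi'(lam) = c; requires c < 0."""
    root = math.sqrt(1.0 + 4.0 * m * m * (-c) / s2 ** 2)
    return -(1.0 / (2.0 * s2) + 1.0 / (4.0 * c) + root / (4.0 * c))


m, s2 = 1.0, 0.25          # hypothetical mean and variance of the base Gaussian
for lam in (0.0, 0.3, 2.0):
    print(lam, psi_closed(lam, m, s2), psi_quad(lam, m, s2))
c = -0.8                   # improvement condition: -(m**2 + s2) = -1.25 < c < 0
lam = lam_star(c, m, s2)
print(lam, dpsi(lam, m, s2))  # the second number should reproduce c
```

Choosing $c$ below $\mathbb{E}_\mu R = -1.25$ instead yields a negative $\lambda^\star$, consistent with the improvement condition.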
Substituting $\psi$ from \eqref{log_part_func} and $c$ from \eqref{lam_star_eq} into Fenchel’s equality \begin{equation*} \psi_*(c) = c \lambda^\star - \psi(\lambda^\star), \end{equation*} we find \begin{equation} \psi_*(c) = \frac{2{\lambda^\star}^2\sigma^2\mu^2}{(2\lambda^\star\sigma^2+1)^2} - \frac{\lambda^\star\sigma^2}{2\lambda^\star\sigma^2+1} - \frac{1}{2} \log \frac{1}{2\lambda^\star\sigma^2+1}, \label{psi_star_gaussian} \end{equation} where $\lambda^\star = \lambda^\star(c)$ is the function of $c$ defined in \eqref{root}. Recall from \eqref{psi_legendre} that \begin{equation*} \psi_*(c) = KL(\pi^\star \| \mu). \end{equation*} We can compute the right-hand side directly to confirm this equality. First, using \eqref{xi} and completing the square in the exponent $-\lambda^\star x^2 - (x-\mu)^2/(2\sigma^2)$, we find \begin{equation*} d\pi^\star = \mathcal{N}\left(x \,\Big\rvert\, \frac{\mu}{2\lambda^\star\sigma^2+1}, \frac{\sigma^2}{2\lambda^\star\sigma^2+1} \right) dx. \end{equation*} Then, using the formula for the KL divergence between Gaussians \begin{equation*} KL\left( \mathcal{N}(\mu_1, \sigma_1^2) \| \mathcal{N}(\mu_2, \sigma_2^2) \right) = \frac{(\mu_1-\mu_2)^2}{2\sigma_2^2} + \frac{1}{2} \left( \frac{\sigma_1^2}{\sigma_2^2} - 1 - \log \frac{\sigma_1^2}{\sigma_2^2} \right) \end{equation*} with $\mu_1 = \mu/(2\lambda^\star\sigma^2+1)$, $\sigma_1^2 = \sigma^2/(2\lambda^\star\sigma^2+1)$, $\mu_2 = \mu$, and $\sigma_2^2 = \sigma^2$, we can confirm term by term that $KL(\pi^\star \| \mu)$ indeed equals $\psi_*(c)$ as given in \eqref{psi_star_gaussian}. As a bonus, you may check that substituting \eqref{root} into \eqref{psi_star_gaussian} yields the following expression for $\psi_*$ as an explicit function of $c$ \begin{equation*} \psi_*(c) = \frac{(-c)+\mu^2}{2\sigma^2} - \frac{1}{2}\sqrt{1 + \frac{4\mu^2(-c)}{\sigma^4}} + \frac{1}{2}\log\left( \frac{ 1 + \sqrt{1 + \frac{4\mu^2(-c)}{\sigma^4}} }{\frac{2(-c)}{\sigma^2}} \right). \end{equation*} It is a bit sad that even for such a simple example, the convex conjugate of the log-partition function is so ugly.
Fortunately, the variational definition \eqref{psi_star_variational}, being a one-dimensional concave maximization problem, makes it possible to find $\lambda^\star$ for any given $c$ quickly and reliably, thus eliminating the need for an explicit formula for $\psi_*(c)$ for all practical purposes.
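A minimal Python sketch of that approach for the quadratic Gaussian example, again with hypothetical values $\mu = 1$ and $\sigma^2 = 0.25$: it maximizes the concave objective $c\lambda - \psi(\lambda)$ over $\lambda > -1/(2\sigma^2)$ by bisection on its stationarity condition $\psi^\prime(\lambda) = c$, and then compares the numerical $\lambda^\star$ with \eqref{root}, and the resulting $\psi_*(c)$ with $KL(\pi^\star \| \mu)$ from the Gaussian KL formula and with the explicit expression in $c$. Bisection is one simple choice here; any robust scalar root finder would do.

```python
import math


def psi(lam, m, s2):
    """Log-partition function (log_part_func) for mu = N(m, s2) and R(x) = -x^2."""
    a = 2.0 * lam * s2 + 1.0
    return -lam * m * m / a - 0.5 * math.log(a)


def dpsi(lam, m, s2):
    """psi'(lam), the left-hand side of (lam_star_eq)."""
    a = 2.0 * lam * s2 + 1.0
    return -m * m / a ** 2 - s2 / a


def lam_numeric(c, m, s2, iters=200):
    """Maximize c lam - psi(lam) on lam > -1/(2 s2) by bisection on psi'(lam) = c.

    psi' increases from -inf to 0 on this interval, so every c < 0 is attained."""
    if c >= 0.0:
        raise ValueError("c must be negative")
    lo, hi = -1.0 / (2.0 * s2), 1.0   # lo itself is never evaluated (psi' = -inf there)
    while dpsi(hi, m, s2) < c:
        hi *= 2.0
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if dpsi(mid, m, s2) < c:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def lam_closed(c, m, s2):
    """Closed-form root (root)."""
    root = math.sqrt(1.0 + 4.0 * m * m * (-c) / s2 ** 2)
    return -(1.0 / (2.0 * s2) + 1.0 / (4.0 * c) + root / (4.0 * c))


def kl_gauss(m1, v1, m2, v2):
    """KL(N(m1, v1) || N(m2, v2)), with v1 and v2 the variances."""
    return (m1 - m2) ** 2 / (2.0 * v2) + 0.5 * (v1 / v2 - 1.0 - math.log(v1 / v2))


def psi_star_explicit(c, m, s2):
    """Explicit expression for psi_*(c) as a function of c."""
    root = math.sqrt(1.0 + 4.0 * m * m * (-c) / s2 ** 2)
    return ((-c) + m * m) / (2.0 * s2) - 0.5 * root + 0.5 * math.log((1.0 + root) / (2.0 * (-c) / s2))


m, s2 = 1.0, 0.25   # hypothetical base Gaussian N(m, s2); E_mu[R] = -1.25
for c in (-1.0, -0.5, -0.1):
    lam = lam_numeric(c, m, s2)
    a = 2.0 * lam * s2 + 1.0
    fenchel = c * lam - psi(lam, m, s2)     # Fenchel's equality
    kl = kl_gauss(m / a, s2 / a, m, s2)     # pi* = N(m / a, s2 / a)
    print(c, lam, lam_closed(c, m, s2), fenchel, kl, psi_star_explicit(c, m, s2))
```

In each printed row the two values of $\lambda^\star$ agree, and so do the three values of $\psi_*(c)$, without ever needing the explicit formula to compute them.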