Background
Most of the math below is taken from Mohamed et al. (2020).
Consider a function $f: \mathbb{R}^n \to \mathbb{R}^m$, a parameter $\theta \in \mathbb{R}^d$ and a parametric probability distribution $p(\theta)$ on the input space. Given a random variable $X \sim p(\theta)$, we want to differentiate the expectation of $Y = f(X)$ with respect to $\theta$:
\[E(\theta) = \mathbb{E}[f(X)] = \int f(x) ~ p(x | \theta) ~\mathrm{d} x = \int y ~ q(y | \theta) ~\mathrm{d} y\]
Usually this is approximated with Monte-Carlo sampling: given i.i.d. samples $x_1, \dots, x_S \sim p(\theta)$, we obtain the estimator
\[E(\theta) \simeq \frac{1}{S} \sum_{s=1}^S f(x_s)\]
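For concreteness, here is a minimal numpy sketch of this estimator. The function $f$ and the Gaussian distribution are arbitrary choices made for illustration; they are not part of the package.

```python
import numpy as np

rng = np.random.default_rng(0)

def f(x):
    # Arbitrary test function f: R -> R^2, chosen for illustration only.
    return np.stack([np.sin(x), x ** 2], axis=-1)

mu, sigma = 0.5, 1.2                 # theta = (mu, sigma) for a Gaussian p(theta)
S = 10_000
xs = rng.normal(mu, sigma, size=S)   # i.i.d. samples x_1, ..., x_S ~ p(theta)
E_hat = f(xs).mean(axis=0)           # (1/S) * sum_s f(x_s)
print(E_hat)
```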
Autodiff
Since $E$ is a vector-to-vector function, the key quantity we want to compute is its Jacobian matrix $\partial E(\theta) \in \mathbb{R}^{m \times d}$:
\[\partial E(\theta) = \int f(x) ~ \nabla_\theta p(x | \theta)^\top ~\mathrm{d} x = \int y ~ \nabla_\theta q(y | \theta)^\top ~ \mathrm{d} y\]
However, to implement automatic differentiation, we only need the vector-Jacobian product (VJP) $\partial E(\theta)^\top \bar{y}$ with an output cotangent $\bar{y} \in \mathbb{R}^m$. See the book by Blondel and Roulet (Mar 2024) for more details.
Our goal is to rephrase this VJP as an expectation, so that we may approximate it with Monte-Carlo sampling as well.
REINFORCE
Implemented by Reinforce.
Score function
The REINFORCE estimator is derived with the help of the identity $\nabla \log u = \nabla u / u$:
\[\begin{aligned} \partial E(\theta) & = \int f(x) ~ \nabla_\theta p(x | \theta)^\top ~ \mathrm{d}x \\ & = \int f(x) ~ \nabla_\theta \log p(x | \theta)^\top p(x | \theta) ~ \mathrm{d}x \\ & = \mathbb{E} \left[f(X) \nabla_\theta \log p(X | \theta)^\top\right] \\ \end{aligned}\]
And the VJP:
\[\partial E(\theta)^\top \bar{y} = \mathbb{E} \left[f(X)^\top \bar{y} ~\nabla_\theta \log p(X | \theta)\right]\]
Our Monte-Carlo approximation will therefore be:
\[\partial E(\theta)^\top \bar{y} \simeq \frac{1}{S} \sum_{s=1}^S f(x_s)^\top \bar{y} ~ \nabla_\theta \log p(x_s | \theta)\]
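As a sanity check, here is a hedged numpy sketch of this estimator for a Gaussian $p(\theta)$ with $\theta = (\mu, \sigma)$, where the score $\nabla_\theta \log p(x | \theta)$ has a closed form. The function $f$ and the cotangent $\bar y$ are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

def f(x):
    return np.stack([np.sin(x), x ** 2], axis=-1)  # arbitrary f: R -> R^2

def score(x, mu, sigma):
    # Closed-form Gaussian score nabla_theta log p(x | theta), theta = (mu, sigma):
    # d/dmu = (x - mu) / sigma^2, d/dsigma = ((x - mu)^2 - sigma^2) / sigma^3.
    return np.stack([(x - mu) / sigma**2,
                     ((x - mu) ** 2 - sigma**2) / sigma**3], axis=-1)

mu, sigma, S = 0.5, 1.2, 10_000
y_bar = np.array([1.0, -2.0])        # arbitrary output cotangent

xs = rng.normal(mu, sigma, size=S)
weights = f(xs) @ y_bar              # f(x_s)^T y_bar, shape (S,)
vjp = (weights[:, None] * score(xs, mu, sigma)).mean(axis=0)
print(vjp)                           # estimate of dE(theta)^T y_bar, shape (2,)
```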
Variance reduction
The REINFORCE estimator has high variance, which can be reduced by subtracting a so-called baseline $b = \frac{1}{S} \sum_{s=1}^S f(x_s)$ (Kool et al., 2022).
For $S > 1$ Monte-Carlo samples, we have
\[\begin{aligned} \partial E(\theta)^\top \bar{y} & \simeq \frac{1}{S} \sum_{s=1}^S \left(f(x_s) - \frac{1}{S - 1}\sum_{j\neq s} f(x_j) \right)^\top \bar{y} ~ \nabla_\theta\log p(x_s | \theta)\\ & = \frac{1}{S - 1}\sum_{s=1}^S (f(x_s) - b)^\top \bar{y} ~ \nabla_\theta\log p(x_s | \theta) \end{aligned}\]
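In code, the variance-reduced estimator only changes a couple of lines with respect to the plain REINFORCE sketch above (same illustrative assumptions; only the $\mu$-component of the gradient is computed, for brevity):

```python
import numpy as np

rng = np.random.default_rng(0)
f = lambda x: np.stack([np.sin(x), x ** 2], axis=-1)   # arbitrary f: R -> R^2
score_mu = lambda x, mu, sigma: (x - mu) / sigma**2    # d/dmu log p(x | theta)

mu, sigma, S = 0.5, 1.2, 1_000
y_bar = np.array([1.0, -2.0])
xs = rng.normal(mu, sigma, size=S)

b = f(xs).mean(axis=0)                  # baseline b = (1/S) sum_s f(x_s)
w = (f(xs) - b) @ y_bar                 # (f(x_s) - b)^T y_bar, shape (S,)
grad_mu = (w * score_mu(xs, mu, sigma)).sum() / (S - 1)   # note the 1/(S-1)
print(grad_mu)
```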
Reparametrization
Implemented by Reparametrization.
Trick
The reparametrization trick assumes that we can rewrite the random variable $X \sim p(\theta)$ as $X = g_\theta(Z)$, where $Z \sim r$ is another random variable whose distribution $r$ does not depend on $\theta$.
The expectation is rewritten with $h_\theta = f \circ g_\theta$:
\[E(\theta) = \mathbb{E}\left[ f(g_\theta(Z)) \right] = \mathbb{E}\left[ h_\theta(Z) \right]\]
And we can directly differentiate through the expectation:
\[\partial E(\theta) = \mathbb{E} \left[ \partial_\theta h_\theta(Z) \right]\]
This yields the VJP:
\[\partial E(\theta)^\top \bar{y} = \mathbb{E} \left[ \partial_\theta h_\theta(Z)^\top \bar{y} \right]\]
We can use a Monte-Carlo approximation with i.i.d. samples $z_1, \dots, z_S \sim r$:
\[\partial E(\theta)^\top \bar{y} \simeq \frac{1}{S} \sum_{s=1}^S \partial_\theta h_\theta(z_s)^\top \bar{y}\]
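Here is a numpy sketch under the same illustrative assumptions (Gaussian with $\theta = (\mu, \sigma)$, arbitrary $f$ and $\bar y$). The Jacobian $\partial_\theta h_\theta(z)$ is written out by hand with the chain rule instead of relying on autodiff:

```python
import numpy as np

rng = np.random.default_rng(0)

f = lambda x: np.stack([np.sin(x), x ** 2], axis=-1)       # arbitrary f: R -> R^2
f_prime = lambda x: np.stack([np.cos(x), 2 * x], axis=-1)  # df/dx

mu, sigma, S = 0.5, 1.2, 10_000
y_bar = np.array([1.0, -2.0])

zs = rng.normal(size=S)       # z_s ~ r = N(0, 1), independent of theta
xs = mu + sigma * zs          # x_s = g_theta(z_s)

# Chain rule: d h_theta(z) / d(mu, sigma) = f'(g_theta(z)) * (1, z), so the VJP
# averages y_bar^T f'(x_s) paired with (1, z_s) over the samples.
w = f_prime(xs) @ y_bar       # y_bar^T f'(x_s), shape (S,)
vjp = np.array([w.mean(), (w * zs).mean()])
print(vjp)                    # estimate of dE(theta)^T y_bar w.r.t. (mu, sigma)
```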
Catalogue
The following reparametrizations are implemented:
- Univariate Normal: $X \sim \mathcal{N}(\mu, \sigma^2)$ is equivalent to $X = \mu + \sigma Z$ with $Z \sim \mathcal{N}(0, 1)$.
- Multivariate Normal: $X \sim \mathcal{N}(\mu, \Sigma)$ is equivalent to $X = \mu + L Z$ with $Z \sim \mathcal{N}(0, I)$ and $L L^\top = \Sigma$. The matrix $L$ can be obtained by Cholesky decomposition of $\Sigma$, as sketched below.
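The multivariate case is easy to check numerically. A small numpy sketch (the parameters $\mu$ and $\Sigma$ are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)

mu = np.array([1.0, -1.0])
Sigma = np.array([[2.0, 0.5],
                  [0.5, 1.0]])
L = np.linalg.cholesky(Sigma)    # lower-triangular L with L @ L.T == Sigma

S = 100_000
Z = rng.normal(size=(S, 2))      # Z ~ N(0, I)
X = mu + Z @ L.T                 # X = mu + L Z, hence X ~ N(mu, Sigma)

print(X.mean(axis=0))            # should approximate mu
print(np.cov(X.T))               # should approximate Sigma
```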
Probability gradients
In the case where $f$ is a function that takes values in a finite set $\mathcal{Y} = \{y_1, \dots, y_K\}$, we may also want to compute the Jacobian of the probability weights vector:
\[q : \theta \longmapsto \begin{pmatrix} q(y_1|\theta) = \mathbb{P}(f(X) = y_1|\theta) \\ \vdots \\ q(y_K|\theta) = \mathbb{P}(f(X) = y_K|\theta) \end{pmatrix}\]
whose Jacobian is given by
\[\partial_\theta q(\theta) = \begin{pmatrix} \nabla_\theta q(y_1|\theta)^\top \\ \vdots \\ \nabla_\theta q(y_K|\theta)^\top \end{pmatrix}\]
REINFORCE probability gradients
The REINFORCE technique can be applied in a similar way:
\[q(y_k | \theta) = \mathbb{E}[\mathbf{1}\{f(X) = y_k\}] = \int \mathbf{1} \{f(x) = y_k\} ~ p(x | \theta) ~ \mathrm{d}x\]
Differentiating through the integral,
\[\begin{aligned} \nabla_\theta q(y_k | \theta) & = \int \mathbf{1} \{f(x) = y_k\} ~ \nabla_\theta p(x | \theta) ~ \mathrm{d}x \\ & = \mathbb{E} [\mathbf{1} \{f(X) = y_k\} ~ \nabla_\theta \log p(X | \theta)] \end{aligned}\]
The Monte-Carlo approximation for this is
\[\nabla_\theta q(y_k | \theta) \simeq \frac{1}{S} \sum_{s=1}^S \mathbf{1} \{f(x_s) = y_k\} ~ \nabla_\theta \log p(x_s | \theta)\]
The VJP is then
\[\begin{aligned} \partial_\theta q(\theta)^\top \bar{q} &= \sum_{k=1}^K \bar{q}_k \nabla_\theta q(y_k | \theta)\\ &\simeq \frac{1}{S} \sum_{s=1}^S \left[\sum_{k=1}^K \bar{q}_k \mathbf{1} \{f(x_s) = y_k\}\right] ~ \nabla_\theta \log p(x_s | \theta) \end{aligned}\]
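As an illustration, take $X \sim \mathcal{N}(\mu, \sigma^2)$ and $f = \mathrm{sign}$, which maps into the finite set $\{-1, +1\}$; in this case $q(+1 | \theta) = \Phi(\mu / \sigma)$, so the exact gradient $\partial q(+1 | \theta) / \partial \mu = \varphi(\mu / \sigma) / \sigma$ is available for comparison. This example is ours, not from Mohamed et al. (2020).

```python
import numpy as np

rng = np.random.default_rng(0)

mu, sigma, S = 0.3, 1.0, 100_000
xs = rng.normal(mu, sigma, size=S)

indicator = (xs > 0).astype(float)        # 1{f(x_s) = +1} with f = sign
score_mu = (xs - mu) / sigma**2           # d/dmu log p(x_s | theta)

grad_hat = (indicator * score_mu).mean()  # REINFORCE estimate of dq(+1)/dmu
exact = np.exp(-0.5 * (mu / sigma) ** 2) / (np.sqrt(2 * np.pi) * sigma)
print(grad_hat, exact)                    # the two should be close
```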
In our implementation, the empirical_distribution method outputs an empirical FixedAtomsProbabilityDistribution with uniform weights $\frac{1}{S}$, where some $x_s$ can be repeated. The probability weights vector is then indexed by the samples:
\[q : \theta \longmapsto \begin{pmatrix} q(f(x_1)|\theta) \\ \vdots \\ q(f(x_S) | \theta) \end{pmatrix}\]
We therefore define the corresponding VJP as
\[\partial_\theta q(\theta)^\top \bar{q} = \frac{1}{S} \sum_{s=1}^S \bar{q}_s \nabla_\theta \log p(x_s | \theta)\]
If $\bar q$ comes from mean, we have $\bar q_s = f(x_s)^\top \bar y$, and we recover the REINFORCE VJP.
This VJP can be interpreted as an empirical expectation, to which we can also apply variance reduction: $\partial_\theta q(\theta)^\top \bar q \approx \frac{1}{S-1}\sum_s(\bar q_s - b') \nabla_\theta \log p(x_s|\theta)$ with $b' = \frac{1}{S}\sum_s \bar q_s$.
Again, if $\bar q$ comes from mean, we have $\bar q_s = f(x_s)^\top \bar y$ and $b' = b^\top \bar y$. We then recover the REINFORCE backward rule with variance reduction: $\partial_\theta q(\theta)^\top \bar q \approx \frac{1}{S-1}\sum_s(f(x_s) - b)^\top \bar y ~ \nabla_\theta \log p(x_s|\theta)$.
Reparametrization probability gradients
To leverage reparametrization, we perform a change of variables:
\[q(y | \theta) = \mathbb{E}[\mathbf{1}\{h_\theta(Z) = y\}] = \int \mathbf{1} \{h_\theta(z) = y\} ~ r(z) ~ \mathrm{d}z\]
Assuming that $h_\theta$ is invertible, we take $z = h_\theta^{-1}(u)$ and
\[\mathrm{d}z = |\det \partial h_{\theta}^{-1}(u)| ~ \mathrm{d}u\]
so that
\[q(y | \theta) = \int \mathbf{1} \{u = y\} ~ r(h_\theta^{-1}(u)) ~ |\det \partial h_{\theta}^{-1}(u)| ~ \mathrm{d}u\]
We can now differentiate, but it gets tedious.
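To see what the formula looks like in a tractable case (an illustration of ours, with $f$ the identity rather than finite-valued), take $h_\theta(z) = \mu + \sigma z$ and $r = \mathcal{N}(0, 1)$. Then
\[h_\theta^{-1}(u) = \frac{u - \mu}{\sigma}, \qquad |\det \partial h_{\theta}^{-1}(u)| = \frac{1}{\sigma}, \qquad q(y | \theta) = \frac{1}{\sigma} \, r\!\left(\frac{y - \mu}{\sigma}\right)\]
which is exactly the density of $\mathcal{N}(\mu, \sigma^2)$, as expected.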
Bibliography
- Blondel, M. and Roulet, V. (Mar 2024). The Elements of Differentiable Programming, arXiv:2403.14606 [cs]. Accessed on Mar 22, 2024.
- Kool, W.; van Hoof, H. and Welling, M. (2022). Buy 4 REINFORCE Samples, Get a Baseline for Free! ICLR. Accessed on Apr 17, 2023.
- Mohamed, S.; Rosca, M.; Figurnov, M. and Mnih, A. (2020). Monte Carlo Gradient Estimation in Machine Learning. Journal of Machine Learning Research 21, 1–62. Accessed on Oct 21, 2022.