Amortized Optimization

Prologue

\(\newcommand{\E}{\mathbb{E}} \newcommand{\brac}[1]{\left[#1\right]} \newcommand{\paren}[1]{\left(#1\right)} \newcommand{\set}[1]{\left\{#1\right\}} \newcommand{\KL}{\text{KL}} \DeclareMathOperator*{\minimize}{min.} \DeclareMathOperator*{\maximize}{max.} \renewcommand{\L}{\mathcal{L}} \newcommand{\A}{\mathcal{A}} \newcommand{\Expect}{\mathbb{E}} \newcommand{\Q}{\mathcal{Q}} \newcommand{\X}{\mathcal{X}} \newcommand{\D}{\mathcal{D}} \newcommand{\giv}{\mid}\)

Over the summer, I made it a goal to use the phrase “amortized optimization” as often as my friends would dare let me. I’m not sure why the idea fascinated me so much, but whatever the reason, it tickled my brain enough that I developed a small bit of notoriety for abusively dropping that phrase into conversation every chance I could. And now that I have a little bit more time on my hands, I’d like to infect any poor soul who comes across this blog post with that same phrase. So without further ado…

What Is Amortized Optimization?

Suppose you wish to solve a large number of optimization problems all at once. Given a loss function \(\L(b; a)\) that takes a context variable \(a\) and an optimization variable \(b\), the goal is to solve

\[\begin{align} \minimize_b \L(b; a), \end{align}\]

for all \(a \in \A\). Naively, this would require going through all the elements of \(\A\) and solving \(|\A|\) optimization problems separately. That would be quite challenging. But here’s an idea: let’s say we solved the problem for a whole bunch of \(a_{1:n} \in \A\); we would then have a dataset of \((a_i, b_i^*)_{1:n}\) samples, where \(b_i^*\) denotes the solution to the optimization problem with respect to context \(a_i\). In true hammer-meets-nail fashion, if given a new context variable \(a_{n+1}\), the church of machine learning compels us to ask, “well, why not predict \(b^*_{n+1}\)?”

In general, this would be a sketchy proposal if \(\L\) is a highly wiggly function with respect to \((a, b)\). But if a pattern relating \(a\) and \(b^*\) is recoverable, then treating this as a prediction task may actually work. Therein lies the heart of amortized optimization, which says: rather than solving for \(b\) directly, we learn to predict the solution \(b^*\) as a function \(f_\theta\) of \(a\). Motivated by this intuition, our amortized optimization objective becomes

\[\begin{align} \minimize_\theta \Expect_{a \sim p(a)} \L(f_\theta(a); a). \end{align}\]

We replace all instances of \(b\) with \(f_\theta\) and simply sample a bunch of \(a\)’s from some distribution \(p(a)\) in order to do optimization. Of course, we would have to figure out a reasonable \(p(a)\) to use if an obvious choice does not exist. And we can either sample a fixed number of \(a\)’s to perform optimization with, or, if using stochastic gradient descent, we can just keep sampling new \(a\)’s on the fly for each gradient step. But details and devils aside, we have effectively converted a massive number of optimization problems into a single learning problem. And if given a new \(a_\text{test}\), we simply leverage the wisdom of all previous \(a\)’s we’ve ever trained on to predict \(\hat{b}_\text{test} = f_\theta(a_\text{test})\). In other words, we’ve amortized the optimization of \(a_\text{test}\) by leveraging previous \(a\)’s for which \(f_\theta\) has been trained to optimize. All in all, a pretty neat approach to solving multiple (related) optimization problems at once!
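To make this concrete, here’s a minimal sketch of what this looks like in code. Everything problem-specific is a stand-in: the toy loss, the context distribution, and the little network playing the role of \(f_\theta\) would all be swapped out for whatever your actual problem provides.

```python
import torch
import torch.nn as nn

# Stand-in for L(b; a): a toy loss whose minimizer depends on the context a.
def loss_fn(b, a):
    return ((b - torch.sin(a)) ** 2).sum(dim=-1).mean()

# Stand-in for p(a): sample fresh contexts on the fly for each gradient step.
def sample_contexts(batch_size, dim=8):
    return torch.randn(batch_size, dim)

# f_theta: maps a context a to a predicted solution b-hat.
f_theta = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 8))
opt = torch.optim.Adam(f_theta.parameters(), lr=1e-3)

for step in range(10_000):
    a = sample_contexts(128)        # a ~ p(a)
    loss = loss_fn(f_theta(a), a)   # L(f_theta(a); a)
    opt.zero_grad()
    loss.backward()
    opt.step()

# At test time, "solving" a new optimization problem is just a forward pass.
a_test = sample_contexts(1)
b_hat_test = f_theta(a_test)
```

All the per-\(a\) optimization effort has been folded into training \(f_\theta\); at test time we pay only the cost of a forward pass.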


Amortized Optimization In the Wild

Conditioned on the fact that you somehow stumbled across this blog post and have read this far, it is likely that you’ve seen amortized optimization before, but probably under a different name. Actually, if at this point you’re wondering whether the term “amortized optimization” is a well-established phrase, my guess is that it’s not—seeing as I made this term up. (If you check Google Scholar, the term “amortized optimization” has definitely been used before, but never, to the best of my knowledge, in the context of learning to optimize.) In any case, as a demonstration of the utility of amortized optimization, here are some well-known use cases of amortized optimization that you’ve probably encountered.

Amortized Inference

In variational inference, we have a joint distribution \(p_\theta(x, z)\) and wish to approximate the posterior \(p_\theta(z \giv x)\). The goal is to find the best distribution \(q(z) \in \Q\) to solve

\[\begin{align} \minimize_{q(z)} \Expect_{z \sim q(z)} \ln \frac{q(z)}{p_\theta(x, z)}, \end{align}\]

for each \(x \in \X\). It may be a little hard to see at first, but \(\L\) in this case is simply

\[\begin{align} \L(q(z); x) = \Expect_{z \sim q(z)} \ln \frac{q(z)}{p_\theta(x, z)}, \end{align}\]

which is the negative of the variational lower bound. To amortize the optimization over all \(x \in \X\), we instead learn a function \(f_\phi: \X \to \Q\). Since \(f_\phi\) is basically a conditional distribution, it is more meaningfully written as \(q_\phi(z \giv x)\), which yields

\[\begin{align} \minimize_{\phi} \Expect_{x \sim p(x)}\Expect_{z \sim q_\phi(z \giv x)} \ln \frac{q_\phi(z \giv x)}{p_\theta(x, z)}. \end{align}\]

Here, \(p(x)\) is the true distribution over the observed data. (If \(\theta\) is fixed and the goal is strictly to learn how to approximate the posterior \(p_\theta(z \giv x)\) for \(x \sim p_\theta(x)\), then it makes sense to set \(p(x) = p_\theta(x)\).) In the case of finite data, \(p(x)\) is simply approximated as uniform sampling over our dataset. Many may recognize this objective as the variational autoencoder objective (if you optimize \((\theta, \phi)\) jointly), where the encoder \(q_\phi\) performs variational inference on the generator \(p_\theta\). I love amortized inference and think it’s one of the coolest things ever. And since inference is an instantiation of optimization, it’s also why I ultimately decided to latch on to the term “amortized optimization”.
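For the curious, here’s a minimal sketch of amortized inference with a Gaussian \(q_\phi(z \giv x)\) and \(\theta\) held fixed. The decoder below is just a stand-in generative model, and all the log-densities are computed only up to additive constants; training \(\theta\) and \(\phi\) jointly instead would give you the usual variational autoencoder.

```python
import torch
import torch.nn as nn

# Fixed generative model p_theta(x, z) = p(z) p_theta(x | z),
# with p(z) = N(0, I) and a unit-variance Gaussian likelihood (stand-ins).
decoder = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 10))

def log_p_xz(x, z):
    log_prior = -0.5 * (z ** 2).sum(dim=-1)            # log N(z; 0, I) up to a constant
    x_mean = decoder(z)
    log_lik = -0.5 * ((x - x_mean) ** 2).sum(dim=-1)   # log N(x; x_mean, I) up to a constant
    return log_prior + log_lik

# Amortized posterior q_phi(z | x): an encoder predicting mean and log-variance.
encoder = nn.Sequential(nn.Linear(10, 64), nn.Tanh(), nn.Linear(64, 4))
opt = torch.optim.Adam(encoder.parameters(), lr=1e-3)

data = torch.randn(512, 10)  # stand-in for samples from p(x)

for step in range(5_000):
    x = data[torch.randint(len(data), (64,))]                # x ~ p(x), uniform over the dataset
    mu, log_var = encoder(x).chunk(2, dim=-1)
    z = mu + torch.randn_like(mu) * (0.5 * log_var).exp()    # reparameterized z ~ q_phi(z | x)
    log_q = (-0.5 * (z - mu) ** 2 / log_var.exp() - 0.5 * log_var).sum(dim=-1)
    loss = (log_q - log_p_xz(x, z)).mean()                   # E_q[ln q_phi(z|x) - ln p_theta(x, z)]
    opt.zero_grad()
    loss.backward()
    opt.step()
```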

Fast Style Transfer

In style transfer, we have a content image \(c\) and a style image \(s\). The goal is to transfer the style of \(s\) onto the content of \(c\) by finding an image \(x\) that optimizes

\[\begin{align} \minimize_{x} \L_c(c, x) + \L_s(s, x), \end{align}\]

where \(\L_c\) measures some notion of content distance between \((c, x)\), and \(\L_s\) measures some notion of style distance between \((s, x)\). To amortize the optimization over all \((s, c) \in \mathcal{S} \times \mathcal{C}\), simply learn \(f_\theta: \mathcal{S} \times \mathcal{C} \to \X\) with the amortized objective

\[\begin{align} \minimize_\theta \Expect_{c, s \sim p(c, s)} \L_c(c, f_\theta(c, s)) + \L_s(s, f_\theta(c, s)). \end{align}\]

In the case of style transfer, the tricky parts are choosing a distribution over art styles and designing the model \(f_\theta\).
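Here’s a rough sketch of what the amortized version might look like. The content and style losses below are placeholders: in practice they’re usually perceptual losses built from the features of a pretrained network like VGG, and \(f_\theta\) would be a proper image-to-image architecture rather than two convolutions.

```python
import torch
import torch.nn as nn

# Stand-in feature extractor; in practice, a frozen pretrained network (e.g. VGG).
features = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
features.requires_grad_(False)

def content_loss(c, x):
    return ((features(c) - features(x)) ** 2).mean()

def gram(f):
    b, ch, h, w = f.shape
    f = f.reshape(b, ch, h * w)
    return f @ f.transpose(1, 2) / (ch * h * w)

def style_loss(s, x):
    return ((gram(features(s)) - gram(features(x))) ** 2).mean()

# f_theta: maps a (content, style) pair to a stylized image.
class TransferNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(6, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 3, 3, padding=1), nn.Sigmoid(),
        )
    def forward(self, c, s):
        return self.net(torch.cat([c, s], dim=1))

f_theta = TransferNet()
opt = torch.optim.Adam(f_theta.parameters(), lr=1e-3)

for step in range(1_000):
    c = torch.rand(8, 3, 64, 64)   # stand-in for content images ~ p(c)
    s = torch.rand(8, 3, 64, 64)   # stand-in for style images ~ p(s)
    x = f_theta(c, s)
    loss = content_loss(c, x) + style_loss(s, x)
    opt.zero_grad()
    loss.backward()
    opt.step()
```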

Meta-Learning

In certain multi-task settings, the goal is to quickly learn to solve new tasks by leveraging experiences from previous tasks. This falls quite naturally within the amortization framework. Indeed, we can think of the “original” optimization problem as

\[\begin{align} \minimize_\theta \L(m_\theta; T), \end{align}\]

where \(m_\theta\) is a model parameterized by \(\theta\), \(T\) is a task drawn from \(\mathcal{T}\), and \(\L\) evaluates \(m_\theta\) on the task \(T\). To amortize this optimization problem, we simply learn a meta-model \(f_\phi: \mathcal{T} \to \Theta\) that maps tasks to parameters of \(m\) and is trained using the following amortized objective:

\[\begin{align} \minimize_\phi \Expect_{T \sim p(T)}\L(m_{f_\phi(T)}; T). \end{align}\]

By doing so, you end up with a meta-learned \(f_\phi\) that hopefully does well on new tasks!
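As a toy instance of that recipe, here’s a hypernetwork-style sketch. Every piece of it is made up for illustration: each task is a random linear-regression problem, a task is represented by its (flattened) support set, and \(f_\phi\) directly spits out the weights of a linear model \(m\). A real meta-learner would also evaluate the loss on a held-out query set rather than on the support set itself.

```python
import torch
import torch.nn as nn

# Stand-in task distribution p(T): random linear-regression tasks,
# each represented by a small support set (x, y) with y = x @ w_T.
def sample_task(n_support=10, dim=4):
    w_true = torch.randn(dim, 1)
    x = torch.randn(n_support, dim)
    y = x @ w_true
    return x, y

# f_phi: maps a task (its flattened support set) to the parameters of m,
# where m_theta(x) = x @ theta is a simple linear model.
dim, n_support = 4, 10
f_phi = nn.Sequential(
    nn.Linear(n_support * (dim + 1), 128), nn.ReLU(),
    nn.Linear(128, dim),
)
opt = torch.optim.Adam(f_phi.parameters(), lr=1e-3)

for step in range(5_000):
    x, y = sample_task(n_support, dim)                 # T ~ p(T)
    task_repr = torch.cat([x, y], dim=1).flatten()     # crude task embedding
    theta = f_phi(task_repr).unsqueeze(1)              # predicted parameters of m
    loss = ((x @ theta - y) ** 2).mean()               # L(m_{f_phi(T)}; T)
    opt.zero_grad()
    loss.backward()
    opt.step()
```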

Classification

Ok, this one is a bit of a stretch. But hear me out. This,

\[\begin{align} \L_\text{survival}, \end{align}\]

is the ultimate loss function, the survival of the fittest. Billions of years of biological optimization led to the development of biological systems that can now perform incredible feats of survival, like detecting cats. As humans, we come from a long lineage of organisms that have optimized

\[\begin{align} \minimize_{y} \L_\text{survival}(y; x) \end{align}\]

for the classification \(y\) of every single object \(x\) our ancestors had the misfortune of encountering. And through countless more hours of Mechanical Turk (and who knows how much money), we have now curated millions of samples of \(\D = (x_i, y^*_i)_{1:n} = (x_i, \text{biological-classifier}(x_i))_{1:n}\). By noting that the amortized optimization takes the following form

\[\begin{align} \minimize_\theta \Expect_{x \sim \D}\L_\text{survival}(f_\theta(x); x) \equiv \Expect_{x \sim \D}\L_c(f_\theta(x), \text{biological-classifier}(x)), \end{align}\]

where the humble classification loss function \(\L_c\) now serves as a stand-in for \(\L_\text{survival}\), we have performed a feat no smaller than the amortization of billions of years of optimization.
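Which, stripped of the evolutionary grandeur, is just the familiar supervised training loop (with made-up data standing in for the fruits of Mechanical Turk):

```python
import torch
import torch.nn as nn

# Stand-in dataset D = (x_i, y_i*): the labels play the role of
# biological-classifier(x_i), i.e. human annotations.
x_data = torch.randn(1000, 32)
y_data = torch.randint(0, 10, (1000,))

# f_theta: predicts the label directly instead of re-running evolution.
f_theta = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
loss_c = nn.CrossEntropyLoss()   # the humble stand-in for L_survival
opt = torch.optim.Adam(f_theta.parameters(), lr=1e-3)

for step in range(2_000):
    idx = torch.randint(0, len(x_data), (64,))   # x ~ D
    loss = loss_c(f_theta(x_data[idx]), y_data[idx])
    opt.zero_grad()
    loss.backward()
    opt.step()
```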

In short, amortized optimization is everywhere.

I’m pretty sure I had a more relevant point to make at first, but I’ve totally lost track of it now.