Amortized Optimization

Prologue

Over the summer, I made it a goal to use the phrase “amortized optimization” as often as my friends would dare let me. I’m not sure why the idea fascinated me so much, but whatever the reason, it tickled my brain enough that I developed a small bit of notoriety for abusively dropping that phrase into conversation every chance I could. And now that I have a little bit more time on my hands, I’d like to infect any poor soul who comes across this blog post with that same phrase. So without further ado…

What Is Amortized Optimization?

Suppose you wish to solve a large number of optimization problems all at once. Given a loss function $L(b; a)$ that takes a context variable $a$ and an optimization variable $b$, the goal is to solve

$$\min_b \; L(b; a),$$

for all $a \in A$. Naively, this would require going through all the elements of $A$ and solving $|A|$ optimization problems separately. That would be quite challenging. But here's an idea: suppose we solved the problem for a whole bunch of $a_{1:n} \in A$; we would then have a dataset of $(a_i, b_i)_{1:n}$ samples, where $b_i$ denotes the solution to the optimization problem with respect to context $a_i$. In true hammer-meets-nail fashion, if given a new context variable $a_{n+1}$, the church of machine learning compels us to ask, "well, why not predict $b_{n+1}$?"
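
To make the setup concrete, here is a minimal sketch of the naive "solve each problem separately" step. The loss, dimensions, and optimizer settings are all placeholder assumptions (a toy quadratic whose minimizer happens to be $a^2$), not anything canonical:

```python
import torch

def L(b, a):
    # Placeholder loss: for this toy example, the optimal b for context a is a**2.
    return ((b - a ** 2) ** 2).sum()

def solve_one(a, steps=200, lr=0.1):
    # Solve min_b L(b; a) for a single context a with plain gradient descent.
    b = torch.zeros_like(a, requires_grad=True)
    opt = torch.optim.SGD([b], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        L(b, a).backward()
        opt.step()
    return b.detach()

# Build the (a_i, b_i) dataset by running a separate optimization per context.
contexts = [torch.randn(3) for _ in range(100)]
dataset = [(a, solve_one(a)) for a in contexts]
```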

In general, predicting $b$ from $a$ would be a sketchy proposal if $L$ is a highly wiggly function with respect to $(a, b)$. But if a pattern relating $a$ and $b$ is recoverable, then treating this as a prediction task may actually work. Therein lies the heart of amortized optimization, which says: rather than solving for $b$ directly, we learn to predict the solution $b$ as a function $f_\theta$ of $a$. Motivated by this intuition, our amortized optimization objective becomes

$$\min_\theta \; \mathbb{E}_{a \sim p(a)} \left[ L(f_\theta(a); a) \right].$$

We replace all instances of $b$ with $f_\theta(a)$ and simply sample a bunch of $a$'s from some distribution $p(a)$ in order to do optimization. Of course, we would have to figure out a reasonable $p(a)$ to use if an obvious choice does not exist. And we can either sample a fixed number of $a$'s to perform optimization with, or, if using stochastic gradient descent, we can just keep sampling new $a$'s on the fly for each gradient step. But details and devils aside, we have effectively converted a massive number of optimization problems into a single learning problem. And if given a new $a_{\text{test}}$, we simply leverage the wisdom of all previous $a$'s we've ever trained on to predict $\hat{b}_{\text{test}} = f_\theta(a_{\text{test}})$. In other words, we've amortized the optimization of $a_{\text{test}}$ by leveraging previous $a$'s for which $f_\theta$ has been trained to optimize. All in all, a pretty neat approach to solving multiple (related) optimization problems at once!
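
And here is what the amortized version of the sketch above might look like, reusing the same toy loss; the network architecture, batch size, and choice of $p(a)$ are arbitrary stand-ins:

```python
import torch
import torch.nn as nn

# f_theta: predicts the solution b directly from the context a.
f = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 3))
opt = torch.optim.Adam(f.parameters(), lr=1e-3)

for step in range(5000):
    a = torch.randn(32, 3)                          # fresh batch of contexts from p(a)
    b_hat = f(a)                                    # predicted solutions
    loss = ((b_hat - a ** 2) ** 2).sum(-1).mean()   # Monte Carlo estimate of E_a[L(f_theta(a); a)]
    opt.zero_grad()
    loss.backward()
    opt.step()

# At test time, a single forward pass replaces an entire optimization run.
a_test = torch.randn(3)
b_hat_test = f(a_test)
```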


Amortized Optimization In the Wild

Conditioned on the fact that you somehow stumbled across this blog post and have read this far, it is likely that you've seen amortized optimization before, but probably under a different name. Actually, if at this point you're wondering whether the term "amortized optimization" is a well-established phrase, my guess is that it's not, seeing as I made this term up.[1] In any case, as a demonstration of the utility of amortized optimization, here are some well-known use cases of amortized optimization that you've probably encountered.

Amortized Inference

In variational inference, we have a joint distribution $p_\theta(x, z)$ and wish to approximate the posterior $p_\theta(z \mid x)$. The goal is to find the best distribution $q(z) \in \mathcal{Q}$ to solve

$$\min_{q(z)} \; \mathbb{E}_{z \sim q(z)} \left[ \ln \frac{q(z)}{p_\theta(x, z)} \right],$$

for each $x \in X$. It may be a little hard to see at first, but $L$ in this case is simply

$$L(q(z); x) = \mathbb{E}_{z \sim q(z)} \left[ \ln \frac{q(z)}{p_\theta(x, z)} \right],$$

which is the negative of the variational lower bound. To amortize the optimization over all $x \in X$, we instead learn an $f_\phi$ that maps $f_\phi: X \to \mathcal{Q}$. Since $f_\phi$ is basically a conditional distribution, it is more meaningfully written as $q_\phi(z \mid x)$, which yields

$$\min_\phi \; \mathbb{E}_{x \sim p(x)} \, \mathbb{E}_{z \sim q_\phi(z \mid x)} \left[ \ln \frac{q_\phi(z \mid x)}{p_\theta(x, z)} \right].$$

Here, $p(x)$ is the true distribution over the observed data.[2] In the case of finite data, $p(x)$ is simply approximated as uniform sampling over our dataset. Many may recognize this objective as the variational autoencoder objective (if you optimize $(\theta, \phi)$ jointly), where the encoder $q_\phi$ performs variational inference on the generator $p_\theta$. I love amortized inference and think it's one of the coolest things ever. And since inference is an instantiation of optimization, it's why I ultimately decided to latch on to the term "amortized optimization".
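
As a rough illustration, here is what an amortized diagonal-Gaussian posterior and a Monte Carlo estimate of the objective above might look like. The encoder architecture is arbitrary, and `log_p_xz` is a hypothetical callable standing in for whatever $\ln p_\theta(x, z)$ you have on hand:

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    # q_phi(z | x): an amortized diagonal-Gaussian posterior over z.
    def __init__(self, x_dim=10, z_dim=2):
        super().__init__()
        self.net = nn.Linear(x_dim, 2 * z_dim)

    def forward(self, x):
        mu, log_var = self.net(x).chunk(2, dim=-1)
        return mu, log_var

def neg_elbo(x, encoder, log_p_xz, n_samples=1):
    # Monte Carlo estimate of E_{z ~ q_phi(z|x)}[ln q_phi(z|x) - ln p_theta(x, z)].
    mu, log_var = encoder(x)
    std = (0.5 * log_var).exp()
    eps = torch.randn(n_samples, *mu.shape)
    z = mu + std * eps                                            # reparameterized samples
    log_q = torch.distributions.Normal(mu, std).log_prob(z).sum(-1)
    return (log_q - log_p_xz(x, z)).mean()                        # averaged over samples and batch
```

Minimizing this with respect to the encoder's parameters (and, if you like, $\theta$ as well) is essentially the usual VAE training loop.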

Fast Style Transfer

In style transfer, we have a content image $c$ and a style image $s$. The goal is to produce an image $x$ that keeps the content of $c$ while adopting the style of $s$ by optimizing

$$\min_x \; L_c(c, x) + L_s(s, x),$$

where $L_c$ measures some notion of content distance between $(c, x)$, and $L_s$ measures some notion of style distance between $(s, x)$. To amortize the optimization over all $(c, s) \in C \times S$, simply learn $f_\theta: C \times S \to X$ with the amortized objective

$$\min_\theta \; \mathbb{E}_{(c, s) \sim p(c, s)} \left[ L_c(c, f_\theta(c, s)) + L_s(s, f_\theta(c, s)) \right].$$

In the case of style transfer, the tricky parts are choosing a reasonable distribution $p(c, s)$ over content and style images and figuring out the best way of designing the model $f_\theta$.
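
A bare-bones training step for the amortized objective might look like the following; the transfer network `f`, the sampler `sample_batch`, and the two losses (in practice, typically VGG feature distances and Gram-matrix distances) are all assumed to be supplied by the caller:

```python
def train_step(f, opt, sample_batch, content_loss, style_loss):
    # One stochastic gradient step on E_{(c,s) ~ p(c,s)}[L_c(c, f(c,s)) + L_s(s, f(c,s))].
    c, s = sample_batch()               # batch of (content, style) images from p(c, s)
    x = f(c, s)                         # stylized output
    loss = content_loss(c, x) + style_loss(s, x)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()
```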

Meta-Learning

In certain multi-task settings, the goal is to quickly learn to solve new tasks by leveraging experiences from previous tasks. This falls quite naturally within the amortization framework. Indeed, we can think of the “original” optimization problem as

$$\min_\theta \; L(m_\theta; T),$$

where $m_\theta$ is a model parameterized by $\theta$, $T$ is a task drawn from $\mathcal{T}$, and $L$ evaluates $m_\theta$ on the task $T$. To amortize this optimization problem, we simply learn a meta-model $f_\phi: \mathcal{T} \to \Theta$ that maps tasks to parameters of $m$ and is trained using the following amortized objective

$$\min_\phi \; \mathbb{E}_{T \sim p(T)} \left[ L(m_{f_\phi(T)}; T) \right].$$

By doing so, you end up with a meta-learned $f_\phi$ that hopefully does well on new tasks!
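
One of many ways to instantiate $f_\phi: \mathcal{T} \to \Theta$ is as a hypernetwork. The sketch below assumes, purely for illustration, that each task is summarized by a fixed-size embedding and that $m$ is just a linear model:

```python
import torch
import torch.nn as nn

class HyperNet(nn.Module):
    # f_phi: maps a task embedding to the parameters of a tiny linear model m.
    def __init__(self, task_dim=8, in_dim=4, out_dim=1):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.net = nn.Linear(task_dim, in_dim * out_dim + out_dim)

    def forward(self, task_embedding):
        params = self.net(task_embedding)
        W = params[: self.in_dim * self.out_dim].view(self.out_dim, self.in_dim)
        b = params[self.in_dim * self.out_dim :]
        return W, b

def task_loss(W, b, x, y):
    # L(m_{f_phi(T)}; T): squared error of the generated model on the task's data (x, y).
    return ((x @ W.t() + b - y) ** 2).mean()
```

Training then amounts to sampling tasks, generating $(W, b)$ with the hypernetwork, and backpropagating `task_loss` through $f_\phi$.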

Classification

Ok, this one is a bit of a stretch. But hear me out. This,

$$L_{\text{survival}},$$

is the ultimate loss function, the survival of the fittest. Billions of years of biological optimization led to the development of biological systems that can now perform incredible survival acts like detecting cats. As humans, we come from a long lineage of organisms that have optimized

$$\min_y \; L_{\text{survival}}(y; x)$$

for the classification $y$ of every single object $x$ our ancestors had the misfortune of encountering. And through countless more hours of Mechanical Turk (and who knows how much money), we have now curated millions of samples of $D = (x_i, y_i)_{1:n} = (x_i, \text{biological-classifier}(x_i))_{1:n}$. By noting that the amortized optimization takes the following form

$$\min_\theta \; \mathbb{E}_{x \sim D} \left[ L_{\text{survival}}(f_\theta(x); x) \right] \approx \mathbb{E}_{x \sim D} \left[ L_c(f_\theta(x), \text{biological-classifier}(x)) \right],$$

where the humble classification loss function $L_c$ now serves as a stand-in for $L_{\text{survival}}$, we have performed a feat no smaller than the amortization of billions of years of optimization.

In short, amortized optimization is everywhere.

I’m pretty sure I had a more relevant point to make at first, but I’ve totally lost track of it now.

Footnotes

  1. If you check Google Scholar, the term "amortized optimization" has definitely been used before, but never, to the best of my knowledge, in the context of learning to optimize.
  2. If $\theta$ is fixed and the goal is strictly to learn how to approximate the posterior $p_\theta(z \mid x)$ for $x \sim p_\theta(x)$, then it makes sense to set $p(x) = p_\theta(x)$.