Using a Bernoulli VAE on Real-Valued Observations

Introduction

\(\newcommand{\mbf}{\mathbf} \newcommand{\paren}[1]{\left(#1\right)} \newcommand{\brac}[1]{\left[#1\right]} \newcommand{\set}[1]{\left\{#1\right\}} \newcommand{\sset}[1]{\{#1\}} \newcommand{\abs}[1]{\left | #1\right |} \newcommand{\giv}{\mid} \newcommand{\nint}[1]{\displaystyle\lfloor #1 \rceil} \DeclareMathOperator*{\argmin}{arg\,min} \DeclareMathOperator*{\argmax}{arg\,max} \DeclareMathOperator*{\minimize}{min.} \DeclareMathOperator*{\maximize}{max.} \DeclareMathOperator*{\supp}{supp} \newcommand{\scolon}{\mathbin{;}} \newcommand{\tf}[1]{\text{#1}} \newcommand{\identity}{\mbf{I}} \newcommand{\zero}{\mbf{0}} \newcommand{\eps}{\epsilon} \newcommand{\veps}{\varepsilon} \newcommand{\data}{\mathrm{data}} \newcommand{\Normal}{\mathcal{N}} \newcommand{\Expect}{\mathbb{E}} \newcommand{\Ber}{\text{Ber}} \newcommand{\diag}{\text{diag}} \newcommand{\I}{\mathbf{I}} \newcommand{\0}{\mathbf{0}} \renewcommand{\P}{\mathcal{P}} \newcommand{\Cat}{\mathrm{Cat}} \newcommand{\R}{\mathbb{R}} \newcommand{\Q}{\mathcal{Q}} \newcommand{\X}{\mathcal{X}} \newcommand{\F}{\mathcal{F}} \newcommand{\G}{\mathcal{G}} \newcommand{\E}{\mathcal{E}} \newcommand{\Z}{\mathcal{Z}} \newcommand{\Y}{\mathcal{Y}} \renewcommand{\d}{\text{d}} \renewcommand{\H}{\mathcal{H}} \newcommand{\B}{\mathcal{B}} \renewcommand{\L}{\mathcal{L}} \newcommand{\U}{\mathcal{U}} \newcommand{\D}{\mathcal{D}} \newcommand{\J}{\mathcal{J}} \newcommand{\C}{\mathcal{C}} \newcommand{\1}[1]{\mathds{1}\!\set{#1}} \newcommand{\KL}[2]{D(#1 \mathbin{\|} #2)}\)The Bernoulli observation VAE is meant to be used when one's observed samples \(x \in \sset{0, 1}^n\) are vectors of binary elements. However, I have, on occasion, seen people (and even papers) apply Bernoulli observation VAEs to real-valued samples \(x \in [0, 1]^n\). This will be a quick and dirty post going over whether this unholy marriage of the Bernoulli VAE with real-valued samples is appropriate.

Background and Notation for Bernoulli VAE

Given an empirical distribution \(\hat{p}(x)\) whose samples are binary \(x \in \sset{0, 1}^n\), the VAE objective is

\[\begin{align} \Expect_{\hat{p}(x)} \ln p_\theta(x) &\ge \Expect_{\hat{p}(x)} \bigg[\ln p_\theta(x) - \KL{q_\phi(z \giv x)}{p_\theta(z \giv x)}\bigg]\\ &= \Expect_{\hat{p}(x)} \bigg[\Expect_{q_\phi(z \giv x)} \ln p_\theta(x \giv z) - \KL{q_\phi(z \giv x)}{p(z)}\bigg]. \end{align}\]
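As a concrete reference, here is a minimal single-sample Monte Carlo sketch of this objective in PyTorch. The `encoder` network (which outputs the mean and log-variance of a diagonal-Gaussian \(q_\phi(z \giv x)\)) and the `log_likelihood` callable (which evaluates \(\ln p_\theta(x \giv z)\); a Bernoulli version appears below) are placeholders of my own, not anything prescribed by the derivation:

```python
# Minimal sketch: one-sample Monte Carlo estimate of the ELBO, assuming a
# diagonal-Gaussian q_phi(z | x) and a standard-normal prior p(z) = N(0, I).
import torch

def elbo(x, encoder, log_likelihood):
    mu, logvar = encoder(x)                                   # parameters of q_phi(z | x)
    z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)   # reparameterized sample z ~ q_phi(z | x)
    recon = log_likelihood(x, z)                              # one-sample estimate of E_q ln p_theta(x | z)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)  # KL(q_phi(z | x) || N(0, I)) in closed form
    return (recon - kl).mean()                                # average over the minibatch
```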

If \(p_\theta(x \giv z)\) is furthermore a fully-factorized Bernoulli observation model, then the distribution can be expressed as

\[\begin{align} p_\theta(x \giv z) = \prod_i p_\theta(x_i \giv z) = \prod_i \Ber(x_i \giv \pi_i(z \scolon \theta)), \end{align}\]

where \(\pi: \Z \to [0, 1]^n\) is a neural network parameterized by \(\theta\). As preparation for the next section, we shall—with a slight abuse of notation—also define

\[\begin{align} p(x \giv \pi) = \prod_i p(x_i \giv \pi_i) = \prod_i \Ber(x_i \giv \pi_i), \end{align}\]

where \(\pi \in [0, 1]^n\).
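For instance, a minimal sketch of this fully-factorized log-likelihood in PyTorch, written against a plain tensor \(\pi\) as in the abuse of notation above (in the VAE itself, \(\pi\) would be the decoder output \(\pi(z \scolon \theta)\)):

```python
# Minimal sketch: ln p(x | pi) = sum_i ln Ber(x_i | pi_i) for binary x and pi in [0, 1]^n.
import torch

def bernoulli_log_likelihood(x, pi, eps=1e-7):
    pi = pi.clamp(eps, 1 - eps)   # keep the logs finite at pi = 0 or 1
    return torch.sum(x * torch.log(pi) + (1 - x) * torch.log(1 - pi), dim=-1)
```

With binary \(x\), this is the usual binary cross-entropy reconstruction term (up to sign).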

Applying Bernoulli VAE to Real-Valued Samples

Suppose we have a distribution \(r(\pi)\) over Bernoulli parameters \(\pi \in [0, 1]^n\), and \(\hat{p}(x)\) is in fact the marginalization of \(r(\pi)p(x \giv \pi)\) over \(\pi\). This is the case for MNIST, where the real-valued pixel intensities are interpreted as observations of \(\pi\), and the binary samples \(x\) as draws from \(p(x \giv \pi)\). This allows us to construct the objective as

\[\begin{align} \Expect_{\hat{p}(x)} \ln p_\theta(x) \ge \Expect_{r(\pi)}\Expect_{p(x \giv \pi)} \bigg[\ln p_\theta(x) - \KL{q_\phi(z \giv x)}{p_\theta(z \giv x)}\bigg]. \end{align}\]
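To make the MNIST interpretation concrete, here is a small sketch (with stand-in tensors of my own) of drawing binary observations \(x \sim p(x \giv \pi)\), with the real-valued intensities playing the role of a sample \(\pi\) from \(r(\pi)\):

```python
# Minimal sketch: real-valued intensities act as pi ~ r(pi); each binary
# observation is a fresh draw x ~ p(x | pi) (i.e. dynamic binarization).
import torch

pi = torch.rand(4, 784)    # stand-in for real-valued intensities in [0, 1]^n
x = torch.bernoulli(pi)    # binary sample, x_i ~ Ber(pi_i)
```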

It turns out there is another, equally valid lower bound, obtained by conditioning the variational posterior on \(\pi\) instead of \(x\):

\[\begin{align}\label{eq:bad} \Expect_{\hat{p}(x)} \ln p_\theta(x) \ge \Expect_{r(\pi)}\Expect_{p(x \giv \pi)} \bigg[\ln p_\theta(x) - \KL{q_\phi(z \giv \pi)}{p_\theta(z \giv x)}\bigg]. \end{align}\]

However, since \(q_\phi(z \giv \pi)\) does not have access to \(x\), it is unlikely to approximate \(p_\theta(z \giv x)\) as well as the \(q_\phi(z \giv x)\) used in the previous bound. Consequently, this bound is likely to be looser (which can be verified empirically). A bit of tedious algebra shows that the objective is equivalent to

\[\begin{align} \Expect_{r(\pi)}&\Expect_{p(x \giv \pi)} \bigg[\ln p_\theta(x) - \KL{q_\phi(z \giv \pi)}{p_\theta(z \giv x)}\bigg] \\ &= \Expect_{r(\pi)}\bigg[\Expect_{p(x \giv \pi)}\bigg[\Expect_{q_\phi(z \giv \pi)} \ln p_\theta(x \giv z) \bigg] - \KL{q_\phi(z \giv \pi)}{p(z)} \bigg] \\ &= \Expect_{r(\pi)}\bigg[\Expect_{q_\phi(z \giv \pi)}\bigg[\Expect_{p(x \giv \pi)} \sum_i \ln p_\theta(x_i \giv z) \bigg] - \KL{q_\phi(z \giv \pi)}{p(z)} \bigg] \\ &= \Expect_{r(\pi)}\bigg[\Expect_{q_\phi(z \giv \pi)}\sum_i \bigg[\Expect_{p(x_i \giv \pi_i)} \ln p_\theta(x_i \giv z) \bigg] - \KL{q_\phi(z \giv \pi)}{p(z)} \bigg], \end{align}\]

where the inner-most term is exactly a sum of element-wise cross-entropy terms, each of which takes the form

\[\begin{align} \Expect_{p(x_i \giv \pi_i)} \ln p_\theta(x_i \giv z) = \pi_i \ln \pi_i(z\scolon \theta) + (1 - \pi_i) \ln (1 - \pi_i(z\scolon \theta)). \end{align}\]

Note that this is exactly what is computed when a Bernoulli observation VAE is applied to real-valued samples. So long as the real-valued samples can be interpreted as Bernoulli distribution parameters, this lower bound is valid. However, as noted above, it tends to be looser than the bound whose variational posterior conditions on binary \(x\).
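As a final sanity check, here is a small PyTorch sketch (with stand-in tensors of my own) showing that the standard binary cross-entropy loss, when fed real-valued targets in \([0, 1]\), computes exactly the negated expected cross-entropy term above. In other words, training a Bernoulli VAE directly on un-binarized samples amounts to optimizing this looser bound:

```python
# Minimal sketch: BCE with real-valued targets equals the negated expected
# cross-entropy term pi_i ln pi_hat_i + (1 - pi_i) ln(1 - pi_hat_i), summed over i.
import torch
import torch.nn.functional as F

pi = torch.rand(4, 784)                            # stand-in for real-valued samples in [0, 1]^n
pi_hat = torch.rand(4, 784).clamp(1e-6, 1 - 1e-6)  # stand-in for decoder outputs pi(z; theta)

# Explicit formula: sum_i [ pi_i ln pi_hat_i + (1 - pi_i) ln(1 - pi_hat_i) ]
explicit = torch.sum(pi * torch.log(pi_hat) + (1 - pi) * torch.log(1 - pi_hat), dim=-1)

# Same quantity via the usual BCE reconstruction loss with "soft" (real-valued) targets
via_bce = -F.binary_cross_entropy(pi_hat, pi, reduction="none").sum(dim=-1)

assert torch.allclose(explicit, via_bce, atol=1e-4)
```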