Variational Autoencoders explained — with PyTorch Implementation

Variational autoencoders (VAEs) act as foundational building blocks in current state-of-the-art text-to-image generators such as DALL-E and Stable Diffusion. Previously, state-of-the-art image generators were generally GANs; however, we have recently seen a shift with the introduction of diffusion models, where models containing VAEs have taken over the field. With that in mind, we would like to gain a deeper understanding of the underlying concepts of VAEs and implement a small version in PyTorch.

Sanna Persson
9 min read · Sep 13, 2022

This article is a summary and implementation of the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling, link: https://arxiv.org/abs/1312.6114.

Key concepts of the Variational Autoencoder

The goal of the VAE is information reconstruction and generation. Given a dataset sampled from some unknown distribution, we want, for example, to be able to conditionally generate new data with the same distribution. This principle can be applied in several domains, but it has mainly been used for image generation. The toy example that we will use, which was also used in the original paper, is that of generating new MNIST images.

The model consists of two parts: the encoder and the decoder. We assume that our data has an underlying latent distribution, explained in detail below. The encoder maps the input data to a latent representation and outputs the distribution of this representation. The decoder then takes a sample from this distribution and generates a new data point. Since the latent distribution of the input batch is used, a copy of the input data point itself has a high likelihood of being generated by the decoder.

The latent distribution is the key concept that makes the VAE different from the autoencoder. Instead of simply compressing and reconstructing the input, the VAE tries to model the underlying data distribution. So, what exactly is this latent distribution? In most implementations of the Variational Autoencoder, two strong assumptions/modelling choices are made. We choose the prior to be a standard Gaussian and the covariance matrix of the latent distribution to be diagonal. The latter is a quite strong assumption: it says that the features of the distribution are independent of each other. The encoder then only needs to output two vectors, one for the means and one for the standard deviations of the latent distribution.

The Mathematics behind VAEs

Feel free to skip this section if you only want a more intuitive understanding of the main concepts. Otherwise, let’s dive a bit deeper into the details of the paper. In the original paper, the VAE is in fact only an example of the underlying ideas. Due to its usefulness, however, it has become widely known. The real value of the paper instead lies in the ideas behind Auto-Encoding Variational Bayes.

Problem regime

The problem the paper tries to solve is one where we have a large dataset of independent and identically distributed samples of a stochastic variable X. This variable is generated by a hidden process that depends on the latent variable z, which comes from a prior distribution with parameters θ. That is, a specific sample of X is generated from the conditional distribution (likelihood)

$$x \sim p_\theta(x \mid z), \qquad z \sim p_\theta(z).$$

Our goal, to be able to generate new samples of X, is to find the marginal likelihood p(x), but we are generally faced with problems of intractability. For example, in many cases we are not able to compute the integral

$$p_\theta(x) = \int p_\theta(z)\, p_\theta(x \mid z)\, dz$$

analytically over all values of z, or the posterior p(z| x) is unknown.

To be able to formulate workarounds for this problem we make two simple assumptions: the prior, p(z), and likelihood p(x| z) have PDFs that are differentiable (almost everywhere) and depend on parameters θ.

We would now like to formulate an objective that finds a maximum likelihood estimate of the parameters θ and approximates the posterior distribution, so that we can represent the data efficiently. In other words, we want to encode the data to a distribution representation and be able to generate samples from this distribution.

Derivation of the variational bound — training objective

We start with the objective: to generate images. This means that given a latent variable z we want to reconstruct and/or generate an image x. To do this we want to estimate the conditional distribution, or likelihood, p(x| z). However, since the latent variable z is unknown to us, we cannot train a model to fit this distribution directly. The idea is instead to let the decoder network approximate the likelihood and then use Bayes' rule to find the marginal distribution, which the data follows, i.e. compute

$$p_\theta(x) = \frac{p_\theta(x \mid z)\, p_\theta(z)}{p_\theta(z \mid x)}.$$

We can assume a Gaussian prior for z, but we are still left with the problem that the posterior is intractable. The proposed solution is to approximate this distribution with the encoder network, q, with parameters ϕ:

$$q_\phi(z \mid x) \approx p_\theta(z \mid x).$$

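Starting from the logarithm of the marginal distribution above, we will now try to derive an objective that we can optimize with stochastic gradient descent. First we use a trick and multiply both the numerator and the denominator by our approximate posterior:

$$\log p_\theta(x) = \log \frac{p_\theta(x \mid z)\, p_\theta(z)\, q_\phi(z \mid x)}{p_\theta(z \mid x)\, q_\phi(z \mid x)}.$$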

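We then use logarithmic rules to split the terms to our convenience:

$$\log p_\theta(x) = \log p_\theta(x \mid z) - \log \frac{q_\phi(z \mid x)}{p_\theta(z)} + \log \frac{q_\phi(z \mid x)}{p_\theta(z \mid x)}.$$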

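Since the left-hand side does not depend on z, we can use an additional trick and take the expectation over z, drawn from the approximate posterior; only the right-hand side will then be affected:

$$\log p_\theta(x) = \mathbb{E}_{z \sim q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \mathbb{E}_{z \sim q_\phi(z \mid x)}\left[\log \frac{q_\phi(z \mid x)}{p_\theta(z)}\right] + \mathbb{E}_{z \sim q_\phi(z \mid x)}\left[\log \frac{q_\phi(z \mid x)}{p_\theta(z \mid x)}\right].$$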

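We are now at a point where we can see that the first term is the expectation of the logarithm of the likelihood of the data. Maximizing this term means that we have a high probability of reconstructing the data x correctly. From the definition of the KL-divergence, we can identify the other two terms as measuring how closely our approximate distribution matches the prior and the true posterior, respectively:

$$\log p_\theta(x) = \mathbb{E}_{z \sim q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - D_{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z)\big) + D_{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big).$$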

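Note that the last term is intractable, since the posterior is unknown, but we can use the fact that the KL-divergence is non-negative to form a lower bound on the marginal likelihood as follows:

$$\log p_\theta(x) \geq \mathbb{E}_{z \sim q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - D_{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z)\big).$$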

This is our final objective. The first term is a reconstruction term, which measures how well the decoder reconstructs the data, and the second term is a competing objective that pushes the approximate posterior closer to the prior. In practice we often choose the prior to be a standard normal, and the second term then has a regularizing effect that simplifies the distribution the encoder outputs.
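To make the bound concrete, here is a small numerical check on a toy model that is not taken from the paper. We assume a one-dimensional latent z ~ N(0, 1) and likelihood x | z ~ N(z, 1), chosen only because the marginal p(x) = N(0, 2) and the true posterior N(x/2, 1/2) are then known in closed form. The function names `log_marginal` and `elbo` are made up for this sketch.

```python
import math


def log_marginal(x: float) -> float:
    """Exact log p(x) for the toy model, where x ~ N(0, 2)."""
    return -0.5 * math.log(2 * math.pi * 2.0) - x * x / 4.0


def elbo(x: float, m: float, s2: float) -> float:
    """Closed-form lower bound for q(z|x) = N(m, s2) in the toy model."""
    # Reconstruction term: E_q[log N(x; z, 1)] with E_q[(x - z)^2] = (x - m)^2 + s2
    reconstruction = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)
    # KL(N(m, s2) || N(0, 1))
    kl = 0.5 * (m * m + s2 - 1.0 - math.log(s2))
    return reconstruction - kl


x = 1.3
print(f"log p(x)                 = {log_marginal(x):.4f}")
print(f"bound with q = prior     = {elbo(x, 0.0, 1.0):.4f}")
print(f"bound with q = posterior = {elbo(x, x / 2, 0.5):.4f}")
```

With q equal to the prior the bound is strictly below log p(x). With q equal to the true posterior the intractable KL term vanishes and the bound becomes tight.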

Reparameterization trick — how to sample z?

To make this all work there is one other detail we also need to consider. We train the decoder to generate a sample from the conditional distribution given a value of z. But how do we generate z in the first place? We need a way to sample from the distribution that is differentiable, so that we can optimize it with stochastic gradient descent. The trick the paper presents is to separate out the stochastic part of z and then transform it, together with the given input and the parameters of the encoder, with a transformation function g:

$$z = g_\phi(\epsilon, x).$$

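In practice we first generate

$$\epsilon \sim p(\epsilon)$$

from a distribution independent of the encoder parameters. Then we transform it with a function to the desired distribution. In the VAE we use the simple reparametrization

$$z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I),$$

where μ and σ are the mean and standard deviation output by the encoder.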
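As a minimal sketch of the trick in PyTorch, we can check that gradients reach the mean and standard deviation even though z is random. The tensors `mu` and `sigma` below are stand-ins for encoder outputs, and the function name `reparameterize` is our own choice.

```python
import torch


def reparameterize(mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Draw z ~ N(mu, diag(sigma^2)) as a differentiable function of mu and sigma."""
    epsilon = torch.randn_like(sigma)  # noise that does not depend on the encoder
    return mu + sigma * epsilon


# Stand-in encoder outputs: a batch of 4 inputs and a 2-dimensional latent space.
mu = torch.zeros(4, 2, requires_grad=True)
sigma = torch.ones(4, 2, requires_grad=True)

z = reparameterize(mu, sigma)
z.sum().backward()

print(z.shape)                                      # torch.Size([4, 2])
print(mu.grad is not None, sigma.grad is not None)  # True True
```

The gradient with respect to `mu` is simply one for every element, while the gradient with respect to `sigma` equals the sampled noise, which is what lets stochastic gradient descent update the encoder.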

Model architecture of the VAE

The Variational Autoencoder is only one example of how the ideas presented in the paper can be used. With few modifications, however, it has proven to be a very useful example.

In the VAE we choose the prior of the latent variable to be a unit Gaussian with diagonal covariance matrix. This tells the model that we want it to learn a latent variable representation with independent features which is actually a quite strict assumption.

We implement the encoder and the decoder as simple MLPs with only a few layers. The encoder outputs the mean and standard deviation of the approximate posterior. Since we assume it to be Gaussian with a diagonal covariance we use the reparametrization trick described above to sample the latent distribution. We feed this value of $z$ to the decoder which generates a reconstructed data point.

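The loss consists of two competing objectives: the KL-divergence, which pushes the latent variable distribution towards a unit normal distribution, and the reconstruction loss, which pushes the model towards accurately reconstructing the original input. In this case we can compute the KL-divergence analytically, and going through the calculations yields the following formula

$$-D_{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z)\big) = \frac{1}{2} \sum_{j=1}^{J} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$

where J is the dimension of z. If you stare at the formula for a bit you will realize that it is maximized for a standard normal distribution, that is, when every $\mu_j = 0$ and $\sigma_j = 1$. The reconstruction loss

$$-\log p_\theta(x \mid z) = -\sum_{i} \big[x_i \log \hat{x}_i + (1 - x_i) \log (1 - \hat{x}_i)\big],$$

where $\hat{x}$ is the output of the decoder, is implemented with a BCE loss in PyTorch, which essentially pushes the output pixel values to be similar to the input.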
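The two terms can be sketched in PyTorch as follows. This is one possible way to write the loss, not necessarily the exact code behind the results in this article. The function names are our own, and we assume that the encoder outputs a strictly positive standard deviation `sigma` and that losses are summed rather than averaged over the batch.

```python
import torch
import torch.nn.functional as F


def kl_to_standard_normal(mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dimensions and batch."""
    return -0.5 * torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))


def vae_loss(x_reconstructed, x, mu, sigma) -> torch.Tensor:
    """Negative lower bound: BCE reconstruction term plus the KL term."""
    reconstruction = F.binary_cross_entropy(x_reconstructed, x, reduction="sum")
    return reconstruction + kl_to_standard_normal(mu, sigma)


# Example: a posterior that differs from the prior pays a positive KL penalty.
mu = torch.tensor([[0.3, -1.2]])
sigma = torch.tensor([[0.5, 2.0]])
print(kl_to_standard_normal(mu, sigma))
```

Minimizing this loss is the same as maximizing the lower bound, since the BCE is the negative Bernoulli log-likelihood of the pixels.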

Code in PyTorch

The implementation of the Variational Autoencoder is simplified to only contain the core parts. The example is on the MNIST dataset and for the encoder and decoder network we use a simple MLP.

Imports

Model

The encoder and decoder are mirrored networks consisting of two layers. In the encoder, we take the input data to a hidden dimension through a linear layer, and then we pass the hidden state to two different linear layers that output the mean and the standard deviation of the latent distribution, respectively.

We then sample from the latent distribution and input it to the decoder that in turn outputs a vector of the same shape as the input.
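A minimal sketch of a model matching this description could look as follows. The class and layer names, the hidden size of 200, the latent size of 20, the ReLU activations, and the softplus used to keep the standard deviation positive are all assumptions of this sketch rather than details given in the text.

```python
import torch
import torch.nn.functional as F
from torch import nn


class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim: int = 784, h_dim: int = 200, z_dim: int = 20):
        super().__init__()
        # Encoder: input -> hidden -> (mu, sigma)
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)
        # Decoder: z -> hidden -> reconstructed input
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

    def encode(self, x):
        h = torch.relu(self.img_2hid(x))
        mu = self.hid_2mu(h)
        # Assumption: softplus plus a small constant keeps sigma strictly positive.
        sigma = F.softplus(self.hid_2sigma(h)) + 1e-6
        return mu, sigma

    def decode(self, z):
        h = torch.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(h))  # pixel values in [0, 1]

    def forward(self, x):
        mu, sigma = self.encode(x)
        z = mu + sigma * torch.randn_like(sigma)  # reparameterization trick
        return self.decode(z), mu, sigma


model = VariationalAutoEncoder()
x = torch.rand(4, 784)  # a batch of four flattened 28x28 images
x_reconstructed, mu, sigma = model(x)
print(x_reconstructed.shape, mu.shape, sigma.shape)
```

Returning `mu` and `sigma` from `forward` alongside the reconstruction is a design choice that makes both loss terms easy to compute in the training loop.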

Training configuration
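A configuration along the following lines would be consistent with the training log shown further down: ten progress lines suggest ten epochs, and 1875 iterations per epoch over the 60,000 MNIST training images correspond to a batch size of 32. The hidden size, latent size and learning rate are not recoverable from the text and are plain assumptions here.

```python
import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_DIM = 28 * 28  # a flattened MNIST image
H_DIM = 200          # hidden layer width (assumed)
Z_DIM = 20           # latent dimension (assumed)
NUM_EPOCHS = 10      # one progress line per epoch in the training log
BATCH_SIZE = 32      # 60000 training images / 32 = 1875 iterations per epoch
LR = 3e-4            # learning rate (assumed)
```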

Loading the MNIST data

Similar to the examples in the paper we use the MNIST dataset to showcase the model concepts.

Training setup

As described above, the loss consists of two different terms: the reconstruction loss, here implemented with a BCE loss, and the KL-divergence. We make the quite strict assumptions that the prior of $z$ is a unit normal and that the posterior is approximately Gaussian with a diagonal covariance matrix, which means we can simplify the expression for the KL-divergence as described above.
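A training loop in this spirit might look like the sketch below. It assumes a `model` whose forward pass returns the reconstruction, the mean and the standard deviation, and a `loader` that yields image and label batches. The function name `train_one_epoch` is our own, and progress reporting is left out.

```python
import torch
import torch.nn.functional as F


def train_one_epoch(model, loader, optimizer, device="cpu") -> float:
    """Run one pass over `loader` and return the loss of the last batch."""
    model.train()
    loss = torch.tensor(float("nan"))
    for x, _ in loader:  # the labels are not used
        x = x.to(device).reshape(x.shape[0], -1)  # flatten images to vectors
        x_reconstructed, mu, sigma = model(x)

        # Reconstruction term and analytic KL term, both summed over the batch.
        reconstruction = F.binary_cross_entropy(x_reconstructed, x, reduction="sum")
        kl = -0.5 * torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))
        loss = reconstruction + kl  # negative lower bound

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()
```

Calling this function once per epoch with, for example, a `torch.optim.Adam` optimizer gives one loss value per epoch. Because the loss is summed over the batch, values in the thousands are expected for batches of MNIST images.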

1875it [00:28, 65.04it/s, loss=5.5e+3]
1875it [00:25, 72.15it/s, loss=4.75e+3]
1875it [00:26, 71.89it/s, loss=4.82e+3]
1875it [00:26, 71.39it/s, loss=4.17e+3]
1875it [00:26, 70.06it/s, loss=4.17e+3]
1875it [00:26, 70.68it/s, loss=4.22e+3]
1875it [00:26, 70.44it/s, loss=4.14e+3]
1875it [00:27, 68.29it/s, loss=4.59e+3]
1875it [00:28, 66.01it/s, loss=4.26e+3]
1875it [00:26, 69.46it/s, loss=4.06e+3]

How to generate new images

Below we can see that the variational autoencoder generates slightly varying images given the same input thanks to the sampling of a new value of the latent variable in each generation.
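The generation step can be sketched as follows, assuming the model exposes `encode` and `decode` methods (hypothetical names) and reconstructs flattened 28x28 images. We encode one image once and then decode several latent samples drawn around its mean.

```python
import torch


@torch.no_grad()
def generate_variations(model, x: torch.Tensor, num_examples: int = 5) -> torch.Tensor:
    """Decode `num_examples` latent samples drawn for a single input image."""
    mu, sigma = model.encode(x.reshape(1, -1))  # x: one 28x28 image
    # Fresh noise for every generated image is what makes the outputs differ.
    epsilon = torch.randn(num_examples, mu.shape[1], device=mu.device)
    z = mu + sigma * epsilon
    return model.decode(z).reshape(num_examples, 1, 28, 28)
```

The returned tensor can be passed to an image-saving utility such as `torchvision.utils.save_image` to inspect the results.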

A few upscaled examples of generated images 28x28 pixels

With some intuition about how VAEs work, and having seen an example of how to implement them, I hope that you are now better equipped to understand and implement more modern architectures incorporating these ideas! At least that is how I feel after going through the paper.

Sanna Persson

Currently exploring the realms of deep learning. Particularly interested in healthcare applications