While standard Variational Autoencoders (VAEs), as discussed in Chapter 4, operate on continuous latent variables, typically sampled from a Gaussian distribution, there are scenarios where a discrete latent representation is more natural or beneficial. Consider domains like natural language, where fundamental units (words, characters) are discrete, or tasks where we desire a more structured, potentially symbolic, latent space. Furthermore, continuous VAEs can sometimes suffer from "posterior collapse," where the decoder learns to ignore the latent variable, especially when paired with powerful autoregressive decoders.
Vector Quantized Variational Autoencoders (VQ-VAEs) address this by introducing a discrete latent space through vector quantization. Instead of mapping the input to parameters of a continuous distribution (like mean and variance in a standard VAE), the VQ-VAE encoder maps the input to a continuous vector, which is then snapped to the closest vector in a learned, finite codebook (also called an embedding space).
A VQ-VAE consists of three main components, with a minimal code sketch following the list:
Encoder: Similar to other autoencoders, this network $f$ takes an input $x$ and produces a continuous output vector (or tensor) $z_e(x) \in \mathbb{R}^D$. In contrast to a standard VAE, $z_e(x)$ is not interpreted as parameters of a distribution but as a point in a $D$-dimensional space.
Codebook (Embedding Space): This is a learnable collection $\mathcal{E} = \{e_i\}_{i=1}^{K}$, where each $e_i \in \mathbb{R}^D$ is an embedding vector. The size of the codebook, $K$, determines the number of possible discrete latent states, and $D$ is the dimensionality of each embedding vector.
Decoder: This network $g$ takes a vector from the codebook and aims to reconstruct the original input $x$.
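To make these components concrete, here is a minimal PyTorch sketch. The layer sizes, codebook size $K$, and embedding dimension $D$ are illustrative assumptions rather than values prescribed by the text:

```python
import torch
import torch.nn as nn

class VQVAE(nn.Module):
    """Minimal VQ-VAE skeleton: encoder f, codebook E, decoder g."""

    def __init__(self, input_dim=784, hidden_dim=256,
                 num_embeddings=512, embedding_dim=64):
        super().__init__()
        # Encoder f: maps input x to a continuous vector z_e(x) in R^D.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
        )
        # Codebook E = {e_i}, i = 1..K: K learnable vectors, each in R^D.
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        # Small uniform init, a common choice in VQ-VAE implementations.
        self.codebook.weight.data.uniform_(-1.0 / num_embeddings,
                                           1.0 / num_embeddings)
        # Decoder g: reconstructs x from a codebook vector z_q(x).
        self.decoder = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )
```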
The core operation is the quantization step that links the encoder output to the codebook. For a given encoder output $z_e(x)$, we find the index $k$ of the nearest codebook vector $e_k$ using the Euclidean distance:

$$k = \arg\min_i \lVert z_e(x) - e_i \rVert_2^2$$

The input to the decoder is then the chosen codebook vector itself, $z_q(x) = e_k$. This $e_k$ (or sometimes just the index $k$) represents the discrete latent representation of the input $x$.
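The lookup can be expressed in a few lines; a sketch assuming a batch of encoder outputs of shape (B, D) and the codebook weights as a (K, D) matrix:

```python
import torch

def quantize(z_e: torch.Tensor, codebook: torch.Tensor):
    """Snap each encoder output to its nearest codebook vector.

    z_e:      (B, D) continuous encoder outputs z_e(x)
    codebook: (K, D) embedding vectors e_1..e_K
    Returns the indices k and the quantized vectors z_q(x) = e_k.
    """
    # Pairwise Euclidean distances between outputs and codebook entries: (B, K).
    distances = torch.cdist(z_e, codebook)
    # k = argmin_i ||z_e(x) - e_i||_2^2 (plain and squared distances
    # give the same argmin).
    indices = distances.argmin(dim=1)
    z_q = codebook[indices]  # look up e_k for each input
    return indices, z_q
```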
A significant challenge arises immediately: the argmin operation used for quantization is non-differentiable. Selecting the closest vector involves a hard decision, and its gradient is zero almost everywhere, preventing gradient flow from the decoder back to the encoder during training via standard backpropagation.
VQ-VAEs cleverly bypass this using a variant of the Straight-Through Estimator (STE). The core idea is: use the quantized vector $z_q(x)$ in the forward pass, but in the backward pass copy the gradient received at $z_q(x)$ directly to $z_e(x)$, treating the quantization step as if it were the identity function. This allows the decoder's reconstruction error signal to flow back to update the encoder weights, even though the quantization step itself has no usable gradient.
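In an autograd framework this trick reduces to a one-line identity: detach the quantization residual so the forward value equals $z_q(x)$ while the only differentiable path runs through $z_e(x)$. A minimal sketch:

```python
import torch

def straight_through(z_e: torch.Tensor, z_q: torch.Tensor) -> torch.Tensor:
    """Forward: returns exactly z_q. Backward: gradient flows to z_e.

    The residual (z_q - z_e) is detached, so it contributes the correct
    forward value but no gradient; the only differentiable input is z_e.
    """
    return z_e + (z_q - z_e).detach()
```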
Training the VQ-VAE involves optimizing the encoder, decoder, and the codebook vectors simultaneously. The loss function comprises three terms:
Reconstruction Loss ($\mathcal{L}_{\text{rec}}$): This is the standard autoencoder loss, measuring the difference between the original input $x$ and the decoder's output $\hat{x} = g(z_q(x))$. Depending on the data type, this could be mean squared error (MSE) for continuous data or cross-entropy for discrete data (like images with pixel values treated categorically).
Codebook Loss ($\mathcal{L}_{\text{codebook}}$): This term updates the codebook vectors $e_i$. The goal is to move the chosen codebook vector $e_k$ closer to the encoder output $z_e(x)$ that selected it. Since this term should train only the codebook and not the encoder, the encoder output is treated as a constant using the stop-gradient operator $\text{sg}[\cdot]$:

$$\mathcal{L}_{\text{codebook}} = \lVert \text{sg}[z_e(x)] - e_k \rVert_2^2$$

The stop-gradient $\text{sg}[v]$ passes $v$ unchanged during the forward pass but has zero gradient during the backward pass, effectively cutting off gradient flow through $v$.
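In PyTorch, $\text{sg}[\cdot]$ corresponds to detach(). A sketch of the codebook loss, reusing the z_e and z_q names from the snippets above:

```python
import torch.nn.functional as F

# sg[z_e(x)] is z_e.detach(): the forward value is unchanged, but no gradient
# flows back to the encoder, so this loss updates only the codebook entry e_k.
codebook_loss = F.mse_loss(z_q, z_e.detach())
```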
Commitment Loss ($\mathcal{L}_{\text{commit}}$): This term regularizes the encoder output, encouraging it to stay close to the chosen codebook vector $e_k$. This prevents $z_e(x)$ from fluctuating too much and ensures the encoder "commits" to a specific codebook vector. It also uses the stop-gradient operator, but this time on the codebook vector, so the gradient only affects the encoder. This loss is typically weighted by a hyperparameter $\beta$:

$$\mathcal{L}_{\text{commit}} = \beta \, \lVert z_e(x) - \text{sg}[e_k] \rVert_2^2$$

A common value for $\beta$ is between 0.1 and 2.0.
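The commitment term is the mirror image: detach the codebook vector instead, so the gradient pulls only the encoder. A sketch (beta = 0.25 is a common choice within the range above):

```python
import torch.nn.functional as F

beta = 0.25  # commitment weight; the text suggests values between 0.1 and 2.0
# sg[e_k] is z_q.detach(): here the codebook is frozen and only the encoder
# output z_e(x) is pulled toward its chosen codebook vector.
commitment_loss = beta * F.mse_loss(z_e, z_q.detach())
```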
The total loss is the sum of these components:

$$\mathcal{L} = \mathcal{L}_{\text{rec}} + \mathcal{L}_{\text{codebook}} + \mathcal{L}_{\text{commit}}$$

Note that the codebook loss moves the embeddings $e_k$ towards the encoder outputs $z_e(x)$, while the commitment loss moves the encoder outputs $z_e(x)$ towards the embeddings $e_k$. The STE handles the gradient flow for the reconstruction loss back to the encoder.
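Putting the pieces together, a hypothetical loss computation might look like the following, reusing the VQVAE, quantize, and straight_through sketches from earlier:

```python
import torch
import torch.nn.functional as F

def vqvae_loss(model: VQVAE, x: torch.Tensor,
               beta: float = 0.25) -> torch.Tensor:
    """Compute L = L_rec + L_codebook + L_commit for one batch."""
    z_e = model.encoder(x)                             # continuous z_e(x)
    _, z_q = quantize(z_e, model.codebook.weight)      # nearest codebook vector e_k
    x_hat = model.decoder(straight_through(z_e, z_q))  # reconstruction via STE

    rec_loss = F.mse_loss(x_hat, x)                          # L_rec (MSE variant)
    codebook_loss = F.mse_loss(z_q, z_e.detach())            # moves e_k toward sg[z_e]
    commitment_loss = beta * F.mse_loss(z_e, z_q.detach())   # moves z_e toward sg[e_k]
    return rec_loss + codebook_loss + commitment_loss
```

A single backward pass through this scalar routes each component correctly: the reconstruction gradient reaches the encoder and decoder via the STE, the codebook loss reaches only the embeddings, and the commitment loss reaches only the encoder.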
High-level architecture of a VQ-VAE. The forward pass computes the encoder output $z_e(x)$, finds the nearest codebook vector $e_k$ via quantization, yielding $z_q(x)$, which is then decoded. Dashed arrows indicate gradient flow during backpropagation, with the Straight-Through Estimator (STE) copying the reconstruction gradient from $z_q$ to $z_e$. Separate gradients update the codebook and enforce commitment.
VQ-VAEs offer several benefits:

Discrete, structured latents: The codebook provides a finite set of latent states, a natural fit for inherently discrete domains such as language and for tasks that call for more symbolic representations.

Resistance to posterior collapse: Because the decoder receives a quantized codebook vector rather than a sample from a continuous posterior, the latent code is harder to ignore, mitigating the collapse issue mentioned earlier.

Compact codes for generation: The discrete indices form a compressed representation over which a separate prior model (for example, an autoregressive model) can later be trained to generate new samples.
In summary, VQ-VAEs provide a compelling alternative to standard VAEs by employing a discrete latent space learned via vector quantization. They overcome the non-differentiability of quantization using the straight-through estimator and rely on a specific loss structure involving reconstruction, codebook, and commitment terms. Their ability to learn meaningful discrete representations has made them a foundation for powerful generative models, particularly in image and audio synthesis.