This practical session builds directly on our understanding of advanced VAE architectures. Having discussed the theoretical aspects of models like Conditional VAEs (CVAEs) and Vector Quantized VAEs (VQ-VAEs), we will now walk through the implementation details of these two prominent architectures. Our goal is to provide you with the foundational code structure and insights needed to build, train, and evaluate these models, preparing you to tackle more complex generative tasks.
We'll focus on highlighting the specific modifications required to a standard VAE framework to realize these advanced variants. For this session, we'll assume you are working with a dataset like MNIST or Fashion-MNIST, as these allow for clear demonstration of conditional generation and the impact of discrete latents on sample quality. You should be comfortable with Python and a deep learning framework such as PyTorch or TensorFlow.
Before you begin, ensure you have a working Python environment with your chosen deep learning framework (PyTorch or TensorFlow) installed, access to a dataset such as MNIST or Fashion-MNIST, and a standard VAE implementation to use as a starting point.
We will provide high-level code structures and logic. You will adapt these to your specific framework and dataset.
Conditional VAEs extend the VAE framework by incorporating conditional information, denoted as c, into both the generation and inference processes. This allows us to direct the VAE to generate data samples possessing specific attributes defined by c. For instance, with MNIST, c could be the digit label (0-9), enabling us to request the CVAE to generate an image of a particular digit.
The core idea is to make both the encoder $q_\phi(z \mid x, c)$ and the decoder $p_\theta(x \mid z, c)$ dependent on the condition $c$.
Condition Representation: The condition $c$ (e.g., a class label) is typically converted into a numerical format, often one-hot encoded, before being fed into the networks. We'll call this numerical representation `c_embed` (see the short sketch after this list).
Encoder $q_\phi(z \mid x, c)$: the input $x$ (flattened for an MLP encoder) and the condition embedding `c_embed` are each projected and then concatenated before the layers that produce $\mu$ and $\log\sigma^2$, so the approximate posterior depends on both $x$ and $c$.
Decoder $p_\theta(x \mid z, c)$: the sampled latent $z$ is concatenated with `c_embed` before being decoded, so the reconstruction is conditioned on $c$.
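As a small sketch of the condition representation step (assuming PyTorch and the 10 MNIST classes):

import torch
import torch.nn.functional as F

labels = torch.tensor([3, 7, 0])                      # integer class labels
c_embed = F.one_hot(labels, num_classes=10).float()   # one-hot condition, shape (batch, 10)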
A diagram illustrating the CVAE structure:
Data flow in a Conditional Variational Autoencoder. The condition c is incorporated into both the encoder and decoder.
The objective function for a CVAE is a conditional version of the ELBO:
$$\mathcal{L}_{\text{CVAE}}(\phi, \theta; x, c) = \mathbb{E}_{q_\phi(z \mid x, c)}\big[\log p_\theta(x \mid z, c)\big] - D_{\text{KL}}\big(q_\phi(z \mid x, c) \,\|\, p(z \mid c)\big)$$

During training, we maximize this $\mathcal{L}_{\text{CVAE}}$.
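In the implementation below, the reconstruction term becomes a binary cross-entropy, and, assuming the common simplification $p(z \mid c) = \mathcal{N}(0, I)$, the KL term has the closed form used in the training loop's code:

$$D_{\text{KL}}\big(q_\phi(z \mid x, c) \,\big\|\, \mathcal{N}(0, I)\big) = -\tfrac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$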
import torch
import torch.nn as nn
import torch.nn.functional as F

# Condition: label (e.g., integer from 0 to 9 for MNIST)
# Convert label to one-hot embedding: c_embed

# Encoder
class CVAEEncoder(nn.Module):
    def __init__(self, input_dim, latent_dim, condition_dim, hidden_dim):
        super().__init__()
        # Separate projections for the input and the condition, then a joint layer
        # (nn.Conv2d layers could replace the input projection for a convolutional variant)
        self.fc_x = nn.Linear(input_dim, hidden_dim)
        self.fc_c = nn.Linear(condition_dim, hidden_dim)
        self.fc_combined = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, c_embed):
        h_x = F.relu(self.fc_x(x.view(x.size(0), -1)))   # Flatten x if it is an image
        h_c = F.relu(self.fc_c(c_embed))
        combined = torch.cat([h_x, h_c], dim=1)           # Concatenate input and condition features
        h_combined = F.relu(self.fc_combined(combined))
        mu = self.fc_mu(h_combined)
        logvar = self.fc_logvar(h_combined)
        return mu, logvar

# Decoder
class CVAEDecoder(nn.Module):
    def __init__(self, latent_dim, condition_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc_z = nn.Linear(latent_dim, hidden_dim)
        self.fc_c = nn.Linear(condition_dim, hidden_dim)
        self.fc_combined = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, z, c_embed):
        h_z = F.relu(self.fc_z(z))
        h_c = F.relu(self.fc_c(c_embed))
        combined = torch.cat([h_z, h_c], dim=1)           # Concatenate latent and condition features
        h_combined = F.relu(self.fc_combined(combined))
        reconstruction = torch.sigmoid(self.fc_out(h_combined))  # Sigmoid for pixel values in [0, 1]
        return reconstruction  # Reshape to (batch, channels, height, width) outside if needed

# Training loop (sketch):
# for x_batch, c_batch_labels in dataloader:
#     c_batch_embed = F.one_hot(c_batch_labels, num_classes=condition_dim).float()
#     mu, logvar = encoder(x_batch, c_batch_embed)
#     z_sampled = reparameterize(mu, logvar)   # reparameterize as sketched below
#     x_reconstructed = decoder(z_sampled, c_batch_embed)
#
#     reconstruction_loss = F.binary_cross_entropy(
#         x_reconstructed, x_batch.view(x_batch.size(0), -1), reduction='sum')
#     kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
#     loss = reconstruction_loss + kl_divergence
#
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
The key is the concatenation of the condition embedding `c_embed` at appropriate points in the encoder and decoder.
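For completeness, here is a minimal sketch of the `reparameterize` helper referenced in the training loop above, plus conditional generation at inference time. It assumes `decoder` is a trained instance of the CVAEDecoder above; the digit label 4 and the dimensions (latent_dim=20, condition_dim=10) are placeholder example values.

import torch
import torch.nn.functional as F

def reparameterize(mu, logvar):
    # z = mu + sigma * epsilon, with epsilon ~ N(0, I)
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

# Conditional generation: ask the trained decoder for a specific digit
with torch.no_grad():
    z = torch.randn(1, 20)                                   # sample z ~ p(z), latent_dim=20 assumed
    c = F.one_hot(torch.tensor([4]), num_classes=10).float() # condition: the digit "4"
    sample = decoder(z, c)                                    # flattened image; reshape to 28x28 to view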
VQ-VAEs introduce a discrete latent space by quantizing the encoder's output to the closest vector in a learned codebook (or embedding space). This often leads to sharper generated samples compared to standard VAEs with continuous latents, as the decoder learns to map a finite set of representations to outputs.
A diagram illustrating the VQ-VAE structure:
Data flow in a Vector Quantized Variational Autoencoder. The encoder output ze(x) is quantized using a learnable codebook. The Straight-Through Estimator (STE) is used for gradient propagation.
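To make the quantization step concrete, here is a tiny standalone example (the numbers are arbitrary) of snapping a single encoder output vector to its nearest codebook entry:

import torch

codebook = torch.tensor([[0.0, 0.0],
                         [1.0, 0.0],
                         [0.0, 1.0],
                         [1.0, 1.0]])           # K=4 codebook vectors, embedding_dim=2
z_e = torch.tensor([[0.9, 0.2]])                # one encoder output
distances = torch.cdist(z_e, codebook)          # Euclidean distances to every codebook entry
k = torch.argmin(distances, dim=1)              # index of the nearest entry -> tensor([1])
z_q = codebook[k]                               # quantized latent: [[1.0, 0.0]]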
The VQ-VAE is trained by minimizing a combined loss:
$$\mathcal{L}_{\text{VQ-VAE}} = \mathcal{L}_{\text{reconstruction}} + \mathcal{L}_{\text{codebook}} + \beta \cdot \mathcal{L}_{\text{commitment}}$$

where:

$\mathcal{L}_{\text{reconstruction}}$ measures how well the decoder reconstructs $x$ from the quantized latents.

$\mathcal{L}_{\text{codebook}}$ moves the codebook vectors toward the encoder outputs, using a stop-gradient ($\text{sg}$) operation on the encoder side: $\mathcal{L}_{\text{codebook}} = \lVert \text{sg}[z_e(x)] - e_k \rVert_2^2$. Here, $e_k$ is the codebook vector closest to $z_e(x)$; the gradient only updates $e_k$.

$\mathcal{L}_{\text{commitment}} = \lVert z_e(x) - \text{sg}[e_k] \rVert_2^2$ encourages the encoder to commit to its chosen codebook entry and prevents encoder outputs from growing arbitrarily large; $\beta$ weights this term.

There's no explicit KL divergence term to a prior in the basic VQ-VAE. The discreteness of the latent space itself acts as a form of regularization. A prior can be learned over the discrete latent codes (e.g., using a PixelCNN) for generation, after the VQ-VAE is trained.
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        # Initialize the codebook (embeddings)
        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.num_embeddings, 1.0 / self.num_embeddings)

    def forward(self, inputs):
        # inputs: (Batch, Channel, Height, Width), with Channel == embedding_dim
        input_shape = inputs.shape
        flat_input = inputs.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)

        # Squared distances from each encoder output to every codebook vector
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self.embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self.embedding.weight.t()))

        # Nearest codebook entry (index) for each spatial position
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(encoding_indices).view(
            input_shape[0], input_shape[2], input_shape[3], self.embedding_dim)
        quantized = quantized.permute(0, 3, 1, 2).contiguous()  # back to (B, C, H, W)

        # Codebook loss updates the embeddings (sg on encoder outputs);
        # commitment loss updates the encoder (sg on codebook vectors).
        # The original paper also describes an EMA-based codebook update as an alternative.
        codebook_loss = F.mse_loss(quantized, inputs.detach())
        commitment_loss = F.mse_loss(inputs, quantized.detach())
        loss = codebook_loss + self.commitment_cost * commitment_loss

        # Straight-Through Estimator: forward pass uses the quantized values,
        # backward pass copies gradients from quantized back to the encoder outputs
        quantized = inputs + (quantized - inputs).detach()

        return quantized, loss, encoding_indices.view(input_shape[0], input_shape[2], input_shape[3])
class VQVAE(nn.Module):
    def __init__(self, encoder, decoder, vq_layer):
        super().__init__()
        self.encoder = encoder
        self.vq_layer = vq_layer
        self.decoder = decoder

    def forward(self, x):
        z_e = self.encoder(x)                               # continuous encoder output z_e(x)
        quantized_latents, vq_loss, _ = self.vq_layer(z_e)  # nearest-codebook quantization
        x_reconstructed = self.decoder(quantized_latents)
        return x_reconstructed, vq_loss
# Training loop:
# For each batch (x_batch, _): # No labels needed unless for evaluation
# x_reconstructed, vq_loss = vq_vae_model(x_batch)
#
# reconstruction_loss = F.mse_loss(x_reconstructed, x_batch)
# total_loss = reconstruction_loss + vq_loss # vq_loss already contains codebook and commitment losses
# optimizer.zero_grad()
# total_loss.backward()
# optimizer.step()
The `VectorQuantizer` module is the most intricate part. Ensure the stop-gradient (`.detach()` in PyTorch) is correctly applied for the codebook and commitment losses, and that the STE is implemented in the forward pass if you want gradients to flow back to the encoder through the quantization step.
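A quick way to convince yourself the STE behaves as intended is a tiny standalone check (not part of the model above): the forward value equals the quantized vectors, but the gradient passes straight through to the encoder outputs.

import torch

z_e = torch.randn(4, 8, requires_grad=True)        # stand-in for encoder outputs
e_k = torch.randn(4, 8)                            # stand-in for nearest codebook vectors
z_q = z_e + (e_k - z_e).detach()                   # forward value equals e_k
z_q.sum().backward()
print(torch.allclose(z_e.grad, torch.ones_like(z_e)))  # True: gradient flows straight to z_e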
To generate new samples with a trained VQ-VAE, first learn a prior $p(k)$ over the discrete `encoding_indices` obtained from the training data (for example, with an autoregressive model such as a PixelCNN). Once $p(k)$ is learned, sample indices from it, retrieve the corresponding codebook vectors $e_k$, and pass them to the decoder, as in the sketch below.
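Here is a rough sketch of that generation path. It assumes `vq_vae_model` is a trained instance of the VQVAE class above, and that `prior_logits`, `latent_h`, and `latent_w` are placeholders for a learned prior over the codebook indices and the spatial size of the latent grid; a real implementation would sample each position autoregressively from the learned PixelCNN rather than from a single categorical distribution.

import torch

# prior_logits: placeholder tensor of shape (num_embeddings,) from a learned prior p(k)
# Sample one code index per latent position
indices = torch.distributions.Categorical(logits=prior_logits).sample((1, latent_h, latent_w))
# Look up the corresponding codebook vectors e_k and arrange them as (B, C, H, W)
codes = vq_vae_model.vq_layer.embedding(indices)        # (1, latent_h, latent_w, embedding_dim)
codes = codes.permute(0, 3, 1, 2).contiguous()
# Decode the quantized latents into an image
generated = vq_vae_model.decoder(codes)

Both CVAE and VQ-VAE offer significant improvements over the vanilla VAE architecture, but they address different aspects and have distinct characteristics.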
| Feature | Conditional VAE (CVAE) | Vector Quantized VAE (VQ-VAE) |
|---|---|---|
| Primary Goal | Controlled generation based on attributes | Improved sample fidelity, discrete latent representation |
| Latent Space | Continuous, conditioned by $c$ | Discrete, finite set of learned codebook vectors |
| Control Mechanism | Explicit input condition $c$ | Implicit, through the structure of the learned codebook |
| Sample Quality | Content controlled by $c$; can still suffer from blurriness | Often sharper, less blurry samples; generation requires a prior over codes |
| Training Objective | Conditional ELBO ($\mathcal{L}_{\text{reconstruction}} + D_{\text{KL}}$) | $\mathcal{L}_{\text{reconstruction}} + \mathcal{L}_{\text{codebook}} + \beta \cdot \mathcal{L}_{\text{commitment}}$ |
| Gradient Flow | Standard reparameterization trick | Straight-Through Estimator for the quantization step |
| Common Challenges | Ensuring condition $c$ is effectively used, posterior collapse | Codebook collapse (unused codes), choosing $K$ and $\beta$ |
| Generation Process | Sample $z \sim p(z)$, provide $c$, decode $p_\theta(x \mid z, c)$ | Sample discrete code indices from a learned prior (e.g., PixelCNN), retrieve codebook vectors, decode |
With the foundations of CVAE and VQ-VAE implementations in place, you can extend your practical work:
This hands-on experience is fundamental for developing a deeper intuition for how architectural choices influence the behavior and capabilities of Variational Autoencoders. By building and experimenting, you will be better equipped to select, design, and adapt VAE models for your specific representation learning and generative modeling tasks.