This section examines the implementation details of two prominent advanced VAE architectures: Conditional VAEs (CVAEs) and Vector Quantized VAEs (VQ-VAEs). The goal is to provide the foundational code structure and understanding needed to build, train, and evaluate these models, preparing you for more complex generative tasks.
We'll focus on highlighting the specific modifications required to a standard VAE framework to realize these advanced variants. For this session, we'll assume you are working with a dataset like MNIST or Fashion-MNIST, as these allow for clear demonstration of conditional generation and the impact of discrete latents on sample quality. You should be comfortable with Python and a deep learning framework such as PyTorch or TensorFlow.
Before you begin, ensure you have:

- A working Python environment with a deep learning framework such as PyTorch or TensorFlow.
- A standard VAE implementation to use as the starting point for the modifications below.
- Access to a dataset such as MNIST or Fashion-MNIST.
We will provide high-level code structures and logic. You will adapt these to your specific framework and dataset.
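As a concrete starting point, the following sketch shows one way to load MNIST in PyTorch; it assumes torchvision is installed, and the batch size is only illustrative.

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Load MNIST as tensors in [0, 1]; Fashion-MNIST works identically
# by swapping in datasets.FashionMNIST.
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
```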
Conditional VAEs extend the VAE framework by incorporating conditional information, denoted as $c$, into both the generation and inference processes. This allows us to direct the VAE to generate data samples possessing specific attributes defined by $c$. For instance, with MNIST, $c$ could be the digit label (0-9), enabling us to request the CVAE to generate an image of a particular digit.
The core idea is to make both the encoder and the decoder dependent on the condition $c$.
Condition Representation: The condition $c$ (e.g., a class label) is typically converted into a numerical format, often one-hot encoded, before being fed into the networks. Let's say c_embed is this numerical representation (see the short example after this list).
Encoder $q_\phi(z \mid x, c)$: The encoder receives both the input $x$ and the condition embedding, typically by concatenating their intermediate representations, and outputs the parameters $\mu$ and $\log\sigma^2$ of the approximate posterior.
Decoder $p_\theta(x \mid z, c)$: The decoder receives both the sampled latent $z$ and the condition embedding, again by concatenation, and produces the reconstruction (or a new sample with the requested attribute).
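As a minimal sketch of the condition representation, assuming integer class labels as in MNIST, one-hot encoding can be done with F.one_hot:

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([3, 7, 0])                     # integer digit labels
c_embed = F.one_hot(labels, num_classes=10).float()  # shape (batch, 10), ready to feed to the networks
```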
A diagram illustrating the CVAE structure:
Data flow in a Conditional Variational Autoencoder. The condition $c$ is incorporated into both the encoder and decoder.
The objective function for a CVAE is a conditional version of the ELBO:

$$\mathcal{L}_{\text{CVAE}}(\theta, \phi; x, c) = \mathbb{E}_{q_\phi(z \mid x, c)}\big[\log p_\theta(x \mid z, c)\big] - D_{KL}\big(q_\phi(z \mid x, c) \,\|\, p(z \mid c)\big)$$

During training, we maximize this $\mathcal{L}_{\text{CVAE}}$ (equivalently, we minimize its negative as the training loss). In the common case where the prior $p(z \mid c)$ is taken to be a standard Gaussian $\mathcal{N}(0, I)$ independent of $c$, the KL term has the familiar closed form used in the training loop below.
# Condition: label (e.g., integer from 0 to 9 for MNIST)
# Convert label to one-hot embedding: c_embed
import torch
import torch.nn as nn
import torch.nn.functional as F

# Encoder
class CVAEEncoder(nn.Module):
    def __init__(self, input_dim, latent_dim, condition_dim, hidden_dim):
        super().__init__()
        # Separate projections for the (flattened) input and the condition embedding
        self.fc_x = nn.Linear(input_dim, hidden_dim)
        self.fc_c = nn.Linear(condition_dim, hidden_dim)
        self.fc_combined = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, c_embed):
        h_x = F.relu(self.fc_x(x.view(x.size(0), -1)))   # flatten x if it is an image
        h_c = F.relu(self.fc_c(c_embed))
        combined = torch.cat([h_x, h_c], dim=1)          # concatenate input and condition features
        h_combined = F.relu(self.fc_combined(combined))
        mu = self.fc_mu(h_combined)
        logvar = self.fc_logvar(h_combined)
        return mu, logvar
# Decoder
class CVAEDecoder(nn.Module):
    def __init__(self, latent_dim, condition_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc_z = nn.Linear(latent_dim, hidden_dim)
        self.fc_c = nn.Linear(condition_dim, hidden_dim)
        self.fc_combined = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, z, c_embed):
        h_z = F.relu(self.fc_z(z))
        h_c = F.relu(self.fc_c(c_embed))
        combined = torch.cat([h_z, h_c], dim=1)          # concatenate latent and condition features
        h_combined = F.relu(self.fc_combined(combined))
        reconstruction = torch.sigmoid(self.fc_out(h_combined))  # sigmoid for pixel values in [0, 1]
        return reconstruction  # shape (batch, output_dim); reshape to (batch, channels, height, width) as needed
# Training loop (sketch):
# for x_batch, c_batch_labels in train_loader:
#     c_batch_embed = F.one_hot(c_batch_labels, num_classes=10).float()
#     mu, logvar = encoder(x_batch, c_batch_embed)
#     z_sampled = reparameterize(mu, logvar)       # reparameterization trick (helper defined below)
#     x_reconstructed = decoder(z_sampled, c_batch_embed)
#
#     reconstruction_loss = F.binary_cross_entropy(
#         x_reconstructed, x_batch.view(x_batch.size(0), -1), reduction='sum')
#     kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
#     loss = reconstruction_loss + kl_divergence   # negative conditional ELBO
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
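The reparameterize helper referenced in the sketch is the standard reparameterization trick; a minimal implementation:

```python
import torch

def reparameterize(mu, logvar):
    # Sample z = mu + sigma * eps with eps ~ N(0, I), keeping the graph differentiable w.r.t. mu and logvar
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std
```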
The principal change from a standard VAE is the concatenation of the condition embedding c_embed at appropriate points in the encoder and decoder. Once trained, the decoder can be used directly for conditional generation, as sketched below.
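A sketch of conditional generation with the trained decoder above; the latent dimension, digit label, and image shape are illustrative choices for MNIST:

```python
import torch
import torch.nn.functional as F

# Generate an image of the digit 7 with a trained CVAEDecoder (here called `decoder`)
latent_dim = 20
with torch.no_grad():
    z = torch.randn(1, latent_dim)                                    # sample z from the standard normal prior
    c_embed = F.one_hot(torch.tensor([7]), num_classes=10).float()    # requested attribute
    generated = decoder(z, c_embed)                                   # shape (1, output_dim), e.g., 784 for MNIST
    generated_image = generated.view(1, 1, 28, 28)                    # reshape for visualization
```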
VQ-VAEs introduce a discrete latent space by quantizing the encoder's output to the closest vector in a learned codebook (or embedding space). This often leads to sharper generated samples compared to standard VAEs with continuous latents, as the decoder learns to map a finite set of representations to outputs.
A diagram illustrating the VQ-VAE structure:
Data flow in a Vector Quantized Variational Autoencoder. The encoder output is quantized using a learnable codebook. The Straight-Through Estimator (STE) is used for gradient propagation.
The VQ-VAE is trained by minimizing a combined loss:

$$\mathcal{L}_{\text{VQ-VAE}} = \underbrace{-\log p_\theta(x \mid z_q(x))}_{\text{reconstruction loss}} + \underbrace{\big\| \text{sg}[z_e(x)] - e \big\|_2^2}_{\text{codebook loss}} + \underbrace{\beta \, \big\| z_e(x) - \text{sg}[e] \big\|_2^2}_{\text{commitment loss}}$$

Where:

- The reconstruction loss trains the encoder and decoder to reproduce the input from the quantized latents $z_q(x)$; in practice it is often a mean squared error or binary cross-entropy.
- The codebook loss moves the codebook vectors $e$ toward the encoder outputs $z_e(x)$; the stop-gradient $\text{sg}[\cdot]$ on $z_e(x)$ means this term updates only the codebook.
- The commitment loss, weighted by $\beta$, uses the stop-gradient ($\text{sg}$) operation to prevent encoder outputs from growing arbitrarily large:

$$\beta \, \big\| z_e(x) - \text{sg}[e_k] \big\|_2^2$$

Here, $e_k$ is the codebook vector closest to $z_e(x)$. Because $e_k$ sits inside the stop-gradient, this term updates only the encoder output $z_e(x)$.

There is no explicit KL divergence term to a prior in the basic VQ-VAE. The discreteness of the latent space itself acts as a form of regularization (with a uniform prior over codes, the KL term is a constant). A prior over the discrete latent codes can be learned (e.g., using a PixelCNN) after the VQ-VAE is trained and used for generation.
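Before looking at the full module, here is a minimal, self-contained illustration of the Straight-Through Estimator; torch.round stands in for the (equally non-differentiable) codebook lookup:

```python
import torch

z_e = torch.randn(4, 8, requires_grad=True)   # pretend encoder output
z_q = torch.round(z_e)                        # stand-in for the nearest-codebook-vector lookup

# STE: forward pass uses z_q, backward pass copies gradients straight to z_e
z_q_ste = z_e + (z_q - z_e).detach()
z_q_ste.sum().backward()
print(z_e.grad)  # all ones: the quantization step is bypassed in the backward pass
```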
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        # Initialize the codebook (embeddings)
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)

    def forward(self, inputs):
        # inputs: (Batch, Channel, Height, Width) with Channel == embedding_dim
        inputs_perm = inputs.permute(0, 2, 3, 1).contiguous()     # (B, H, W, C)
        flat_input = inputs_perm.view(-1, self.embedding_dim)     # (B*H*W, C)

        # Squared distances from each encoder vector to every codebook vector
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self.embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self.embedding.weight.t()))

        # Nearest codebook entry for each spatial position
        encoding_indices = torch.argmin(distances, dim=1)                       # (B*H*W,)
        quantized = self.embedding(encoding_indices).view(inputs_perm.shape)    # (B, H, W, C)
        quantized = quantized.permute(0, 3, 1, 2).contiguous()                  # back to (B, C, H, W)

        # Codebook loss (sg on encoder outputs) moves the embeddings toward the encoder outputs;
        # commitment loss (sg on quantized vectors) keeps encoder outputs close to their codebook vectors
        codebook_loss = F.mse_loss(quantized, inputs.detach())
        commitment_loss = F.mse_loss(inputs, quantized.detach())
        loss = codebook_loss + self.commitment_cost * commitment_loss

        # Straight-Through Estimator: identity gradient through the quantization step
        quantized = inputs + (quantized - inputs).detach()

        return quantized, loss, encoding_indices.view(inputs.shape[0], -1)
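A quick usage check of the layer above; the codebook size, embedding dimension, and feature-map shape are illustrative:

```python
import torch

vq = VectorQuantizer(num_embeddings=512, embedding_dim=64, commitment_cost=0.25)
z_e = torch.randn(8, 64, 7, 7)                 # fake encoder output: batch of 8, 64 channels, 7x7 grid
quantized, vq_loss, indices = vq(z_e)
print(quantized.shape, vq_loss.item(), indices.shape)  # torch.Size([8, 64, 7, 7]) ... torch.Size([8, 49])
```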
class VQVAE(nn.Module):
    def __init__(self, encoder, decoder, vq_layer):
        super().__init__()
        self.encoder = encoder
        self.vq_layer = vq_layer
        self.decoder = decoder

    def forward(self, x):
        z_e = self.encoder(x)                                   # continuous encoder output
        quantized_latents, vq_loss, _ = self.vq_layer(z_e)      # discretize via the codebook
        x_reconstructed = self.decoder(quantized_latents)
        return x_reconstructed, vq_loss
# Training loop (sketch):
# for x_batch, _ in train_loader:                  # labels are not needed for training
#     x_reconstructed, vq_loss = vq_vae_model(x_batch)
#
#     reconstruction_loss = F.mse_loss(x_reconstructed, x_batch)
#     total_loss = reconstruction_loss + vq_loss   # vq_loss already contains codebook and commitment terms
#     optimizer.zero_grad()
#     total_loss.backward()
#     optimizer.step()
The VectorQuantizer module is the most intricate part. Ensure the stop-gradient (.detach() in PyTorch) is applied to the correct tensor in the codebook and commitment losses, and that the STE is implemented in the forward pass so that gradients can flow back to the encoder through the quantization step.
To generate new samples, train an autoregressive prior (e.g., a PixelCNN) over the discrete encoding_indices obtained from the training data. Once the prior is learned, sample index maps from it, retrieve the corresponding codebook vectors $e_k$, and pass them to the decoder. The sketch below shows how these indices can be collected.
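A sketch of gathering the code indices for prior training, assuming the vq_vae_model and train_loader names from the sketches above:

```python
import torch

# Collect the discrete code maps produced by a trained VQ-VAE; these become the
# training data for an autoregressive prior (e.g., a PixelCNN) over latent codes.
all_indices = []
with torch.no_grad():
    for x_batch, _ in train_loader:
        z_e = vq_vae_model.encoder(x_batch)
        _, _, indices = vq_vae_model.vq_layer(z_e)   # (batch, H*W) integer indices into the codebook
        all_indices.append(indices.cpu())
codes = torch.cat(all_indices, dim=0)
```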
Both CVAE and VQ-VAE offer significant improvements over the vanilla VAE architecture, but they address different aspects and have distinct characteristics.

| Feature | Conditional VAE (CVAE) | Vector Quantized VAE (VQ-VAE) |
|---|---|---|
| Primary Goal | Controlled generation based on attributes | Improved sample fidelity, discrete latent representation |
| Latent Space | Continuous, conditioned on $c$ | Discrete, finite set of learned codebook vectors |
| Control Mechanism | Explicit input condition $c$ | Implicit through the structure of the learned codebook |
| Sample Quality | Content controlled by $c$; can still suffer from blurriness | Often sharper, less blurry samples; generation requires a prior over codes |
| Training Objective | Conditional ELBO ($\mathcal{L}_{\text{CVAE}}$) | Reconstruction + codebook + commitment losses |
| Gradient Flow | Standard reparameterization trick | Straight-Through Estimator for the quantization step |
| Common Challenges | Ensuring the condition is effectively used, posterior collapse | Codebook collapse (unused codes), choosing the codebook size and $\beta$ |
| Generation Process | Sample $z \sim p(z)$, provide $c$, decode $p_\theta(x \mid z, c)$ | Sample discrete codes from a learned prior (e.g., PixelCNN), then decode |
With the foundations of CVAE and VQ-VAE implementations in place, you can extend your practical work, for example by learning a prior over the VQ-VAE's discrete codes or by applying these models to richer datasets and condition types.
This hands-on experience is fundamental for developing a deeper intuition for how architectural choices influence the behavior and capabilities of Variational Autoencoders. By building and experimenting, you will be better equipped to select, design, and adapt VAE models for your specific representation learning and generative modeling tasks.