This practical session builds directly on our understanding of advanced VAE architectures. Having discussed the theoretical aspects of models like Conditional VAEs (CVAEs) and Vector Quantized VAEs (VQ-VAEs), we will now walk through the implementation details of these two prominent architectures. Our goal is to provide you with the foundational code structure and insights needed to build, train, and evaluate these models, preparing you to tackle more complex generative tasks.

We'll focus on highlighting the specific modifications required to a standard VAE framework to realize these advanced variants. For this session, we'll assume you are working with a dataset like MNIST or Fashion-MNIST, as these allow for a clear demonstration of conditional generation and of the impact of discrete latents on sample quality. You should be comfortable with Python and a deep learning framework such as PyTorch or TensorFlow.

## Prerequisites for this Practical

Before you begin, ensure you have:

- A working Python environment (e.g., Python 3.8+).
- A deep learning library installed: PyTorch (version 1.10+ recommended) or TensorFlow (version 2.5+ recommended).
- Familiarity with building and training standard VAEs using your chosen framework.
- Standard data science libraries: NumPy, Matplotlib (for visualizations).

We will provide high-level code structures and logic. You will adapt these to your specific framework and dataset.

## Implementing a Conditional VAE (CVAE)

Conditional VAEs extend the VAE framework by incorporating conditional information, denoted as $c$, into both the generation and inference processes. This allows us to direct the VAE to generate data samples possessing specific attributes defined by $c$. For instance, with MNIST, $c$ could be the digit label (0-9), enabling us to request the CVAE to generate an image of a particular digit.

### CVAE Architecture

The core idea is to make both the encoder $q_\phi(z|x,c)$ and the decoder $p_\theta(x|z,c)$ dependent on the condition $c$.

**Condition representation.** The condition $c$ (e.g., a class label) is typically converted into a numerical format, often one-hot encoded, before being fed into the networks. Let's say $c_{embed}$ is this numerical representation.
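As a minimal sketch of this step in PyTorch (the labels and the class count here are illustrative, assuming integer class labels as the condition):

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([3, 7, 0, 9])                   # integer class labels (illustrative)
c_embed = F.one_hot(labels, num_classes=10).float()   # shape (4, 10): one row per label
```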
**Encoder $q_\phi(z|x,c)$.**

- Input: the original data $x$ and the condition $c_{embed}$.
- Modification: concatenate $c_{embed}$ with $x$ (if $x$ is flattened) or with an intermediate feature representation of $x$ within the encoder.
- Output: parameters $(\mu, \log \sigma^2)$ of the approximate posterior $q_\phi(z|x,c)$, which is usually a Gaussian distribution.

**Decoder $p_\theta(x|z,c)$.**

- Input: a latent sample $z$ (drawn from $q_\phi(z|x,c)$ during training, or from $p(z|c)$ during generation) and the condition $c_{embed}$.
- Modification: concatenate $c_{embed}$ with $z$ before passing it through the decoder network.
- Output: parameters of the distribution for the reconstructed data $\hat{x}$ (e.g., pixel values for an image).

A diagram illustrating the CVAE structure:

```dot
digraph G {
  rankdir=TB;
  node [shape=box, style="filled", fillcolor="#e9ecef"];
  edge [color="#495057"];
  subgraph cluster_encoder {
    label = "Encoder q_phi(z|x,c)";
    bgcolor="#f8f9fa";
    X [label="Input x", fillcolor="#a5d8ff"];
    C_enc [label="Condition c", shape=parallelogram, fillcolor="#ffec99"];
    Enc_Net [label="Encoder Network"];
    Enc_Concat [label="Concatenate\n(x_features, c_embed)", shape=oval];
    Mu_Sigma [label="μ, log σ²", shape=ellipse];
    X -> Enc_Net -> Enc_Concat;
    C_enc -> Enc_Concat;
    Enc_Concat -> Mu_Sigma;
  }
  Z_sample [label="Sample z ~ q_phi(z|x,c)", shape=ellipse, fillcolor="#b2f2bb"];
  Mu_Sigma -> Z_sample [label="Reparameterization"];
  subgraph cluster_decoder {
    label = "Decoder p_theta(x|z,c)";
    bgcolor="#f8f9fa";
    C_dec [label="Condition c", shape=parallelogram, fillcolor="#ffec99"];
    Dec_Net [label="Decoder Network"];
    Dec_Concat [label="Concatenate\n(z, c_embed)", shape=oval];
    X_hat [label="Reconstruction x_hat", fillcolor="#a5d8ff"];
    Z_sample -> Dec_Concat;
    C_dec -> Dec_Concat;
    Dec_Concat -> Dec_Net -> X_hat;
  }
}
```

*Data flow in a Conditional Variational Autoencoder. The condition $c$ is incorporated into both the encoder and decoder.*

### CVAE Objective Function

The objective function for a CVAE is a conditional version of the ELBO:

$$ L_{CVAE}(\phi, \theta; x, c) = \mathbb{E}_{q_\phi(z|x,c)}[\log p_\theta(x|z,c)] - D_{KL}(q_\phi(z|x,c) \,||\, p(z|c)) $$

During training, we maximize this $L_{CVAE}$.

- The first term is the conditional reconstruction likelihood.
- The second term is the KL divergence between the approximate posterior $q_\phi(z|x,c)$ and the prior $p(z|c)$. Often, the prior $p(z|c)$ is simplified to a standard normal distribution $p(z) = \mathcal{N}(0, I)$, especially if the condition $c$ primarily influences the decoder. If $p(z|c)$ is used, it could be a learned prior that also depends on $c$.
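The implementation sketch below calls a `reparameterize` helper that is not spelled out there. A minimal version for a diagonal Gaussian posterior (an assumption consistent with the encoder outputs $\mu$ and $\log \sigma^2$ above) could look like this:

```python
import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); keeps the sampling step differentiable.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std
```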
### Implementation Sketch (PyTorch-like pseudocode)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Condition: a label (e.g., an integer from 0 to 9 for MNIST),
# converted to a one-hot embedding c_embed before being passed to the networks.

class CVAEEncoder(nn.Module):
    def __init__(self, input_dim, latent_dim, condition_dim, hidden_dim):
        super().__init__()
        # Separate projections for the (flattened) input and the condition,
        # followed by a layer that fuses the two representations.
        self.fc_x = nn.Linear(input_dim, hidden_dim)
        self.fc_c = nn.Linear(condition_dim, hidden_dim)
        self.fc_combined = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, c_embed):
        h_x = F.relu(self.fc_x(x.view(x.size(0), -1)))  # flatten x if it is an image
        h_c = F.relu(self.fc_c(c_embed))
        combined = torch.cat([h_x, h_c], dim=1)
        h_combined = F.relu(self.fc_combined(combined))
        mu = self.fc_mu(h_combined)
        logvar = self.fc_logvar(h_combined)
        return mu, logvar


class CVAEDecoder(nn.Module):
    def __init__(self, latent_dim, condition_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc_z = nn.Linear(latent_dim, hidden_dim)
        self.fc_c = nn.Linear(condition_dim, hidden_dim)
        self.fc_combined = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, z, c_embed):
        h_z = F.relu(self.fc_z(z))
        h_c = F.relu(self.fc_c(c_embed))
        combined = torch.cat([h_z, h_c], dim=1)
        h_combined = F.relu(self.fc_combined(combined))
        # Sigmoid output assumes pixel values in [0, 1]; reshape to an image outside if needed.
        return torch.sigmoid(self.fc_out(h_combined))


# Training loop sketch:
# for x_batch, c_batch_labels in dataloader:
#     c_batch_embed = F.one_hot(c_batch_labels, num_classes=condition_dim).float()
#     mu, logvar = encoder(x_batch, c_batch_embed)
#     z_sampled = reparameterize(mu, logvar)
#     x_reconstructed = decoder(z_sampled, c_batch_embed)
#
#     reconstruction_loss = F.binary_cross_entropy(
#         x_reconstructed, x_batch.view(x_batch.size(0), -1), reduction='sum')
#     kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
#     loss = reconstruction_loss + kl_divergence
#
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
```

The key modification is the concatenation of the condition embedding `c_embed` at appropriate points in the encoder and decoder.

### Evaluation and Generation

- **Conditional reconstruction:** reconstruct an input image $x$ given its true label $c$.
- **Conditional generation:** sample $z \sim p(z)$ (e.g., $\mathcal{N}(0,I)$), pick a desired condition $c_{target}$, and generate $\hat{x} = \text{decoder}(z, c_{target})$, with $c_{target}$ one-hot encoded. You should see samples corresponding to $c_{target}$. For MNIST, you can generate images of specific digits on demand, as in the sketch below.
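A minimal conditional-generation sketch, assuming the `CVAEDecoder` defined above, a `latent_dim` variable, and 10 MNIST classes (names and shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# Hypothetical usage: generate 16 samples of the digit 7 with a trained decoder.
decoder.eval()
with torch.no_grad():
    z = torch.randn(16, latent_dim)                                     # z ~ N(0, I)
    c_target = F.one_hot(torch.full((16,), 7), num_classes=10).float()  # condition: digit 7
    samples = decoder(z, c_target)                                      # flattened images, reshape as needed
```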
## Implementing a Vector Quantized VAE (VQ-VAE)

VQ-VAEs introduce a discrete latent space by quantizing the encoder's output to the closest vector in a learned codebook (or embedding space). This often leads to sharper generated samples compared to standard VAEs with continuous latents, as the decoder learns to map a finite set of representations to outputs.

### VQ-VAE Architecture

**Encoder $E$.** Maps an input $x$ to a continuous representation $z_e(x)$. This output is typically a tensor of vectors, e.g., an $H' \times W' \times D$ feature map if $x$ is an image.

**Vector Quantizer (VQ) layer.**

- Codebook: a learnable embedding space $\mathcal{E} = \{e_1, e_2, \dots, e_K\}$, where each $e_i \in \mathbb{R}^D$ is an embedding vector and $K$ is the size of the codebook.
- Quantization: for each vector $z_{e,j}(x)$ from the encoder's output feature map, find the nearest codebook embedding $e_k$:
  $$ k_j = \arg\min_i ||z_{e,j}(x) - e_i||_2 $$
  The quantized representation of $z_{e,j}(x)$ is $z_{q,j}(x) = e_{k_j}$.
- Straight-Through Estimator (STE): during backpropagation, the gradient from the decoder $\nabla_{z_q} L$ is copied directly to the encoder output $z_e(x)$, i.e., $\nabla_{z_e} L = \nabla_{z_q} L$. This allows gradients to flow through the non-differentiable $\arg\min$ operation.

**Decoder $D$.** Maps the quantized latent vectors $z_q(x)$ back to the data space to produce $\hat{x}$.

A diagram illustrating the VQ-VAE structure:

```dot
digraph G {
  rankdir=TB;
  node [shape=box, style="filled", fillcolor="#e9ecef"];
  edge [color="#495057"];
  X [label="Input x", fillcolor="#a5d8ff"];
  Encoder [label="Encoder E"];
  Ze_map [label="Continuous z_e(x)\n(e.g., H' x W' x D feature map)", shape=Mrecord, fillcolor="#b2f2bb"];
  subgraph cluster_vq {
    label = "Vector Quantizer";
    bgcolor="#f8f9fa";
    Codebook [label="Codebook\n{e_1, ..., e_K}", shape=cylinder, fillcolor="#ffd8a8"];
    Quant_Op [label="Quantization\n(Nearest Neighbor Lookup)", shape=diamond, fillcolor="#fcc2d7"];
  }
  Zq_map [label="Quantized z_q(x)\n(map of e_k vectors)", shape=Mrecord, fillcolor="#c0eb75"];
  Decoder [label="Decoder D"];
  X_hat [label="Reconstruction x_hat", fillcolor="#a5d8ff"];
  X -> Encoder -> Ze_map;
  Ze_map -> Quant_Op;
  Codebook -> Quant_Op;
  Quant_Op -> Zq_map;
  Zq_map -> Decoder -> X_hat;
  edge [style=dashed, constraint=false, color="#fa5252"];
  Decoder -> Ze_map [label=" Gradient (STE)"];
}
```

*Data flow in a Vector Quantized Variational Autoencoder. The encoder output $z_e(x)$ is quantized using a learnable codebook. The Straight-Through Estimator (STE) is used for gradient propagation.*

### VQ-VAE Objective Function

The VQ-VAE is trained by minimizing a combined loss:

$$ L_{VQVAE} = L_{reconstruction} + L_{codebook} + \beta \cdot L_{commitment} $$

where:

- **Reconstruction loss $L_{reconstruction}$:** measures how well the decoder reconstructs the input $x$ from the quantized latents $z_q(x)$. For images, this is often the mean squared error (MSE):
  $$ L_{reconstruction} = ||x - D(z_q(x))||^2_2 $$
- **Codebook loss $L_{codebook}$:** aims to move the codebook vectors $e_i$ closer to the encoder outputs $z_e(x)$ they are mapped to. It uses a stop-gradient (sg) operation to prevent the encoder outputs from growing arbitrarily large:
  $$ L_{codebook} = ||\text{sg}[z_e(x)] - e_k||^2_2 $$
  Here, $e_k$ is the codebook vector closest to $z_e(x)$. The gradient only updates $e_k$.
- **Commitment loss $L_{commitment}$:** aims to ensure the encoder commits to an embedding and that its output does not grow. It encourages $z_e(x)$ to be close to its chosen codebook vector $e_k$:
  $$ L_{commitment} = ||z_e(x) - \text{sg}[e_k]||^2_2 $$
  The $\beta$ hyperparameter (typically between 0.1 and 2.0, often 0.25) controls the strength of this term. The gradient only updates $z_e(x)$.

There is no explicit KL divergence term to a prior in the basic VQ-VAE; the discreteness of the latent space itself acts as a form of regularization. A prior can be learned over the discrete latent codes (e.g., using a PixelCNN) for generation, after the VQ-VAE is trained.
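Before the full module sketch, the two stop-gradient losses and the straight-through estimator can be condensed into a few lines. This is a minimal sketch under the assumption that `z_e` (the encoder output) and `e_k` (the tensor of nearest codebook vectors, same shape) are already given:

```python
import torch
import torch.nn.functional as F

def vq_loss_and_ste(z_e: torch.Tensor, e_k: torch.Tensor, beta: float = 0.25):
    """z_e: encoder output; e_k: nearest codebook vectors, same shape as z_e."""
    codebook_loss = F.mse_loss(e_k, z_e.detach())     # ||sg[z_e] - e_k||^2, updates the codebook only
    commitment_loss = F.mse_loss(z_e, e_k.detach())   # ||z_e - sg[e_k]||^2, updates the encoder only
    vq_loss = codebook_loss + beta * commitment_loss
    # Straight-through estimator: forward pass uses e_k, backward pass copies gradients to z_e.
    z_q = z_e + (e_k - z_e).detach()
    return z_q, vq_loss
```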
### Implementation Sketch (PyTorch-like pseudocode)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        # Codebook: K embeddings of dimension D
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)

    def forward(self, inputs):
        # inputs: (B, D, H', W') -> (B*H'*W', D) for the nearest-neighbor lookup
        inputs_permuted = inputs.permute(0, 2, 3, 1).contiguous()
        flat_input = inputs_permuted.view(-1, self.embedding_dim)

        # Squared distances from each encoder vector to every codebook vector
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self.embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self.embedding.weight.t()))

        # Nearest codebook entry for each encoder vector
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(encoding_indices).view(inputs_permuted.shape)
        quantized = quantized.permute(0, 3, 1, 2).contiguous()  # back to (B, D, H', W')

        # Codebook loss: ||sg[z_e] - e_k||^2 (gradient updates the codebook only)
        codebook_loss = F.mse_loss(quantized, inputs.detach())
        # Commitment loss: ||z_e - sg[e_k]||^2 (gradient updates the encoder only)
        commitment_loss = F.mse_loss(inputs, quantized.detach())
        loss = codebook_loss + self.commitment_cost * commitment_loss

        # Straight-Through Estimator: forward pass uses e_k, backward copies gradients to z_e
        quantized = inputs + (quantized - inputs).detach()

        # Return the quantized map, the VQ losses, and the (B, H', W') grid of code indices
        return quantized, loss, encoding_indices.view(inputs.shape[0], inputs.shape[2], inputs.shape[3])


class VQVAE(nn.Module):
    def __init__(self, encoder, decoder, vq_layer):
        super().__init__()
        self.encoder = encoder
        self.vq_layer = vq_layer
        self.decoder = decoder

    def forward(self, x):
        z_e = self.encoder(x)
        quantized_latents, vq_loss, _ = self.vq_layer(z_e)
        x_reconstructed = self.decoder(quantized_latents)
        return x_reconstructed, vq_loss


# Training loop sketch:
# for x_batch, _ in dataloader:   # no labels needed unless for evaluation
#     x_reconstructed, vq_loss = vq_vae_model(x_batch)
#     reconstruction_loss = F.mse_loss(x_reconstructed, x_batch)
#     total_loss = reconstruction_loss + vq_loss  # vq_loss already combines codebook and commitment terms
#     optimizer.zero_grad()
#     total_loss.backward()
#     optimizer.step()
```

The `VectorQuantizer` module is the most intricate part. Ensure the stop-gradient (`.detach()` in PyTorch) is applied to the correct side in the codebook and commitment losses, and that the STE is implemented in the forward pass so that gradients flow back to the encoder through the quantization step.

### Evaluation and Generation

- **Reconstruction quality:** VQ-VAEs often produce sharper reconstructions than standard VAEs due to the discrete latent bottleneck.
- **Generation:** to generate new samples, you first need to learn a prior $p(k)$ over the discrete latent codes $k$. This is often done by training an autoregressive model (such as a PixelCNN or Transformer) on the sequences of `encoding_indices` obtained from the training data. Once $p(k)$ is learned, sample indices from it, retrieve the corresponding codebook vectors $e_k$, and pass them to the decoder, as sketched below.
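A minimal two-stage sampling sketch, assuming a trained `vq_vae` built from the `VectorQuantizer` above and a hypothetical `prior_model` whose `sample(n)` method returns an `(n, H', W')` grid of code indices:

```python
import torch

with torch.no_grad():
    indices = prior_model.sample(16)                 # (16, H', W') LongTensor of code indices (hypothetical API)
    codes = vq_vae.vq_layer.embedding(indices)       # (16, H', W', D) codebook vectors
    codes = codes.permute(0, 3, 1, 2).contiguous()   # (16, D, H', W'), the layout the decoder expects here
    samples = vq_vae.decoder(codes)                  # generated images
```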
## Comparing CVAE and VQ-VAE

Both CVAE and VQ-VAE offer significant improvements over the vanilla VAE architecture, but they address different aspects and have distinct characteristics.

| Feature | Conditional VAE (CVAE) | Vector Quantized VAE (VQ-VAE) |
| --- | --- | --- |
| Primary goal | Controlled generation based on attributes | Improved sample fidelity, discrete latent representation |
| Latent space | Continuous, conditioned on $c$ | Discrete, finite set of learned codebook vectors |
| Control mechanism | Explicit input condition $c$ | Implicit, through the structure of the learned codebook |
| Sample quality | Content controlled by $c$; can still suffer from blurriness | Often sharper, less blurry samples; generation requires a prior over codes |
| Training objective | Conditional ELBO ($L_{reconstruction} + D_{KL}$) | $L_{reconstruction} + L_{codebook} + \beta \cdot L_{commitment}$ |
| Gradient flow | Standard reparameterization trick | Straight-Through Estimator for the quantization step |
| Common challenges | Ensuring condition $c$ is effectively used; posterior collapse | Codebook collapse (unused codes); choosing $K$ and $\beta$ |
| Generation process | Sample $z \sim p(z)$, provide $c$, decode $p_\theta(x \mid z, c)$ | Sample discrete codes $k \sim p(k)$, retrieve $e_k$, decode $p_\theta(x \mid e_k)$ |

### When to Use Which?

**CVAE is preferable when:**

- You need to generate data with specific, controllable attributes.
- Interpretability of the latent space with respect to conditions is desired.
- Applications like style transfer or data augmentation based on attributes are the goal.

**VQ-VAE is a strong candidate when:**

- The primary goal is high-fidelity, sharp sample generation.
- A discrete representation of the data is beneficial (e.g., for downstream tasks or learning a prior).
- You are working with complex data where continuous latents might lead to overly smooth or blurry outputs (e.g., high-resolution images, audio).

## Further Exploration

With the foundations of CVAE and VQ-VAE implementations in place, you can extend your practical work:

- **Implement both models** on a dataset like MNIST or Fashion-MNIST.
- **Qualitatively compare** the reconstructions and generated samples from both models. For the CVAE, test its ability to generate specific classes. For the VQ-VAE, observe sample sharpness.
- **Experiment with hyperparameters:**
  - For the CVAE: latent dimension, network depth, how the condition is incorporated.
  - For the VQ-VAE: codebook size ($K$), embedding dimension ($D$), commitment cost ($\beta$). Observe the impact of $K$ on sample diversity and reconstruction quality.
- **Visualize the latent space:** for the CVAE, try to visualize how different conditions map to different regions of the latent space (e.g., using t-SNE on $z$ samples colored by condition $c$). For the VQ-VAE, analyze codebook usage (see the sketch after this list).
- **Combine architectures:** consider how you might combine these ideas, for example a Conditional VQ-VAE (C-VQ-VAE) where the codebook or its usage could be conditioned.
- **Explore other architectures:** refer back to other models discussed in this chapter, such as Hierarchical VAEs or $\beta$-VAEs, and think about the specific architectural or loss-function changes required for their implementation. The principles of modifying the encoder, decoder, latent space, or loss function are central to all these advanced VAEs.
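As one concrete starting point for the codebook-usage analysis suggested above (a sketch assuming the `VectorQuantizer`-based `vq_vae` from earlier and a `dataloader` over the training set):

```python
import torch

# Count how often each of the K codes is selected over the training set;
# codes that are never selected indicate codebook collapse.
usage = torch.zeros(vq_vae.vq_layer.num_embeddings, dtype=torch.long)
with torch.no_grad():
    for x_batch, _ in dataloader:
        _, _, indices = vq_vae.vq_layer(vq_vae.encoder(x_batch))
        usage += torch.bincount(indices.flatten(), minlength=usage.numel())
print(f"{(usage > 0).sum().item()} of {usage.numel()} codes are in use")
```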
This hands-on experience is fundamental for developing a deeper intuition for how architectural choices influence the behavior and capabilities of Variational Autoencoders. By building and experimenting, you will be better equipped to select, design, and adapt VAE models for your specific representation learning and generative modeling tasks.