Let's put the theory of conditional generation into practice. Building upon the architectures and training strategies discussed earlier, this section provides practical guidance on implementing a system that generates images based on specific conditions, such as class labels. This capability is fundamental for tasks requiring controllable synthesis, like generating specific types of objects or augmenting datasets in a targeted manner.
We will focus on implementing a class-conditional Generative Adversarial Network (cGAN), a common and effective approach. The core idea is simple yet powerful: provide the conditional information (the class label) as an additional input to both the generator and the discriminator. This forces the generator to produce images relevant to the label and trains the discriminator to verify if an image is real and matches its given label.
Assume you have a standard GAN implementation (using PyTorch, for instance) trained on an unlabeled dataset. To make it conditional, we need to modify the data pipeline, the generator, and the discriminator. We'll use a dataset like CIFAR-10, which contains images and corresponding class labels.
First, ensure your data loader provides both images and their integer class labels. Since neural networks work best with continuous vectors, we need to convert these integer labels into embeddings. An embedding layer is suitable for this.
import torch
import torch.nn as nn

# Example parameters
num_classes = 10 # For CIFAR-10
embedding_dim = 16 # Size of the label embedding vector
# Embedding layer
label_embedding = nn.Embedding(num_classes, embedding_dim)
# Inside the training loop:
# real_images, labels = next(data_loader_iter)
# Convert labels to embeddings
label_input = label_embedding(labels) # Shape: (batch_size, embedding_dim)
The embedding_dim is a hyperparameter you can tune.
The generator needs to receive both the random noise vector z and the condition (the label embedding). A common strategy is to concatenate them.
# Example generator input
latent_dim = 100
# ... (define label_embedding as above; move it to the target device, e.g. label_embedding.to(device))
# Inside the training loop:
noise = torch.randn(batch_size, latent_dim, device=device)
# Generate labels for synthesis (e.g., random labels or specific ones)
gen_labels = torch.randint(0, num_classes, (batch_size,), device=device)
gen_label_input = label_embedding(gen_labels)
# Concatenate noise and label embedding
generator_input = torch.cat((noise, gen_label_input), dim=1)
# Pass through generator
fake_images = generator(generator_input)
The generator's first layer must now accept an input size of latent_dim + embedding_dim. Alternatively, the label embedding can be projected and added or concatenated at intermediate layers, similar to techniques used in StyleGAN or attention-based conditioning, allowing for more nuanced control.
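To make the required size change concrete, here is a minimal generator sketch for 32x32 RGB images such as CIFAR-10. The class name CondGenerator, the layer widths, and the DCGAN-style upsampling stack are illustrative assumptions rather than a prescribed architecture; it accepts the already-concatenated generator_input from the snippet above.
import torch
import torch.nn as nn

class CondGenerator(nn.Module):  # hypothetical name; a sketch, not a fixed recipe
    def __init__(self, latent_dim=100, embedding_dim=16, img_channels=3):
        super().__init__()
        # The first layer must accept latent_dim + embedding_dim features
        self.project = nn.Linear(latent_dim + embedding_dim, 128 * 4 * 4)
        self.upsample = nn.Sequential(
            nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),   # 4x4 -> 8x8
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),    # 8x8 -> 16x16
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, img_channels, 4, stride=2, padding=1),  # 16x16 -> 32x32
            nn.Tanh(),  # outputs in [-1, 1]
        )

    def forward(self, generator_input):
        # generator_input: (batch_size, latent_dim + embedding_dim)
        x = self.project(generator_input).view(-1, 128, 4, 4)
        return self.upsample(x)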
The discriminator must also receive the conditional information. It needs to determine if an image is real given its supposed class. A typical approach is to feed the label embedding alongside the image data.
One common technique is to embed the label, reshape it spatially, and concatenate it to the image tensor as an extra channel. Another approach involves processing the image through initial convolutional layers and then concatenating the flattened image features with the label embedding before feeding them into later fully connected layers.
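As a minimal sketch of the first technique, the label can be embedded with one value per pixel, reshaped to the image's spatial size, and stacked as an extra channel; the helper below and the assumption of 32x32 inputs are illustrative.
import torch
import torch.nn as nn

# Sketch: provide the label as an extra image channel (assumes 32x32 images)
num_classes, img_size = 10, 32
label_to_map = nn.Embedding(num_classes, img_size * img_size)  # one value per pixel

def add_label_channel(images, labels):
    # images: (batch_size, 3, 32, 32), labels: (batch_size,)
    label_map = label_to_map(labels).view(-1, 1, img_size, img_size)
    return torch.cat((images, label_map), dim=1)  # (batch_size, 4, 32, 32)

# The discriminator's first convolution then expects 3 + 1 input channels.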
Let's illustrate the second approach, concatenating flattened image features with the label embedding:
# Assume 'self.feature_extractor' contains the discriminator's convolutional layers
# Assume 'label_embedding' produces label embeddings
# Inside the discriminator's forward pass:
image_features = self.feature_extractor(image) # e.g., output of conv layers
image_features_flat = image_features.view(image.size(0), -1)
# label_input is the embedded label passed to the discriminator
discriminator_input = torch.cat((image_features_flat, label_input), dim=1)
validity = self.final_layers(discriminator_input) # Final classification layers
The discriminator's layers following the concatenation point need to accommodate the combined size of the image features and the label embedding.
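Putting the pieces together, a minimal discriminator sketch for 32x32 RGB images might look like the following. The class name CondDiscriminator and the layer sizes are illustrative assumptions; here the label embedding is created inside the module rather than passed in, which is an equivalent design choice.
import torch
import torch.nn as nn

class CondDiscriminator(nn.Module):  # hypothetical name; a sketch, not a fixed recipe
    def __init__(self, num_classes=10, embedding_dim=16, img_channels=3):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, embedding_dim)
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(img_channels, 64, 4, stride=2, padding=1),   # 32x32 -> 16x16
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),            # 16x16 -> 8x8
            nn.LeakyReLU(0.2, inplace=True),
        )
        feature_dim = 128 * 8 * 8
        # Layers after the concatenation accept feature_dim + embedding_dim inputs
        self.final_layers = nn.Sequential(
            nn.Linear(feature_dim + embedding_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),  # raw real/fake logit
        )

    def forward(self, image, labels):
        image_features = self.feature_extractor(image)
        image_features_flat = image_features.view(image.size(0), -1)
        label_input = self.label_embedding(labels)
        return self.final_layers(torch.cat((image_features_flat, label_input), dim=1))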
The training loop requires careful handling of labels for both real and fake samples: real images must be paired with their true labels, while generated images are scored against the labels they were conditioned on.
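A minimal sketch of one training iteration, continuing from the snippets above, is shown below. It assumes the generator and discriminator sketches from this section, raw-logit outputs scored with BCEWithLogitsLoss, and optimizers named g_optimizer and d_optimizer; these names are illustrative.
# One training iteration (sketch; assumes the definitions from the snippets above)
criterion = nn.BCEWithLogitsLoss()
real_targets = torch.ones(batch_size, 1, device=device)
fake_targets = torch.zeros(batch_size, 1, device=device)

fake_images = generator(generator_input)  # conditioned on gen_labels via generator_input

# Discriminator step: real images with their true labels, fakes with the sampled labels
d_optimizer.zero_grad()
real_validity = discriminator(real_images, labels)
fake_validity = discriminator(fake_images.detach(), gen_labels)
d_loss = criterion(real_validity, real_targets) + criterion(fake_validity, fake_targets)
d_loss.backward()
d_optimizer.step()

# Generator step: try to make fakes look real for the labels they were conditioned on
g_optimizer.zero_grad()
gen_validity = discriminator(fake_images, gen_labels)
g_loss = criterion(gen_validity, real_targets)
g_loss.backward()
g_optimizer.step()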
Here's a simplified diagram illustrating the information flow in a cGAN:
Information flow in a class-conditional GAN. The condition (label y) is embedded and provided as input to both the Generator (along with the noise vector z) and the Discriminator (along with the real or generated image).
After successful training, you should be able to provide a specific class label (e.g., '3' for digit generation, or 'cat' for CIFAR-10) along with a noise vector to the generator and obtain an image representative of that class. Visual inspection is the first step: generate batches of images conditioned on each class and check if they look appropriate.
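A minimal sampling sketch, reusing generator, label_embedding, and latent_dim from above:
# Sketch: generate a batch of images for a single class after training
generator.eval()
with torch.no_grad():
    target_class = 3  # 'cat' in the standard CIFAR-10 label ordering
    sample_labels = torch.full((16,), target_class, dtype=torch.long, device=device)
    sample_noise = torch.randn(16, latent_dim, device=device)
    sample_input = torch.cat((sample_noise, label_embedding(sample_labels)), dim=1)
    samples = generator(sample_input)  # values in [-1, 1] if the generator ends with Tanh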
For quantitative evaluation (covered in Chapter 5), you can compute metrics like FID or IS per class to assess the quality and diversity within each category. Comparing the distribution of generated samples for a specific class against the distribution of real samples from that same class is also informative.
While we focused on cGANs, conditional generation is also highly effective with diffusion models. Classifier guidance involves using a separate pre-trained classifier to steer the sampling process towards a desired class at each denoising step. A more modern and often preferred technique is classifier-free guidance (detailed in Chapter 4). This involves training the diffusion model on a mixture of conditional and unconditional inputs. During sampling, the noise prediction is extrapolated from the conditional and unconditional predictions, allowing strong guidance without needing an external classifier. Implementing classifier-free guidance involves modifying the U-Net architecture slightly to accept conditional embeddings (such as class labels or text embeddings) and adjusting the training process to randomly drop the condition for some percentage of training steps.
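As a brief sketch of the sampling-time combination (the guidance scale and the model call signature here are illustrative assumptions; Chapter 4 covers the details):
# Classifier-free guidance at one denoising step (sketch)
# 'model' is assumed to predict noise from (noisy_images, t, class_labels),
# where class_labels=None (or a learned null token) means "unconditional".
guidance_scale = 4.0  # larger values push samples more strongly toward the condition

eps_uncond = model(noisy_images, t, class_labels=None)
eps_cond = model(noisy_images, t, class_labels=labels)
eps_guided = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
# eps_guided replaces the plain noise prediction in the usual denoising update.
# During training, the condition is dropped with some probability (e.g. 10%)
# so the same network learns both conditional and unconditional prediction.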
This practical exercise demonstrates how to extend generative models for controllable synthesis. By incorporating conditional information, you gain significant control over the output, enabling targeted data generation for a variety of advanced applications discussed in this chapter, from specific object synthesis to guided data augmentation.