Alright, let's translate the theory of Conditional GANs (cGANs) into practice. You've learned that cGANs allow us to direct the generator's output by providing additional information, typically a class label or some other attribute, denoted as y. This practical section guides you through the essential steps to build and train a cGAN, focusing on how to integrate this conditional input y into the generator and discriminator networks.
We'll use the familiar MNIST dataset as our example. It consists of grayscale images of handwritten digits (0-9), making the digit label the natural choice for our condition y. Our goal is to train a generator that can produce an image of a specific digit when prompted with the corresponding label.
First, load your dataset (e.g., MNIST) using your preferred deep learning framework's utilities. Unlike standard GAN training where we only need the images x, for a cGAN, we also need the corresponding labels y. Ensure your data loader provides pairs of (x,y).
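Here is a minimal loader sketch in PyTorch, assuming torchvision is available; note the normalization to [-1, 1], which matches the Tanh output activation used in the generator later on.

# Data loading sketch (assumes torchvision is installed)
import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # scale pixel values to [-1, 1]
])
dataset = datasets.MNIST(root="data", train=True, download=True, transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)
# Each batch yields images x of shape (128, 1, 28, 28) and integer labels y of shape (128,)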
The labels y are typically integers (0 to 9 for MNIST). Since neural networks work best with numerical vectors, we need to convert these integer labels into a suitable format. A common and effective approach is to use embedding layers. We can represent each label as a learnable vector. Alternatively, for discrete labels like in MNIST, one-hot encoding is a straightforward option, although embedding often provides more flexibility and potentially better performance, especially with a large number of classes.
Let's assume we use embedding. If we have Nc classes, we can create an embedding layer that maps each integer label i∈{0,1,...,Nc−1} to a dense vector of a chosen dimension, say de.
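As a concrete sketch, with the Nc = 10 MNIST classes and an illustrative embedding dimension of de = 50 (any reasonable value works):

# Label embedding sketch: 10 classes, embedding dimension 50 (illustrative values)
import torch
import torch.nn as nn

label_embedding = nn.Embedding(num_embeddings=10, embedding_dim=50)
labels = torch.tensor([3, 7, 1])         # a small batch of integer labels
label_vectors = label_embedding(labels)  # shape: (3, 50); vectors are learned during training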
The generator, G, must now accept two inputs: the random noise vector z and the conditional information y. The core idea is to combine these inputs effectively so the generator learns to use y to shape its output.
Here's a possible structure (PyTorch-like); the specific layer choices below are illustrative:
# Generator Structure
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    def __init__(self, noise_dim, num_classes, embedding_dim, output_channels):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, embedding_dim)
        # Main generator body. The first layer's input channels must
        # accommodate noise_dim + embedding_dim. The layer sizes below are
        # one illustrative choice for 28x28 outputs, not the only option.
        self.main = nn.Sequential(
            # Project the 1x1 (noise + label) input up to a 7x7 feature map
            nn.ConvTranspose2d(noise_dim + embedding_dim, 128, kernel_size=7,
                               stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # Upsample 7x7 -> 14x14
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # Upsample 14x14 -> 28x28
            nn.ConvTranspose2d(64, output_channels, kernel_size=4, stride=2,
                               padding=1, bias=False),
            nn.Tanh()  # Output activation: Tanh for images scaled to [-1, 1]
        )

    def forward(self, noise, labels):
        # Embed labels: (batch_size,) -> (batch_size, embedding_dim)
        label_embedding_vector = self.label_embedding(labels)
        # Noise is assumed to have shape (batch_size, noise_dim, 1, 1) for
        # ConvTranspose2d, so reshape the embedding to match spatially
        label_embedding_reshaped = label_embedding_vector.view(
            label_embedding_vector.size(0), label_embedding_vector.size(1), 1, 1)
        # Concatenate along the channel dimension:
        # shape (batch_size, noise_dim + embedding_dim, 1, 1)
        combined_input = torch.cat([noise, label_embedding_reshaped], dim=1)
        # Generate image
        generated_image = self.main(combined_input)
        return generated_image
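A quick shape check for this sketch, using illustrative hyperparameters (noise_dim = 100 and embedding_dim = 50 are assumptions, not requirements):

# Shape check for the generator sketch (illustrative hyperparameters)
G = ConditionalGenerator(noise_dim=100, num_classes=10, embedding_dim=50, output_channels=1)
z = torch.randn(16, 100, 1, 1)     # batch of 16 noise vectors
y = torch.randint(0, 10, (16,))    # batch of 16 random digit labels
fake = G(z, y)                     # shape: (16, 1, 28, 28)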
Similarly, the discriminator, D, must now evaluate not just the image x, but the pair (x,y). It needs to determine if the image x is a real image corresponding to label y, or a fake image generated for label y.
Here's a possible structure using late concatenation (PyTorch-like); again, the layer choices are illustrative:
# Discriminator Structure
class ConditionalDiscriminator(nn.Module):
    def __init__(self, num_classes, embedding_dim, input_channels):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, embedding_dim)
        # Image processing part: strided convolutions downsample 28x28 -> 7x7.
        # The layer sizes are one illustrative choice, not the only option.
        self.image_processor = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),  # 28 -> 14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),             # 14 -> 7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Flattened feature dimension for 28x28 inputs: 128 channels x 7 x 7
        feature_dim = 128 * 7 * 7
        # Final classifier: takes flattened image features + label embedding
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim + embedding_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1)  # No sigmoid: use BCEWithLogitsLoss or a Wasserstein loss
        )

    def forward(self, image, labels):
        # Process image: (batch_size, 128, 7, 7) for 28x28 inputs
        image_features = self.image_processor(image)
        # Flatten features to (batch_size, feature_dim)
        image_features_flat = image_features.view(image_features.size(0), -1)
        # Embed labels: (batch_size, embedding_dim)
        label_embedding_vector = self.label_embedding(labels)
        # Late concatenation of flattened features and label embedding
        combined_input = torch.cat([image_features_flat, label_embedding_vector], dim=1)
        # Classify the (image, label) pair: one logit per example
        validity = self.classifier(combined_input)
        return validity
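And a matching shape check, reusing fake and y from the generator example above:

# Shape check for the discriminator sketch
D = ConditionalDiscriminator(num_classes=10, embedding_dim=50, input_channels=1)
logits = D(fake, y)                # shape: (16, 1), raw scores (no sigmoid)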
The following diagram illustrates the data flow in a cGAN, highlighting where the conditional label y is incorporated.
Data flow in a Conditional GAN. The condition y (yellow) is embedded and combined with the noise z (blue) in the generator, and with image features (green) in the discriminator. The discriminator outputs a decision (red) based on both the image and its supposed condition.
The objective function remains a minimax game, but now D and G also depend on y. The value function V(D,G) is:
$$\min_G \max_D V(D,G) = \mathbb{E}_{(x,y)\sim p_{\text{data}}(x,y)}\big[\log D(x,y)\big] + \mathbb{E}_{z\sim p_z(z),\, y\sim p_y(y)}\big[\log\big(1 - D(G(z,y),\, y)\big)\big]$$

Here, $p_{\text{data}}(x,y)$ is the joint distribution of real data and labels, and $p_y(y)$ is the distribution of labels (which we often sample uniformly or according to the training set distribution).
In practice, when using standard binary cross-entropy loss (often implemented with BCEWithLogitsLoss for numerical stability), the discriminator tries to output high values for real pairs (x, y) and low values for fake pairs (G(z, y), y). The generator tries to fool the discriminator by making D(G(z, y), y) output high values. Remember to use the same label y when generating G(z, y) and passing it to the discriminator.
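Continuing the shape-check example above, the generator loss under this convention might look like this; note that the same y appears in both G and D:

# Generator loss sketch: G is rewarded when D scores its output as "real"
criterion = nn.BCEWithLogitsLoss()
g_loss = criterion(D(G(z, y), y), torch.ones(16, 1))  # same labels y for G and D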
The cGAN training loop follows the standard GAN pattern, with the crucial addition of handling the labels y (a condensed code sketch follows this list):

1. Update Discriminator: Sample a batch of real pairs (x, y) and score them against the "real" target. Sample noise z and labels y, generate x_fake = G(z, y), and score the pairs (x_fake, y) against the "fake" target. Use .detach() on x_fake when training D, so no generator gradients are computed.
2. Update Generator: Pass (x_fake, y) through the discriminator and compute the loss against the "real" target, rewarding G for fooling D. Do not use .detach() here, since the generator update needs gradients to flow through D back into G.

Repeat these steps for the desired number of epochs. Remember standard GAN training practices like using appropriate optimizers (e.g., Adam), learning rates, and potentially stabilization techniques discussed in Chapter 3 if needed.
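Putting the pieces together, here is a condensed sketch of one pass over the data, assuming the G, D, loader, and criterion defined above; the optimizer settings (Adam with lr=2e-4, betas=(0.5, 0.999)) are common DCGAN-style defaults, not requirements:

# Training loop sketch (assumes G, D, loader, criterion from above)
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))

for x_real, y_real in loader:
    b = x_real.size(0)
    z = torch.randn(b, 100, 1, 1)
    y_fake = torch.randint(0, 10, (b,))
    x_fake = G(z, y_fake)

    # Discriminator update: detach x_fake so no gradients flow into G
    opt_d.zero_grad()
    d_loss = (criterion(D(x_real, y_real), torch.ones(b, 1)) +
              criterion(D(x_fake.detach(), y_fake), torch.zeros(b, 1)))
    d_loss.backward()
    opt_d.step()

    # Generator update: no detach, gradients must flow through D into G
    opt_g.zero_grad()
    g_loss = criterion(D(x_fake, y_fake), torch.ones(b, 1))
    g_loss.backward()
    opt_g.step()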
Once training is complete, you can generate images conditioned on specific labels. Simply:

1. Set the generator to evaluation mode.
2. Choose the target label y.
3. Sample a noise vector z.
4. Pass both to the generator: G(z, y).

For instance, to generate only images of the digit '7', you would repeatedly call G(z, label=7) with different noise vectors z, as in the sketch below.
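# Sampling sketch: generate 8 different images of the digit '7'
G.eval()
with torch.no_grad():
    z = torch.randn(8, 100, 1, 1)
    labels = torch.full((8,), 7, dtype=torch.long)
    sevens = G(z, labels)          # shape: (8, 1, 28, 28)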
Evaluating a cGAN involves assessing not only the quality and diversity of generated images (using metrics like FID or IS, discussed in Chapter 5) but also the conditional consistency. Did the generator produce an image that actually matches the requested label y? This can be checked qualitatively by visual inspection or quantitatively by feeding the generated images G(z,y) into a pre-trained classifier (independent of the cGAN's discriminator) and measuring its accuracy in predicting y.
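A minimal sketch of this quantitative check, assuming a hypothetical pre-trained classifier clf that returns class logits:

# Conditional-consistency sketch; clf is a hypothetical pre-trained classifier
G.eval()
with torch.no_grad():
    z = torch.randn(1000, 100, 1, 1)
    y = torch.randint(0, 10, (1000,))
    preds = clf(G(z, y)).argmax(dim=1)       # predicted digit for each generated image
    accuracy = (preds == y).float().mean()   # fraction matching the requested label y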
This practical exercise provides the blueprint for implementing cGANs. By carefully integrating conditional information into both the generator and discriminator, you gain significant control over the generation process, enabling targeted synthesis based on specific attributes. Experiment with different embedding dimensions and concatenation strategies to see how they impact performance on your chosen dataset.