Following our conceptual overview of Generative Adversarial Networks (GANs), we now turn to implementation. Building a GAN involves setting up two competing neural networks, the Generator and the Discriminator, and training them simultaneously in an adversarial manner. We will construct a basic GAN using TensorFlow and Keras, focusing on generating images resembling the MNIST dataset. This hands-on example will solidify the core mechanics discussed previously.
The Generator's task is to create synthetic data that mimics the real data distribution. It takes a random noise vector (often drawn from a Gaussian or uniform distribution) as input and transforms it into an output with the same structure as the real data (e.g., a 28x28 grayscale image for MNIST).
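For instance, a batch of input noise vectors can be drawn with tf.random.normal; the batch size and latent dimension below are arbitrary values chosen purely for illustration:

import tensorflow as tf

# Draw a batch of 64 Gaussian noise vectors, each of length 100 (illustrative values)
noise = tf.random.normal([64, 100])
print(noise.shape)  # (64, 100); the Generator maps each vector to a 784-value (28x28) image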
A simple Generator can be constructed using tf.keras.Sequential. We'll start with a Dense layer to project the input noise into a higher-dimensional space, followed by reshaping and potentially using transposed convolutions (Conv2DTranspose) if building a Deep Convolutional GAN (DCGAN). For simplicity here, let's illustrate with Dense layers, suitable for generating flattened MNIST images or adaptable to simple image structures.
import tensorflow as tf

def build_generator(latent_dim, output_dim):
    model = tf.keras.Sequential(name='Generator')
    model.add(tf.keras.layers.Input(shape=(latent_dim,)))
    # Example using Dense layers - adjust architecture for specific tasks
    model.add(tf.keras.layers.Dense(128, activation='relu'))
    model.add(tf.keras.layers.Dense(256, activation='relu'))
    model.add(tf.keras.layers.Dense(output_dim, activation='tanh'))  # Use tanh for outputs scaled to [-1, 1]
    return model

# Example usage for flattened MNIST (28*28 = 784)
latent_dim = 100
output_dim = 784
generator = build_generator(latent_dim, output_dim)
generator.summary()  # Display the model structure
The latent_dim parameter defines the size of the input noise vector. The final activation function (e.g., tanh or sigmoid) should match the expected range of the real data. For MNIST images normalized to [-1, 1], tanh is appropriate.
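As noted above, image-specific architectures usually replace the Dense stack with transposed convolutions. The following is a minimal sketch of what a DCGAN-style generator for 28x28 MNIST images might look like; the layer sizes and the build_dcgan_generator name are illustrative choices, not a prescribed architecture:

def build_dcgan_generator(latent_dim):
    model = tf.keras.Sequential(name='DCGAN_Generator')
    model.add(tf.keras.layers.Input(shape=(latent_dim,)))
    # Project and reshape the noise into a small feature map
    model.add(tf.keras.layers.Dense(7 * 7 * 128, use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Reshape((7, 7, 128)))
    # Upsample 7x7 -> 14x14
    model.add(tf.keras.layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())
    # Upsample 14x14 -> 28x28, single channel, tanh for [-1, 1] outputs
    model.add(tf.keras.layers.Conv2DTranspose(1, kernel_size=4, strides=2, padding='same', activation='tanh'))
    return model

With this variant, real images keep their (28, 28, 1) shape rather than being flattened, and the Discriminator should use Conv2D layers to match (a corresponding sketch appears after the Dense Discriminator below).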
The Discriminator acts as a binary classifier. Its input is either a real data sample or a fake sample produced by the Generator. Its goal is to output a probability indicating whether the input is real (probability close to 1) or fake (probability close to 0).
Similar to the Generator, a simple Discriminator can be a tf.keras.Sequential model. It typically consists of Dense layers (or convolutional layers for image data) followed by a final Dense layer with a single output unit and a sigmoid activation function to produce the probability score.
def build_discriminator(input_dim):
    model = tf.keras.Sequential(name='Discriminator')
    model.add(tf.keras.layers.Input(shape=(input_dim,)))
    # Example using Dense layers
    model.add(tf.keras.layers.Dense(256, activation='relu'))
    model.add(tf.keras.layers.Dense(128, activation='relu'))
    model.add(tf.keras.layers.Dropout(0.3))  # Regularization can help
    model.add(tf.keras.layers.Dense(1, activation='sigmoid'))  # Output probability
    return model

# Example usage for flattened MNIST (784)
discriminator = build_discriminator(output_dim)  # Input size matches generator output / real data
discriminator.summary()
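For completeness, a convolutional counterpart to this Dense Discriminator, suitable for pairing with the DCGAN-style generator sketched earlier, might look like the following; the layer sizes and the build_dcgan_discriminator name are again illustrative, and this variant expects unflattened (28, 28, 1) inputs:

def build_dcgan_discriminator(input_shape=(28, 28, 1)):
    model = tf.keras.Sequential(name='DCGAN_Discriminator')
    model.add(tf.keras.layers.Input(shape=input_shape))
    # Downsample 28x28 -> 14x14
    model.add(tf.keras.layers.Conv2D(64, kernel_size=4, strides=2, padding='same'))
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Dropout(0.3))
    # Downsample 14x14 -> 7x7
    model.add(tf.keras.layers.Conv2D(128, kernel_size=4, strides=2, padding='same'))
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Dropout(0.3))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(1, activation='sigmoid'))  # Probability that the input is real
    return model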
The adversarial training requires distinct loss functions for the Generator and the Discriminator. We typically use Binary Cross-Entropy loss (tf.keras.losses.BinaryCrossentropy) because the Discriminator performs binary classification (real vs. fake).
Discriminator Loss (L_D): This loss encourages the Discriminator to output 1 for real images and 0 for fake images. It is composed of two parts: the loss on real images and the loss on fake images.

L_D = -\frac{1}{m} \sum_{i=1}^{m} \left[ \log D\left(x^{(i)}\right) + \log\left(1 - D\left(G(z^{(i)})\right)\right) \right]

where D(x) is the Discriminator's output for real data x, G(z) is the Generator's output for noise z, and m is the batch size.
Generator Loss (L_G): This loss encourages the Generator to produce outputs that the Discriminator classifies as real (output 1).

L_G = -\frac{1}{m} \sum_{i=1}^{m} \log D\left(G(z^{(i)})\right)

We can implement both losses using tf.keras.losses.BinaryCrossentropy. Note that for the generator loss, we compare the discriminator's output for fake images against labels of 1 (real).
# Use from_logits=True if the discriminator's final layer doesn't have sigmoid
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    # Generator wants the discriminator to think fake images are real (label 1)
    return cross_entropy(tf.ones_like(fake_output), fake_output)
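As a quick sanity check (an illustrative example, not part of the training code), we can evaluate these losses on hand-picked discriminator outputs; a confident, correct discriminator should yield a low discriminator loss and a high generator loss:

# Hypothetical discriminator outputs: close to 1 for real, close to 0 for fake
real_output = tf.constant([[0.9], [0.8]])
fake_output = tf.constant([[0.1], [0.2]])
print(discriminator_loss(real_output, fake_output).numpy())  # Small: discriminator is doing well
print(generator_loss(fake_output).numpy())                   # Large: generator is not fooling it yet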
Since the Generator and Discriminator have different objectives and are updated separately, we need distinct optimizers for each. The Adam optimizer is commonly used.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
Learning rates might require tuning; sometimes different rates are used for the generator and discriminator.
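For example, the commonly cited DCGAN settings use Adam with a reduced beta_1, and giving the two networks different rates (a two time-scale update) is another option if training is unstable. The values below are illustrative starting points, not guaranteed to suit every architecture:

# Common DCGAN-style settings: lower beta_1 reduces momentum-induced oscillation
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
# Optionally give the discriminator a different (here slightly higher) learning rate
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=4e-4, beta_1=0.5)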
GAN training requires a custom training loop because the updates for the Generator and Discriminator must be carefully orchestrated. Standard model.fit() is not directly applicable. We'll use tf.GradientTape to compute gradients for each network.
Here's the structure of a single training step, often wrapped in a tf.function for performance optimization:
# Assume 'real_images' is a batch from the dataset (e.g., MNIST)
# Assume 'latent_dim' is defined
@tf.function
def train_step(real_images, generator, discriminator, gen_optimizer, disc_optimizer, batch_size, latent_dim):
    # 1. Generate noise
    noise = tf.random.normal([batch_size, latent_dim])

    # Use GradientTape for automatic differentiation
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # 2. Generate fake images
        generated_images = generator(noise, training=True)

        # 3. Get Discriminator predictions for real and fake images
        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(generated_images, training=True)

        # 4. Calculate losses
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # 5. Calculate gradients
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # 6. Apply gradients to update weights
    gen_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    disc_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss
This train_step function encapsulates one iteration of the adversarial training process: generating fakes, evaluating both networks, calculating losses, computing gradients, and updating weights.
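Before launching a long run, it can be worth smoke-testing train_step on a single batch of random data shaped like the real inputs; this sketch (assuming the models, optimizers, latent_dim, and output_dim defined above) is purely a shape check, not real training:

# One-off smoke test with random "images" in [-1, 1] (illustrative only)
batch_size = 64
dummy_batch = tf.random.uniform([batch_size, output_dim], minval=-1.0, maxval=1.0)
g_loss, d_loss = train_step(dummy_batch, generator, discriminator,
                            generator_optimizer, discriminator_optimizer,
                            batch_size, latent_dim)
print(f"gen_loss={float(g_loss):.4f}, disc_loss={float(d_loss):.4f}")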
A full training process involves iterating this train_step over multiple epochs and batches of the real dataset.
# Conceptual Training Loop (needs dataset loading, epoch loops, etc.)
# epochs = ...
# batch_size = ...
# dataset = load_and_prepare_mnist_dataset(...)  # Normalized to [-1, 1]
#
# for epoch in range(epochs):
#     print(f"Epoch {epoch+1}/{epochs}")
#     epoch_gen_loss_avg = tf.keras.metrics.Mean()
#     epoch_disc_loss_avg = tf.keras.metrics.Mean()
#
#     for image_batch in dataset:  # Assuming dataset yields batches of real images
#         gen_loss, disc_loss = train_step(
#             image_batch,
#             generator,
#             discriminator,
#             generator_optimizer,
#             discriminator_optimizer,
#             batch_size,
#             latent_dim
#         )
#         epoch_gen_loss_avg.update_state(gen_loss)
#         epoch_disc_loss_avg.update_state(disc_loss)
#
#     print(f"Generator Loss: {epoch_gen_loss_avg.result():.4f}, Discriminator Loss: {epoch_disc_loss_avg.result():.4f}")
#     # Add code here to save checkpoints and generate sample images periodically
#     # Reset metrics at the end of each epoch
#     epoch_gen_loss_avg.reset_states()
#     epoch_disc_loss_avg.reset_states()
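The load_and_prepare_mnist_dataset call above is a placeholder; one possible sketch of that preparation step, assuming the flattened Dense-based models from earlier, is:

def load_and_prepare_mnist_dataset(batch_size):
    # Load MNIST digits, flatten to 784 values, and scale pixels from [0, 255] to [-1, 1]
    (train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
    train_images = train_images.reshape(-1, 784).astype('float32')
    train_images = (train_images - 127.5) / 127.5  # Matches the Generator's tanh output range
    dataset = tf.data.Dataset.from_tensor_slices(train_images)
    # drop_remainder keeps every batch at exactly batch_size, as train_step assumes
    return dataset.shuffle(60000).batch(batch_size, drop_remainder=True)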
This structure provides the foundation for coding a simple GAN. Key aspects include defining the separate Generator and Discriminator networks, setting up their respective loss functions based on binary cross-entropy, using separate optimizers, and orchestrating their updates within a custom training loop using tf.GradientTape. Monitoring the loss values and periodically visualizing the Generator's output are important steps for assessing training progress. Remember that GAN training can be unstable; careful tuning of hyperparameters (learning rates, network architectures) is often necessary.
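For the periodic visualization mentioned above, one simple approach is to hold a fixed noise batch and render the Generator's output for it after each epoch. The helper below is an illustrative sketch (the function name and file naming are arbitrary, and matplotlib is assumed to be available):

import matplotlib.pyplot as plt

# A fixed noise batch makes progress across epochs easy to compare
fixed_noise = tf.random.normal([16, latent_dim])

def save_sample_images(generator, epoch, filename_prefix='gan_samples'):
    # training=False so layers like Dropout/BatchNorm run in inference mode
    samples = generator(fixed_noise, training=False)
    samples = tf.reshape(samples, (-1, 28, 28))  # Undo the flattening
    samples = (samples + 1.0) / 2.0              # Map [-1, 1] back to [0, 1] for display
    fig, axes = plt.subplots(4, 4, figsize=(4, 4))
    for img, ax in zip(samples, axes.flat):
        ax.imshow(img.numpy(), cmap='gray')
        ax.axis('off')
    fig.savefig(f'{filename_prefix}_epoch_{epoch:03d}.png')
    plt.close(fig)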