Now that we understand the U-Net architecture, how to integrate timestep information, and the simplified loss function (predicting the noise), let's outline the algorithm used to train this noise prediction network $\epsilon_\theta(x_t, t)$.
The goal is to train the network $\epsilon_\theta$ such that, given a noisy input $x_t$ and its corresponding timestep $t$, it accurately predicts the noise $\epsilon$ that was originally added to the clean input $x_0$ to produce $x_t$. We achieve this using stochastic gradient descent (or variants like Adam) by repeatedly sampling data points and timesteps, computing the loss, and updating the network's parameters $\theta$.
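Written out in full, the quantity being minimized is the expected squared error over random data points, timesteps, and noise draws (with $x_t$ constructed from $x_0$, $t$, and $\epsilon$ as described below):

$$L(\theta) = \mathbb{E}_{x_0 \sim q(x_0),\; t \sim \mathcal{U}\{1, \dots, T\},\; \epsilon \sim \mathcal{N}(0, I)}\left[ \left\| \epsilon - \epsilon_\theta(x_t, t) \right\|^2 \right]$$

Each training step described below computes one stochastic estimate of this expectation.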
Here is a breakdown of the training loop, typically performed over batches of data:
1. **Sample a clean data point:** Obtain an example $x_0$ from your training dataset $q(x_0)$. In practice, you'll sample a mini-batch of data points; for simplicity, let's consider a single data point first.
2. **Sample a timestep:** Choose a timestep $t$ uniformly at random from the range $\{1, 2, \dots, T\}$, where $T$ is the total number of diffusion steps defined in the forward process. This ensures the network learns to denoise across all noise levels.
3. **Sample noise:** Draw a noise sample $\epsilon$ from a standard Gaussian distribution: $\epsilon \sim \mathcal{N}(0, I)$. This is the "ground truth" noise that the network will try to predict.
4. **Compute the noisy sample $x_t$:** Using the sampled $x_0$, $t$, and $\epsilon$, calculate the corresponding noisy version $x_t$ via the closed-form equation derived from the forward process:

   $$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$$

   Remember, $\bar{\alpha}_t = \prod_{i=1}^{t} (1 - \beta_i)$ is the cumulative product of $(1 - \beta_i)$ up to timestep $t$, based on the predefined noise schedule $\beta_t$.
5. **Predict noise using the network:** Pass the noisy sample $x_t$ and the timestep $t$ (usually encoded, e.g., using sinusoidal embeddings) through the U-Net model $\epsilon_\theta$ to get the predicted noise $\epsilon_\theta(x_t, t)$.
6. **Calculate the loss:** Compute the difference between the actual noise $\epsilon$ (sampled in step 3) and the predicted noise $\epsilon_\theta(x_t, t)$ using the chosen loss function. As discussed previously, this is typically the Mean Squared Error (MSE):

   $$L = \left\| \epsilon - \epsilon_\theta(x_t, t) \right\|^2$$

   When using mini-batches, this loss is averaged over all samples in the batch.
7. **Compute gradients:** Calculate the gradients of this loss $L$ with respect to the network parameters $\theta$: $\nabla_\theta L$.
8. **Update network parameters:** Update the parameters $\theta$ using an optimization algorithm (like Adam) and the computed gradients. For example: $\theta \leftarrow \theta - \eta \nabla_\theta L$, where $\eta$ is the learning rate.
This entire sequence (steps 1-8) constitutes one training step. The process is repeated for many iterations or epochs over the dataset until the loss converges and the model reliably predicts the noise added at any timestep; a minimal code sketch of a single training step is shown below the diagram.
A diagram illustrating the core steps within a single training iteration for the noise prediction network.
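To make steps 1-8 concrete, here is a minimal PyTorch sketch of one training step. The names are illustrative assumptions, not a fixed API: `model` stands in for the U-Net $\epsilon_\theta$ (taking a noisy batch and a batch of integer timesteps), `optimizer` for an Adam optimizer over its parameters, and `alpha_bar` for a precomputed length-$T$ tensor of the cumulative products $\bar{\alpha}_t$.

```python
import torch
import torch.nn.functional as F

def train_step(model, optimizer, x0, alpha_bar):
    """One training step for the noise prediction network (steps 1-8).

    x0:        mini-batch of clean data, shape (B, C, H, W)   -- step 1
    alpha_bar: cumulative products alpha_bar_t, shape (T,)
    """
    B, T = x0.shape[0], alpha_bar.shape[0]

    # Step 2: sample a timestep uniformly at random for each example.
    # (0-indexed here; the text numbers timesteps 1..T.)
    t = torch.randint(0, T, (B,), device=x0.device)

    # Step 3: sample the "ground truth" noise eps ~ N(0, I).
    eps = torch.randn_like(x0)

    # Step 4: build x_t via the closed-form forward-process equation
    #   x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
    ab = alpha_bar[t].view(B, 1, 1, 1)  # broadcast over (C, H, W)
    x_t = ab.sqrt() * x0 + (1.0 - ab).sqrt() * eps

    # Step 5: predict the noise with the U-Net.
    eps_pred = model(x_t, t)

    # Step 6: MSE between actual and predicted noise, averaged over the batch.
    loss = F.mse_loss(eps_pred, eps)

    # Steps 7-8: backpropagate and update the parameters theta.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

One detail worth noting: `t` is 0-indexed into `alpha_bar` here, whereas the prose counts timesteps from 1 to $T$. The two conventions differ only by an offset; the timestep encoding fed to the model just needs to be consistent with whichever one you use.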
By iterating through these steps, the U-Net gradually learns the complex mapping from a noisy image and a timestep to the noise component itself, forming the foundation for the reverse diffusion (denoising) process used during generation.
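In code, this repeated iteration is just an outer loop over epochs and mini-batches calling the sketch above; `dataloader` and `num_epochs` below are again hypothetical stand-ins for your own data pipeline and training budget:

```python
# Hypothetical outer training loop using the train_step sketch above.
for epoch in range(num_epochs):
    for x0 in dataloader:  # yields mini-batches of clean data x0
        loss = train_step(model, optimizer, x0, alpha_bar)
    print(f"epoch {epoch}: last batch loss {loss:.4f}")
```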