Building an affine coupling layer from scratch translates the mathematical foundations into working PyTorch code. This implementation demonstrates how partitioning the input yields a transformation that is cheap to invert and whose Jacobian determinant is cheap to compute, enabling both efficient sampling and exact density estimation.
An affine coupling layer operates by splitting the input tensor into two halves. The first half is left completely unchanged and acts as the input to a neural network. This network computes the scale and translation parameters used to transform the second half of the input.
The forward transformation relies on the following mathematical operations:

$$y_1 = x_1$$
$$y_2 = x_2 \odot \exp(s(x_1)) + t(x_1)$$

Here, $x = (x_1, x_2)$ represents the input data, $y = (y_1, y_2)$ represents the output, and the functions $s$ and $t$ correspond to the scale and translation neural networks. The symbol $\odot$ denotes element-wise multiplication.

Because the $s$ and $t$ networks only process $x_1$, the scaling factor is easily computable during both the forward and inverse passes. The inverse pass mirrors the forward pass closely:

$$x_1 = y_1$$
$$x_2 = (y_2 - t(y_1)) \odot \exp(-s(y_1))$$
[Figure: Data flow architecture of an affine coupling layer, showing the input split and the parameterization of the scale and translation networks.]
We will build an AffineCouplingLayer class that inherits from torch.nn.Module. For the $s$ and $t$ functions, we can use a single multi-layer perceptron that outputs a tensor with twice the necessary dimensions. We will then split this output tensor into the scale and translation parameters.
To ensure numerical stability during training, it is common practice to apply a hyperbolic tangent activation function to the scale parameter before exponentiating it. This prevents the exponential function from producing extremely large values that cause exploding gradients.
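A small numeric sketch illustrates why the tanh constraint matters (the specific values here are illustrative, not from the original text):

```python
import torch

# An unconstrained scale output can explode after exponentiation:
# exp(12) is roughly 1.6e5, which compounds across stacked layers.
raw_s = torch.tensor([-12.0, 0.0, 12.0])
print(torch.exp(raw_s))

# tanh confines s to (-1, 1), so exp(s) always stays within (1/e, e),
# keeping the scaling factor and its gradients well behaved.
bounded_s = torch.tanh(raw_s)
print(torch.exp(bounded_s))
```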
```python
import torch
import torch.nn as nn


class AffineCouplingLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # The network processes half of the input dimensions
        self.half_dim = input_dim // 2
        # A simple multi-layer perceptron to compute both s and t
        self.st_net = nn.Sequential(
            nn.Linear(self.half_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            # Output dimension is twice the half_dim to yield both s and t
            nn.Linear(hidden_dim, self.half_dim * 2)
        )

    def forward(self, x):
        # Split the input in half along the feature dimension
        x1, x2 = x.chunk(2, dim=-1)
        # Compute scale and translation parameters from x1
        st_params = self.st_net(x1)
        s, t = st_params.chunk(2, dim=-1)
        # Constrain the scale parameter for numerical stability
        s = torch.tanh(s)
        # Apply the affine transformation to x2
        y1 = x1
        y2 = x2 * torch.exp(s) + t
        # Recombine the output components
        y = torch.cat([y1, y2], dim=-1)
        # Compute the log-determinant of the Jacobian
        # It is the sum of the constrained scale parameters
        log_det_jacobian = s.sum(dim=-1)
        return y, log_det_jacobian

    def inverse(self, y):
        # Split the output in half
        y1, y2 = y.chunk(2, dim=-1)
        # Recompute scale and translation parameters from y1
        st_params = self.st_net(y1)
        s, t = st_params.chunk(2, dim=-1)
        # Apply the exact same constraint to s
        s = torch.tanh(s)
        # Invert the affine transformation
        x1 = y1
        x2 = (y2 - t) * torch.exp(-s)
        # Recombine to form the original input
        x = torch.cat([x1, x2], dim=-1)
        return x
```
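The claim that the log-determinant of the Jacobian is simply the sum of the constrained scale parameters can be checked numerically against a brute-force Jacobian. The sketch below uses a minimal functional version of the forward pass; the `net` architecture and the 4-dimensional input are illustrative assumptions, not part of the layer above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A small stand-in for st_net: maps the 2-dim first half to s and t
net = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 4))

def coupling_forward(x):
    # Same split-transform-recombine logic as the layer's forward pass
    x1, x2 = x.chunk(2, dim=-1)
    s, t = net(x1).chunk(2, dim=-1)
    s = torch.tanh(s)
    return torch.cat([x1, x2 * torch.exp(s) + t], dim=-1)

x = torch.randn(4)

# Analytic log-determinant: sum of the constrained scale parameters
s = torch.tanh(net(x[:2]).chunk(2, dim=-1)[0])
analytic = s.sum()

# Brute-force log-determinant from the full 4x4 Jacobian
J = torch.autograd.functional.jacobian(coupling_forward, x)
brute_force = torch.slogdet(J).logabsdet

# The two values should agree up to floating-point error
print(analytic.item(), brute_force.item())
```

The agreement follows from the Jacobian being lower triangular: the first half passes through unchanged (an identity block), so the determinant reduces to the product of the diagonal scaling factors $\exp(s)$.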
Notice the use of dim=-1 in the chunk and cat operations. This specifies that the splitting and concatenation should always happen along the last dimension of the tensor. This design allows the layer to accept single data points or batches of data interchangeably.
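A quick shape check (with dummy tensors) illustrates this shape-agnostic behavior:

```python
import torch

x_single = torch.randn(8)      # one unbatched data point
x_batch = torch.randn(4, 8)    # a batch of four data points

# chunk along the last dimension splits the features in both cases
a, b = x_single.chunk(2, dim=-1)
print(a.shape, b.shape)   # torch.Size([4]) torch.Size([4])

a, b = x_batch.chunk(2, dim=-1)
print(a.shape, b.shape)   # torch.Size([4, 4]) torch.Size([4, 4])

# cat along the same dimension restores the original shape
print(torch.cat([a, b], dim=-1).shape)   # torch.Size([4, 8])
```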
When building normalizing flows, a simple error in the inverse method or a mismatched activation function will ruin the model's ability to estimate probabilities accurately. You should consistently test your invertible layers to guarantee that passing data through the forward method and then the inverse method recovers the original input exactly.
We can instantiate our new layer and run a quick verification test.
```python
# Instantiate the coupling layer for an 8-dimensional input
layer = AffineCouplingLayer(input_dim=8, hidden_dim=32)

# Create a batch of random dummy data
original_x = torch.randn(4, 8)

# Pass the data through the forward method
y, log_det = layer(original_x)

# Recover the data using the inverse method
reconstructed_x = layer.inverse(y)

# Measure the maximum absolute difference between the original and reconstruction
max_error = torch.max(torch.abs(original_x - reconstructed_x))
print(f"Maximum reconstruction error: {max_error.item():.6e}")
```
If the implementation is correct, the maximum reconstruction error should be very close to zero. It will not be exactly zero because of floating-point rounding, but an error on the order of 1e-6 or smaller confirms that the layer is functioning as expected.