Translating the mathematical formulation of planar flows into a functional PyTorch model involves writing a planar flow layer from scratch, stacking multiple layers to increase the expressiveness of the model, and writing a training loop to map a standard normal distribution to a complex 2D target distribution.
A single planar flow applies a simple invertible transformation to an input vector. We recall the mathematical definition of a planar flow transformation:

$$f(z) = z + u \, \tanh(w^\top z + b)$$

Here, $w$ and $u$ are learnable parameter vectors of the same dimension as our input $z$, and $b$ is a learnable scalar bias. The function $\tanh$ acts as our non-linear activation.
For this transformation to serve as a valid normalizing flow, it must be strictly invertible. Invertibility requires that the derivative of the transformation does not change sign, which translates to the mathematical condition $w^\top u \geq -1$. We ensure this condition is met during training by dynamically modifying the vector $u$. We compute a safe vector $\hat{u}$ using the following modification:

$$\hat{u} = u + \left[ m(w^\top u) - w^\top u \right] \frac{w}{\lVert w \rVert^2}, \qquad m(x) = -1 + \log\!\left(1 + e^{x}\right)$$
We build this operation directly into the forward pass of our PyTorch module. We rely on torch.nn.functional.softplus to compute the term $m(w^\top u) = -1 + \operatorname{softplus}(w^\top u)$ safely without numerical overflow.
import torch
import torch.nn as nn
import torch.nn.functional as F

class PlanarFlow(nn.Module):
    def __init__(self, dim=2):
        super().__init__()
        self.w = nn.Parameter(torch.randn(1, dim))
        self.u = nn.Parameter(torch.randn(1, dim))
        self.b = nn.Parameter(torch.randn(1))

    def forward(self, z):
        # Compute the dot product of w and u
        wu = torch.sum(self.w * self.u, dim=-1, keepdim=True)

        # Enforce the invertibility constraint to compute u_hat
        m_wu = -1 + F.softplus(wu) - wu
        w_norm_sq = torch.sum(self.w ** 2, dim=-1, keepdim=True)
        u_hat = self.u + m_wu * self.w / w_norm_sq

        # Compute the transformation f(z)
        linear_term = F.linear(z, self.w, self.b)
        activation = torch.tanh(linear_term)
        f_z = z + u_hat * activation

        # Compute the log determinant of the Jacobian
        psi = (1 - activation ** 2) * self.w
        det_jacobian = 1 + torch.sum(u_hat * psi, dim=-1)

        # Add a small epsilon to prevent taking log of absolute zero
        log_det_jacobian = torch.log(torch.abs(det_jacobian) + 1e-6)

        return f_z, log_det_jacobian
The forward method returns both the transformed samples f_z and the log_det_jacobian. Returning the log determinant at every step is a standard practice in normalizing flows because we need to accumulate these values to calculate the final probability density.
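As a quick sanity check, we can push a small batch through a single layer and confirm the output shapes. The snippet below is a minimal sketch of this check, not part of the original walkthrough; it only assumes the PlanarFlow class defined above.

# Sanity check for a single PlanarFlow layer (illustrative sketch)
layer = PlanarFlow(dim=2)
z = torch.randn(64, 2)       # batch of 64 two-dimensional samples
f_z, log_det = layer(z)
print(f_z.shape)             # torch.Size([64, 2]): transformed samples
print(log_det.shape)         # torch.Size([64]): one log determinant per sample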
A single planar flow behaves like a ridge function: it stretches and compresses the probability density around a single hyperplane, $w^\top z + b = 0$. To model complex 2D shapes, we must pass our data through a sequence of these transformations.
We accomplish this by initializing a list of PlanarFlow modules using nn.ModuleList. In the forward pass, we iterate through these layers, updating our sample and accumulating the sum of the log determinants.
class NormalizingFlow(nn.Module):
    def __init__(self, dim=2, num_layers=8):
        super().__init__()
        self.layers = nn.ModuleList([PlanarFlow(dim) for _ in range(num_layers)])

    def forward(self, z):
        log_det_sum = torch.zeros(z.shape[0], device=z.device)
        for layer in self.layers:
            z, log_det = layer(z)
            log_det_sum += log_det
        return z, log_det_sum
Forward pass through a stacked planar flow model showing the sequence of transformations and log determinant accumulation.
Because the inverse of a planar flow has no closed form, we cannot evaluate the model density at arbitrary data points, so exact maximum likelihood training on a static dataset is impractical. Instead, planar flows are typically trained to match an unnormalized target density function by minimizing the reverse Kullback-Leibler (KL) divergence.
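Written out, with $U$ denoting the energy function and $Z$ the unknown normalizing constant (notation introduced here for clarity), the target is $p(z) = e^{-U(z)} / Z$ and the objective is

$$\mathrm{KL}\!\left(q_K \,\|\, p\right) = \mathbb{E}_{z_K \sim q_K}\!\left[\log q_K(z_K) + U(z_K)\right] + \log Z$$

Since $\log Z$ does not depend on the flow parameters, it can be dropped from the loss.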
We will define an unnormalized 2D energy function. The flow will learn to generate points that fall into the low-energy regions of this function.
def target_energy(z):
    x, y = z[:, 0], z[:, 1]
    # Ring-like structure
    u1 = 0.5 * ((torch.norm(z, p=2, dim=1) - 2.0) / 0.4) ** 2
    # Bimodal separation to create a split ring
    u2 = -torch.log(
        torch.exp(-0.5 * ((x - 2) / 0.6) ** 2) +
        torch.exp(-0.5 * ((x + 2) / 0.6) ** 2) + 1e-6
    )
    return u1 + u2
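To get a feel for this landscape, we can probe the energy at a few points. The coordinates below are our own illustrative choices: points on the ring near $x = \pm 2$ sit in the low-energy modes, while points between the modes or at the center of the ring are penalized.

# Probe the energy at a few illustrative points (lower energy = higher target density)
probe = torch.tensor([
    [2.0, 0.0],   # on the ring, inside the right mode -> low energy
    [0.0, 2.0],   # on the ring but between the modes -> penalized by u2
    [0.0, 0.0],   # center of the ring -> penalized by u1
])
print(target_energy(probe))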
The training process involves drawing samples from a simple base distribution and pushing them through our normalizing flow. We calculate the log-probability of the final generated samples using the change of variables formula. We then compare this generated log-probability against the negative target energy to compute our loss.
The change of variables formula allows us to compute the density of our generated samples based on the base distribution $q_0$:

$$\log q_K(z_K) = \log q_0(z_0) - \sum_{k=1}^{K} \log \left| \det \frac{\partial f_k}{\partial z_{k-1}} \right|$$
We minimize the expected difference between our generated density and the target density.
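Concretely, with $N$ base samples $z_0^{(i)}$ pushed through the flow to obtain $z_K^{(i)}$, the loss computed in the loop below is the Monte Carlo estimate of the reverse KL objective with the constant $\log Z$ dropped:

$$\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \left[ \log q_K\!\left(z_K^{(i)}\right) + U\!\left(z_K^{(i)}\right) \right]$$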
import torch.optim as optim

# Initialize the model, optimizer, and base distribution
model = NormalizingFlow(dim=2, num_layers=16)
optimizer = optim.Adam(model.parameters(), lr=0.005)

# We use a standard 2D normal distribution as our starting point
base_dist = torch.distributions.MultivariateNormal(
    torch.zeros(2), torch.eye(2)
)

epochs = 3000
batch_size = 512

for epoch in range(epochs):
    optimizer.zero_grad()

    # 1. Sample from the base distribution
    z0 = base_dist.sample((batch_size,))

    # 2. Pass samples through the stacked flow
    zK, log_det_sum = model(z0)

    # 3. Compute the log probability of the base samples
    log_prob_z0 = base_dist.log_prob(z0)

    # 4. Apply the change of variables formula
    log_prob_qK = log_prob_z0 - log_det_sum

    # 5. Calculate target log probability (negative energy)
    target_log_prob = -target_energy(zK)

    # 6. Compute reverse KL divergence loss
    loss = torch.mean(log_prob_qK - target_log_prob)

    loss.backward()
    optimizer.step()

    if epoch % 500 == 0:
        print(f"Epoch {epoch} | Loss: {loss.item():.4f}")
In this loop, the optimizer continuously adjusts the parameters $w$, $u$, and $b$ of every planar flow layer. With 16 layers, the model possesses enough flexibility to warp the initial standard Gaussian blob into the disconnected, bimodal target distribution defined by our energy function. By inspecting the samples at the end of training, you will observe the points migrating from the simple Gaussian into the specified low-energy regions.
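One way to inspect the result is to sample from the trained flow and scatter-plot the output. The snippet below is a minimal sketch of such a check, assuming matplotlib is installed; it is not part of the original walkthrough.

import matplotlib.pyplot as plt

# Sample from the trained flow and visualize the learned distribution
with torch.no_grad():
    z0 = base_dist.sample((2000,))
    zK, _ = model(z0)

plt.scatter(zK[:, 0], zK[:, 1], s=4, alpha=0.5)
plt.title("Samples after the trained planar flow")
plt.show()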