Processing high-dimensional data, such as high-resolution images, through dozens of normalizing flow layers requires significant computational resources. Every dimension of the input vector passes through every layer of the model in standard architectures. Calculating the affine coupling transformations and their Jacobians at every step becomes a severe bottleneck when training on inputs with hundreds of thousands of individual components.
To address this computational burden, we use multi-scale architectures. This design pattern was popularized by the RealNVP model. Instead of transforming the entire data volume at every layer, a multi-scale architecture factors out a portion of the variables at regular intervals. These factored-out variables are sent directly to the final output, while only the remaining variables continue through the rest of the flow network.
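Let us examine how this works mathematically and structurally. Assume we have an intermediate tensor $h_i$ at stage $i$, with $h_1 = x$ being the input. We apply a series of flow operations $f_i$ to $h_i$, and then we split the result into two halves along the channel dimension:

$$(z_i, h_{i+1}) = \mathrm{split}\big(f_i(h_i)\big)$$

The tensor $z_i$ represents the variables that are factored out and finalized; we do not process them any further. The tensor $h_{i+1}$ contains the active variables that are passed to the next stage of the network. At the final stage $L$ no split is performed, and the entire output $z_L = f_L(h_L)$ is factored out.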
[Figure: Multi-scale architecture factoring out variables at intermediate stages to reduce the computational load for subsequent flow blocks.]
This structural choice has three important benefits for density estimation.
First, it directly reduces memory and computational requirements. Because half of the active variables are factored out at each split, the tensor passed to the next stage has half as many elements as before. The neural networks that compute the scale and translation in the subsequent coupling layers therefore operate on progressively smaller inputs.
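To make the savings concrete, the short calculation below counts how many elements each stage's flow has to process. It is a sketch under assumptions not stated in the text: a hypothetical 3×64×64 input, four scales, and a Glow-style ordering in which each stage squeezes, applies its flow, and then splits off half of the channels (the final stage does not split). The function name `active_elements_per_stage` is made up for this example.

```python
def active_elements_per_stage(channels, height, width, num_scales):
    """Count the elements each stage's flow processes (hypothetical Glow-style layout)."""
    sizes = []
    c, h, w = channels, height, width
    for stage in range(num_scales):
        # Squeeze: (c, h, w) -> (4c, h/2, w/2); the element count is unchanged.
        c, h, w = 4 * c, h // 2, w // 2
        # The flow blocks at this stage operate on all c * h * w active elements.
        sizes.append(c * h * w)
        # Every stage except the last factors out half of its channels.
        if stage < num_scales - 1:
            c //= 2
    return sizes


sizes = active_elements_per_stage(3, 64, 64, num_scales=4)
print(sizes)         # [12288, 6144, 3072, 1536]
print(sum(sizes))    # 23040
print(4 * sizes[0])  # 49152: the cost if no variables were factored out
```

Under these assumptions the per-stage workload halves each time, so the total stays below twice the cost of the first stage no matter how many scales are added, instead of growing linearly with the number of stages.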
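Second, it creates a natural hierarchy of learned features. The early layers of the flow operate on the full spatial resolution of the data, and the variables factored out there tend to represent fine-grained, local details such as texture and edges. Between stages, squeezing operations reshape each 2×2 spatial block into channels, turning a $C \times H \times W$ tensor into a $4C \times \frac{H}{2} \times \frac{W}{2}$ tensor. As the data passes through these squeezes and subsequent splits, the spatial resolution decreases while each position summarizes a larger region of the original input, so the effective receptive field of the coupling networks grows. The variables factored out at the very end of the network capture global structure and semantic information.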
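The squeeze operation can be sketched in PyTorch with a reshape followed by a permute. This is one common space-to-depth layout written for illustration; the helper names `squeeze2x2` and `unsqueeze2x2` are hypothetical. Because squeezing only rearranges entries, its log-determinant is zero, just like the split.

```python
import torch


def squeeze2x2(x: torch.Tensor) -> torch.Tensor:
    """Space-to-depth: (B, C, H, W) -> (B, 4C, H/2, W/2)."""
    b, c, h, w = x.shape
    assert h % 2 == 0 and w % 2 == 0, "spatial dimensions must be even"
    # Expose each 2x2 block as two extra axes: (B, C, H/2, 2, W/2, 2).
    x = x.view(b, c, h // 2, 2, w // 2, 2)
    # Move the block offsets next to the channel axis: (B, C, 2, 2, H/2, W/2).
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
    return x.view(b, c * 4, h // 2, w // 2)


def unsqueeze2x2(x: torch.Tensor) -> torch.Tensor:
    """Depth-to-space, the exact inverse of squeeze2x2: (B, 4C, H, W) -> (B, C, 2H, 2W)."""
    b, c, h, w = x.shape
    x = x.view(b, c // 4, 2, 2, h, w)
    # Interleave the block offsets back into the spatial axes: (B, C, H, 2, W, 2).
    x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
    return x.view(b, c // 4, h * 2, w * 2)


x = torch.randn(2, 3, 8, 8)
y = squeeze2x2(x)
print(y.shape)  # torch.Size([2, 12, 4, 4])
```

PyTorch also provides `torch.nn.functional.pixel_unshuffle` and `torch.nn.functional.pixel_shuffle` for space-to-depth and depth-to-space rearrangements, which could stand in for these hand-written helpers.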
Third, factoring out variables distributes the loss computation throughout the network. In exact maximum likelihood estimation, we optimize the sum of the log-probabilities under the base distribution and the log-determinants of the Jacobians. By evaluating the base distribution on each factored-out tensor $z_i$ at multiple intermediate stages, we provide shorter gradient paths to the early layers, which tends to improve training stability in deep models.
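The total log-likelihood in a multi-scale model is therefore not a single base-distribution term but a sum over all factored-out components. If a model has $L$ scales producing $z_1, \dots, z_L$, where $f_i$ is the flow applied at stage $i$ to its active input $h_i$ (with $h_1 = x$), the objective aggregates the base log-probabilities of every factored-out tensor plus the Jacobian log-determinants of all flow blocks along the active path:

$$\log p_X(x) = \sum_{i=1}^{L} \log p_Z(z_i) + \sum_{i=1}^{L} \log \left| \det \frac{\partial f_i(h_i)}{\partial h_i} \right|$$

The split operations themselves add nothing to the second sum, because splitting only regroups variables and has a zero log-determinant.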
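A minimal sketch of this objective in PyTorch, assuming a standard normal base distribution and that the model has already produced the list of factored-out tensors and the per-sample log-determinants of its flow blocks. The function names are hypothetical.

```python
import math

import torch


def standard_normal_logprob(z: torch.Tensor) -> torch.Tensor:
    """Per-sample log N(z; 0, I), summed over all non-batch dimensions."""
    log_p = -0.5 * (z ** 2 + math.log(2 * math.pi))
    return log_p.flatten(start_dim=1).sum(dim=1)


def multiscale_log_likelihood(zs, log_dets):
    """zs: factored-out tensors z_1..z_L; log_dets: per-sample log|det J| tensors of shape (B,)."""
    # Base-distribution term: one contribution per factored-out tensor.
    log_pz = sum(standard_normal_logprob(z) for z in zs)
    # Change-of-variables term: the split steps contribute zero, so only flow blocks appear here.
    return log_pz + sum(log_dets)
```

Training then minimizes the negative of this quantity averaged over the batch, for example `-multiscale_log_likelihood(zs, log_dets).mean()`. Keeping the factored-out tensors in a list, rather than concatenating them first, lets each scale keep its own shape.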
To implement this in PyTorch, you need a mechanism to slice the tensor and store the finalized variables. Here is a practical example showing how a split operation works in a forward pass.
```python
import torch
import torch.nn as nn


class MultiScaleSplit(nn.Module):
    """Factor out half of the channels; the split itself has no parameters."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Split the input tensor into two halves along the channel dimension (dim=1).
        # Explicit section sizes keep this correct even when the channel count is odd.
        channels = x.shape[1]
        half = channels // 2
        z, h = torch.split(x, [half, channels - half], dim=1)
        # Splitting only regroups variables (its Jacobian is the identity),
        # so the log-determinant for this step is zero for every sample.
        log_det = torch.zeros(x.shape[0], device=x.device)
        return z, h, log_det

    def inverse(self, z, h):
        # Concatenate the factored-out variables back with the active variables.
        x = torch.cat([z, h], dim=1)
        return x
```
During the inverse pass, which is used for generating new data, the process is reversed. You sample all the independent partitions $z_1, \dots, z_L$ from your base distribution, which is usually a standard normal distribution. You pass the final partition $z_L$ through the last flow block in reverse to obtain $h_L = f_L^{-1}(z_L)$. Then you concatenate the result with $z_{L-1}$, pass that combined tensor through the preceding flow block in reverse, and repeat the process until you reconstruct the full-dimensional output $x$.
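The forward factoring and the reverse generation loop can be sketched together in a toy model. This is an illustration under stated assumptions, not RealNVP itself: each stage's flow is a hypothetical `ElementwiseAffine` block (a per-channel scale and shift standing in for a stack of coupling layers), there is no squeezing, and both class names are made up for this example. The point is to show that running the stages in reverse, concatenating each $z_i$ back in turn, exactly inverts the forward pass.

```python
import torch
import torch.nn as nn


class ElementwiseAffine(nn.Module):
    """Hypothetical stand-in for a stage's flow block: y = x * exp(s) + b per channel."""

    def __init__(self, channels):
        super().__init__()
        self.log_scale = nn.Parameter(0.1 * torch.randn(1, channels, 1, 1))
        self.bias = nn.Parameter(0.1 * torch.randn(1, channels, 1, 1))

    def forward(self, x):
        y = x * torch.exp(self.log_scale) + self.bias
        # Diagonal Jacobian: every spatial position contributes sum(log_scale).
        log_det = self.log_scale.sum() * x.shape[2] * x.shape[3]
        return y, log_det.expand(x.shape[0])

    def inverse(self, y):
        return (y - self.bias) * torch.exp(-self.log_scale)


class ToyMultiScaleFlow(nn.Module):
    """Hypothetical multi-scale flow: flow block, then split, at every stage but the last."""

    def __init__(self, channels, num_scales):
        super().__init__()
        self.blocks = nn.ModuleList()
        for _ in range(num_scales):
            self.blocks.append(ElementwiseAffine(channels))
            channels //= 2  # half of the channels are factored out after each block

    def forward(self, x):
        zs, total_log_det, h = [], 0.0, x
        for i, block in enumerate(self.blocks):
            h, log_det = block(h)
            total_log_det = total_log_det + log_det
            if i < len(self.blocks) - 1:
                z, h = torch.chunk(h, 2, dim=1)  # z_i is finalized, h continues
                zs.append(z)
        zs.append(h)  # the final stage factors out everything that remains
        return zs, total_log_det

    def inverse(self, zs):
        # Start from z_L and walk back through the stages, re-attaching each z_i.
        h = self.blocks[-1].inverse(zs[-1])
        for i in range(len(self.blocks) - 2, -1, -1):
            h = self.blocks[i].inverse(torch.cat([zs[i], h], dim=1))
        return h


flow = ToyMultiScaleFlow(channels=8, num_scales=3)
x = torch.randn(2, 8, 4, 4)
zs, log_det = flow(x)
print([z.shape[1] for z in zs])  # [4, 2, 2]
```

To generate new data, draw each $z_i$ from a standard normal with the shapes produced by the forward pass, for example `flow.inverse([torch.randn_like(z) for z in zs])`.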
This architecture ensures that normalizing flows remain tractable even when dealing with high-dimensional datasets. You will use these multi-scale patterns extensively when scaling up coupling layers for image generation tasks.
© 2026 ApX Machine Learning