Standard auto-wrapping policies typically satisfy the requirements of homogeneous architectures like BERT or GPT. However, complex production environments often demand bespoke strategies for distributed training. Hybrid multimodal models, Mixture-of-Experts (MoE) configurations, and architectures with non-standard residual connections introduce memory patterns that common default policies fail to optimize. In these scenarios, relying on policies like transformer_auto_wrap_policy can result in suboptimal sharding, leading to either memory spikes (when shards are too large) or communication bottlenecks (when shards are too small).
This section addresses the construction of custom wrapping policies using Python lambda functions and the functools.partial utility. By defining explicit rules for graph segmentation, you gain control over the trade-off between memory granularity and network overhead.
In PyTorch FSDP, the auto_wrap_policy argument accepts a callable that determines whether a specific submodule should be wrapped in its own FSDP unit. During initialization, FSDP traverses the model's module tree. For every module encountered, it invokes this callable with three specific arguments:
- module: The nn.Module instance currently being evaluated.
- recurse: A boolean indicating if the traversal can continue to the module's children.
- unwrapped_params: The number of parameters in the current module that have not yet been assigned to a child FSDP unit.
The function must return a boolean. If True, the module is wrapped. If False, it remains part of the parent FSDP unit (or the root).
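The built-in size_based_auto_wrap_policy follows exactly this contract, with its threshold bound ahead of time through functools.partial. A minimal sketch (the model itself is a placeholder):

import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap any module whose not-yet-assigned parameters exceed 100M
policy = functools.partial(size_based_auto_wrap_policy, min_num_params=100_000_000)

# my_model = ...  # placeholder model
# fsdp_model = FSDP(my_model, auto_wrap_policy=policy)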
The logic operates recursively. If a parent module is wrapped, its parameters are sharded. However, if that parent contains children that are also wrapped, the children become independent FSDP units. This nesting is what enables the "gather-compute-scatter" overlap. The system gathers the child, computes, frees the child, then gathers the parent (if needed), computes the remainder, and so forth.
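The same nesting can be produced explicitly with the enable_wrap/wrap utilities. The sketch below assumes a simple nn.Sequential model and is meant only to illustrate how child wraps become independent units inside the outer one:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import enable_wrap, wrap

def build_nested(model: nn.Sequential) -> FSDP:
    with enable_wrap(wrapper_cls=FSDP):
        # Each child wrapped here becomes its own FSDP unit,
        # gathered and freed independently during forward/backward
        for i in range(len(model)):
            model[i] = wrap(model[i])
        # The outer wrap owns any parameters not claimed by a child
        return wrap(model)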
A custom policy often combines parameter counting with module type inspection. Consider a scenario where you are training a model with a large embedding layer that must be sharded, but numerous small projection heads that should remain aggregated to avoid launching thousands of tiny NCCL kernels.
The following code demonstrates a policy that wraps modules based on two conditions: they must be of a specific type (e.g., a Transformer Block) OR they must exceed a specific parameter threshold.
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (  # built-in policies, shown for reference
    transformer_auto_wrap_policy,
    _or_policy,
    lambda_auto_wrap_policy,
)


def hybrid_shard_policy(
    module: nn.Module,
    recurse: bool,
    unwrapped_params: int,
    min_num_params: int = 1_000_000,
    target_module_types: tuple = (),
) -> bool:
    """
    Custom policy that wraps if:
    1. The module is an instance of target_module_types, OR
    2. The module contains more unwrapped parameters than min_num_params.
    """
    # Always recurse if the traversal allows it, so children are also checked
    if recurse:
        return True

    # Condition 1: Type-based wrapping
    if isinstance(module, target_module_types):
        return True

    # Condition 2: Size-based wrapping
    # Wrap only if the remaining parameters justify a new shard
    if unwrapped_params >= min_num_params:
        return True

    return False


# Application within the FSDP constructor
# Assume my_model contains nn.TransformerEncoderLayer blocks
# my_model = ...
custom_policy = functools.partial(
    hybrid_shard_policy,
    min_num_params=5 * 10**6,  # 5 million parameters minimum
    target_module_types=(nn.TransformerEncoderLayer,),
)
# fsdp_model = FSDP(my_model, auto_wrap_policy=custom_policy, ...)
In this implementation, the logic prevents "over-sharding." A purely recursive wrapper might isolate a small LayerNorm or Dropout module if not carefully restricted. By enforcing a min_num_params floor, we ensure that very small layers remain part of their parent block, reducing the number of distinct all-gather operations required during the forward pass.
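The same OR logic can also be assembled from the built-in helpers imported at the top of the example. The sketch below relies on _or_policy, a private utility whose interface may change between PyTorch releases, so treat it as illustrative rather than a stable API:

import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import (
    _or_policy,
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
)

type_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)
size_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=5 * 10**6,
)
combined_policy = functools.partial(_or_policy, policies=[type_policy, size_policy])
# fsdp_model = FSDP(my_model, auto_wrap_policy=combined_policy, ...)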
The following diagram illustrates how different policies affect the sharding of a hybrid architecture containing an image encoder and a text decoder.
Comparison of graph partitioning outcomes. Type-based wrapping (left) targets architectural units, while parameter-count wrapping (right) targets memory consumption directly.
When training architectures like Mixture of Experts (MoE), standard wrapping falls short because the "Expert" layers are sparsely activated: only a subset of experts runs for any given token. If you wrap the entire MoE block as a single unit, you force an all-gather of all experts even if only one is active, defeating the purpose of sparse computation.
For MoE, the custom policy must target the individual experts. This requires inspecting the module name or structure rather than just the class type, especially if the experts are instances of generic nn.Linear layers.
def moe_sparse_policy(module, recurse, unwrapped_params):
    if recurse:
        return True
    # Identify whether this module is a sparse Expert container.
    # The check depends on model implementation details; here the expert
    # modules are assumed to carry an is_sparse_expert attribute.
    if hasattr(module, 'is_sparse_expert') and module.is_sparse_expert:
        return True
    return False
By wrapping individual experts, FSDP only gathers the parameters for the active expert during the forward pass (assuming the experts are on different execution paths or CUDA streams), significantly reducing peak memory usage.
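How the is_sparse_expert marker gets attached is model-specific; the layer below is a hypothetical illustration of tagging generic nn.Linear experts so the policy above can find them:

import torch.nn as nn

class SimpleMoELayer(nn.Module):
    """Hypothetical MoE layer used only to illustrate expert tagging."""
    def __init__(self, d_model: int, num_experts: int):
        super().__init__()
        self.router = nn.Linear(d_model, num_experts)
        self.experts = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(num_experts)]
        )
        # Tag each expert so moe_sparse_policy recognizes it,
        # even though it is a generic nn.Linear
        for expert in self.experts:
            expert.is_sparse_expert = True

# fsdp_model = FSDP(moe_model, auto_wrap_policy=moe_sparse_policy, ...)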
Designing a custom policy involves balancing memory efficiency against communication latency.
The objective is to find the "sweet spot" where the shard size is large enough to saturate network bandwidth but small enough to fit within the GPU memory budget.
The relationship between shard size and throughput. Extremely small shards increase communication overhead (red dotted line), while extremely large shards degrade performance due to memory fragmentation and lack of computation overlap.
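One practical way to locate that sweet spot is to survey the unsharded size of each candidate module before settling on a threshold. The helper below is not part of the FSDP API; it is a rough sketch that assumes half-precision (2-byte) parameters by default:

import torch.nn as nn

def report_candidate_sizes(model: nn.Module, dtype_bytes: int = 2) -> None:
    """Print per-module parameter counts and approximate unsharded sizes in MB."""
    for name, module in model.named_modules():
        num_params = sum(p.numel() for p in module.parameters())
        if num_params == 0:
            continue
        size_mb = num_params * dtype_bytes / 1e6
        print(f"{name or '<root>'}: {num_params:,} params (~{size_mb:.1f} MB unsharded)")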
Occasionally, a custom policy is needed to prevent wrapping. Parameters that are shared across the network, such as embedding layers tied to the output projection (common in GPT architectures), often cause synchronization issues if they are sharded in different FSDP units.
A lambda policy can explicitly return False for known shared modules, forcing them to be managed by the root FSDP unit or a specific parent. This ensures that the shared parameter is broadcast only once per forward pass, rather than being gathered and scattered multiple times by different wrappers.
def exclude_embeddings_policy(module, recurse, unwrapped_params):
    if recurse:
        return True
    # Identify embedding layers to exclude from wrapping
    if isinstance(module, (nn.Embedding, nn.EmbeddingBag)):
        return False  # Force exclusion; managed by the parent/root unit
    # Proceed with standard size-based wrapping for other layers
    return unwrapped_params >= 10_000_000
This strategy is essential when using sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, where gradients and optimizer states are sharded but the gathered parameters persist between the forward and backward passes. By explicitly controlling the cut points, you ensure the sharding layout respects the parameter-tying requirements of the model architecture.
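A sketch of wiring the exclusion policy into the constructor under that strategy; my_model stands in for a network with tied embeddings:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# fsdp_model = FSDP(
#     my_model,
#     auto_wrap_policy=exclude_embeddings_policy,
#     sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
# )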