Standard auto-wrapping policies typically satisfy the requirements of homogeneous architectures like BERT or GPT. However, complex production environments often demand bespoke strategies for distributed training. Hybrid multimodal models, Mixture-of-Experts (MoE) configurations, and architectures with non-standard residual connections introduce memory patterns that common default policies fail to optimize. In these scenarios, relying on policies like `transformer_auto_wrap_policy` can result in suboptimal sharding, leading to either memory spikes (when shards are too large) or communication bottlenecks (when shards are too small).

This section addresses the construction of custom wrapping policies using Python lambda functions and the `functools.partial` utility. By defining explicit rules for graph segmentation, you gain control over the trade-off between memory granularity and network overhead.

## The Mechanics of the Wrapping Policy

In PyTorch FSDP, the `auto_wrap_policy` argument accepts a callable that determines whether a specific submodule should be wrapped in its own FSDP unit. During initialization, FSDP traverses the model's module tree. For every module encountered, it invokes this callable with three arguments:

- `module`: the `nn.Module` instance currently being evaluated.
- `recurse`: a boolean indicating whether FSDP is still descending the module tree. When `True`, the policy is being asked whether traversal should continue into the module's children; when `False`, it is being asked whether to wrap the module itself.
- `unwrapped_params`: the number of parameters in the current module that have not yet been assigned to a child FSDP unit.

The function must return a boolean. If `True`, the module is wrapped. If `False`, it remains part of the parent FSDP unit (or the root).

The logic operates recursively. If a parent module is wrapped, its parameters are sharded. However, if that parent contains children that are also wrapped, the children become independent FSDP units. This nesting is what enables the "gather-compute-scatter" overlap: the system gathers the child, computes, frees the child, then gathers the parent (if needed), computes the remainder, and so forth.

## Implementing a Logic-Based Policy

A custom policy often combines parameter counting with module type inspection. Consider a scenario where you are training a model with a large embedding layer that must be sharded, but numerous small projection heads that should remain aggregated to avoid launching thousands of tiny NCCL kernels.

The following code demonstrates a policy that wraps modules based on two conditions: they must be of a specific type (e.g., a Transformer block) OR they must exceed a specific parameter threshold.

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def hybrid_shard_policy(
    module: nn.Module,
    recurse: bool,
    unwrapped_params: int,
    min_num_params: int = 1_000_000,
    target_module_types: tuple = (),
) -> bool:
    """
    Custom policy that wraps a module if:
      1. it is an instance of target_module_types, OR
      2. it contains more parameters than min_num_params.
    """
    # While FSDP is still descending the tree, keep recursing so that
    # every child is also evaluated.
    if recurse:
        return True

    # Condition 1: type-based wrapping
    if isinstance(module, target_module_types):
        return True

    # Condition 2: size-based wrapping -- only wrap if the remaining
    # parameters justify a new shard
    if unwrapped_params >= min_num_params:
        return True

    return False
```
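Before handing the policy to FSDP, it can be useful to dry-run it against a small model to confirm which submodules it would select. The sketch below is only a rough approximation of FSDP's traversal: it counts each submodule's total parameters, whereas FSDP passes only the parameters not already claimed by a nested unit, and the toy model and thresholds are purely illustrative assumptions.

```python
import torch.nn as nn

# Toy model for a dry run of hybrid_shard_policy; the layer sizes are
# arbitrary and chosen only to exercise both wrapping conditions.
toy_model = nn.Sequential(
    nn.Embedding(50_000, 512),                         # ~25.6M params -> exceeds the size floor
    nn.TransformerEncoderLayer(d_model=512, nhead=8),  # matches the target type
    nn.Linear(512, 10),                                # ~5K params -> stays in the parent unit
)

for name, submodule in toy_model.named_modules():
    if name == "":
        continue  # skip the root container itself
    decision = hybrid_shard_policy(
        submodule,
        recurse=False,  # ask the "wrap this module?" question, not the traversal question
        unwrapped_params=sum(p.numel() for p in submodule.parameters()),
        min_num_params=5 * 10**6,
        target_module_types=(nn.TransformerEncoderLayer,),
    )
    print(f"{name}: {'wrap' if decision else 'keep in parent'}")
```

On this toy model the embedding is selected by the size rule, the encoder layer by the type rule, and everything else (including the encoder layer's own sub-modules, which fall below the 5M floor) is left to its parent.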
To apply the policy, bind the extra keyword arguments with `functools.partial` so that FSDP can invoke it with only the three arguments described above:

```python
# Application within the FSDP constructor.
# Assume my_model (built from nn.TransformerEncoderLayer blocks) is defined elsewhere.
custom_policy = functools.partial(
    hybrid_shard_policy,
    min_num_params=5 * 10**6,  # 5 million parameters minimum
    target_module_types=(nn.TransformerEncoderLayer,),
)

# fsdp_model = FSDP(my_model, auto_wrap_policy=custom_policy, ...)
```

In this implementation, the logic prevents "over-sharding." A purely recursive wrapper might isolate a small LayerNorm or Dropout module if not carefully restricted. By enforcing a `min_num_params` floor, we ensure that very small layers remain part of their parent block, reducing the number of distinct all-gather operations required during the forward pass.

## Visualization of Graph Partitioning

The following diagram illustrates how different policies affect the sharding of a hybrid architecture containing an image encoder and a text decoder.

```dot
digraph FSDP_Wrapping {
    rankdir=TB;
    node [shape=box, style=filled, fontname="Helvetica"];

    subgraph cluster_0 {
        label = "Policy: Type-Based (Wrap 'Block')";
        style=dashed;
        color="#adb5bd";
        root1 [label="Root Model", fillcolor="#e9ecef"];
        enc1 [label="Image Encoder", fillcolor="#e9ecef"];
        dec1 [label="Text Decoder", fillcolor="#e9ecef"];
        blk1 [label="Block 1\n(Wrapped)", fillcolor="#b2f2bb"];
        blk2 [label="Block 2\n(Wrapped)", fillcolor="#b2f2bb"];
        ln1 [label="LayerNorm\n(Unwrapped)", fillcolor="#ffc9c9"];
        root1 -> enc1;
        root1 -> dec1;
        enc1 -> blk1;
        dec1 -> blk2;
        dec1 -> ln1;
    }

    subgraph cluster_1 {
        label = "Policy: Parameter Count (>1M)";
        style=dashed;
        color="#adb5bd";
        root2 [label="Root Model", fillcolor="#e9ecef"];
        lg_layer [label="Linear (2M params)\n(Wrapped)", fillcolor="#b2f2bb"];
        sm_layer [label="Linear (0.5M params)\n(Unwrapped)", fillcolor="#ffc9c9"];
        root2 -> lg_layer;
        root2 -> sm_layer;
    }
}
```

*Comparison of graph partitioning outcomes. Type-based wrapping (left) targets architectural units, while parameter-count wrapping (right) targets memory consumption directly.*

## Handling Heterogeneous Architectures

When training architectures like Mixture-of-Experts (MoE), standard wrapping fails because the "expert" layers are sparsely activated. If you wrap the entire MoE block as a single unit, you force an all-gather of all experts even if only one is active, defeating the purpose of sparse computation.

For MoE, the custom policy must target the individual experts. This requires inspecting the module name or structure rather than just the class type, especially if the experts are instances of generic `nn.Linear` layers.

```python
def moe_sparse_policy(module, recurse, unwrapped_params):
    if recurse:
        return True
    # Identify whether this module is a specific expert container.
    # The exact check depends on the model implementation.
    if hasattr(module, 'is_sparse_expert') and module.is_sparse_expert:
        return True
    return False
```

By wrapping individual experts, FSDP only gathers the parameters of the experts whose forward passes are actually invoked; inactive experts are never all-gathered, which significantly reduces peak memory usage.
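The `is_sparse_expert` attribute checked above has to come from somewhere. One option, sketched below, is to tag the expert modules before constructing FSDP; the helper name and the module-path pattern (`...experts.<index>`) are assumptions about how the MoE block names its children, so adapt them to your implementation.

```python
import torch.nn as nn


def tag_sparse_experts(model: nn.Module) -> None:
    """Mark every module that sits directly under an `experts` ModuleList."""
    for name, submodule in model.named_modules():
        parts = name.split(".")
        # Matches paths such as "layers.3.moe.experts.7" (assumed naming scheme)
        if len(parts) >= 2 and parts[-2] == "experts" and parts[-1].isdigit():
            submodule.is_sparse_expert = True


# Hypothetical usage:
# tag_sparse_experts(my_moe_model)
# fsdp_model = FSDP(my_moe_model, auto_wrap_policy=moe_sparse_policy, ...)
```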
## Optimizing Shard Granularity

Designing a custom policy involves balancing memory efficiency against communication latency.

- **Too granular (over-wrapping):** Wrapping every small layer results in hundreds of FSDP units. This triggers excessive CUDA kernel launches and NCCL synchronization barriers, and the GPU spends more time waiting for network operations than computing.
- **Too coarse (under-wrapping):** Wrapping only the top-level container means large chunks of parameters must be gathered at once. This increases the peak memory requirement, potentially causing Out-Of-Memory (OOM) errors.

The objective is to find the "sweet spot" where the shard size is large enough to saturate network bandwidth but small enough to fit within the GPU memory budget.

```json
{
  "layout": {
    "title": "Impact of Wrapping Granularity on Training Throughput",
    "xaxis": {"title": "Average Parameters per FSDP Unit (Millions)", "type": "log"},
    "yaxis": {"title": "Throughput (Tokens/Sec)", "showgrid": true, "gridcolor": "#e9ecef"},
    "yaxis2": {"title": "Communication Overhead (%)", "overlaying": "y", "side": "right"},
    "showlegend": true,
    "plot_bgcolor": "white"
  },
  "data": [
    {
      "x": [0.1, 0.5, 1, 5, 10, 50, 100, 500],
      "y": [1200, 2800, 4100, 4350, 4400, 4200, 3600, 1500],
      "mode": "lines+markers",
      "name": "Throughput",
      "line": {"color": "#228be6", "width": 3},
      "marker": {"size": 8}
    },
    {
      "x": [0.1, 0.5, 1, 5, 10, 50, 100, 500],
      "y": [85, 60, 45, 40, 38, 35, 32, 30],
      "mode": "lines",
      "name": "Communication Overhead (%)",
      "yaxis": "y2",
      "line": {"color": "#fa5252", "dash": "dot"}
    }
  ]
}
```

*The relationship between shard size and throughput. Extremely small shards increase communication overhead (red dotted line), while extremely large shards degrade performance due to memory fragmentation and lack of computation overlap.*

## Exclusion Strategies

Occasionally, a custom policy is needed to prevent wrapping. Parameters that are shared across the network, such as embedding layers tied to the output projection (common in GPT architectures), often cause synchronization issues if they are sharded into different FSDP units.

A lambda policy can explicitly return `False` for known shared modules, forcing them to be managed by the root FSDP unit or a specific parent. This ensures that the shared parameter is gathered only once per forward pass, rather than being gathered and scattered multiple times by different wrappers.

```python
def exclude_embeddings_policy(module, recurse, unwrapped_params):
    # Identify embedding layers to exclude: never wrap them in their own unit
    if isinstance(module, (nn.Embedding, nn.EmbeddingBag)):
        return False
    # Keep traversing everywhere else
    if recurse:
        return True
    # Proceed with standard size-based wrapping for other layers (10M floor)
    return unwrapped_params >= 10_000_000
```

This strategy is essential when using `sharding_strategy=ShardingStrategy.SHARD_GRAD_OP`, where gradients and optimizer states are sharded but the gathered parameters persist between the forward and backward passes. By explicitly controlling the cut points, you ensure the system behavior aligns with the specific mathematical requirements of the model architecture.
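To decide which modules an exclusion policy should cover, it helps to locate tied parameters before wrapping. The helper below is a minimal sketch that reports any `Parameter` object referenced from more than one module; the function name and the example output are illustrative, not part of any PyTorch API.

```python
from collections import defaultdict

import torch.nn as nn


def find_shared_parameters(model: nn.Module) -> dict:
    """Map one name of each tied parameter to all module paths that reference it."""
    owners = defaultdict(list)
    for module_name, submodule in model.named_modules():
        for param_name, param in submodule.named_parameters(recurse=False):
            owners[id(param)].append(f"{module_name}.{param_name}".lstrip("."))
    # Keep only parameters that appear under more than one module
    return {names[0]: names for names in owners.values() if len(names) > 1}


# A GPT-style model with tied input/output embeddings might report, e.g.:
#   {"embed.weight": ["embed.weight", "lm_head.weight"]}
# shared = find_shared_parameters(my_model)
```

Any modules that show up together in this report should either be excluded from wrapping, as above, or kept inside the same FSDP unit so that the tied weight is not split across wrappers.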