Implementing FSDP effectively requires more than instantiating the wrapper class. Its memory savings and throughput depend heavily on how the model's layers are grouped into FSDP units, a process known as wrapping. If the whole model is treated as a single monolithic unit, FSDP must all-gather every parameter at once during the forward pass. Each rank then briefly holds the full model, which negates the primary memory benefit of sharding. To avoid this, the model is split into smaller units. FSDP can then gather each unit's parameters just before that unit runs and, with full sharding, free them again once it finishes.
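The following back-of-the-envelope model in plain Python makes the memory argument concrete. It assumes each rank permanently stores 1/N of every unit and, while a unit runs, also holds that unit's fully gathered parameters. Prefetching, gradients, optimizer state, and activations are ignored. The layer sizes and the helper name `peak_param_elements` are hypothetical.

```python
def peak_param_elements(unit_sizes, world_size):
    """Rough peak count of parameter elements resident on one rank.

    Simplified model (an assumption, not FSDP's exact accounting): every rank
    keeps 1/world_size of every unit and, while a unit executes, also holds
    that unit's full gathered parameters. Prefetching, gradients, optimizer
    state, and activations are ignored.
    """
    total = sum(unit_sizes)
    local_shard = total / world_size
    largest = max(unit_sizes)
    # The gathered unit already includes this rank's own slice of it,
    # so only the other ranks' slices count as extra memory.
    return local_shard + largest - largest / world_size


# Hypothetical model: a 100M-parameter embedding plus 24 blocks of 50M each.
blocks = [50_000_000] * 24
embedding = 100_000_000
total = embedding + sum(blocks)  # 1.3B parameters
world_size = 8

# One monolithic unit: the whole model is gathered at once.
monolithic = peak_param_elements([total], world_size)
# One unit per block, embedding in its own unit: only one unit is
# fully materialized at a time.
per_block = peak_param_elements([embedding] + blocks, world_size)

print(f"monolithic: {monolithic / 1e6:.1f}M elements")  # → monolithic: 1300.0M elements
print(f"per-block:  {per_block / 1e6:.1f}M elements")   # → per-block:  250.0M elements
```

Under these assumptions, per-block wrapping cuts the peak from the full 1.3B elements to 250M, roughly a 5× reduction. The 100M-element embedding unit now sets the peak, which shows why the choice of unit boundaries matters.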
This chapter examines the policies that decide how a model is partitioned into units. You will configure auto-wrapping policies designed for Transformer architectures, so that specific blocks, such as encoder or decoder layers, each become their own unit. With this granular approach, the peak parameter memory on each rank is roughly its local shard of the model (about M/N for a model of size M sharded across N ranks) plus the fully gathered parameters of the unit currently executing, rather than the full model size M. For architectures that do not fit standard patterns, we will construct custom wrapping strategies using lambda functions that control exactly where the module hierarchy is split into units.
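The sketch below shows which submodules each kind of policy would select, without launching a distributed job. It assumes PyTorch 2.x. `transformer_auto_wrap_policy` and `lambda_auto_wrap_policy` come from `torch.distributed.fsdp.wrap`. `Block`, `TinyLM`, and the helper `units_selected` are hypothetical names for illustration. The helper asks each policy the per-module question (`recurse=False`): should this module become its own unit?

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import (
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)


class Block(nn.Module):
    """Hypothetical stand-in for a Transformer encoder/decoder layer."""

    def __init__(self, dim):
        super().__init__()
        self.attn_proj = nn.Linear(dim, dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        x = x + self.attn_proj(x)
        return x + self.mlp(x)


class TinyLM(nn.Module):
    """Hypothetical toy language model: embedding, a stack of Blocks, output head."""

    def __init__(self, vocab=1000, dim=64, depth=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))
        self.head = nn.Linear(dim, vocab)

    def forward(self, tokens):
        x = self.embed(tokens)
        for block in self.blocks:
            x = block(x)
        return self.head(x)


# Transformer policy: every Block instance becomes its own FSDP unit.
transformer_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={Block}
)

# Custom rule via a lambda: additionally give the embedding its own unit.
custom_policy = functools.partial(
    lambda_auto_wrap_policy,
    lambda_fn=lambda m: isinstance(m, (Block, nn.Embedding)),
)


def units_selected(model, policy):
    # recurse=False asks whether this module should become a unit; the third
    # argument (count of not-yet-wrapped parameters) is ignored by both policies.
    # The root ("") is skipped because the outer FSDP call always wraps it.
    return [name for name, m in model.named_modules() if name and policy(m, False, 0)]


model = TinyLM()
print(units_selected(model, transformer_policy))
print(units_selected(model, custom_policy))
```

In a real job, after `torch.distributed.init_process_group`, you would pass one of these callables as `FullyShardedDataParallel(model, auto_wrap_policy=transformer_policy)`. Any parameters not claimed by a selected unit (here the embedding and head under the transformer policy) remain in the root unit.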
We will also address initialization bottlenecks. Loading a large model into CPU RAM before sharding often causes out-of-memory errors on the host, especially when every rank on a node loads its own full copy. You will learn to use the PyTorch meta device to build the model structure without allocating storage for the weights. By combining this delayed initialization with each module's reset_parameters() method, you can instantiate models whose total parameter size exceeds the available system memory. Finally, the chapter covers techniques for handling shared (tied) parameters, keeping them synchronized across distributed ranks without duplicating storage.
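The following is a minimal single-process sketch of delayed initialization, assuming PyTorch 2.0 or newer for the `torch.device` context manager. `TiedLM` and its `tie_weights` method are hypothetical. The output head shares the embedding weight to show that tied parameters need attention after materialization. Under FSDP, materialization can instead happen unit by unit (for example through the `param_init_fn` constructor argument). This sketch only shows the underlying steps.

```python
import torch
import torch.nn as nn


class TiedLM(nn.Module):
    """Hypothetical toy model whose output head shares the embedding weight."""

    def __init__(self, vocab=1000, dim=64):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.body = nn.Linear(dim, dim)
        self.head = nn.Linear(dim, vocab, bias=False)
        self.tie_weights()

    def tie_weights(self):
        # Make both modules reference the same Parameter object.
        self.head.weight = self.embed.weight

    def forward(self, tokens):
        return self.head(torch.relu(self.body(self.embed(tokens))))


# 1. Build the structure on the meta device: shapes and dtypes, no storage.
with torch.device("meta"):
    model = TiedLM()
print(next(model.parameters()).device)  # → meta

# 2. Allocate real but uninitialized storage on the target device.
model.to_empty(device="cpu")

# 3. Run each layer's own initializer. The head is skipped because its only
#    weight is borrowed from the embedding.
skip = {model.head}
for module in model.modules():
    if module not in skip and hasattr(module, "reset_parameters"):
        module.reset_parameters()

# 4. Re-tie last, so the head reuses the embedding's freshly initialized
#    weight whether or not materialization preserved the shared reference.
model.tie_weights()

# parameters() yields the shared tensor only once.
print(sum(p.numel() for p in model.parameters()))  # → 68160
out = model(torch.tensor([[1, 2, 3]]))
```

Re-tying as the final step keeps the shared weight stored once and makes its initial value come from the embedding's initializer, regardless of the order in which modules were reset.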
2.1 Transformer Wrapping Policies
2.2 Custom Wrapping Strategies
2.3 Delayed Initialization and meta Device
2.4 Handling Shared Parameters
2.5 Code Practice: Advanced Wrapping Configuration
© 2026 ApX Machine Learning