Reducing the precision of model parameters to BFloat16 or Float16 significantly lowers the static memory footprint of the model weights. However, during training, the transient memory required to store intermediate activations for backpropagation often exceeds the memory occupied by the parameters themselves. For large language models (LLMs) with long sequence lengths, this activation memory scales linearly with the number of layers and the batch size, and, with standard attention implementations that materialize the score matrix, quadratically with sequence length. Activation checkpointing, also known as gradient checkpointing, addresses this bottleneck by discarding intermediate activations during the forward pass and recomputing them during the backward pass.

## The Recomputation Strategy

Standard backpropagation requires the output of every operation in the computation graph to be kept in GPU memory until its gradient has been computed. For a Transformer model with $L$ layers, this results in activation memory that grows as $O(L)$.

Activation checkpointing alters this behavior by designating specific modules as "checkpoints." Only the inputs to these checkpointed modules are preserved in memory; all intermediate activations generated within the module are discarded. When the backward pass reaches a checkpointed module, the system performs a localized forward pass using the stored inputs to regenerate the necessary intermediate states. This trades computational overhead for memory savings.

With optimal checkpoint placement, memory consumption typically follows a square-root relationship with the number of layers. If we divide a network of $N$ nodes into segments of length $\sqrt{N}$, we only store the $\sqrt{N}$ segment boundaries; during backpropagation, we recompute the at most $\sqrt{N}$ nodes within the current segment.

$$ \text{Memory}_{\text{checkpoint}} \approx O(\sqrt{N}) + O(\text{BlockSize}) $$

The following diagram illustrates the difference in memory state between standard training and checkpointed training during the backward pass.

```dot
digraph G {
    rankdir=TB;
    node [fontname="Helvetica", shape=box, style=filled, color="#dee2e6", fontcolor="#495057"];
    edge [color="#adb5bd"];
    bgcolor="transparent";

    subgraph cluster_0 {
        label="Standard Backpropagation";
        fontname="Helvetica"; fontcolor="#495057"; color="#ced4da"; style="dashed";
        node [fillcolor="#a5d8ff"];  // Blue for kept
        s1 [label="Layer 1\n(Stored)"];
        s2 [label="Layer 2\n(Stored)"];
        s3 [label="Layer 3\n(Stored)"];
        s4 [label="Layer 4\n(Stored)"];
        s1 -> s2 -> s3 -> s4;
    }

    subgraph cluster_1 {
        label="Activation Checkpointing";
        fontname="Helvetica"; fontcolor="#495057"; color="#ced4da"; style="dashed";
        node [fillcolor="#ffc9c9"];  // Red for discarded
        c1 [label="Layer 1\n(Checkpoint)", fillcolor="#a5d8ff"];  // Blue kept
        c2 [label="Layer 2\n(Discarded)"];
        c3 [label="Layer 3\n(Discarded)"];
        c4 [label="Layer 4\n(Checkpoint)", fillcolor="#a5d8ff"];  // Blue kept
        c1 -> c2 -> c3 -> c4;
    }
}
```

Comparison of stored tensors in standard training versus activation checkpointing. In the checkpointing approach, intermediate layers are discarded and recomputed on demand.

## Integrating Checkpointing with FSDP

Implementing activation checkpointing in an FSDP environment requires careful orchestration. Naively applying PyTorch's `torch.utils.checkpoint` to sub-modules can conflict with FSDP's sharding logic: FSDP shards parameters across GPUs, but checkpointing requires full parameter access during the recomputation phase.

PyTorch provides the `apply_activation_checkpointing` utility specifically for this purpose.
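To ground what that utility automates, it helps to see the underlying mechanism in isolation. The following is a minimal, single-GPU sketch, assuming a toy stack of feed-forward blocks in place of real transformer layers; it uses `torch.utils.checkpoint.checkpoint_sequential` to keep only segment-boundary activations, mirroring the $\sqrt{N}$ strategy described earlier.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Hypothetical stand-in for a stack of 16 transformer blocks.
blocks = nn.Sequential(*[
    nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
    for _ in range(16)
])

x = torch.randn(8, 1024, requires_grad=True)

# Split the 16 blocks into 4 segments (~sqrt(N)). Only segment-boundary
# activations are stored; everything inside a segment is recomputed when
# the backward pass reaches it.
out = checkpoint_sequential(blocks, 4, x, use_reentrant=False)
out.sum().backward()
```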
In an FSDP setup, `apply_activation_checkpointing` ensures that the wrapping order is consistent: the checkpoint wrapper and the FSDP wrapper must nest correctly, with the outer wrapper chosen according to the desired granularity. The recommended pattern for Transformers is to wrap each Transformer block (attention + feed-forward).

When used with FSDP, the utility automatically wraps each matching sub-module in a `checkpoint_wrapper`. A critical configuration detail is the choice between the reentrant and non-reentrant wrapper implementations.

### Non-Reentrant Wrappers

For FSDP, it is strongly advised to use non-reentrant checkpointing (`use_reentrant=False`). The legacy reentrant variant relies on global state and a nested call back into the autograd engine, which can interfere with FSDP's backward hooks and state synchronization, potentially leading to deadlocks or incorrect gradients. The non-reentrant implementation integrates the checkpointed region into the normal autograd graph, providing better stability and compatibility with distributed settings.

Here is the implementation pattern for applying checkpointing to a Transformer-based model wrapped in FSDP:

```python
import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Use the non-reentrant checkpoint implementation (see above).
non_reentrant_wrapper = functools.partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

# Define the check function to identify transformer blocks.
# Assuming the model structure has a class 'TransformerBlock'.
check_fn = lambda submodule: isinstance(submodule, TransformerBlock)

# Apply checkpointing
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=check_fn,
)

# Wrap in FSDP after applying the checkpointing logic.
# Note: in recent PyTorch versions, applying checkpointing to an
# already-FSDP-wrapped module is also supported and often preferred.
fsdp_model = FSDP(model, auto_wrap_policy=...)
```

## Throughput vs. Memory Analysis

Activation checkpointing introduces a computational cost. Because the forward pass for the checkpointed segments runs twice (once for the loss calculation, and once again during backpropagation), the computational overhead is approximately 20% to 30%.

However, this overhead is often justifiable. The reduction in memory usage allows for a significant increase in micro-batch size. Larger batch sizes improve GPU occupancy and arithmetic intensity, potentially offsetting the recomputation cost. In scenarios where the model fits in memory without checkpointing but only at a very small batch size, the GPU compute units may be underutilized.
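Whether this trade-off pays off is easy to check empirically. The sketch below is a simplified, hypothetical helper (the names `measure_step`, `batch`, and `loss_fn` are placeholders, and a real loop would also include the optimizer step and data loading); it records step time and peak GPU memory so the two configurations can be compared across micro-batch sizes.

```python
import time
import torch

def measure_step(model, batch, loss_fn):
    """Run one forward/backward pass and report step time and peak GPU memory."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()

    loss = loss_fn(model(batch))
    loss.backward()

    torch.cuda.synchronize()
    step_time = time.perf_counter() - start
    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    return step_time, peak_gb
```

In practice, you sweep the micro-batch size upward with checkpointing enabled and keep the largest value that stays safely under the memory limit.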
By enabling checkpointing and increasing the batch size, you often achieve higher overall tokens-per-second throughput despite the recomputation penalty.

The following chart demonstrates the relationship between batch size and memory consumption with and without checkpointing.

```json
{
  "layout": {
    "title": "Memory Usage vs Batch Size (7B Param Model)",
    "xaxis": { "title": "Micro-Batch Size", "showgrid": true, "gridcolor": "#dee2e6" },
    "yaxis": { "title": "Peak Memory (GB)", "showgrid": true, "gridcolor": "#dee2e6" },
    "plot_bgcolor": "white",
    "showlegend": true
  },
  "data": [
    {
      "x": [1, 2, 4, 8, 16, 32],
      "y": [24, 32, 48, 80, 144, 272],
      "type": "scatter",
      "mode": "lines+markers",
      "name": "Standard Training",
      "line": {"color": "#fa5252", "width": 3},
      "marker": {"size": 8}
    },
    {
      "x": [1, 2, 4, 8, 16, 32, 64],
      "y": [18, 20, 24, 32, 48, 80, 144],
      "type": "scatter",
      "mode": "lines+markers",
      "name": "Activation Checkpointing",
      "line": {"color": "#228be6", "width": 3},
      "marker": {"size": 8}
    },
    {
      "x": [1, 64],
      "y": [80, 80],
      "type": "scatter",
      "mode": "lines",
      "name": "GPU Memory Limit (80GB)",
      "line": {"color": "#868e96", "width": 2, "dash": "dashdot"}
    }
  ]
}
```

Memory scaling behavior. The red line shows standard training hitting the 80 GB VRAM limit at a batch size of 8; the blue line shows checkpointing enabling batch sizes up to 32 within the same hardware constraints.

## Selective Checkpointing and Offloading

Checkpointing does not have to be all-or-nothing. PyTorch FSDP also allows for selective checkpointing: instead of checkpointing every transformer layer, you might choose to checkpoint only every $n$-th layer (see the sketch at the end of this section). This provides a granular dial to balance memory savings against computational overhead.

For extreme-scale models where even full checkpointing is insufficient, the `checkpoint_wrapper` additionally supports CPU offloading. While FSDP handles parameter offloading, activation offloading is distinct: by configuring `offload_to_cpu=True` in the checkpoint wrapper, the preserved inputs to the checkpointed modules are moved to system RAM and prefetched back to the GPU immediately before the recomputation step. This is particularly effective when PCIe bandwidth is not the bottleneck, and it allows training models that are significantly larger than the aggregate GPU memory capacity of the cluster.

$$ \text{Total Throughput} \propto \frac{\text{BatchSize}}{\text{StepTime}} $$

When optimizing, the goal is to maximize the expression above. If checkpointing increases the step time by 30% but allows the batch size to increase by 100% without out-of-memory (OOM) errors, the net result is a substantial gain in training efficiency.
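As a concrete, purely illustrative instance of that arithmetic:

```python
# Hypothetical numbers: checkpointing makes each step 30% slower but lets the
# micro-batch size double without running out of memory.
baseline_tokens, baseline_time = 1.0, 1.0   # normalized units
ckpt_tokens, ckpt_time = 2.0, 1.3

relative_throughput = (ckpt_tokens / ckpt_time) / (baseline_tokens / baseline_time)
print(f"{relative_throughput:.2f}x")  # ~1.54x tokens per second
```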
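Finally, returning to selective checkpointing: one way to checkpoint only every $n$-th transformer block is to make the `check_fn` predicate index-aware. The sketch below is an assumption-laden example, not a fixed API pattern: it presumes the model exposes its blocks as `model.layers` and reuses the hypothetical `TransformerBlock` class and the `non_reentrant_wrapper` defined in the earlier snippet.

```python
# Checkpoint every 2nd transformer block instead of all of them.
# Assumes model.layers is an iterable of blocks (hypothetical attribute).
ckpt_every_n = 2
blocks_to_ckpt = {
    id(block)
    for i, block in enumerate(model.layers)
    if isinstance(block, TransformerBlock) and i % ckpt_every_n == 0
}

selective_check_fn = lambda submodule: id(submodule) in blocks_to_ckpt

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=selective_check_fn,
)
```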