Optimizing distributed training is fundamentally about increasing the density of computation per unit of time while minimizing the latency of data movement. Once functional correctness is established, the focus shifts to maximizing Model FLOPs Utilization (MFU), a metric that quantifies how effectively your training loop uses the theoretical peak performance of the underlying hardware. In high-performance computing environments, achieving 50% to 60% MFU on Large Language Models (LLMs) is considered an excellent result, while unoptimized FSDP setups often languish below 30%.

## Measuring Model FLOPs Utilization

Raw throughput measured in "samples per second" or "tokens per second" helps compare runs, but it does not account for hardware capability or model architecture changes. MFU provides a normalized efficiency score. To calculate it, we first estimate the floating-point operations required for a single training step.

For a Transformer-based model, the number of floating-point operations (FLOPs) per token is roughly proportional to the parameter count. The standard approximation for a forward and backward pass without activation checkpointing is:

$$ C_{\text{step}} \approx 6 \cdot P \cdot D_{\text{batch}} $$

where $P$ is the number of trainable parameters and $D_{\text{batch}}$ is the total number of tokens in the global batch (sequence length $\times$ batch size). The factor of 6 arises from the forward pass ($2P$) and the backward pass ($4P$).

However, when training large models with FSDP, activation checkpointing (also known as gradient checkpointing) is almost always enabled to save memory. This technique re-computes the forward pass during the backward phase, so the compute the hardware actually executes increases:

$$ C_{\text{checkpointed}} \approx 8 \cdot P \cdot D_{\text{batch}} $$

To determine utilization, we divide the achieved FLOPs per second by the aggregate peak theoretical throughput of the devices (e.g., 312 TFLOPS for an NVIDIA A100 using BF16 tensor cores). By convention, MFU counts only the model FLOPs $C_{\text{step}}$, since re-computation is not useful work; the variant that counts $C_{\text{checkpointed}}$ is known as Hardware FLOPs Utilization (HFU) and is always at least as large as MFU on the same run.

$$ \text{MFU} = \frac{C_{\text{step}} / \text{Step Time (s)}}{\text{Peak Device FLOPs} \times \text{Number of GPUs}} $$

A low MFU indicates that the GPU execution units are stalling, likely due to memory bandwidth limitations (HBM bound), communication overhead (network bound), or kernel launch latencies (latency bound).
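As a concrete reference point, the sketch below implements the formulas above. All inputs (the 7B parameter count, token counts, step time, GPU count, and the 312 TFLOPS A100 BF16 peak) are illustrative assumptions, not measurements from a real run.

```python
# Minimal MFU/HFU estimate from the approximations above.
# All numbers below are illustrative assumptions for a hypothetical 7B-parameter run.

def estimate_utilization(
    num_params: float,         # P, trainable parameters
    global_batch_tokens: int,  # D_batch = sequences per global batch * sequence length
    step_time_s: float,        # measured wall-clock time per optimizer step
    num_gpus: int,
    peak_flops_per_gpu: float = 312e12,   # A100 BF16 tensor-core peak
    activation_checkpointing: bool = True,
) -> dict:
    model_flops = 6 * num_params * global_batch_tokens                        # C_step
    hardware_flops = (8 if activation_checkpointing else 6) * num_params * global_batch_tokens
    cluster_peak = peak_flops_per_gpu * num_gpus
    return {
        "mfu": model_flops / step_time_s / cluster_peak,
        "hfu": hardware_flops / step_time_s / cluster_peak,
        "achieved_tflops_per_gpu": model_flops / step_time_s / num_gpus / 1e12,
    }

# Example: 7B parameters, 1024 sequences of 2048 tokens per global batch,
# 8.8 s per step on 64 GPUs -> roughly 50% MFU.
print(estimate_utilization(7e9, 1024 * 2048, step_time_s=8.8, num_gpus=64))
```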
Optimization workflow for systematically improving training throughput:

```dot
digraph G {
  rankdir=TB;
  node [fontname="Sans-Serif", shape=box, style=filled, color="#dee2e6", fillcolor="#f8f9fa"];
  edge [color="#adb5bd"];

  start [label="Measure Baseline Throughput", fillcolor="#e7f5ff", color="#74c0fc"];
  calc_mfu [label="Calculate MFU", fillcolor="#e7f5ff", color="#74c0fc"];
  check_bound [label="Identify Bottleneck", shape=diamond, fillcolor="#fff3bf", color="#fcc419"];
  mem_bound [label="Memory Bound\n(Low Arithmetic Intensity)", fillcolor="#ffe3e3", color="#ff8787"];
  comm_bound [label="Communication Bound\n(High NCCL Wait)", fillcolor="#ffe3e3", color="#ff8787"];
  lat_bound [label="Latency Bound\n(Small Kernels)", fillcolor="#ffe3e3", color="#ff8787"];
  opt_mem [label="Apply FlashAttention\nIncrease Micro-Batch Size", fillcolor="#d3f9d8", color="#69db7c"];
  opt_comm [label="Tune Bucket Size\nBackward Prefetching", fillcolor="#d3f9d8", color="#69db7c"];
  opt_lat [label="CUDA Graph Capture\nFuse Optimizers", fillcolor="#d3f9d8", color="#69db7c"];

  start -> calc_mfu;
  calc_mfu -> check_bound;
  check_bound -> mem_bound [label=" High HBM Util"];
  check_bound -> comm_bound [label=" High Idle"];
  check_bound -> lat_bound [label=" Low GPU Util"];
  mem_bound -> opt_mem;
  comm_bound -> opt_comm;
  lat_bound -> opt_lat;
  opt_mem -> start;
  opt_comm -> start;
  opt_lat -> start;
}
```

## Maximizing Arithmetic Intensity

The most common reason for low MFU in FSDP training is a micro-batch size that is too small. GPUs perform best when they operate on large, contiguous blocks of data. If the local batch size per GPU is too small, the system becomes memory-bandwidth bound: the compute cores spend more time waiting for data to arrive from High Bandwidth Memory (HBM) than performing matrix multiplications.

To resolve this, increase the micro-batch size per GPU until you approach the out-of-memory (OOM) limit. This raises arithmetic intensity, the ratio of FLOPs performed to bytes accessed. If the global batch size is fixed by convergence hyperparameters, use gradient accumulation to maintain the global batch size while maximizing the micro-batch size on the hardware.

For example, if your target global batch size is 1024 sequences and you have 64 GPUs:

- **Scenario A:** Micro-batch 16 per GPU ($16 \times 64 = 1024$). Gradient accumulation steps = 1.
- **Scenario B:** Micro-batch 4 per GPU ($4 \times 64 = 256$). Gradient accumulation steps = 4.

Scenario A is strictly superior for throughput because it launches fewer, larger kernels, reducing overhead and better saturating the tensor cores. Scenario B should only be used if Scenario A results in OOM.

## IO-Aware Kernels and FlashAttention

In Transformer architectures, the self-attention mechanism is quadratically expensive with respect to sequence length. Standard implementations read and write the $N \times N$ attention matrix to HBM, which creates a massive memory bottleneck.

Integrating FlashAttention (v2 or newer) is mandatory for high-performance training. It fuses the attention operation into a single kernel, keeping the attention matrix in the GPU's fast SRAM (L1/shared memory) and avoiding round-trips to HBM. This not only speeds up the calculation but also reduces the memory footprint, allowing for larger batch sizes. In PyTorch, ensuring your attention layers call `torch.nn.functional.scaled_dot_product_attention` (which dispatches to FlashAttention when available) is a high-priority optimization.

## Optimizing Communication Frequency

FSDP introduces communication overhead by gathering sharded parameters (AllGather) before computation and synchronizing gradients (ReduceScatter) after computation. The frequency and size of these messages impact throughput.

The sharding strategy controls whether each FSDP unit releases its gathered parameter shards immediately after its forward pass or holds them for the backward pass:

- **`FULL_SHARD` (the default):** Saves memory by freeing the unsharded weights immediately after forward. Requires re-gathering the weights during the backward pass (more communication).
- **`SHARD_GRAD_OP`:** Holds the gathered weights until the backward pass is complete. This reduces communication volume but increases peak memory usage.

(The separate `limit_all_gathers` flag is a rate limiter that caps how far ahead the CPU may issue AllGathers; it bounds peak memory from prefetching rather than total communication volume.)

If profiling reveals significant gaps caused by NCCL AllGather operations and VRAM is not fully saturated, switching specific submodules or the entire model to `SHARD_GRAD_OP` can yield substantial speedups.

Additionally, the granularity of communication matters. In FSDP, each wrapped unit issues one AllGather/ReduceScatter, so the `auto_wrap_policy` effectively sets the bucket size (DDP exposes this directly as `bucket_cap_mb`). Buckets that are too small increase the number of NCCL calls (latency overhead), while buckets that are too large may prevent effective overlap between computation and communication. A value between 25 MB and 100 MB is generally optimal for modern clusters.
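The sketch below shows where these communication-related knobs live in the FSDP constructor. It assumes an already-initialized process group, a `model` instance, and a `TransformerBlock` layer class; the specific strategy choices are illustrative, not a recommended production configuration.

```python
# Sketch: FSDP wrapping with the communication-related knobs discussed above.
# `model` and `TransformerBlock` are placeholders for your own module and layer class.
import functools
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    BackwardPrefetch,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},  # one FSDP unit (one AllGather) per block
)

fsdp_model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    # Keep unsharded weights between forward and backward to skip the
    # backward-pass AllGather (more memory, less communication).
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    # Prefetch the next unit's parameters while the current backward runs.
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    # Rate-limit in-flight AllGathers to bound peak memory from prefetching.
    limit_all_gathers=True,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
    ),
    device_id=torch.cuda.current_device(),
)
```

If larger micro-batch sizes exhaust the VRAM headroom, reverting to `ShardingStrategy.FULL_SHARD` restores the memory savings at the cost of the extra backward-pass AllGather.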
Impact of cumulative optimizations on training throughput for a 7B-parameter model on A100 GPUs:

| Optimization Stage | Throughput (TFLOPS per GPU) | MFU |
| --- | --- | --- |
| Baseline (FP32) | 45 | 14% |
| Mixed Precision (BF16) | 110 | 35% |
| + Activation Checkpointing | 135 | 43% |
| + FlashAttention v2 | 168 | 54% |
| + Tuned Batch/Comm | 195 | 62% |

## Data Loading and CPU Bottlenecks

While attention typically rests on the GPU, the CPU can silently throttle training speed. If the DataLoader cannot feed the GPU fast enough, the GPU execution timeline will show gaps between steps where no kernels are running. This is often visible in the PyTorch Profiler as the DataLoader's `__next__` call taking significant time.

To mitigate this (a minimal configuration sketch appears at the end of this section):

- **Pin memory:** Set `pin_memory=True` in the DataLoader to use page-locked host memory, accelerating transfers from host to device.
- **Worker count:** Set `num_workers` to roughly the number of CPU cores divided by the number of GPUs per node, then tune from there.
- **Pre-fetching:** Ensure the data loader pre-fetches batches so the next batch is ready the moment the current step finishes.

By systematically addressing arithmetic intensity, kernel efficiency, communication overhead, and data logistics, you transform a functional distributed setup into a high-performance training engine capable of iterating through terabytes of data in feasible timeframes.
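As referenced in the data-loading checklist above, here is a minimal DataLoader configuration sketch. The `train_dataset` and `micro_batch_size` names are placeholders, and the worker/prefetch counts are starting points to tune per node rather than universal settings.

```python
# Sketch: DataLoader settings aimed at keeping the GPU fed.
# `train_dataset` and `micro_batch_size` are placeholders you must define.
import os
import torch
from torch.utils.data import DataLoader, DistributedSampler

local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", torch.cuda.device_count()))
cpu_cores = os.cpu_count() or 1

train_loader = DataLoader(
    train_dataset,
    batch_size=micro_batch_size,
    sampler=DistributedSampler(train_dataset),  # each rank reads its own shard
    pin_memory=True,                  # page-locked host memory for faster H2D copies
    num_workers=max(1, cpu_cores // max(1, local_world_size)),
    prefetch_factor=2,                # batches pre-fetched per worker
    persistent_workers=True,          # avoid re-forking workers every epoch
    drop_last=True,
)
```

In the training loop, move each batch with `.to(device, non_blocking=True)` so that, combined with pinned memory, the host-to-device copy overlaps with GPU compute.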