High utilization numbers reported by `nvidia-smi` often deceive engineers into believing their training pipeline is efficient. A generic utilization metric only indicates that a CUDA kernel is resident on the device; it does not distinguish between high-throughput matrix multiplications and memory-bound operations, nor does it account for the device waiting on data from the CPU or the network. To optimize FSDP training effectively, you must bypass high-level metrics and analyze the execution timeline directly through trace files.

## Anatomy of a GPU Trace

A standard PyTorch profiler trace, viewed in Perfetto or the Chrome Trace Viewer, organizes execution data into horizontal tracks. For distributed training on NVIDIA GPUs, three specific categories of tracks provide the necessary signals for performance debugging.

The CPU Python Thread track displays the high-level application logic, including the `nn.Module` calls and the dataloader loop. Beneath this lies the CPU CUDA Runtime track, which logs the calls made to the CUDA runtime API (e.g., `cudaLaunchKernel`, `cudaMemcpyAsync`). Finally, the GPU Stream tracks visualize the actual execution of kernels on the hardware.

In an optimized FSDP setup, the relationship between these tracks is strictly hierarchical but asynchronous: the CPU dispatches kernels to a launch queue, and the GPU consumes them. Performance degradation typically manifests as a breakdown in this producer-consumer relationship.

```dot
digraph G {
  rankdir=TB;
  node [shape=box, style=filled, fontname="Helvetica", fontsize=10];
  edge [fontname="Helvetica", fontsize=9];

  subgraph cluster_cpu {
    label = "Host (CPU)";
    style = filled;
    color = "#e9ecef";
    python  [label="Python Interpreter\n(FSDP Forward/Backward)", fillcolor="#a5d8ff", color="#1c7ed6"];
    runtime [label="CUDA Runtime API\n(Kernel Launch)", fillcolor="#bac8ff", color="#4263eb"];
  }

  subgraph cluster_gpu {
    label = "Device (GPU)";
    style = filled;
    color = "#e9ecef";
    stream_c [label="Compute Stream\n(GEMM / Elem-wise)", fillcolor="#b2f2bb", color="#37b24d"];
    stream_n [label="NCCL Stream\n(AllGather / ReduceScatter)", fillcolor="#ffc9c9", color="#f03e3e"];
  }

  python -> runtime    [label=" Dispatch"];
  runtime -> stream_c  [label=" Async Launch"];
  runtime -> stream_n  [label=" Async Launch"];
  stream_c -> stream_n [label=" Synchronization\n(Event Wait)", style=dashed];
}
```

*Interaction between host dispatch and device execution streams in a distributed environment.*

## Identifying CPU-Bound Execution

The most common pathology in large-scale training is "GPU starvation," where the GPU completes its work queue faster than the CPU can replenish it. In the profiler trace, this appears as gaps or whitespace between blocks on the GPU Stream track.

If you observe small, distinct gaps between every kernel execution, your system suffers from launch overhead. PyTorch eager execution incurs a fixed cost per kernel launch (roughly 3-10 microseconds). When FSDP wraps a model with many small layers (such as a Transformer with a small hidden dimension relative to the batch size), the ratio of launch overhead to compute time increases.

The metric to monitor here is the relative density of the GPU timeline. In a healthy trace, kernel blocks are tightly packed, appearing as a solid bar of color. Significant whitespace implies the CPU is struggling to dispatch work fast enough to keep the device busy. This is often resolved by enabling `torch.compile` to fuse pointwise operations, thereby reducing the total number of kernel launches required per step.
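To confirm where the gaps come from (and whether a change such as `torch.compile` actually removes them), capture a short trace window around a few steady-state steps and open the exported file in Perfetto. Below is a minimal sketch using `torch.profiler`; the schedule values and output path are illustrative, and `model`, `dataloader`, and `optimizer` stand in for your own training objects.

```python
import torch
import torch.distributed as dist
from torch.profiler import ProfilerActivity, profile, schedule


def capture_trace(model, dataloader, optimizer):
    """Record a few steady-state steps and export a Perfetto-compatible trace."""
    prof = profile(
        # Host-side activity (Python, CUDA runtime calls) plus device kernels.
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        # Skip 1 step, warm up for 2, then record 3 steps.
        schedule=schedule(wait=1, warmup=2, active=3, repeat=1),
        # Export once the active window completes; one file per rank.
        on_trace_ready=lambda p: p.export_chrome_trace(
            f"fsdp_trace_rank{dist.get_rank()}.json"
        ),
        record_shapes=True,
    )
    with prof:
        for step, batch in enumerate(dataloader):
            loss = model(batch).mean()  # placeholder loss for illustration
            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            prof.step()                 # advance the wait/warmup/active schedule
            if step >= 5:               # 1 + 2 + 3 = 6 steps is enough
                break
```

Comparing the density of the GPU Stream track before and after a change is usually more informative than comparing any single utilization number.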
## Communication and Computation Overlap

FSDP relies heavily on overlapping communication (NCCL operations) with computation. During the forward pass, while the GPU computes layer $N$, the system should simultaneously AllGather the parameters for layer $N+1$.

To verify this behavior, locate the NCCL kernels in the GPU tracks. These are typically labeled `ncclKernel`, `ncclAllGatherRingLL`, or similar variants depending on the backend version. You should see these communication kernels executing on a separate stream, in parallel with the compute kernels (e.g., `sgemm`, `elementwise_kernel`).

If the NCCL kernels and compute kernels execute sequentially, meaning one stream is idle while the other is active, your overlap configuration is failing. This serial execution typically results from improper CUDA stream prioritization or from limits in the NCCL process group configuration.

The efficiency of this overlap is quantifiable. Let $T_{compute}$ be the time taken for the forward/backward pass of a layer, and $T_{comm}$ be the time required to gather or scatter its parameters. The exposed communication overhead $T_{exposed}$ is:

$$
T_{exposed} = \max(0, T_{comm} - T_{compute})
$$

Ideal scaling occurs when $T_{exposed}$ approaches zero. If the profiler shows $T_{comm}$ significantly exceeding $T_{compute}$, network bandwidth is the bottleneck, and simply adding more GPUs may not improve training speed linearly.

## Analyzing the Backward Pass Structure

The backward pass in FSDP is complex because it involves two distinct communication primitives: AllGather (to retrieve full parameters for gradient computation) and ReduceScatter (to synchronize gradients and shard them immediately).

A trace of the backward pass should demonstrate a "zigzag" pattern across streams. As the backward computation for layer $L$ finishes:

1. A ReduceScatter operation for layer $L$'s gradients begins on the communication stream.
2. An AllGather operation for layer $L-1$'s parameters begins (prefetching).
3. The compute stream executes the backward pass for layer $L-1$.

```json
{
  "layout": {
    "title": "Ideal FSDP Backward Pass Overlap",
    "height": 300,
    "margin": {"l": 100, "r": 20, "t": 40, "b": 20},
    "xaxis": {"title": "Time (microseconds)", "showgrid": true, "zeroline": false},
    "yaxis": {
      "showgrid": false,
      "zeroline": false,
      "ticktext": ["Compute Stream", "NCCL Stream"],
      "tickvals": [1, 0]
    },
    "showlegend": true
  },
  "data": [
    {
      "type": "bar",
      "y": [1, 1, 1],
      "x": [300, 300, 300],
      "base": [0, 350, 700],
      "orientation": "h",
      "name": "Compute (Backward)",
      "marker": {"color": "#339af0"},
      "text": ["Layer N", "Layer N-1", "Layer N-2"],
      "textposition": "auto"
    },
    {
      "type": "bar",
      "y": [0, 0, 0],
      "x": [200, 200, 200],
      "base": [50, 400, 750],
      "orientation": "h",
      "name": "Comm (ReduceScatter)",
      "marker": {"color": "#fa5252"},
      "text": ["Grad N", "Grad N-1", "Grad N-2"],
      "textposition": "auto"
    },
    {
      "type": "bar",
      "y": [0, 0],
      "x": [250, 250],
      "base": [0, 350],
      "orientation": "h",
      "name": "Comm (AllGather Prefetch)",
      "marker": {"color": "#fcc419"},
      "text": ["Param N-1", "Param N-2"],
      "textposition": "auto"
    }
  ]
}
```

*Gantt chart depicting optimal stream concurrency. Note how communication operations for adjacent layers execute simultaneously with the computation of the current layer.*

In a suboptimal trace, you will instead observe long waits on the CPU track (blocking calls such as `cudaStreamSynchronize` or `cudaEventSynchronize`) extending for significant durations. This indicates that the Python thread is blocked waiting for the GPU to finish a communication event before it can dispatch the next compute kernel, which defeats the purpose of asynchronous execution.
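When the trace shows this serialized pattern, the FSDP constructor arguments that govern prefetching are the usual first check. The sketch below uses names from the `torch.distributed.fsdp` API; defaults have shifted across PyTorch releases, so treat the values as a starting point rather than a prescription, and `layer_cls` is a placeholder for your transformer block class.

```python
import functools

import torch
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def shard_model(model, layer_cls):
    """Wrap a model so parameter AllGathers can overlap with compute."""
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls={layer_cls}
    )
    return FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=wrap_policy,
        # Issue the AllGather for the next layer's parameters before the
        # current layer's gradient computation finishes (the "zigzag" above).
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        # Prefetch forward-pass parameters for layer N+1 while computing layer N.
        forward_prefetch=True,
        # Rate-limit outstanding AllGathers; trades memory for scheduling slack.
        limit_all_gathers=True,
        device_id=torch.cuda.current_device(),
    )
```

Whether `forward_prefetch` helps depends on model size and interconnect bandwidth; re-profile after changing any of these knobs rather than assuming the overlap improved.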
## Memory Allocator Fragmentation

Aside from compute and communication, the profiler reveals memory management overhead. Frequent calls to `cudaMalloc` and `cudaFree` are expensive. PyTorch uses a caching allocator to mitigate this, but in FSDP the constant allocation and deallocation of the gathered (unsharded) parameters can lead to fragmentation.

In the profiler, look for the memory timeline track. A sawtooth pattern is expected and healthy in FSDP as layers are materialized and discarded. However, if you see spikes of `cudaMalloc` calls in the middle of a training step (not just during the first few iterations), it suggests the caching allocator is thrashing: it is forced to search for free blocks or split existing ones, which stalls the CPU dispatch thread. This is often rectified by setting the `PYTORCH_CUDA_ALLOC_CONF` environment variable to adjust `max_split_size_mb`, forcing the allocator to keep larger contiguous blocks available for the recurring parameter expansions.
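As a hedged illustration, the snippet below sets the allocator option before the first CUDA allocation and then checks `num_alloc_retries`, a counter reported by `torch.cuda.memory_stats()` that increments when a `cudaMalloc` fails and the allocator must flush its cache and retry. The 128 MB threshold is purely illustrative; a reasonable value depends on the size of your unsharded parameter buffers.

```python
import os

# The allocator reads this the first time CUDA memory is allocated, so set it
# before any GPU work happens. "128" is an illustrative value: blocks larger
# than max_split_size_mb will not be split to serve smaller requests.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

import torch  # noqa: E402  (imported after the env var on purpose)


def allocator_is_thrashing() -> bool:
    """True if the caching allocator has had to flush its cache and retry a
    cudaMalloc, which often shows up as the mid-step stalls described above."""
    stats = torch.cuda.memory_stats()
    return stats.get("num_alloc_retries", 0) > 0
```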