Parameter sharing, frequently referred to as weight tying, is a standard architectural efficiency technique in modern Transformer models. The most common instance ties the input embedding matrix to the output language modeling head. Mathematically, if the input embedding is $W_E \in \mathbb{R}^{V \times d}$, the output layer reuses the same weights, computing logits as $y = x W_E^T$. This reduces the parameter count significantly, often saving hundreds of millions of parameters in large-vocabulary models.

However, Fully Sharded Data Parallel (FSDP) introduces complexity when handling these shared references. FSDP operates by flattening parameters into a single contiguous `FlatParameter` within each wrapped unit. If two modules share a parameter but are assigned to different FSDP units (sharding groups), the system faces a conflict: it cannot assign the same underlying storage to two distinct flattened vectors managed by different communication streams.

## The Sharding Conflict with Shared Weights

When FSDP initializes, it traverses the module hierarchy to identify parameters. It relies on Python object identity (checked via `id()`) to detect shared weights. If it encounters a parameter that is already managed by an existing FSDP unit, it must decide how to proceed.

The primary constraint in FSDP is that shared parameters must belong to the same FSDP unit; they cannot be sharded across different boundaries. If Module A and Module B share a weight $W$, you cannot wrap Module A in one FSDP instance and Module B in another. Doing so would force $W$ to exist in two separate `FlatParameter` storages, breaking the synchronization logic and effectively untying the weights during optimization.

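Before choosing a wrapping strategy, it helps to know exactly which modules actually share storage. The following is a minimal sketch (the `find_shared_parameters` helper is illustrative, not part of FSDP) that groups parameter names by object identity, so you can see up front which modules must end up inside the same FSDP unit.

```python
from collections import defaultdict

import torch.nn as nn


def find_shared_parameters(model: nn.Module) -> dict:
    """Map each tied parameter to the list of module paths that reference it.

    Any entry with more than one name is a shared weight, and all of the
    modules involved must fall under the same FSDP unit.
    """
    groups = defaultdict(list)
    # remove_duplicate=False keeps every alias of a tied parameter in the output.
    for name, param in model.named_parameters(remove_duplicate=False):
        groups[id(param)].append(name)
    return {names[0]: names for names in groups.values() if len(names) > 1}


# Example (hypothetical GPT-style model with a tied embedding and head):
# find_shared_parameters(model)
# -> {'token_emb.weight': ['token_emb.weight', 'lm_head.weight']}
```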
The following diagram illustrates the structural difference between valid and invalid wrapping hierarchies for shared parameters.

```dot
digraph G {
    rankdir=TB;
    node [shape=box, style="filled", fontname="Arial", fontsize=12];
    edge [color="#adb5bd"];

    subgraph cluster_invalid {
        label = "Invalid Wrapping Strategy";
        style = dashed;
        color = "#fa5252";
        fontcolor = "#fa5252";
        inv_root [label="Model Root", fillcolor="#e9ecef"];
        inv_fsdp1 [label="FSDP(Embedding)", fillcolor="#ffc9c9"];
        inv_fsdp2 [label="FSDP(LM Head)", fillcolor="#ffc9c9"];
        inv_weight [label="Shared Weight W", shape=ellipse, fillcolor="#eebefa"];
        inv_root -> inv_fsdp1;
        inv_root -> inv_fsdp2;
        inv_fsdp1 -> inv_weight [label="Owns", color="#fa5252"];
        inv_fsdp2 -> inv_weight [label="Conflict", style=dashed, color="#fa5252"];
    }

    subgraph cluster_valid {
        label = "Valid Wrapping Strategy";
        style = solid;
        color = "#40c057";
        fontcolor = "#40c057";
        val_root [label="FSDP(Model Root)", fillcolor="#b2f2bb"];
        val_emb [label="Embedding", fillcolor="#ffffff"];
        val_head [label="LM Head", fillcolor="#ffffff"];
        val_weight [label="Shared Weight W", shape=ellipse, fillcolor="#eebefa"];
        val_root -> val_emb;
        val_root -> val_head;
        val_emb -> val_weight;
        val_head -> val_weight;
        val_root -> val_weight [label="Manages Single Storage", color="#40c057"];
    }
}
```

Visualizing ownership conflicts. In the invalid strategy, two FSDP instances attempt to shard the same memory address. In the valid strategy, the shared weight remains within a single FSDP scope.

## Implementing Safe Wrapping Policies

To handle shared parameters correctly, you must design your auto-wrapping policy to exclude the modules containing shared weights from individual wrapping. Instead, you allow the top-level FSDP wrapper to manage them.

Consider a standard GPT-style architecture where the transformer blocks are computationally heavy, but the embedding and head share weights. The goal is to wrap the transformer layers individually to save memory while keeping the embedding and head in the root FSDP unit.

Here is how to construct a wrapping policy that respects parameter sharing:

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class TransformerBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attn = nn.Linear(dim, dim)
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x):
        # Simplified residual block standing in for attention + MLP.
        x = x + self.attn(x)
        return x + self.mlp(x)


class GPTModel(nn.Module):
    def __init__(self, vocab_size, dim, layers):
        super().__init__()
        # Shared weight logic happens here
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList([TransformerBlock(dim) for _ in range(layers)])
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        # Explicitly tie weights
        self.lm_head.weight = self.token_emb.weight

    def forward(self, input_ids):
        x = self.token_emb(input_ids)
        for block in self.layers:
            x = block(x)
        return self.lm_head(x)


def get_wrapping_policy():
    """
    Returns a policy that wraps TransformerBlock instances but leaves
    Embedding and LM Head for the root FSDP unit.
    """
    return functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock},
    )


# Setup code inside the training script
model = GPTModel(vocab_size=32000, dim=1024, layers=12)

# The policy wraps blocks individually.
# The token_emb and lm_head remain unwrapped until the root FSDP call.
fsdp_model = FSDP(
    model,
    auto_wrap_policy=get_wrapping_policy(),
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
    device_id=torch.cuda.current_device(),
)
```

In this configuration, `TransformerBlock` instances are wrapped and sharded immediately. The `token_emb` and `lm_head` are ignored by the auto-wrap policy. Consequently, they fall into the scope of the outer `fsdp_model`. Since both are managed by the same root FSDP unit, the shared weight `self.token_emb.weight` is flattened once, and both modules reference the correct index in the flattened storage.

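A quick way to confirm that the policy behaved as intended is to list the FSDP units after wrapping. This is a minimal sketch using only the public module-traversal API, assuming the `fsdp_model` constructed above; with this policy you should see the root wrapper plus one unit per `TransformerBlock`, and no separate units for `token_emb` or `lm_head`.

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def print_fsdp_units(fsdp_model: FSDP) -> None:
    """Print every submodule that FSDP wrapped as its own sharded unit."""
    for name, module in fsdp_model.named_modules():
        if isinstance(module, FSDP):
            print(f"FSDP unit: {name or '<root>'}")


# print_fsdp_units(fsdp_model)
```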
## Verification of Parameter Identity

When working with distributed setups, silence is not success. It is important to verify that the weights remain tied after FSDP initialization. If the wrapping logic accidentally breaks the link, the model might train with dissociated embeddings and output heads, leading to poor convergence.

You can verify the identity of the storage pointers after wrapping. Note that FSDP modifies `param.data` to point to a view into the `FlatParameter`.

```python
def verify_weight_tying(fsdp_model):
    # FSDP forwards attribute access to the wrapped module, so in this simple
    # hierarchy we can reach the original submodules directly. In more deeply
    # nested models you may need to go through fsdp_model.module or similar,
    # depending on the exact hierarchy.
    embedding_weight = fsdp_model.token_emb.weight
    head_weight = fsdp_model.lm_head.weight

    # Python object identity might change if FSDP replaces parameters with
    # views, but the underlying storage data_ptr should match.
    ptr_emb = embedding_weight.data_ptr()
    ptr_head = head_weight.data_ptr()

    if ptr_emb == ptr_head:
        print(f"Success: Weights are tied. Storage Pointer: {ptr_emb}")
    else:
        print(f"WARNING: Weights are NOT tied. {ptr_emb} != {ptr_head}")


# Call this on Rank 0 before training
verify_weight_tying(fsdp_model)
```

## Shared Parameters and Initialization on Meta Device

Combining shared parameters with meta device initialization requires careful ordering. As discussed in the "Delayed Initialization" section, we often create models on the meta device to avoid OOM errors. When you materialize these weights (move them to CPU or CUDA), you must ensure the tying relationship is re-established before wrapping with FSDP.

If you initialize on the meta device, the order is:

1. Create the model on the meta device.
2. Materialize the parameters (reset parameters).
3. Re-apply weight tying.
4. Wrap with FSDP.

Often, the `reset_parameters()` or materialization process allocates new memory for every module, breaking the reference created in `__init__`. You must explicitly re-tie the weights:

```python
# After materializing weights from the meta device
model.to_empty(device="cuda")
model.apply(init_weights_fn)  # your parameter initialization function

# CRITICAL: Re-tie weights explicitly before FSDP wrapping
model.lm_head.weight = model.token_emb.weight

# Now safe to wrap
fsdp_model = FSDP(model, ...)
```

## Memory Implications of Weight Tying

While weight tying saves parameters, it can create a communication bottleneck at the boundaries of the network. In FSDP, the root unit containing the embeddings and head is often quite large. For a vocabulary size $V=128,000$ and dimension $d=4096$, the embedding table alone is roughly 1 GB in BF16 (128,000 × 4096 × 2 bytes ≈ 1.05 GB).

Because this large parameter set lives in the root FSDP unit, it must be all-gathered at the very beginning of the forward pass (for the embeddings) and again at the end (for the head). This can create memory pressure if the root unit is not sharded efficiently.

The chart below depicts the memory allocation timeline during a forward pass when the embedding and head are in the root unit versus distributed throughout the layers.

```json
{
  "layout": {
    "title": "Memory Spikes with Shared Embeddings (Root FSDP Unit)",
    "xaxis": {"title": "Execution Steps (Forward Pass)", "showgrid": false},
    "yaxis": {"title": "Allocated Memory (GB)", "showgrid": true},
    "showlegend": true,
    "margin": {"l": 60, "r": 30, "t": 50, "b": 50}
  },
  "data": [
    {
      "x": ["Start", "Emb Gather", "Layer 1", "Layer 6", "Layer 12", "Head Gather", "End"],
      "y": [2, 12, 4, 4, 4, 12, 2],
      "type": "scatter",
      "mode": "lines+markers",
      "name": "Active Param Memory",
      "line": {"color": "#4dabf7", "width": 3}
    },
    {
      "x": ["Start", "Emb Gather", "Layer 1", "Layer 6", "Layer 12", "Head Gather", "End"],
      "y": [2, 4, 4, 4, 4, 4, 2],
      "type": "scatter",
      "mode": "lines",
      "name": "Ideal (No Root Spike)",
      "line": {"color": "#adb5bd", "dash": "dot"}
    }
  ]
}
```

Active memory usage during the forward pass. The "Root Spike" occurs because the large shared embedding layer must be gathered at the start and end, exceeding the memory footprint of the internal transformer layers.

If the shared embedding is excessively large, it may force you to enable CPU offloading of parameters or implement sequence parallelism to shard the activation memory, compensating for the parameter spike. Proper management of these shared resources is essential for training models where the vocabulary size contributes significantly to the total parameter count.

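As a reference point for the offloading option, here is a minimal sketch of FSDP's CPU offload applied to the earlier hypothetical `GPTModel` and wrapping policy. Note that `cpu_offload` applies to every FSDP unit created by the wrapper, not only the root, so it trades the gather spike for extra host-device transfers on each unit.

```python
import torch
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)

# Same hypothetical model and policy as in the wrapping example above.
fsdp_model = FSDP(
    model,
    auto_wrap_policy=get_wrapping_policy(),
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
    # Keep sharded parameters in CPU memory between uses; FSDP moves each
    # unit's shard to the GPU around its all-gather and frees it afterwards.
    cpu_offload=CPUOffload(offload_params=True),
    device_id=torch.cuda.current_device(),
)
```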