Let's translate the theoretical understanding of structured pruning into practice. Unlike unstructured pruning, which zeros out individual weights, structured pruning removes entire groups of parameters, such as attention heads or neurons within feed-forward network (FFN) layers. This approach creates regular sparsity patterns that can potentially lead to more significant inference speedups on hardware designed to exploit such structures, although it often requires more careful implementation and fine-tuning to maintain model performance.
In this practical exercise, we will focus on implementing attention head pruning for a transformer-based model. This involves identifying and removing the least important attention heads across the model's layers.
Implement structured pruning by removing a fixed percentage of attention heads from a pre-trained transformer model and evaluate the impact on model size and a relevant performance metric.
We'll use the Hugging Face transformers library along with PyTorch. Ensure you have these installed. We'll work with a smaller, pre-trained transformer model, like bert-base-uncased or distilbert-base-uncased, to make computations manageable.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import numpy as np
# Load a pre-trained model and tokenizer
model_name = "distilbert-base-uncased" # Use a manageable model
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Example: Prepare some dummy data for importance calculation or evaluation
dummy_texts = ["This is an example sentence.", "Another example for testing."]
inputs = tokenizer(dummy_texts, return_tensors="pt", padding=True, truncation=True)
print(f"Model: {model_name}")
print(f"Number of parameters: {model.num_parameters()}")
# Accessing transformer block structure (example for DistilBERT)
# Note: Structure varies between models (BERT, GPT, etc.)
transformer_blocks = model.distilbert.transformer.layer
num_layers = len(transformer_blocks)
num_heads = model.config.num_attention_heads
head_dim = model.config.dim // num_heads
print(f"Layers: {num_layers}, Heads per layer: {num_heads}, Head dim: {head_dim}")
We need a metric to rank the importance of each attention head. A common approach is to use the magnitude (L1 or L2 norm) of the weights associated with each head. Specifically, we can look at the output projection weight (WO) within the self-attention mechanism for each head. Heads with smaller norm weights are considered less important.
Let's outline the process for calculating the L2 norm for each head's output projection weights:
head_importances = []
for layer_idx in range(num_layers):
    attention_layer = transformer_blocks[layer_idx].attention
    # Output projection weight matrix W_O: shape (hidden_dim, hidden_dim), e.g. (768, 768) for distilbert-base
    W_O = attention_layer.out_lin.weight.data
    # nn.Linear stores weights as (out_features, in_features), so the *columns* of W_O
    # correspond to the concatenated head outputs:
    # W_O = [W_O_1 | W_O_2 | ... | W_O_num_heads], where each W_O_i has shape (hidden_dim, head_dim).
    # The L2 norm of each column block measures how strongly that head's output is
    # projected back into the residual stream, so we compute one norm per head.
    layer_head_norms = []
    for head_idx in range(num_heads):
        # Extract the columns corresponding to the output projection of head_idx
        # Shape: (hidden_dim, head_dim)
        head_weights = W_O[:, head_idx * head_dim : (head_idx + 1) * head_dim]
        norm = torch.linalg.norm(head_weights).item()
        layer_head_norms.append(norm)
    head_importances.append(layer_head_norms)  # Store norms for each head in this layer
# Flatten the list of lists into a single list of ((layer_idx, head_idx), importance) tuples
all_head_importances = []
for layer_idx, norms in enumerate(head_importances):
    for head_idx, norm in enumerate(norms):
        all_head_importances.append(((layer_idx, head_idx), norm))
# Sort heads globally by importance (ascending)
all_head_importances.sort(key=lambda x: x[1])
print(f"Calculated importance for {len(all_head_importances)} heads.")
# Display a few least important heads
print("Least important heads (layer, head):")
for i in range(min(10, len(all_head_importances))):
    print(f"  Layer {all_head_importances[i][0][0]}, Head {all_head_importances[i][0][1]}: Norm = {all_head_importances[i][1]:.4f}")
Note: Importance calculation can be more sophisticated, involving activation analysis or gradient information during a forward/backward pass, but weight norm is a common and simpler starting point.
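For comparison, here is a minimal sketch of a gradient-informed score. It assumes a small labeled batch, labeled_inputs (a hypothetical tokenized dict containing input_ids, attention_mask, and labels), and ranks heads by the magnitude of weight times gradient over each head's output-projection slice; this is one common first-order approximation, not the only option.
# Hedged sketch: gradient-based head importance
# (assumes `labeled_inputs` is a hypothetical tokenized batch that includes a "labels" tensor)
model.zero_grad()
outputs = model(**labeled_inputs)
outputs.loss.backward()

grad_head_importances = []
for layer_idx in range(num_layers):
    out_lin = transformer_blocks[layer_idx].attention.out_lin
    W = out_lin.weight        # shape: (hidden_dim, hidden_dim)
    G = out_lin.weight.grad   # same shape; gradient of the loss w.r.t. W
    layer_scores = []
    for head_idx in range(num_heads):
        cols = slice(head_idx * head_dim, (head_idx + 1) * head_dim)
        # First-order Taylor-style score: |weight * gradient| summed over the head's slice
        score = (W[:, cols] * G[:, cols]).abs().sum().item()
        layer_scores.append(score)
    grad_head_importances.append(layer_scores)
model.zero_grad()  # clean up gradients before continuing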
Let's set a target sparsity level. For instance, we might aim to prune 20% of the total attention heads.
target_sparsity = 0.20 # Prune 20% of heads
total_heads = num_layers * num_heads
num_heads_to_prune = int(total_heads * target_sparsity)
# Get the heads with the lowest importance scores
heads_to_prune = {head_info[0] for head_info in all_head_importances[:num_heads_to_prune]}
print(f"Total heads: {total_heads}")
print(f"Target sparsity: {target_sparsity*100:.1f}%")
print(f"Number of heads to prune: {num_heads_to_prune}")
# print(f"Heads identified for pruning: {sorted(list(heads_to_prune))}") # Uncomment to see the list
Applying structured pruning involves creating masks to zero out the parameters associated with the selected heads. This requires careful handling of the weight matrices for Query (Q), Key (K), Value (V), and Output (O) projections within each attention layer.
In DistilBERT, the Q, K, and V weights are stored as separate matrices, q_lin.weight, k_lin.weight, and v_lin.weight (each of shape [hidden_dim, hidden_dim]); other implementations sometimes combine them into one large in_proj_weight. The out_lin.weight matrix (shape [hidden_dim, hidden_dim]) combines the head outputs. We need to identify the rows or columns corresponding to the specific heads being pruned.
Let's illustrate masking the Q, K, V, and output projection weights for each pruned head.
def create_mask(param_shape, head_idx_to_prune, num_heads, head_dim, prune_dim):
    """Creates a binary mask that zeroes out the slice belonging to one head."""
    mask = torch.ones(param_shape)
    start_index = head_idx_to_prune * head_dim
    end_index = start_index + head_dim
    if len(param_shape) == 1:  # 1D bias vectors of shape [hidden_dim]
        mask[start_index:end_index] = 0
    elif prune_dim == 0:  # Pruning rows (e.g., Q, K, V weights of shape [hidden_dim, hidden_dim])
        mask[start_index:end_index, :] = 0
    elif prune_dim == 1:  # Pruning columns (e.g., output projection weight of shape [hidden_dim, hidden_dim])
        mask[:, start_index:end_index] = 0
    return mask
# Apply pruning (conceptually - using permanent modification here for simplicity)
# In practice, use torch.nn.utils.prune for proper masking and potential removal
for layer_idx, head_idx in heads_to_prune:
    attention_layer = transformer_blocks[layer_idx].attention

    # --- Prune Q, K, V weights ---
    # Shape: [hidden_dim, hidden_dim]. Prune the rows that produce this head's output features.
    q_weight = attention_layer.q_lin.weight
    k_weight = attention_layer.k_lin.weight
    v_weight = attention_layer.v_lin.weight
    q_mask = create_mask(q_weight.shape, head_idx, num_heads, head_dim, prune_dim=0)
    k_mask = create_mask(k_weight.shape, head_idx, num_heads, head_dim, prune_dim=0)
    v_mask = create_mask(v_weight.shape, head_idx, num_heads, head_dim, prune_dim=0)
    # Apply masks (directly modifying weights here)
    with torch.no_grad():
        q_weight.data *= q_mask
        k_weight.data *= k_mask
        v_weight.data *= v_mask
        # Also prune the corresponding bias slices: each bias has shape [hidden_dim],
        # and the slice [head_idx*head_dim : (head_idx+1)*head_dim] belongs to this head.
        for lin in (attention_layer.q_lin, attention_layer.k_lin, attention_layer.v_lin):
            if lin.bias is not None:
                bias_mask = create_mask(lin.bias.shape, head_idx, num_heads, head_dim, prune_dim=0)
                lin.bias.data *= bias_mask

    # --- Prune output projection weights ---
    # Shape: [hidden_dim, hidden_dim]. Prune the columns that read this head's output features.
    o_weight = attention_layer.out_lin.weight
    o_mask = create_mask(o_weight.shape, head_idx, num_heads, head_dim, prune_dim=1)
    with torch.no_grad():
        o_weight.data *= o_mask
    # The output bias (out_lin.bias) is a single vector of size [hidden_dim] and is not pruned per-head.

print(f"Applied pruning masks to {len(heads_to_prune)} heads.")
This visualization shows how structured pruning removes entire components (like an attention head), unlike unstructured pruning which scatters zeros.
Comparison of unstructured vs. structured sparsity patterns. Structured pruning removes entire blocks (e.g., Head 2), potentially enabling hardware acceleration.
Important Implementation Note: The torch.nn.utils.prune module offers more robust ways to handle pruning, including managing masks persistently, and functions like prune.remove make the pruning permanent by baking the mask into the weights (it does not physically shrink the tensors; actually removing the pruned structures is more involved for head pruning). For production scenarios, using such utilities or specialized libraries (like NVIDIA's FasterTransformer or sparsity-aware compilers) is recommended. Directly zeroing weights as shown here illustrates the concept but might not yield speedups alone.
After pruning, we need to assess the impact.
# Example: Recalculate parameters (conceptual - requires detailed check)
# A simple way is to count non-zero elements if weights were zeroed directly
# Note: torch.nn.utils.prune handles this more formally
non_zero_params = sum(p.nonzero().size(0) for p in model.parameters() if p.requires_grad)
total_params = model.num_parameters()
print(f"Original parameters: {total_params}")
print(f"Parameters after pruning (non-zero): {non_zero_params}")
print(f"Reduction: {(total_params - non_zero_params) / total_params * 100:.2f}%")
# Example: Evaluate performance (requires a proper evaluation dataset and task)
# model.eval()
# with torch.no_grad():
# outputs = model(**inputs)
# logits = outputs.logits
# # ... calculate accuracy, perplexity, or other relevant metric ...
# print("Evaluation results need a proper dataset and metric.")
Structured pruning can sometimes cause a noticeable drop in performance. Fine-tuning the pruned model on the original task (or a downstream task) for a short duration with a low learning rate can help recover lost accuracy.
# Pseudocode for fine-tuning setup
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# model.train()
# for epoch in range(num_finetune_epochs):
#     for batch in fine_tuning_dataloader:
#         optimizer.zero_grad()
#         outputs = model(**batch)
#         loss = outputs.loss
#         loss.backward()
#         # Important: Ensure gradients don't revive pruned weights if using direct zeroing
#         # Apply masks again or use a pruning utility that handles this
#         optimizer.step()
# print("Fine-tuning complete.")
Structured pruning typically exhibits a trade-off between the level of sparsity and the performance degradation, often requiring fine-tuning for recovery.
Typical relationship between structured pruning sparsity (e.g., removing attention heads) and the resulting drop in model performance before fine-tuning. Higher sparsity often leads to a more significant performance decrease.
This practical exercise provides a foundation for applying structured pruning. Remember that optimizing the process involves careful selection of the pruning target, importance metric, sparsity level, and potentially integrating it with fine-tuning and specialized deployment frameworks.