趋近智
与将单个权重 (weight)置零的非结构化剪枝不同,结构化剪枝移除的是整个参数 (parameter)组,例如注意力头或前馈网络 (FFN) 层内的神经元。这种方法形成规则的稀疏模式,有望在为发挥此类结构作用而设计的硬件上带来更明显的推理 (inference)加速,尽管它通常需要更细致的实现和微调 (fine-tuning)才能保持模型性能。
在此实践练习中,我们将专注于实现基于 Transformer 模型的注意力头剪枝。这包括识别并移除模型各层中最不重要的注意力头。
通过从预训练 (pre-training)的 Transformer 模型中移除固定比例的注意力头,来实现结构化剪枝,并评估其对模型大小和相关性能指标的影响。
我们将使用 Hugging Face transformers 库以及 PyTorch。请确保已安装这些库。我们将使用一个较小的预训练 (pre-training) Transformer 模型,例如 bert-base-uncased 或 distilbert-base-uncased,以便管理计算量。
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import numpy as np
# 加载预训练模型和分词器
model_name = "distilbert-base-uncased" # 使用一个易于管理的模型
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 示例:准备一些虚拟数据用于重要性计算或评估
dummy_texts = ["This is an example sentence.", "Another example for testing."]
inputs = tokenizer(dummy_texts, return_tensors="pt", padding=True, truncation=True)
print(f"模型: {model_name}")
print(f"参数数量: {model.num_parameters()}")
# 访问Transformer块结构(DistilBERT示例)
# 注意:不同模型(BERT、GPT等)的结构有所不同
transformer_blocks = model.distilbert.transformer.layer
num_layers = len(transformer_blocks)
num_heads = model.config.num_attention_heads
head_dim = model.config.dim // num_heads
print(f"层数: {num_layers}, 每层头数: {num_heads}, 头维度: {head_dim}")
我们需要一个衡量各注意力头重要性的指标。一种常见方法是使用与每个头相关的权重 (weight)的范数(L1 或 L2 范数)。具体来说,我们可以查看自注意力 (self-attention)机制 (attention mechanism)中每个头的输出投影权重 ()。范数权重较小的头被认为重要性较低。
让我们概述为每个头的输出投影权重计算 L2 范数的过程:
head_importances = []
for layer_idx in range(num_layers):
attention_layer = transformer_blocks[layer_idx].attention
# 输出投影权重矩阵:形状 (hidden_dim, hidden_dim)
W_O = attention_layer.out_lin.weight.data # 形状:(768, 768) 对于 distilbert-base
# W_O 结合了所有头的输出。每个头贡献一个切片。
# W_O 可以看作 concat([W_O_h1, W_O_h2, ..., W_O_hN]),其中每个 W_O_hi 的形状为 (head_dim, hidden_dim)
# 但在实际矩阵中是转置的。
# 因此,我们需要为 W_O 按头计算每列的范数。
# 重新整形为头视图后的有效形状:(hidden_dim, num_heads, head_dim)
# 我们需要的是从每个头的输出空间进行投影的范数。
# 让我们计算从每个头的输出投影的权重的 L2 范数。
# 输出线性层权重矩阵 W_O 的形状为 [dim, dim]。
# 它可以被视为连接了多个矩阵,每个矩阵的形状为 [dim, head_dim],每个头对应一个。
# W_O = [W_O_1 | W_O_2 | ... | W_O_num_heads],其中 W_O_i 的形状为 [dim, head_dim]
layer_head_norms = []
for head_idx in range(num_heads):
# 提取对应于 head_idx 输出投影的权重
# 形状:(hidden_dim, head_dim)
head_weights = W_O[:, head_idx * head_dim : (head_idx + 1) * head_dim]
norm = torch.linalg.norm(head_weights).item()
layer_head_norms.append(norm)
head_importances.append(layer_head_norms) # 存储该层中每个头的范数
# 将列表的列表展平为单个 (layer_idx, head_idx, importance) 元组列表
all_head_importances = []
for layer_idx, norms in enumerate(head_importances):
for head_idx, norm in enumerate(norms):
all_head_importances.append(((layer_idx, head_idx), norm))
# 根据重要性全局排序头(升序)
all_head_importances.sort(key=lambda x: x[1])
print(f"已计算 {len(all_head_importances)} 个头的重要性。")
# 显示几个最不重要的头
print("最不重要的头(层,头):")
for i in range(min(10, len(all_head_importances))):
print(f" 层 {all_head_importances[i][0][0]}, 头 {all_head_importances[i][0][1]}: 范数 = {all_head_importances[i][1]:.4f}")
说明: 重要性计算可以更为复杂,涉及激活分析或正向/反向传播 (backpropagation)过程中的梯度信息,但权重范数是一种常用且更简单的起点。
让我们设定一个目标稀疏度。例如,我们可能目标是剪枝总注意力头的 20%。
target_sparsity = 0.20 # 剪枝20%的头
total_heads = num_layers * num_heads
num_heads_to_prune = int(total_heads * target_sparsity)
# 获取重要性得分最低的头
heads_to_prune = {head_info[0] for head_info in all_head_importances[:num_heads_to_prune]}
print(f"总头数: {total_heads}")
print(f"目标稀疏度: {target_sparsity*100:.1f}%")
print(f"要剪枝的头数: {num_heads_to_prune}")
# print(f"识别出要剪枝的头: {sorted(list(heads_to_prune))}") # 取消注释以查看列表
应用结构化剪枝涉及创建掩码以将与所选头相关的参数 (parameter)置零。这需要细致处理每个注意力层中查询 (Q)、键 (K)、值 (V) 和输出 (O) 投影的权重 (weight)矩阵。
Q、K 和 V 权重通常组合存储在 q_lin.weight、k_lin.weight、v_lin.weight 等矩阵中(形状 [hidden_dim, hidden_dim]),有时也合并成一个大的 in_proj_weight。out_lin.weight(形状 [hidden_dim, hidden_dim])结合了输出。我们需要识别与被剪枝的特定头对应的行/列。
让我们演示如何掩蔽单个头的 K、V 和输出投影权重。
def create_mask(param_shape, head_idx_to_prune, num_heads, head_dim, prune_dim):
"""根据头索引为权重张量创建掩码。"""
mask = torch.ones(param_shape)
start_index = head_idx_to_prune * head_dim
end_index = start_index + head_dim
if prune_dim == 0: # 剪枝行(例如,对于 Q, K, V 权重,如果形状是 [hidden_dim, hidden_dim])
mask[start_index:end_index, :] = 0
elif prune_dim == 1: # 剪枝列(例如,对于 O 权重,如果形状是 [hidden_dim, hidden_dim])
mask[:, start_index:end_index] = 0
return mask
# 应用剪枝 - 这里为简化起见采用永久修改
# 实际中,使用 torch.nn.utils.prune 进行正确的掩蔽和潜在的移除
for layer_idx, head_idx in heads_to_prune:
attention_layer = transformer_blocks[layer_idx].attention
# --- 剪枝 Q, K, V 权重 ---
# 形状:[hidden_dim, hidden_dim]。需要剪枝对应于头输出特征的行。
q_weight = attention_layer.q_lin.weight
k_weight = attention_layer.k_lin.weight
v_weight = attention_layer.v_lin.weight
# 我们将 Q, K, V 的头输出维度视为剪枝目标
q_mask = create_mask(q_weight.shape, head_idx, num_heads, head_dim, prune_dim=0)
k_mask = create_mask(k_weight.shape, head_idx, num_heads, head_dim, prune_dim=0)
v_mask = create_mask(v_weight.shape, head_idx, num_heads, head_dim, prune_dim=0)
# 应用掩码(这里直接修改权重)
with torch.no_grad():
q_weight.data *= q_mask
k_weight.data *= k_mask
v_weight.data *= v_mask
# 如果偏置存在且按头结构化(通常不是),也剪枝相应的偏置
if attention_layer.q_lin.bias is not None:
# 偏置形状通常是 [hidden_dim],剪枝与头对应的切片
bias_mask = create_mask(attention_layer.q_lin.bias.shape, head_idx, num_heads, head_dim, prune_dim=0) # 偏置向量的维度0
attention_layer.q_lin.bias.data *= bias_mask[:, 0] # 使用掩码的第一列
# 如果 k_lin.bias, v_lin.bias 存在,重复此操作
# --- 剪枝输出投影权重 ---
# 形状:[hidden_dim, hidden_dim]。需要剪枝对应于头输入特征的列。
o_weight = attention_layer.out_lin.weight
o_mask = create_mask(o_weight.shape, head_idx, num_heads, head_dim, prune_dim=1) # 剪枝输出投影的列
with torch.no_grad():
o_weight.data *= o_mask
# 输出偏置 (out_lin.bias) 通常是大小为 [hidden_dim] 的单个向量,通常不按头剪枝。
print(f"已对 {len(heads_to_prune)} 个头应用剪枝掩码。")
此图示说明了结构化剪枝如何移除整个组件(如注意力头),这与分散零值的非结构化剪枝不同。
非结构化与结构化稀疏模式的对比。结构化剪枝移除整个块(例如,头 2),这有望实现硬件加速。
关于实现的一点说明: torch.nn.utils.prune 模块提供了处理剪枝的方法,包括持久管理掩码以及 prune.remove 等函数,可以通过实际移除置零参数使剪枝永久化(如果结构允许,但这对于头部剪枝而言比较复杂)。对于生产环境,建议使用此类工具或专用库(如 NVIDIA 的 FasterTransformer 或稀疏感知编译器)。如这里所示直接将权重置零展示了基本方法,但单独使用可能无法带来加速。
剪枝后,我们需要评估其影响。
# 示例:重新计算参数(需要详细检查)
# 如果权重直接置零,一种简单方法是计算非零元素的数量
# 注意:torch.nn.utils.prune 更正式地处理此问题
non_zero_params = sum(p.nonzero().size(0) for p in model.parameters() if p.requires_grad)
total_params = model.num_parameters()
print(f"原始参数: {total_params}")
print(f"剪枝后参数(非零): {non_zero_params}")
print(f"减少量: {(total_params - non_zero_params) / total_params * 100:.2f}%")
# 示例:评估性能(需要适当的评估数据集和任务)
# model.eval()
# with torch.no_grad():
# outputs = model(**inputs)
# logits = outputs.logits
# # ... 计算准确率、困惑度或其他相关指标 ...
# print("评估结果需要适当的数据集和指标。")
结构化剪枝有时会导致性能出现显著下降。对剪枝后的模型在原始任务(或下游任务)上以较低的学习率进行短时间微调,有助于恢复损失的准确率。
# 微调设置的伪代码
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# model.train()
# for epoch in range(num_finetune_epochs):
# for batch in fine_tuning_dataloader:
# optimizer.zero_grad()
# outputs = model(**batch)
# loss = outputs.loss
# loss.backward()
# # 重要提示:如果直接置零,请确保梯度不会使剪枝的权重复活
# # 再次应用掩码或使用处理此问题的剪枝工具
# optimizer.step()
# print("微调完成。")
结构化剪枝通常呈现出稀疏度水平与性能下降之间的权衡,通常需要通过微调来恢复。
结构化剪枝稀疏度(例如,移除注意力头)与微调前模型性能下降之间的典型关系。更大幅度的稀疏度通常会导致更显著的性能下降。
此实践练习为应用结构化剪枝奠定了基础。请记住,优化该过程涉及仔细选择剪枝目标、重要性指标、稀疏度水平,并可能将其与微调和专用部署框架相结合。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•