与将单个权重置零的非结构化剪枝不同,结构化剪枝移除的是整个参数组,例如注意力头或前馈网络 (FFN) 层内的神经元。这种方法形成规则的稀疏模式,有望在为发挥此类结构作用而设计的硬件上带来更明显的推理加速,尽管它通常需要更细致的实现和微调才能保持模型性能。在此实践练习中,我们将专注于实现基于 Transformer 模型的注意力头剪枝。这包括识别并移除模型各层中最不重要的注意力头。目标通过从预训练的 Transformer 模型中移除固定比例的注意力头,来实现结构化剪枝,并评估其对模型大小和相关性能指标的影响。设置我们将使用 Hugging Face transformers 库以及 PyTorch。请确保已安装这些库。我们将使用一个较小的预训练 Transformer 模型,例如 bert-base-uncased 或 distilbert-base-uncased,以便管理计算量。import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer import numpy as np # 加载预训练模型和分词器 model_name = "distilbert-base-uncased" # 使用一个易于管理的模型 model = AutoModelForSequenceClassification.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) # 示例:准备一些虚拟数据用于重要性计算或评估 dummy_texts = ["This is an example sentence.", "Another example for testing."] inputs = tokenizer(dummy_texts, return_tensors="pt", padding=True, truncation=True) print(f"模型: {model_name}") print(f"参数数量: {model.num_parameters()}") # 访问Transformer块结构(DistilBERT示例) # 注意:不同模型(BERT、GPT等)的结构有所不同 transformer_blocks = model.distilbert.transformer.layer num_layers = len(transformer_blocks) num_heads = model.config.num_attention_heads head_dim = model.config.dim // num_heads print(f"层数: {num_layers}, 每层头数: {num_heads}, 头维度: {head_dim}")步骤 1:计算头的重要性我们需要一个衡量各注意力头重要性的指标。一种常见方法是使用与每个头相关的权重的范数(L1 或 L2 范数)。具体来说,我们可以查看自注意力机制中每个头的输出投影权重 ($W^O$)。范数权重较小的头被认为重要性较低。让我们概述为每个头的输出投影权重计算 L2 范数的过程:head_importances = [] for layer_idx in range(num_layers): attention_layer = transformer_blocks[layer_idx].attention # 输出投影权重矩阵:形状 (hidden_dim, hidden_dim) W_O = attention_layer.out_lin.weight.data # 形状:(768, 768) 对于 distilbert-base # W_O 结合了所有头的输出。每个头贡献一个切片。 # W_O 可以看作 concat([W_O_h1, W_O_h2, ..., W_O_hN]),其中每个 W_O_hi 的形状为 (head_dim, hidden_dim) # 但在实际矩阵中是转置的。 # 因此,我们需要为 W_O 按头计算每列的范数。 # 重新整形为头视图后的有效形状:(hidden_dim, num_heads, head_dim) # 我们需要的是从每个头的输出空间进行投影的范数。 # 让我们计算从每个头的输出投影的权重的 L2 范数。 # 输出线性层权重矩阵 W_O 的形状为 [dim, dim]。 # 它可以被视为连接了多个矩阵,每个矩阵的形状为 [dim, head_dim],每个头对应一个。 # W_O = [W_O_1 | W_O_2 | ... | W_O_num_heads],其中 W_O_i 的形状为 [dim, head_dim] layer_head_norms = [] for head_idx in range(num_heads): # 提取对应于 head_idx 输出投影的权重 # 形状:(hidden_dim, head_dim) head_weights = W_O[:, head_idx * head_dim : (head_idx + 1) * head_dim] norm = torch.linalg.norm(head_weights).item() layer_head_norms.append(norm) head_importances.append(layer_head_norms) # 存储该层中每个头的范数 # 将列表的列表展平为单个 (layer_idx, head_idx, importance) 元组列表 all_head_importances = [] for layer_idx, norms in enumerate(head_importances): for head_idx, norm in enumerate(norms): all_head_importances.append(((layer_idx, head_idx), norm)) # 根据重要性全局排序头(升序) all_head_importances.sort(key=lambda x: x[1]) print(f"已计算 {len(all_head_importances)} 个头的重要性。") # 显示几个最不重要的头 print("最不重要的头(层,头):") for i in range(min(10, len(all_head_importances))): print(f" 层 {all_head_importances[i][0][0]}, 头 {all_head_importances[i][0][1]}: 范数 = {all_head_importances[i][1]:.4f}") 说明: 重要性计算可以更为复杂,涉及激活分析或正向/反向传播过程中的梯度信息,但权重范数是一种常用且更简单的起点。步骤 2:定义稀疏度并识别要剪枝的头让我们设定一个目标稀疏度。例如,我们可能目标是剪枝总注意力头的 20%。target_sparsity = 0.20 # 剪枝20%的头 total_heads = num_layers * num_heads num_heads_to_prune = int(total_heads * target_sparsity) # 获取重要性得分最低的头 heads_to_prune = {head_info[0] for head_info in all_head_importances[:num_heads_to_prune]} print(f"总头数: {total_heads}") print(f"目标稀疏度: {target_sparsity*100:.1f}%") print(f"要剪枝的头数: {num_heads_to_prune}") # print(f"识别出要剪枝的头: {sorted(list(heads_to_prune))}") # 取消注释以查看列表步骤 3:应用剪枝掩码应用结构化剪枝涉及创建掩码以将与所选头相关的参数置零。这需要细致处理每个注意力层中查询 (Q)、键 (K)、值 (V) 和输出 (O) 投影的权重矩阵。Q、K 和 V 权重通常组合存储在 q_lin.weight、k_lin.weight、v_lin.weight 等矩阵中(形状 [hidden_dim, hidden_dim]),有时也合并成一个大的 in_proj_weight。out_lin.weight(形状 [hidden_dim, hidden_dim])结合了输出。我们需要识别与被剪枝的特定头对应的行/列。让我们演示如何掩蔽单个头的 K、V 和输出投影权重。def create_mask(param_shape, head_idx_to_prune, num_heads, head_dim, prune_dim): """根据头索引为权重张量创建掩码。""" mask = torch.ones(param_shape) start_index = head_idx_to_prune * head_dim end_index = start_index + head_dim if prune_dim == 0: # 剪枝行(例如,对于 Q, K, V 权重,如果形状是 [hidden_dim, hidden_dim]) mask[start_index:end_index, :] = 0 elif prune_dim == 1: # 剪枝列(例如,对于 O 权重,如果形状是 [hidden_dim, hidden_dim]) mask[:, start_index:end_index] = 0 return mask # 应用剪枝 - 这里为简化起见采用永久修改 # 实际中,使用 torch.nn.utils.prune 进行正确的掩蔽和潜在的移除 for layer_idx, head_idx in heads_to_prune: attention_layer = transformer_blocks[layer_idx].attention # --- 剪枝 Q, K, V 权重 --- # 形状:[hidden_dim, hidden_dim]。需要剪枝对应于头输出特征的行。 q_weight = attention_layer.q_lin.weight k_weight = attention_layer.k_lin.weight v_weight = attention_layer.v_lin.weight # 我们将 Q, K, V 的头输出维度视为剪枝目标 q_mask = create_mask(q_weight.shape, head_idx, num_heads, head_dim, prune_dim=0) k_mask = create_mask(k_weight.shape, head_idx, num_heads, head_dim, prune_dim=0) v_mask = create_mask(v_weight.shape, head_idx, num_heads, head_dim, prune_dim=0) # 应用掩码(这里直接修改权重) with torch.no_grad(): q_weight.data *= q_mask k_weight.data *= k_mask v_weight.data *= v_mask # 如果偏置存在且按头结构化(通常不是),也剪枝相应的偏置 if attention_layer.q_lin.bias is not None: # 偏置形状通常是 [hidden_dim],剪枝与头对应的切片 bias_mask = create_mask(attention_layer.q_lin.bias.shape, head_idx, num_heads, head_dim, prune_dim=0) # 偏置向量的维度0 attention_layer.q_lin.bias.data *= bias_mask[:, 0] # 使用掩码的第一列 # 如果 k_lin.bias, v_lin.bias 存在,重复此操作 # --- 剪枝输出投影权重 --- # 形状:[hidden_dim, hidden_dim]。需要剪枝对应于头输入特征的列。 o_weight = attention_layer.out_lin.weight o_mask = create_mask(o_weight.shape, head_idx, num_heads, head_dim, prune_dim=1) # 剪枝输出投影的列 with torch.no_grad(): o_weight.data *= o_mask # 输出偏置 (out_lin.bias) 通常是大小为 [hidden_dim] 的单个向量,通常不按头剪枝。 print(f"已对 {len(heads_to_prune)} 个头应用剪枝掩码。") 此图示说明了结构化剪枝如何移除整个组件(如注意力头),这与分散零值的非结构化剪枝不同。digraph G { rankdir=LR; node [shape=record, style=filled, color="#ced4da", fillcolor="#e9ecef", fontname="helvetica"]; edge [color="#868e96"]; subgraph cluster_unstructured { label = "非结构化剪枝"; bgcolor="#f8f9fa"; node [shape=point, color="#adb5bd"]; edge [style=invis]; u1 [pos="0,1!", label=""]; u2 [pos="0.5,1!", label=""]; u3 [pos="1,1!", label=""]; u4 [pos="1.5,1!", label=""]; u5 [pos="0,0.5!", label=""]; u6 [pos="0.5,0.5!", label="", color="#fa5252"]; u7 [pos="1,0.5!", label=""]; u8 [pos="1.5,0.5!", label="", color="#fa5252"]; u9 [pos="0,0!", label="", color="#fa5252"]; u10 [pos="0.5,0!", label=""]; u11 [pos="1,0!", label="", color="#fa5252"]; u12 [pos="1.5,0!", label=""]; u1->u2->u3->u4; u5->u6->u7->u8; u9->u10->u11->u12; u1->u5->u9; u2->u6->u10; u3->u7->u11; u4->u8->u12; pruned_u [label="移除的单个权重(红色)", shape=plaintext, pos="0.75,-0.5!", fontcolor="#495057"]; } subgraph cluster_structured { label = "结构化剪枝(头部示例)"; bgcolor="#f8f9fa"; node [shape=rect, style="filled", width=0.6, height=0.4, label="", fontname="helvetica"]; subgraph cluster_head1 { label="头 1"; color="#e9ecef"; bgcolor="#f8f9fa"; sh1_1; sh1_2; sh1_3;} subgraph cluster_head2 { label="头 2\n(已剪枝)"; color="#ffc9c9"; bgcolor="#ffecf0"; node[fillcolor="#ffa8a8"]; sh2_1; sh2_2; sh2_3;} subgraph cluster_head3 { label="头 3"; color="#e9ecef"; bgcolor="#f8f9fa"; sh3_1; sh3_2; sh3_3;} sh1_1 [pos="3,1.2!", fillcolor="#a5d8ff"]; sh1_2 [pos="3.7,1.2!", fillcolor="#a5d8ff"]; sh1_3 [pos="4.4,1.2!", fillcolor="#a5d8ff"]; sh2_1 [pos="3,0.5!", fillcolor="#ffa8a8"]; sh2_2 [pos="3.7,0.5!", fillcolor="#ffa8a8"]; sh2_3 [pos="4.4,0.5!", fillcolor="#ffa8a8"]; sh3_1 [pos="3, -0.2!", fillcolor="#a5d8ff"]; sh3_2 [pos="3.7,-0.2!", fillcolor="#a5d8ff"]; sh3_3 [pos="4.4,-0.2!", fillcolor="#a5d8ff"]; pruned_s [label="整个组(头)被移除", shape=plaintext, pos="3.7,-0.8!", fontcolor="#495057"]; } }非结构化与结构化稀疏模式的对比。结构化剪枝移除整个块(例如,头 2),这有望实现硬件加速。关于实现的一点说明: torch.nn.utils.prune 模块提供了处理剪枝的方法,包括持久管理掩码以及 prune.remove 等函数,可以通过实际移除置零参数使剪枝永久化(如果结构允许,但这对于头部剪枝而言比较复杂)。对于生产环境,建议使用此类工具或专用库(如 NVIDIA 的 FasterTransformer 或稀疏感知编译器)。如这里所示直接将权重置零展示了基本方法,但单独使用可能无法带来加速。步骤 4:评估剪枝后的模型剪枝后,我们需要评估其影响。参数数量: 虽然头部剪枝会移除参数,但减少量可能小于目标头部稀疏度百分比,因为共享嵌入和非注意力层保持不变。实际参数数量需要重新计算。性能: 在相关任务上评估剪枝后的模型。如果是分类模型,请检查验证集上的准确率。如果是生成模型,请检查困惑度或其他生成质量指标。延迟: 测量推理延迟。特别指出的是,从结构化剪枝中观察到显著的延迟降低,通常需要专门的推理后端或硬件,这些后端或硬件可以跳过涉及剪枝结构的计算。标准硬件上的简单掩蔽可能无法加速推理,甚至可能由于掩码应用开销而略微降低速度。# 示例:重新计算参数(需要详细检查) # 如果权重直接置零,一种简单方法是计算非零元素的数量 # 注意:torch.nn.utils.prune 更正式地处理此问题 non_zero_params = sum(p.nonzero().size(0) for p in model.parameters() if p.requires_grad) total_params = model.num_parameters() print(f"原始参数: {total_params}") print(f"剪枝后参数(非零): {non_zero_params}") print(f"减少量: {(total_params - non_zero_params) / total_params * 100:.2f}%") # 示例:评估性能(需要适当的评估数据集和任务) # model.eval() # with torch.no_grad(): # outputs = model(**inputs) # logits = outputs.logits # # ... 计算准确率、困惑度或其他相关指标 ... # print("评估结果需要适当的数据集和指标。")步骤 5:微调(可选但推荐)结构化剪枝有时会导致性能出现显著下降。对剪枝后的模型在原始任务(或下游任务)上以较低的学习率进行短时间微调,有助于恢复损失的准确率。# 微调设置的伪代码 # optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5) # model.train() # for epoch in range(num_finetune_epochs): # for batch in fine_tuning_dataloader: # optimizer.zero_grad() # outputs = model(**batch) # loss = outputs.loss # loss.backward() # # 重要提示:如果直接置零,请确保梯度不会使剪枝的权重复活 # # 再次应用掩码或使用处理此问题的剪枝工具 # optimizer.step() # print("微调完成。")结构化剪枝通常呈现出稀疏度水平与性能下降之间的权衡,通常需要通过微调来恢复。{"layout": {"title": "性能与头部稀疏度的关系", "xaxis": {"title": "注意力头稀疏度 (%)"}, "yaxis": {"title": "性能下降 (%)", "range": [0, 20]}, "font": {"family": "sans-serif"}, "plot_bgcolor": "#f8f9fa", "paper_bgcolor": "#ffffff"}, "data": [{"x": [0, 10, 20, 30, 40, 50], "y": [0, 1, 3, 8, 15, 25], "type": "scatter", "mode": "lines+markers", "name": "性能下降", "marker": {"color": "#f06595", "size": 8}, "line": {"color": "#f06595", "width": 2}}]}结构化剪枝稀疏度(例如,移除注意力头)与微调前模型性能下降之间的典型关系。更大幅度的稀疏度通常会导致更显著的性能下降。考量重要性指标: 重要性得分(权重范数、激活值大小、基于梯度)的选择,对哪些结构被剪枝以及最终性能有很大影响。通常需要进行实验。硬件/软件支持: 实现延迟优势需要兼容的硬件和推理库(例如 TensorRT、经过稀疏性优化的 ONNX Runtime、自定义核函数),这些硬件和库可以善用结构化稀疏性。微调: 在进行非微不足道的结构化剪枝后,通常需要为微调预留时间以达到可接受的性能。其他结构: 本示例专注于注意力头。类似的原理也适用于剪枝 FFN 层中的神经元/滤波器甚至整个层,并相应调整识别和掩蔽逻辑。此实践练习为应用结构化剪枝奠定了基础。请记住,优化该过程涉及仔细选择剪枝目标、重要性指标、稀疏度水平,并可能将其与微调和专用部署框架相结合。