趋近智
自回归 (autoregressive)生成,即基于先前已生成的标记 (token)逐个生成文本的过程,是大型语言模型生成回复的中心方式。然而,简单的实现会遇到一个很大的性能障碍。Transformer架构中的自注意力 (self-attention)机制 (attention mechanism)是此过程的核心。为了生成下一个标记,例如标记 ,标准的自注意力计算需要基于所有之前的标记 来计算查询(Q)、键(K)和值(V),然后计算注意力分数。当生成随后的标记 时,这个完整过程会使用标记 再次进行。请注意这种重复:在生成标记 期间为标记 计算的键和值向量 (vector),与生成标记 时前 个标记所需的向量是相同的。在每一步重复这些计算是计算上的浪费,特别是当序列长度增加时。
键值(KV)缓存是一种基本优化技术,其目的在于消除自回归推理 (inference)中的这种重复。其中心思想简单而高效:存储自注意力层中为所有先前标记计算出的键(K)和值((V)张量,并在随后的生成步骤中重复使用它们。
在Transformer的自注意力 (self-attention)层中,输入序列 被投影到三个矩阵:查询()、键()和值()。
为可学习的权重 (weight)矩阵。注意力输出随后按以下方式计算:
考虑生成标记 (token) 。模型将标记嵌入 (embedding)序列 作为输入。在每个注意力层内,它计算 和 。它还基于最后一个标记 的嵌入(或对应于位置 的位置嵌入)计算查询向量 (vector) 。然后,注意力计算使用 和完整的键集合 以及值集合 。
现在,考虑生成标记 。输入序列为 。模型需要计算 和 。重要的是,对 和 的计算与上一步完全相同,因为它们仅依赖于输入标记 以及固定的权重矩阵 和 。
KV缓存善用这个特性。它不是在每一步重新计算所有的键和值,而是:
简化的流程图,展示了在步骤
t计算出的键(K)和值(V)如何被缓存并在步骤t+1复用,这样只需要为新标记x_{t+1}进行计算。
这大大减少了每个生成标记的计算成本。注意力计算复杂度不再像每个步骤中与序列长度 的平方大致成比例(如果考虑完整矩阵乘法是 ,或者仅将查询应用于现有键是 ),而是与过去标记相关的计算实际变成了常数时间(缓存查找和拼接),主要成本变为计算单个新标记的K和V,并将新查询应用于缓存的键( 部分为 )。
尽管KV缓存显著加快了推理 (inference)速度,但它也带来了内存成本。缓存需要为批次中的每个序列,存储所有先前标记 (token)在所有层和所有注意力头中的键和值张量。KV缓存的大小可以估算为:
缓存大小 ≈ batch_size × num_layers × 2(针对 K 和 V) × num_heads × sequence_length × head_dimension × bytes_per_element
这种内存占用随 sequence_length 线性增长。对于具有多层和多头的模型,以及处理长序列或大批次时,KV缓存会占用大量的GPU内存,有时成为可处理最大序列长度的限制因素。管理这种内存使用是一个重要的考量,它促成了分页注意力或缓存本身的量化 (quantization)等技术,尽管这些内容超出了本基本介绍的范围。
实作KV缓存通常涉及修改Transformer块的 forward 方法(或直接修改注意力模块),使其接受一个可选的 past_key_values 参数 (parameter)并返回更新后的 present_key_values。
下面是一个(高度简化的)草图,对比了标准注意力计算和使用KV缓存的注意力计算:
import torch
import torch.nn as nn
# 假设 'attention_layer' 是一个预定义的多头注意力模块
# 简化的多头注意力占位符
class SimpleMultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, (
"embed_dim 必须能被 num_heads 整除"
)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, query, key, value, past_kv=None):
batch_size, seq_len, _ = query.size()
# 投影查询、键、值
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
# 为多头注意力重塑形状
q = q.view(
batch_size, -1, self.num_heads, self.head_dim
).transpose(1, 2) # [B, nh, L_q, hd]
k = k.view(
batch_size, -1, self.num_heads, self.head_dim
).transpose(1, 2) # [B, nh, L_k, hd]
v = v.view(
batch_size, -1, self.num_heads, self.head_dim
).transpose(1, 2) # [B, nh, L_v, hd]
present_kv = None
if past_kv is not None:
# past_kv 是一个元组 (past_k, past_v)
# 每个的形状为 [B, nh, L_past, hd]
past_k, past_v = past_kv
# 沿着序列长度维度 (dim=2) 拼接
k = torch.cat((past_k, k), dim=2)
v = torch.cat((past_v, v), dim=2)
# 存储更新后的 K, V 以供下一步使用
present_kv = (k, v) # Shape [B, nh, L_past + L_k, hd]
else:
# 首次存储 K, V
present_kv = (k, v) # Shape [B, nh, L_k, hd]
# 计算注意力分数
# q: [B, nh, L_q, hd], k.transpose: [B, nh, hd, L_k]
# -> scores: [B, nh, L_q, L_k]
scores = torch.matmul(q, k.transpose(-2, -1))
scores = scores / (self.head_dim ** 0.5)
attn_weights = torch.softmax(scores, dim=-1)
# 将注意力权重应用于值
# attn_weights: [B, nh, L_q, L_k], v: [B, nh, L_v, hd]
# -> output: [B, nh, L_q, hd]
# (此处假设 L_v == L_k)
output = torch.matmul(attn_weights, v)
# 重塑并投影输出
output = output.transpose(1, 2).contiguous()
output = output.view(
batch_size, -1, self.embed_dim
) # [B, L_q, embed_dim]
output = self.out_proj(output)
# 返回输出和此层更新后的键值缓存
return output, present_kv
# --- 生成过程中的用法 ---
# model = 你的Transformer模型(...)
# kv_cache = None # 初始空缓存(每层一个列表或元组)
# input_ids = 初始提示词ID
# for _ in range(最大新标记数):
# # 准备当前步骤的输入
# # (通常仅为最后一个生成的标记)
# current_input_ids = input_ids[:, -1:] # Shape [B, 1]
# # 带缓存的前向传播
# # 注意:模型的前向传播需要处理缓存的向下传递
# # 并收集更新
# outputs = model(
# input_ids=current_input_ids,
# past_key_values=kv_cache,
# use_cache=True
# )
# logits = outputs.logits
# kv_cache = outputs.past_key_values # 为下一次迭代更新缓存
# # 获取预测的下一个标记ID(例如,使用argmax或采样)
# next_token_id = torch.argmax(
# logits[:, -1:, :], dim=-1
# ) # Shape [B, 1]
# # 为下一次迭代的完整输入追加新的标记ID
# # (尽管只有最后一个用于Q)
# input_ids = torch.cat([input_ids, next_token_id], dim=-1)
# # 检查停止条件等
在实践中,像Hugging Face Transformers这样的框架对这种缓存机制进行了抽象。当调用 generate 方法或使用带有 use_cache=True 参数的模型前向传播时,框架会自动处理生成步骤之间KV缓存的创建、传递和更新。然而,理解其基本原理对认识性能提升和内存影响很重要。
KV缓存是高效Transformer推理 (inference)的根本。它直接解决了简单自回归 (autoregressive)解码在序列长度方面的二次复杂度瓶颈,使得在实践中生成更长序列成为可能。尽管它带来了内存开销,但计算上的节省几乎总是使其成为一项不可或缺的优化。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•