趋近智
大师班
旋转位置编码(RoPE)提供了一种独特方式,将位置信息融入到 Transformer 架构中。不同于通过添加位置向量的绝对位置编码,或通常直接修改注意力分数计算的相对位置编码(如 Shaw 等人的方法或 Transformer-XL),RoPE 在注意力分数计算 之前,对查询 (Q) 和键 (K) 向量应用依赖于位置的旋转。这种方式通过旋转变换巧妙地表示了相对位置信息。
其核心思想源于一个发现:两个分别旋转了 α 和 β 角度的向量之间的点积,取决于它们的原始点积以及角度差 (α−β)。RoPE 运用此特性,设计出依赖于令牌绝对位置的旋转矩阵。
设位于位置 m 的查询向量为 qm,位于位置 n 的键向量为 kn。RoPE 旨在变换这些向量,使得它们的内积 qm′⋅kn′ 主要依赖于原始向量 qm,kn 和它们的相对位置 m−n。
实现此目的的方式是将嵌入维度 d 视为若干维度对,并对每对应用二维旋转。对于向量 x 和位置 m,变换 f(x,m) 会施加一个旋转。设查询和键向量的维度为 d。我们可以将向量分成 d/2 个大小为 2 的块。对于第 i 个块(对应维度 2i−1 和 2i),旋转矩阵 Rm,i 定义为:
Rm,i=(cos(mθi)sin(mθi)−sin(mθi)cos(mθi))这里,θi 是一个频率项,它依赖于块索引 i。一个常见选择是 θi=base−2i/d,其中 base 是一个较大的数字(例如 10000),确保频率在不同维度上有所变化。这类似于正弦绝对位置编码中的频率选择。
RoPE 变换随后按块应用于查询 qm 和键 kn:
qm′=f(qm,m)=Rmqmkn′=f(kn,n)=Rnkn其中 Rm 和 Rn 分别表示由 Rm,i 和 Rn,i 块形成的块对角矩阵。
值得关注的属性是,旋转后的查询向量和键向量之间的内积本身就捕捉了相对位置信息:
(qm′)Tkn′=(Rmqm)T(Rnkn)=qmTRmTRnkn由于 Rm 是一个旋转矩阵,其转置是其逆,RmT=Rm−1=R−m。因此,RmTRn=R−mRn=Rn−m。内积变为:
(qm′)Tkn′=qmTRn−mkn这种最终形式表明,位于位置 m 的查询与位于位置 n 的键之间的作用,明确依赖于它们的相对位置 n−m(或等效地,m−n,因为 Rn−m 包含了此差异)以及原始查询和键向量。
另一种方法是,使用复数提供了一种简洁的视角。将每个二维块 [x2i−1,x2i] 表示为复数 xi=x2i−1+jx2i,则旋转 mθi 等价于乘以 ejmθi。旋转后的查询和键分量为 qm,i′=qm,iejmθi 和 kn,i′=kn,iejnθi。它们对注意力分数的贡献涉及其乘积的实部(考虑到复数点积中有一个被共轭):
实部(qm,i′kn,i′)=实部((qm,iejmθi)(kn,iejnθi))=实部(qm,ikn,iejmθie−jnθi)=实部(qm,ikn,iej(m−n)θi)对所有块 i 求和,再次显示了对相对位置 m−n 的依赖。
实际应用中,RoPE 应用于多头注意力机制内的查询和键投影,在计算注意力分数之前。这通常涉及预先计算所有所需位置和维度的余弦和正弦值。
让我们看一个 PyTorch 实现代码片段。假设 q 和 k 是形状为 (batch_size, seq_len, num_heads, head_dim) 的张量。我们还需要预先计算好的 cos_cached 和 sin_cached 张量,通常形状为 (max_seq_len, head_dim // 2)。
import torch
def rotate_half(x):
"""将输入张量的隐藏维度旋转一半。"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
# 将后半部分取负,然后连接起来:(-x2, x1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos_cached, sin_cached):
"""
将旋转位置编码应用于查询和键张量。
参数:
q (torch.Tensor): 查询张量,形状 (bs, seq_len, num_heads, head_dim)
k (torch.Tensor): 键张量,形状 (bs, seq_len, num_heads, head_dim)
cos_cached (torch.Tensor): 预先计算的余弦值,
形状 (seq_len, head_dim // 2)
sin_cached (torch.Tensor): 预先计算的正弦值,
形状 (seq_len, head_dim // 2)
返回:
Tuple[torch.Tensor, torch.Tensor]: 旋转后的查询和键张量。
"""
# 为 num_heads 添加维度,并在需要时沿批次维度扩展
# cos_cached 形状: (seq_len, 1, head_dim // 2)
cos = cos_cached[:q.shape[1], ...].unsqueeze(1)
# sin_cached 形状: (seq_len, 1, head_dim // 2)
sin = sin_cached[:q.shape[1], ...].unsqueeze(1)
# 将 cos 和 sin 重复以适配完整的 head_dim: (seq_len, 1, head_dim)
cos = torch.cat((cos, cos), dim=-1)
sin = torch.cat((sin, sin), dim=-1)
# 应用旋转
# q_rot = (q * cos) + (rotate_half(q) * sin)
# k_rot = (k * cos) + (rotate_half(k) * sin)
# 替代计算方法,避免在主计算中显式调用
# rotate_half 函数
# 重塑 q 和 k,以分离维度对
# q 形状: (bs, seq_len, num_heads, head_dim / 2, 2)
q_reshaped = q.float().reshape(*q.shape[:-1], -1, 2)
k_reshaped = k.float().reshape(*k.shape[:-1], -1, 2)
# 使用复数乘法逻辑应用旋转
# 将 cos/sin 转换为复数: R = cos + j*sin
# 将 q/k 块转换为复数: Q = q1 + j*q2
# 旋转后的 Q' = Q * R = (q1 + j*q2)(cos + j*sin)
# = (q1*cos - q2*sin) + j*(q1*sin + q2*cos)
# q_out1 = q1*cos - q2*sin
# q_out2 = q2*cos + q1*sin
# 重塑 cos/sin 以便广播:
# (1, seq_len, 1, head_dim / 2) -> (1, seq_len, 1, head_dim / 2, 1)
# 只保留前半部分用于配对
cos = cos[..., :q.shape[-1] // 2].unsqueeze(-1)
# 只保留前半部分用于配对
sin = sin[..., :q.shape[-1] // 2].unsqueeze(-1)
q_out1 = q_reshaped[..., 0:1] * cos - q_reshaped[..., 1:2] * sin
q_out2 = q_reshaped[..., 1:2] * cos + q_reshaped[..., 0:1] * sin
q_rot = torch.cat((q_out1, q_out2), dim=-1).flatten(start_dim=-2)
k_out1 = k_reshaped[..., 0:1] * cos - k_reshaped[..., 1:2] * sin
k_out2 = k_reshaped[..., 1:2] * cos + k_reshaped[..., 0:1] * sin
k_rot = torch.cat((k_out1, k_out2), dim=-1).flatten(start_dim=-2)
return q_rot.type_as(q), k_rot.type_as(k)
# 示例用法:
# 假设您已经预先计算了 cos_cached, sin_cached
# max_seq_len = 2048
# head_dim = 128
# base = 10000.0
# inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() /
# head_dim))
# t = torch.arange(max_seq_len, device=inv_freq.device,
# dtype=inv_freq.dtype)
# freqs = torch.einsum("i,j->ij", t, inv_freq)
# emb = torch.cat((freqs, freqs), dim=-1)
# cos_cached = emb.cos()[:, :head_dim // 2]
# sin_cached = emb.sin()[:, :head_dim // 2]
# 在注意力层内部:
# q_rot, k_rot = apply_rotary_pos_emb(q, k, cos_cached, sin_cached)
# 使用 q_rot 和 k_rot 计算注意力分数
apply_rotary_pos_emb 函数接收查询、键和预先计算的余弦/正弦值(源自位置索引和频率)。它重塑最后一个维度以处理维度对,应用旋转逻辑,并返回修改后的查询和键张量。这些旋转后的张量随后用于标准的缩放点积注意力计算。
RoPE 在现代大型语言模型中得到广泛应用,缘于其多项优点:
与其他方法比较:
频率计算中 base 超参数的选择 (θi=base−2i/d) 可能影响性能和外推能力,需要仔细调整。尽管它在数学上精妙且在 Llama 和 PaLM 等模型中取得了实际成效,但理解它与其他模型组件的关联以及在超长序列上的表现,仍是一个活跃的研究方向。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造