Rotary Position Embedding (RoPE) offers a distinct method for incorporating positional information into the Transformer architecture. Unlike absolute positional embeddings that add positional vectors or relative position embeddings that often modify the attention score calculation directly (like in Shaw et al.'s method or Transformer-XL), RoPE applies position-dependent rotations to the query (Q) and key (K) vectors before the attention scores are computed. This approach elegantly encodes relative positional information through the rotational transformation.
The core idea stems from the observation that the dot product between two vectors rotated by angles α and β respectively depends on their original dot product and the difference between the angles (α−β). RoPE leverages this property by designing rotation matrices that depend on the token's absolute position.
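To see this concretely, here is a small self-contained check (an illustrative sketch, not part of any library) that rotates two 2D vectors by angles alpha and beta and confirms that their dot product equals a single rotation by the angle difference:

import torch

def rot2d(angle):
    """2x2 rotation matrix for a scalar angle in radians."""
    angle = torch.as_tensor(angle)
    c, s = torch.cos(angle), torch.sin(angle)
    return torch.stack((torch.stack((c, -s)), torch.stack((s, c))))

q = torch.tensor([1.0, 2.0])
k = torch.tensor([0.5, -1.0])
alpha, beta = 0.7, 0.3

# Rotate q and k separately, then take the dot product ...
lhs = (rot2d(alpha) @ q) @ (rot2d(beta) @ k)
# ... which equals the original q dotted with k rotated by (beta - alpha).
rhs = q @ (rot2d(beta - alpha) @ k)
print(torch.allclose(lhs, rhs))  # True (up to floating-point tolerance)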
Consider a query vector $q_m$ at position $m$ and a key vector $k_n$ at position $n$. RoPE transforms these vectors such that their inner product $q_m' \cdot k_n'$ depends only on the original vectors $q_m, k_n$ and their relative position $m - n$.
This is achieved by viewing the embedding dimension $d$ as pairs of dimensions and applying a 2D rotation to each pair. For a vector $x$ and a position $m$, the transformation $f(x, m)$ applies a rotation. Let the query and key vectors have dimension $d$. We can partition the vectors into $d/2$ blocks of size 2. For the $i$-th block (corresponding to dimensions $2i-1$ and $2i$), the rotation matrix $R_{m,i}$ is defined as:
$$R_{m,i} = \begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix}$$

Here, $\theta_i$ is a frequency term that depends on the block index $i$. A common choice is $\theta_i = \text{base}^{-2(i-1)/d}$ for $i \in \{1, \dots, d/2\}$, where base is a large number (e.g., 10000) ensuring that frequencies vary across dimensions. This mirrors the frequency schedule used in sinusoidal absolute positional embeddings.
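As a minimal sketch (head_dim and base here are illustrative values), these frequencies can be computed exactly as the cache-building code later in this section does:

import torch

head_dim = 8        # d; must be even
base = 10000.0

# theta_i = base^(-2(i - 1)/d) for i = 1, ..., d/2, written with a
# 0-indexed arange as 1 / base^(2i/d):
theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
print(theta)  # d/2 frequencies, decreasing from 1.0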
The RoPE transformation is then applied block-wise to the query $q_m$ and key $k_n$:
$$q_m' = f(q_m, m) = R_m q_m, \qquad k_n' = f(k_n, n) = R_n k_n$$

where $R_m$ and $R_n$ are the block-diagonal matrices formed from the blocks $R_{m,i}$ and $R_{n,i}$, respectively.
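Continuing the previous snippet, the hypothetical helper below applies this block-wise rotation to a single vector at position m, treating consecutive dimension pairs as the 2D blocks:

def rope_rotate(x, m, theta):
    """Rotate each consecutive pair (x[2i], x[2i+1]) by the angle m * theta[i]."""
    x_pairs = x.reshape(-1, 2)                      # (d/2, 2)
    cos, sin = torch.cos(m * theta), torch.sin(m * theta)
    x1, x2 = x_pairs[:, 0], x_pairs[:, 1]
    rotated = torch.stack((x1 * cos - x2 * sin,
                           x1 * sin + x2 * cos), dim=-1)
    return rotated.reshape(-1)                      # back to shape (d,)

q_m = torch.randn(head_dim)
q_rotated = rope_rotate(q_m, 5, theta)              # query rotated for position m = 5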
The remarkable property is that the inner product between the rotated query and key vectors inherently captures relative position:
$$(q_m')^T k_n' = (R_m q_m)^T (R_n k_n) = q_m^T R_m^T R_n k_n$$

Since $R_m$ is a rotation matrix, its transpose is its inverse: $R_m^T = R_m^{-1} = R_{-m}$. Therefore, $R_m^T R_n = R_{-m} R_n = R_{n-m}$. The inner product becomes:
$$(q_m')^T k_n' = q_m^T R_{n-m} k_n$$

This final form shows that the interaction between the query at position $m$ and the key at position $n$ depends only on the original query and key vectors and on their relative position, since $R_{n-m}$ is a function of the offset $n - m$ alone.
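Using the same helper, this property can be checked numerically: shifting both positions by the same amount leaves the score unchanged, because only the difference between the positions enters the result.

q = torch.randn(head_dim)
k = torch.randn(head_dim)

score_a = rope_rotate(q, 10, theta) @ rope_rotate(k, 7, theta)      # positions (10, 7)
score_b = rope_rotate(q, 110, theta) @ rope_rotate(k, 107, theta)   # both shifted by 100
print(torch.allclose(score_a, score_b, atol=1e-5))                  # True: only m - n matters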
Alternatively, complex numbers provide a concise view. Representing each 2D block $[x_{2i-1}, x_{2i}]$ as a complex number $x_i = x_{2i-1} + j\,x_{2i}$, rotation by $m\theta_i$ is equivalent to multiplication by $e^{jm\theta_i}$. The rotated query and key components are $q_{m,i}' = q_{m,i} e^{jm\theta_i}$ and $k_{n,i}' = k_{n,i} e^{jn\theta_i}$. Their contribution to the attention score is the real part of the product in which the key component is conjugated (the complex inner product):
$$\mathrm{Re}\!\left(q_{m,i}'\,\overline{k_{n,i}'}\right) = \mathrm{Re}\!\left(q_{m,i} e^{jm\theta_i}\,\overline{k_{n,i}}\, e^{-jn\theta_i}\right) = \mathrm{Re}\!\left(q_{m,i}\,\overline{k_{n,i}}\, e^{j(m-n)\theta_i}\right)$$

Summing over all blocks $i$ again shows the dependence on the relative position $m - n$.
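PyTorch's complex tensor support expresses this view directly. The sketch below (illustrative, in the spirit of implementations that precompute the $e^{jm\theta_i}$ terms) reproduces the rotation with a single complex multiplication and agrees with the pairwise helper above:

def rope_rotate_complex(x, m, theta):
    """Same rotation as rope_rotate, via complex multiplication."""
    x_complex = torch.view_as_complex(x.reshape(-1, 2))        # x[2i] + j*x[2i+1]
    rotation = torch.polar(torch.ones_like(theta), m * theta)  # e^(j*m*theta_i)
    return torch.view_as_real(x_complex * rotation).reshape(-1)

x = torch.randn(head_dim)
print(torch.allclose(rope_rotate(x, 3, theta),
                     rope_rotate_complex(x, 3, theta), atol=1e-5))  # True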
In practice, RoPE is applied to the query and key projections within the multi-head attention mechanism before computing the attention scores. This typically involves precomputing the cosine and sine values for all required positions and dimensions.
Let's consider a PyTorch implementation snippet. Assume q and k are tensors of shape (batch_size, seq_len, num_heads, head_dim). We need precomputed cos_cached and sin_cached tensors, usually of shape (max_seq_len, head_dim // 2).
import torch


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    # Negate the second half, then concatenate: (-x2, x1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos_cached, sin_cached):
    """
    Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (torch.Tensor): Query tensor, shape (bs, seq_len, num_heads, head_dim)
        k (torch.Tensor): Key tensor, shape (bs, seq_len, num_heads, head_dim)
        cos_cached (torch.Tensor): Precomputed cosine values,
            shape (max_seq_len, head_dim // 2)
        sin_cached (torch.Tensor): Precomputed sine values,
            shape (max_seq_len, head_dim // 2)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors.
    """
    # Slice the cache to the current sequence length and add a singleton
    # dimension so the values broadcast over the num_heads axis.
    # cos/sin shape: (seq_len, 1, head_dim // 2)
    cos = cos_cached[:q.shape[1], ...].unsqueeze(1)
    sin = sin_cached[:q.shape[1], ...].unsqueeze(1)

    # A common alternative formulation repeats cos/sin to the full head_dim
    # and uses rotate_half:
    #     q_rot = (q * cos_full) + (rotate_half(q) * sin_full)
    #     k_rot = (k * cos_full) + (rotate_half(k) * sin_full)
    # That variant pairs dimension i with dimension i + head_dim // 2.
    # The code below instead rotates consecutive dimension pairs (2i, 2i + 1),
    # matching the block-wise description above. Both are valid RoPE layouts,
    # but they are not numerically interchangeable.

    # Reshape q and k so the last axis holds the 2D pairs.
    # q shape: (bs, seq_len, num_heads, head_dim // 2, 2)
    q_reshaped = q.float().reshape(*q.shape[:-1], -1, 2)
    k_reshaped = k.float().reshape(*k.shape[:-1], -1, 2)

    # Apply the rotation using complex-number multiplication logic:
    #     Q = q1 + j*q2, R = cos + j*sin
    #     Q' = Q * R = (q1*cos - q2*sin) + j*(q1*sin + q2*cos)

    # Add a trailing axis so cos/sin broadcast against the pair dimension:
    # (seq_len, 1, head_dim // 2) -> (seq_len, 1, head_dim // 2, 1)
    cos = cos.unsqueeze(-1)
    sin = sin.unsqueeze(-1)

    q_out1 = q_reshaped[..., 0:1] * cos - q_reshaped[..., 1:2] * sin
    q_out2 = q_reshaped[..., 1:2] * cos + q_reshaped[..., 0:1] * sin
    q_rot = torch.cat((q_out1, q_out2), dim=-1).flatten(start_dim=-2)

    k_out1 = k_reshaped[..., 0:1] * cos - k_reshaped[..., 1:2] * sin
    k_out2 = k_reshaped[..., 1:2] * cos + k_reshaped[..., 0:1] * sin
    k_rot = torch.cat((k_out1, k_out2), dim=-1).flatten(start_dim=-2)

    return q_rot.type_as(q), k_rot.type_as(k)

# Example usage:
# Precompute cos_cached and sin_cached once, outside the forward pass.
# max_seq_len = 2048
# head_dim = 128
# base = 10000.0
# inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
# t = torch.arange(max_seq_len, dtype=inv_freq.dtype)
# freqs = torch.einsum("i,j->ij", t, inv_freq)   # (max_seq_len, head_dim // 2)
# cos_cached = freqs.cos()
# sin_cached = freqs.sin()
#
# Inside the attention layer:
# q_rot, k_rot = apply_rotary_pos_emb(q, k, cos_cached, sin_cached)
# Compute attention scores using q_rot and k_rot.
The apply_rotary_pos_emb function takes queries, keys, and the precomputed cosine/sine values (derived from position indices and frequencies). It reshapes the last dimension to handle pairs of dimensions, applies the rotation logic, and returns the modified query and key tensors. These rotated tensors are then used in the standard scaled dot-product attention calculation.
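As a sketch of how this fits into an attention layer (assuming PyTorch 2.x, cos_cached and sin_cached built as in the example above, and v as the unrotated value tensor), the rotated queries and keys can be passed straight into the built-in scaled dot-product attention:

import torch.nn.functional as F

bs, seq_len, num_heads, head_dim = 2, 16, 4, 128
q = torch.randn(bs, seq_len, num_heads, head_dim)
k = torch.randn(bs, seq_len, num_heads, head_dim)
v = torch.randn(bs, seq_len, num_heads, head_dim)   # values are not rotated

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos_cached, sin_cached)

# scaled_dot_product_attention expects (bs, num_heads, seq_len, head_dim).
attn_out = F.scaled_dot_product_attention(
    q_rot.transpose(1, 2), k_rot.transpose(1, 2), v.transpose(1, 2),
    is_causal=True,
)
attn_out = attn_out.transpose(1, 2)  # back to (bs, seq_len, num_heads, head_dim)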
RoPE has become popular in modern LLMs due to several advantages. It injects relative positional information directly into the query-key interaction without adding learned parameters or extra bias terms to the attention scores, and because rotations are orthogonal it preserves the norms of the query and key vectors. It also combines naturally with key-value caching and efficient attention implementations. Compared to learned absolute embeddings or methods that add biases to the attention logits, RoPE tends to behave more gracefully on sequence lengths somewhat beyond those seen during training, although extrapolation far past the training context still degrades without additional techniques.
The choice of the base hyperparameter for the frequency calculation ($\theta_i = \text{base}^{-2(i-1)/d}$) can influence performance and extrapolation capabilities, and may require careful tuning. Despite its mathematical elegance and practical success in models like Llama and PaLM, understanding its interaction with other model components and its behavior on very long sequences remains an area of active research.
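One way to build intuition for base is to look at the wavelength 2π/θ_i of each dimension pair, i.e., the number of tokens after which that pair completes a full rotation. The sketch below compares the conventional base of 10000 against a larger value of the kind used when models are adapted to longer contexts (both values here are illustrative):

import torch

head_dim = 128
for base in (10000.0, 500000.0):
    theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    wavelengths = 2 * torch.pi / theta
    # The slowest-rotating pair bounds the range of offsets the model can
    # distinguish before angles wrap around; a larger base stretches it.
    print(f"base={base:>9.0f}  longest wavelength ~ {wavelengths.max().item():,.0f} tokens")

Rescaling these frequencies, rather than retraining from scratch, is the idea behind several popular context-extension approaches.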