While absolute positional encodings, whether sinusoidal or learned, provide the sequence order information Transformers need, they have limitations. Learned embeddings assume a maximum sequence length, and models trained with either scheme often generalize poorly to sequences longer than those seen during training. Furthermore, absolute encodings don't explicitly represent the relative distance between tokens, which is arguably a more natural quantity for attention mechanisms to operate on.
The Transformer-XL architecture, introduced by Dai et al. (2019), proposed a novel relative positional encoding scheme designed specifically to address these issues, particularly in the context of processing longer sequences using a segment-level recurrence mechanism (though the encoding method itself is valuable independently). Instead of adding positional information to the word embeddings, Transformer-XL injects relative position information directly into the attention score calculation.
The central idea is to modify how the attention score between a query at position $i$ and a key at position $j$ is computed. In the standard Transformer, the score depends on the query $q_i$ and the key $k_j$, where both potentially contain absolute positional information added to their respective embeddings.
Transformer-XL reformulates the attention score calculation to explicitly depend on the relative distance $(i - j)$. It achieves this by making two primary modifications:
Relative Position Embeddings for Keys: Instead of using absolute position embeddings $p_j$ for the key vector, it uses relative position embeddings $R_{i-j}$ that represent the offset between the query and key positions. These embeddings are typically fixed sinusoidal encodings, similar to the original Transformer, but they encode the relative distance rather than an absolute position. Importantly, the same relative embedding $R_k$ is used for all query positions $i$ when considering a key that is $k$ positions away (i.e., $j = i - k$). This allows the model to potentially generalize to unseen relative distances.
Decomposition of Query Interaction: The query vector $q_i$ interacts differently with the content and positional aspects of the key. The standard dot product $q_i^T k_j$ is decomposed into multiple terms that separate content-content, content-position, and position-position interactions, using dedicated trainable parameters.
Let $q_i = W_Q x_i$ be the query vector for token $x_i$ at position $i$, and $k_j = W_K x_j$ be the content-based key vector for token $x_j$ at position $j$. Let $R_{i-j}$ be the sinusoidal embedding vector for the relative position $i - j$. In the standard Transformer, where absolute position embeddings $p_i$ and $p_j$ are added to the token embeddings before projection, the core term inside the softmax is effectively $(W_Q(x_i + p_i))^{T}\, W_K(x_j + p_j)$.
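Expanding that product makes it clear why a four-term decomposition is natural:

$$
\big(W_Q(x_i + p_i)\big)^{T} W_K (x_j + p_j)
= x_i^{T} W_Q^{T} W_K\, x_j
+ x_i^{T} W_Q^{T} W_K\, p_j
+ p_i^{T} W_Q^{T} W_K\, x_j
+ p_i^{T} W_Q^{T} W_K\, p_j
$$

Transformer-XL reparameterizes each term: every appearance of the absolute key position $p_j$ is replaced by the relative encoding $R_{i-j}$, projected by a dedicated matrix $W_R$ instead of $W_K$, and the two query-position factors $p_i^{T} W_Q^{T}$ are replaced by trainable vectors $u^{T}$ and $v^{T}$ that no longer depend on $i$.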
Transformer-XL replaces this with a more sophisticated calculation for the attention score $A^{\text{rel}}_{i,j}$:
$$
A^{\text{rel}}_{i,j}
= \underbrace{q_i^{T} W_K x_j}_{\text{(a) content-based}}
+ \underbrace{q_i^{T} W_R R_{i-j}}_{\text{(b) content-position}}
+ \underbrace{u^{T} W_K x_j}_{\text{(c) global content bias}}
+ \underbrace{v^{T} W_R R_{i-j}}_{\text{(d) global position bias}}
$$

Let's break down the terms:

Term (a), content-based: the familiar content-to-content interaction, measuring how well the query token's content matches the key token's content.

Term (b), content-position: the query content $q_i$ interacts with the projected relative encoding $W_R R_{i-j}$, letting the model weight keys by how far away they are, conditioned on what the query is.

Term (c), global content bias: the trainable vector $u$ acts as a position-independent query, expressing a general preference for certain key contents regardless of where the query sits.

Term (d), global position bias: the trainable vector $v$ encodes a general preference for certain relative distances, independent of the query's content.

Here $W_R$ is a separate trainable projection applied only to the relative encodings, and $u$ and $v$ are shared across all query positions.
The final attention weights are obtained by applying softmax over these scores $A^{\text{rel}}_{i,j}$ (usually after scaling by $1/\sqrt{d_k}$).
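Written out, the weight that query $i$ places on key $j$ is

$$
\alpha_{i,j} = \frac{\exp\big(A^{\text{rel}}_{i,j} / \sqrt{d_k}\big)}{\sum_{j'} \exp\big(A^{\text{rel}}_{i,j'} / \sqrt{d_k}\big)},
$$

exactly as in standard scaled dot-product attention; only the scores themselves change.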
Generating the relative positional encodings $R_{i-j}$ typically involves creating standard sinusoidal encodings for a maximum expected relative distance (e.g., from $-L$ to $+L$, where $L$ is the context or segment length). During the attention calculation for a query at position $i$, you look up the appropriate encoding $R_k$ for each key at position $j = i - k$.
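As a quick sketch of that lookup (the buffer layout here is an assumption, chosen to match the module further down: one row per distance from $-(L-1)$ to $L-1$, indexed by $(i - j) + L - 1$):

import torch

L, d_model = 8, 16                       # illustrative sizes
# One row per relative distance -(L-1), ..., 0, ..., (L-1);
# random values stand in for the sinusoidal rows here
rel_buffer = torch.randn(2 * L - 1, d_model)

i, j = 5, 2                              # query and key positions
R_ij = rel_buffer[(i - j) + (L - 1)]     # encoding for distance i - j = 3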
The introduction of the trainable vectors $u$ and $v$ and the separate projection matrix $W_R$ adds parameters compared to standard Transformer attention, but allows for more nuanced modeling of relative positional importance.
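To put rough numbers on that, assume (as in the sketch below) that $W_R$ is a full $d_{\text{model}} \times d_{\text{model}}$ projection and that $u$ and $v$ each hold $d_{\text{model}}$ values split across heads: with $d_{\text{model}} = 512$, $W_R$ adds $512 \times 512 = 262{,}144$ weights per attention layer, while $u$ and $v$ contribute only $512$ parameters each.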
A simplified PyTorch-style outline of the two pieces, the relative sinusoidal encoding table and the relative attention score calculation, might look like this (a single multi-head layer, without the segment-level recurrence and memory of the full Transformer-XL):
import torch
import torch.nn as nn
import math

class RelativeSinusoidalPositionalEncoding(nn.Module):
    # Generates fixed sinusoidal encodings for relative positions
    # in the range [-(max_len - 1), max_len - 1]
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        # One row per relative distance -(max_len-1), ..., 0, ..., (max_len-1)
        position = torch.arange(
            -(max_len - 1), max_len, dtype=torch.float
        ).unsqueeze(1)                              # (2*max_len - 1, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe = torch.zeros(2 * max_len - 1, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Row (k + max_len - 1) holds the encoding R_k for relative distance k
        self.register_buffer('pe', pe)

    def forward(self, seq_len_q, seq_len_k):
        # For a query at position i and a key at position j, the relative
        # distance i - j ranges from -(seq_len_k - 1) up to (seq_len_q - 1).
        # Return those encodings ordered from the largest distance down to
        # the smallest: R_{Lq-1}, ..., R_0, ..., R_{-(Lk-1)}.
        # Shape: (seq_len_q + seq_len_k - 1, d_model). Expanding this into
        # the full (Lq, Lk) grid of R_{i-j} vectors is left to the attention
        # module below (or to an efficient shift/skew trick).
        relative_indices = torch.arange(
            seq_len_q - 1, -seq_len_k, -1, dtype=torch.long
        )
        # Shift indices to be non-negative for lookup in the 'pe' buffer;
        # distance 0 sits at row max_len - 1
        buffer_indices = relative_indices + self.max_len - 1
        return self.pe[buffer_indices]
class TransformerXLRelativeAttention(nn.Module):
    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        assert d_model % nhead == 0
        self.d_model = d_model
        self.nhead = nhead
        self.d_head = d_model // nhead
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        # Projection W_R applied to the relative positional encodings
        self.W_r = nn.Linear(d_model, d_model, bias=False)
        # Trainable global biases: u for content (term c), v for position (term d)
        self.u = nn.Parameter(torch.empty(self.nhead, self.d_head))
        self.v = nn.Parameter(torch.empty(self.nhead, self.d_head))
        nn.init.xavier_uniform_(self.u)
        nn.init.xavier_uniform_(self.v)
        self.dropout = nn.Dropout(dropout)
        self.scale = 1.0 / math.sqrt(self.d_head)

    def forward(self, query_embed, key_embed, value_embed, rel_enc,
                mask=None):
        # query_embed: (batch, Lq, d_model); key/value_embed: (batch, Lk, d_model)
        # rel_enc: (Lq + Lk - 1, d_model) sinusoidal encodings ordered from
        #          relative distance Lq-1 down to -(Lk-1), as returned by
        #          RelativeSinusoidalPositionalEncoding
        # mask: optional (batch, Lq, Lk) or (batch, 1, Lq, Lk); zeros mark
        #       positions that must not be attended to
        batch_size, seq_len_q, _ = query_embed.size()
        seq_len_k = key_embed.size(1)

        # Content projections, split into heads: (batch, nhead, L, d_head)
        Q = self.W_q(query_embed).view(
            batch_size, seq_len_q, self.nhead, self.d_head).transpose(1, 2)
        K = self.W_k(key_embed).view(
            batch_size, seq_len_k, self.nhead, self.d_head).transpose(1, 2)
        V = self.W_v(value_embed).view(
            batch_size, seq_len_k, self.nhead, self.d_head).transpose(1, 2)

        # Relative encodings projected by W_R and split into heads:
        # (Lq + Lk - 1, nhead, d_head)
        R = self.W_r(rel_enc).view(-1, self.nhead, self.d_head)

        # Terms (a) + (c): (q_i + u)^T W_K x_j
        # u broadcasts over batch and query positions
        AC = torch.matmul(Q + self.u.unsqueeze(0).unsqueeze(2),
                          K.transpose(-2, -1))       # (batch, nhead, Lq, Lk)

        # Terms (b) + (d): (q_i + v)^T W_R R_{i-j}
        # Build the full grid of relative encodings explicitly: entry (i, j)
        # of gather_idx points at the row of R holding R_{i-j}. Dai et al.
        # avoid this O(Lq*Lk) construction with a shift ("skew") trick.
        i_idx = torch.arange(seq_len_q, device=R.device).unsqueeze(1)
        j_idx = torch.arange(seq_len_k, device=R.device).unsqueeze(0)
        gather_idx = (seq_len_q - 1) - i_idx + j_idx  # (Lq, Lk)
        R_grid = R[gather_idx]                        # (Lq, Lk, nhead, d_head)
        BD = torch.einsum('bhid,ijhd->bhij',
                          Q + self.v.unsqueeze(0).unsqueeze(2), R_grid)

        scores = (AC + BD) * self.scale               # (batch, nhead, Lq, Lk)

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)              # broadcast over heads
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        output = torch.matmul(attn_weights, V)        # (batch, nhead, Lq, d_head)
        output = output.transpose(1, 2).contiguous().view(
            batch_size, seq_len_q, self.d_model)
        return output, attn_weights
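Continuing from the two modules above, a minimal usage sketch for a single self-attention call might look like this (the sizes and the causal mask construction are illustrative, not prescribed by the paper):

batch_size, seq_len, d_model, nhead = 2, 16, 64, 4

pos_encoder = RelativeSinusoidalPositionalEncoding(d_model, max_len=512)
attention = TransformerXLRelativeAttention(d_model, nhead)

x = torch.randn(batch_size, seq_len, d_model)     # token representations
rel_enc = pos_encoder(seq_len, seq_len)           # (2*seq_len - 1, d_model)

# Causal mask: query i may only attend to keys j <= i
causal_mask = torch.tril(
    torch.ones(seq_len, seq_len)
).expand(batch_size, -1, -1)                      # (batch, Lq, Lk)

out, attn = attention(x, x, x, rel_enc, mask=causal_mask)
print(out.shape)    # torch.Size([2, 16, 64])
print(attn.shape)   # torch.Size([2, 4, 16, 16])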
Note: the sketch above computes terms (b) and (d) by explicitly building the full grid of $R_{i-j}$ vectors for every pair $(i, j)$, which is easy to follow but costs extra memory. Efficient implementations instead compute the positional products once against the 1-D sequence of relative encodings and then shift ("skew") the resulting matrix so that each entry lines up with the correct relative distance; see Dai et al. (2019), Appendix B, for details.
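To illustrate what that shift looks like, here is a minimal sketch for the causal self-attention case, under the assumption that the raw positional scores were computed with one column per relative encoding, ordered from distance klen − 1 down to 0 (a layout commonly used in Transformer-XL-style code); the function name and the 2-D shape are illustrative:

def rel_shift(x):
    # x: (qlen, klen) raw positional scores, where column m was computed
    # against the encoding for relative distance (klen - 1 - m).
    # After the shift, entry (i, j) holds the score for the relative distance
    # between query i (at absolute position klen - qlen + i) and key j.
    # Entries that end up referring to "future" keys are meaningless and are
    # removed by the causal mask afterwards.
    qlen, klen = x.size()
    zero_pad = torch.zeros(qlen, 1, dtype=x.dtype, device=x.device)
    x_padded = torch.cat([zero_pad, x], dim=1)   # (qlen, klen + 1)
    x_padded = x_padded.view(klen + 1, qlen)     # reinterpret the flat buffer
    return x_padded[1:].view_as(x)               # drop one row, reshape back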
Compared to the approach of Shaw et al. (2018), which injects learned relative-position embeddings into the keys inside the query-key dot product, Transformer-XL integrates relative position information more deeply into the score calculation: the query interacts directly with sinusoidal relative encodings through a dedicated projection ($q_i^T W_R R_{i-j}$), and global content and position biases ($u^T W_K x_j$ and $v^T W_R R_{i-j}$) are added.
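For reference, the pre-softmax score in Shaw et al. has the form

$$
e_{i,j} = \frac{q_i^{T}\big(k_j + a^{K}_{ij}\big)}{\sqrt{d_k}},
$$

where $a^{K}_{ij}$ is a learned embedding of the (clipped) relative distance between positions $i$ and $j$; there is no separate positional projection $W_R$ and no global bias vectors $u$ and $v$.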
In summary, the Transformer-XL relative positional encoding offers a sophisticated alternative to absolute encodings. By focusing on relative distances and decomposing the attention score calculation, it generalizes better to unseen sequence lengths and forms a cornerstone for architectures designed to handle very long contexts. Its implementation requires careful handling of the relative positional embeddings and a few additional trainable parameters, but the benefits for long-sequence modeling are significant.