生成 $N$ 个词元需要大型语言模型进行 $N$ 次顺序前向传递,这是自回归生成中的一个根本限制。尽管像 KV 缓存这样的技术优化了单次前向传递内的计算,这种顺序依赖性依然自然限制了可达到的最大速度。推测解码提供了一种巧妙的方法来并行化此过程的部分内容,目标是在主大型模型每次单次前向传递中生成多个词元,从而降低总体的实际运行延迟。核心思路在于使用两个模型:目标模型 (Target Model): 大型、高质量语言模型,我们希望精确匹配其输出分布。这是我们最终希望用于生成的模型,但它速度慢。草稿模型 (Draft Model): 一个小得多、速度更快的语言模型。该模型不如目标模型精确,但可以非常快地生成词元序列。目标模型不再一次生成一个词元,过程如下:草稿生成: 快速草稿模型“推测”或提议一个短序列,包含 $k$ 个候选词元,紧随当前上下文。令当前序列为 $x_{1..i}$。草稿模型生成 $\hat{x}{i+1}, \hat{x}{i+2}, \dots, \hat{x}_{i+k}$。目标验证: 大型目标模型随后执行单次前向传递,将原始上下文 $x_{1..i}$ 和整个草稿序列 $\hat{x}{i+1..i+k}$ 作为输入。此单次传递根据目标模型,有效计算每个词元 $j=1..k$ 的真实概率 $p_T(\hat{x}{i+j} | x_{1..i}, \hat{x}_{i+1..i+j-1})$。重要的是,KV 缓存等技术在此仍然适用,使此验证步骤高效。接受检查: 一种统计接受机制(通常基于拒绝采样)用于比较草稿模型的预测与目标模型验证的概率。对于每个草稿词元 $\hat{x}_{i+j}$(从 $j=1$ 开始):比较目标模型分配的概率 $p_T(\hat{x}{i+j} | \dots)$ 与草稿模型分配的概率 $p_D(\hat{x}{i+j} | \dots)$。如果目标模型认为草稿词元 $\hat{x}{i+j}$ 足够可能(与草稿模型认为的概率相比),则接受该词元。一种常见方法是当 $p_T(\hat{x}{i+j} | \dots) / p_D(\hat{x}_{i+j} | \dots) \ge u_j$ 时接受,其中 $u_j \sim U(0, 1)$ 是一个随机数。此检查顺序进行,从 $j=1$ 到 $k$。如果任何词元 $\hat{x}_{i+j}$ 被拒绝,过程在该点停止接受词元。令 $n$ 为接受的词元数量($0 \le n < k$)。修正/继续:如果接受了 $n < k$ 个词元(意味着 $\hat{x}{i+n+1}$ 被拒绝),序列 $x{1..i+n}$ 现在已确定。下一个词元 $x_{i+n+1}$ 从一个修改过的分布中采样,该分布来源于目标模型的概率和草稿模型在该位置的概率,确保整体分布与目标模型匹配。如果所有 $k$ 个词元都被接受($n = k$),序列 $x_{1..i+k}$ 已确定。目标模型的前向传递已经计算了下一个词元 $x_{i+k+1}$ 的分布,所以我们直接从 $p_T(x | x_{1..i+k})$ 采样。重复: 过程从步骤1重复,使用新扩展的序列。潜在的加速来自于如果一个周期中接受了 $n > 0$ 个词元,我们就有效生成了 $n+1$ 个词元($n$ 个接受的词元加上最终采样的词元),仅使用昂贵的目标模型的一次前向传递和草稿模型的 $k$ 次快速前向传递。如果草稿模型足够准确,接受率($n$ 接近 $k$)可以很高,从而大幅减少生成时间。重要的是,统计接受机制确保最终生成的序列遵循目标模型的精确概率分布。digraph G { rankdir=TB; node [shape=box, style=rounded, fontname="sans-serif", color="#495057", fillcolor="#e9ecef", style="filled,rounded", fontsize=11]; edge [color="#495057", fontsize=11]; Start [label="当前上下文\nx_1 ... x_i"]; Draft [label="草稿模型生成\nk 个候选词元:\n x̂_(i+1) ... x̂_(i+k)"]; Verify [label="目标模型验证\n(单次前向传递)\n计算 p_T(x̂_(i+j)|...)"]; AcceptLoop [label="对于 j = 1 到 k:\n检查 x̂_(i+j) 的接受情况"]; Accept [label="词元 x̂_(i+j) 已接受"]; Reject [label="词元 x̂_(i+j) 已拒绝\n(已接受 n = j-1)"]; SampleCorrected [label="从修正分布中\n采样 x_(i+n+1)"]; AllAccepted [label="所有 k 个词元都已接受\n(n = k)"]; SampleNext [label="从目标分布 p_T 中\n采样 x_(i+k+1)"]; Update [label="更新上下文:\nx_1 ... x_(i+n+1 或 i+k+1)"]; Start -> Draft; Draft -> Verify; Verify -> AcceptLoop; AcceptLoop -> Accept [label="p_T/p_D >= U(0,1)"]; AcceptLoop -> Reject [label="p_T/p_D < U(0,1)"]; Accept -> AcceptLoop [label="j < k"]; Accept -> AllAccepted [label="j = k"]; Reject -> SampleCorrected; AllAccepted -> SampleNext; SampleCorrected -> Update; SampleNext -> Update; Update -> Draft [label="继续生成"]; } 流程图,说明推测解码过程。草稿模型提出词元,目标模型在单次传递中验证它们,接受循环决定在采样下一个词元之前保留多少个提出的词元。实现考量草稿模型选择: 选择合适的草稿模型很重要。它需要比目标模型快很多(例如,更少的层/参数,蒸馏模型)。但是,如果其预测与目标模型的差异过大,接受率会很低,抵消性能优势。在速度和预测质量之间找到平衡是必要的。步数 ($k$): 选择推测步数 $k$ 涉及权衡。较大的 $k$ 为每次目标模型推理提供了更大的加速潜力。但是,草稿模型在较长序列上的预测可能与目标模型偏离更多,降低所有 $k$ 个词元都被接受的概率。最佳 $k$ 通常取决于模型和任务。开销: 尽管在实际运行时间上更快,推测解码需要将目标模型和草稿模型都保留在内存中,增加了内存占用。运行草稿模型和执行接受检查也有计算开销。以下是一个类似 PyTorch 的代码片段,说明了核心循环结构:import torch import torch.nn.functional as F def speculative_decode_step(target_model, draft_model, input_ids, k): """ 执行推测解码的一个步骤。 假设模型返回对数(logits)并在内部处理 KV 缓存。 这是一个简化说明。 """ # 1. 草稿生成 # (使用草稿模型的自回归生成) draft_output_ids = draft_model.generate(input_ids, max_new_tokens=k, ...) # 只获取 k 个新词元 draft_ids = draft_output_ids[:, input_ids.shape[-1]:] # 将原始输入与草稿词元结合用于验证 verify_ids = torch.cat([input_ids, draft_ids], dim=-1) # 2. 目标验证 (单次前向传递) # target_logits 的形状: [batch_size, verify_seq_len, vocab_size] with torch.no_grad(): # 确保不计算梯度 target_logits = target_model(verify_ids).logits # 提取草稿位置的目标概率 # 我们查看用于预测 draft_ids[j] 的对数,给定 # 前导词元 # target_probs 的形状: [batch_size, k, vocab_size] target_probs = F.softmax( target_logits[:, input_ids.shape[-1]-1:-1, :], dim=-1 ) # 同时获取草稿模型对草稿词元的概率 # (可能需要单独调用或作为 draft_model.generate 的一部分) # 假设 draft_probs 的形状为 [batch_size, k, vocab_size] # draft_probs = get_draft_probs( # draft_model, input_ids, draft_ids # ) # 占位函数 accepted_count = 0 for j in range(k): # 获取在步骤 j *被*选作草稿的特定词元的概率 # 形状 [batch_size, 1] p_target = target_probs[:, j, draft_ids[:, j]].unsqueeze(-1) # 形状 [batch_size, 1] p_draft = draft_probs[:, j, draft_ids[:, j]].unsqueeze(-1) # 添加 epsilon 以提高数值稳定性 ratio = p_target / (p_draft + 1e-8) # 形状 [batch_size, 1] random_uniform = torch.rand_like(ratio) # 检查批处理中所有项是否被接受 if (ratio >= random_uniform).all(): accepted_count += 1 else: # 发生了拒绝 # 基于修改的分布采样第 (accepted_count + 1) 个词元 # p_modified = (target_probs[:, j, :] # - random_uniform * draft_probs[:, j, :]).clamp(min=0) # p_modified /= p_modified.sum(dim=-1, keepdim=True) # next_token = torch.multinomial(p_modified, num_samples=1) # final_ids = torch.cat([ # input_ids, # draft_ids[:, :accepted_count], # next_token # ], dim=-1) # return final_ids break # 简化:停止接受 if accepted_count == k: # 所有 k 个都接受,从目标模型的最后分布采样第 (k+1) 个词元 next_token_probs = F.softmax(target_logits[:, -1, :], dim=-1) next_token = torch.multinomial(next_token_probs, num_samples=1) final_ids = torch.cat([input_ids, draft_ids, next_token], dim=-1) else: # 在 accepted_count + 1 处发生拒绝 # 简化:仅为说明返回已接受的前缀 # 实际实现会在这里采样修正后的词元 final_ids = torch.cat( [input_ids, draft_ids[:, :accepted_count]], dim=-1 ) # 需要基于修正的分布采样下一个词元 return final_ids # 返回扩展序列 # 示例用法 # current_tokens = ... # 初始序列 # new_tokens = speculative_decode_step( # large_model, small_model, current_tokens, k=5 # )推测解码代表了一个有前景的方向,用于加速大型语言模型推理,在对延迟敏感的应用中尤其有价值。尽管它与标准自回归解码相比引入了额外的复杂度,但显著加速的潜力通常值得付出努力,特别是当与 KV 缓存和优化注意力核等其他优化技术结合时。