趋近智
大师班
生成 N 个词元需要大型语言模型进行 N 次顺序前向传递,这是自回归生成中的一个根本限制。尽管像 KV 缓存这样的技术优化了单次前向传递内的计算,这种顺序依赖性依然自然限制了可达到的最大速度。推测解码提供了一种巧妙的方法来并行化此过程的部分内容,目标是在主大型模型每次单次前向传递中生成多个词元,从而降低总体的实际运行延迟。
核心思路在于使用两个模型:
目标模型不再一次生成一个词元,过程如下:
潜在的加速来自于如果一个周期中接受了 n>0 个词元,我们就有效生成了 n+1 个词元(n 个接受的词元加上最终采样的词元),仅使用昂贵的目标模型的一次前向传递和草稿模型的 k 次快速前向传递。如果草稿模型足够准确,接受率(n 接近 k)可以很高,从而大幅减少生成时间。重要的是,统计接受机制确保最终生成的序列遵循目标模型的精确概率分布。
流程图,说明推测解码过程。草稿模型提出词元,目标模型在单次传递中验证它们,接受循环决定在采样下一个词元之前保留多少个提出的词元。
以下是一个类似 PyTorch 的代码片段,说明了核心循环结构:
import torch
import torch.nn.functional as F
def speculative_decode_step(target_model, draft_model, input_ids, k):
"""
执行推测解码的一个步骤。
假设模型返回对数(logits)并在内部处理 KV 缓存。
这是一个简化说明。
"""
# 1. 草稿生成
# (使用草稿模型的自回归生成)
draft_output_ids = draft_model.generate(input_ids, max_new_tokens=k, ...)
# 只获取 k 个新词元
draft_ids = draft_output_ids[:, input_ids.shape[-1]:]
# 将原始输入与草稿词元结合用于验证
verify_ids = torch.cat([input_ids, draft_ids], dim=-1)
# 2. 目标验证 (单次前向传递)
# target_logits 的形状: [batch_size, verify_seq_len, vocab_size]
with torch.no_grad(): # 确保不计算梯度
target_logits = target_model(verify_ids).logits
# 提取草稿位置的目标概率
# 我们查看用于预测 draft_ids[j] 的对数,给定
# 前导词元
# target_probs 的形状: [batch_size, k, vocab_size]
target_probs = F.softmax(
target_logits[:, input_ids.shape[-1]-1:-1, :],
dim=-1
)
# 同时获取草稿模型对草稿词元的概率
# (可能需要单独调用或作为 draft_model.generate 的一部分)
# 假设 draft_probs 的形状为 [batch_size, k, vocab_size]
# draft_probs = get_draft_probs(
# draft_model, input_ids, draft_ids
# ) # 占位函数
accepted_count = 0
for j in range(k):
# 获取在步骤 j *被*选作草稿的特定词元的概率
# 形状 [batch_size, 1]
p_target = target_probs[:, j, draft_ids[:, j]].unsqueeze(-1)
# 形状 [batch_size, 1]
p_draft = draft_probs[:, j, draft_ids[:, j]].unsqueeze(-1)
# 添加 epsilon 以提高数值稳定性
ratio = p_target / (p_draft + 1e-8)
# 形状 [batch_size, 1]
random_uniform = torch.rand_like(ratio)
# 检查批处理中所有项是否被接受
if (ratio >= random_uniform).all():
accepted_count += 1
else:
# 发生了拒绝
# 基于修改的分布采样第 (accepted_count + 1) 个词元
# p_modified = (target_probs[:, j, :]
# - random_uniform * draft_probs[:, j, :]).clamp(min=0)
# p_modified /= p_modified.sum(dim=-1, keepdim=True)
# next_token = torch.multinomial(p_modified, num_samples=1)
# final_ids = torch.cat([
# input_ids,
# draft_ids[:, :accepted_count],
# next_token
# ], dim=-1)
# return final_ids
break # 简化:停止接受
if accepted_count == k:
# 所有 k 个都接受,从目标模型的最后分布采样第 (k+1) 个词元
next_token_probs = F.softmax(target_logits[:, -1, :], dim=-1)
next_token = torch.multinomial(next_token_probs, num_samples=1)
final_ids = torch.cat([input_ids, draft_ids, next_token], dim=-1)
else:
# 在 accepted_count + 1 处发生拒绝
# 简化:仅为说明返回已接受的前缀
# 实际实现会在这里采样修正后的词元
final_ids = torch.cat(
[input_ids, draft_ids[:, :accepted_count]], dim=-1
)
# 需要基于修正的分布采样下一个词元
return final_ids # 返回扩展序列
# 示例用法
# current_tokens = ... # 初始序列
# new_tokens = speculative_decode_step(
# large_model, small_model, current_tokens, k=5
# )
推测解码代表了一个有前景的方向,用于加速大型语言模型推理,在对延迟敏感的应用中尤其有价值。尽管它与标准自回归解码相比引入了额外的复杂度,但显著加速的潜力通常值得付出努力,特别是当与 KV 缓存和优化注意力核等其他优化技术结合时。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造