自动微分(AD)是一种机制,使得JAX等框架能够准确高效地计算可能复杂函数的导数。虽然您可能熟悉使用jax.grad来获取函数对其第一个参数的梯度,但这种便利是对更基础运算的抽象,对应于AD的两种主要模式:前向模式和反向模式。了解这些模式对于处理更复杂的微分任务、优化性能以及定义自定义导数规则非常重要。本节回顾了这两种模式。AD通过将函数分解为一系列基本运算(如加法、乘法、sin、exp),并在这些运算层面反复应用链式法则来进行操作。与使用有限差分近似导数(并存在截断误差和舍入误差)的数值微分,或可能导致指数级大表达式的符号微分不同,AD以可控的计算成本计算精确导数(达到机器精度)。前向模式自动微分(雅可比向量积)前向模式AD在计算导数的同时,也进行原函数的求值。可以将其视为将中间变量对输入的“敏感度”向前传播通过计算图。考虑函数 $f: \mathbb{R}^n \to \mathbb{R}^m$。前向模式特别擅长计算雅可比向量积(JVP)。函数 $f$ 在点 $x$ 的雅可比矩阵 $J$ 包含所有偏导数 $\frac{\partial f_i}{\partial x_j}$。JVP 计算此雅可比矩阵 $J$ 与“切向”向量 $v$ 的乘积,其中 $v$ 通常表示输入空间中的一个方向:$$ \text{JVP}(f, x, v) = J \cdot v = \left. \frac{d}{d\alpha} f(x + \alpha v) \right|_{\alpha=0} $$此运算表明了当输入 $x$ 沿方向 $v$ 发生微小扰动时,函数 $f$ 的输出如何变化。前向模式无需显式构建完整的雅可比矩阵 $J$,即可高效计算此 JVP。它通过在函数求值的同时,通过链式法则传播方向导数 $v$ 来实现这一点。使用前向模式计算一个 JVP 的计算成本大致与计算原函数 $f(x)$ 本身的成本成比例,无论输出维度 $m$ 为何。然而,如果您需要完整的雅可比矩阵,则需要计算 $n$ 个 JVP,输入空间中每个标准基向量对应一个。因此,当前向模式的输入数量 $n$ 明显小于输出数量 $m$ 时(即“高”雅可比矩阵),其计算效率高。反向模式自动微分(向量雅可比积)反向模式AD,在神经网络背景下常与反向传播同义,其工作方式不同。它首先执行前向传播来计算函数的输出值 $f(x)$,并可能存储中间值(激活)。然后,它执行反向传播,从输出开始,使用链式法则将导数反向传播通过计算图。反向模式擅长计算向量雅可比积(VJP)。VJP 计算“余切”向量 $v^T$ 与雅可比矩阵 $J$ 的乘积:$$ \text{VJP}(f, x, v) = v^T \cdot J $$此运算接受一个向量 $v$,该向量代表函数输出处的敏感度或梯度,并计算出函数输入对应的敏感度或梯度。使用反向模式计算一个 VJP 的计算成本大致与计算原函数 $f(x)$ 的成本成比例,无论输入维度 $n$ 为何。这在处理具有许多输入和少量输出的函数时是一个显著优势,这是机器学习模型训练中的常见情况,函数即损失计算(许多参数作为输入,单个标量损失作为输出)。要计算标量值函数 $L(x)$(其中 $m=1$)的梯度,我们只需要一次 VJP 计算。由于 $L(x)$ 是标量,其雅可比矩阵是一个行向量(即梯度的转置,$\nabla L(x)^T$)。那么,当标量“敏感度” $v=1$ 时,VJP 为:$$ \text{VJP}(L, x, 1) = 1 \cdot \nabla L(x)^T = \nabla L(x)^T $$结果是梯度向量的转置。这就是反向模式作为 jax.grad 基础的原因。使用反向模式计算完整的雅可比矩阵将需要 $m$ 次 VJP 计算,输出空间中每个标准基向量对应一次。digraph ADModes { rankdir=LR; node [shape=box, style=rounded, fontname="sans-serif", color="#495057", fontcolor="#495057"]; edge [fontname="sans-serif", color="#495057", fontcolor="#495057"]; subgraph cluster_fwd { label = "前向模式 (JVP)"; style=dashed; color="#adb5bd"; bgcolor="#f8f9fa"; node [color="#1c7ed6", fontcolor="#1c7ed6", style="rounded,filled", fillcolor="#e7f5ff"]; edge [color="#1c7ed6"]; InputF [label="输入 x"]; PerturbationF [label="输入方向 v"]; FunctionF [label="函数 f\n(计算 f(x) & Jv)"]; OutputF [label="输出 f(x)"]; JVPOutput [label="方向输出 Jv"]; {InputF, PerturbationF} -> FunctionF [label="单次传递"]; FunctionF -> {OutputF, JVPOutput}; } subgraph cluster_rev { label = "反向模式 (VJP)"; style=dashed; color="#adb5bd"; bgcolor="#f8f9fa"; node [color="#ae3ec9", fontcolor="#ae3ec9", style="rounded,filled", fillcolor="#f8f0fc"]; edge [color="#ae3ec9"]; InputR [label="输入 x"]; FunctionR [label="函数 f"]; OutputR [label="输出 f(x)"]; PerturbationR [label="输出敏感度 vᵀ"]; Adjoints [label="输入敏感度 vᵀJ"]; InputR -> FunctionR [label="1. 前向传播"]; FunctionR -> OutputR; {OutputR, PerturbationR} -> FunctionR [label="2. 反向传播", style=dashed, dir=back]; FunctionR -> Adjoints [style=dashed, dir=back]; // Changed arrow direction to point from FunctionR to Adjoints for clarity of result Adjoints -> InputR [style=invis]; // Ensure InputR aligns visually } }前向模式 (JVP) 和反向模式 (VJP) 自动微分的流程。前向模式在前向传播中同时计算函数值和导数。反向模式在一次前向传播中计算函数值,然后在一个单独的反向传播中计算导数。选择合适的模式使用前向模式 (JVP) 当输入数量 ($n$) 相对于输出数量 ($m$) 较少时。适用于计算雅可比矩阵的少量列或海森向量积。使用反向模式 (VJP) 当输出数量 ($m$) 相对于输入数量 ($n$) 较少时。这是训练神经网络的典型情况(许多参数作为输入,单个标量损失作为输出),使其成为 jax.grad 的默认选择。回顾了这些基础思想后,我们现在可以查看 JAX 如何通过 jax.jvp 和 jax.vjp 提供对 JVP 和 VJP 的显式控制,支持更复杂的微分技术。