趋近智
自动微分(AD)是一种机制,使得JAX等框架能够准确高效地计算可能复杂函数的导数。虽然您可能熟悉使用jax.grad来获取函数对其第一个参数的梯度,但这种便利是对更基础运算的抽象,对应于AD的两种主要模式:前向模式和反向模式。了解这些模式对于处理更复杂的微分任务、优化性能以及定义自定义导数规则非常重要。
本节回顾了这两种模式。AD通过将函数分解为一系列基本运算(如加法、乘法、sin、exp),并在这些运算层面反复应用链式法则来进行操作。与使用有限差分近似导数(并存在截断误差和舍入误差)的数值微分,或可能导致指数级大表达式的符号微分不同,AD以可控的计算成本计算精确导数(达到机器精度)。
前向模式AD在计算导数的同时,也进行原函数的求值。可以将其视为将中间变量对输入的“敏感度”向前传播通过计算图。
考虑函数 f:Rn→Rm。前向模式特别擅长计算雅可比向量积(JVP)。函数 f 在点 x 的雅可比矩阵 J 包含所有偏导数 ∂xj∂fi。JVP 计算此雅可比矩阵 J 与“切向”向量 v 的乘积,其中 v 通常表示输入空间中的一个方向:
JVP(f,x,v)=J⋅v=dαdf(x+αv)α=0此运算表明了当输入 x 沿方向 v 发生微小扰动时,函数 f 的输出如何变化。前向模式无需显式构建完整的雅可比矩阵 J,即可高效计算此 JVP。它通过在函数求值的同时,通过链式法则传播方向导数 v 来实现这一点。
使用前向模式计算一个 JVP 的计算成本大致与计算原函数 f(x) 本身的成本成比例,无论输出维度 m 为何。然而,如果您需要完整的雅可比矩阵,则需要计算 n 个 JVP,输入空间中每个标准基向量对应一个。因此,当前向模式的输入数量 n 明显小于输出数量 m 时(即“高”雅可比矩阵),其计算效率高。
反向模式AD,在神经网络背景下常与反向传播同义,其工作方式不同。它首先执行前向传播来计算函数的输出值 f(x),并可能存储中间值(激活)。然后,它执行反向传播,从输出开始,使用链式法则将导数反向传播通过计算图。
反向模式擅长计算向量雅可比积(VJP)。VJP 计算“余切”向量 vT 与雅可比矩阵 J 的乘积:
VJP(f,x,v)=vT⋅J此运算接受一个向量 v,该向量代表函数输出处的敏感度或梯度,并计算出函数输入对应的敏感度或梯度。
使用反向模式计算一个 VJP 的计算成本大致与计算原函数 f(x) 的成本成比例,无论输入维度 n 为何。这在处理具有许多输入和少量输出的函数时是一个显著优势,这是机器学习模型训练中的常见情况,函数即损失计算(许多参数作为输入,单个标量损失作为输出)。
要计算标量值函数 L(x)(其中 m=1)的梯度,我们只需要一次 VJP 计算。由于 L(x) 是标量,其雅可比矩阵是一个行向量(即梯度的转置,∇L(x)T)。那么,当标量“敏感度” v=1 时,VJP 为:
VJP(L,x,1)=1⋅∇L(x)T=∇L(x)T结果是梯度向量的转置。这就是反向模式作为 jax.grad 基础的原因。使用反向模式计算完整的雅可比矩阵将需要 m 次 VJP 计算,输出空间中每个标准基向量对应一次。
前向模式 (JVP) 和反向模式 (VJP) 自动微分的流程。前向模式在前向传播中同时计算函数值和导数。反向模式在一次前向传播中计算函数值,然后在一个单独的反向传播中计算导数。
jax.grad 的默认选择。回顾了这些基础思想后,我们现在可以查看 JAX 如何通过 jax.jvp 和 jax.vjp 提供对 JVP 和 VJP 的显式控制,支持更复杂的微分技术。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造