趋近智
自动微分(AD)是一种机制,使得JAX等框架能够准确高效地计算可能复杂函数的导数。虽然您可能熟悉使用jax.grad来获取函数对其第一个参数 (parameter)的梯度,但这种便利是对更基础运算的抽象,对应于AD的两种主要模式:前向模式和反向模式。了解这些模式对于处理更复杂的微分任务、优化性能以及定义自定义导数规则非常重要。
本节回顾了这两种模式。AD通过将函数分解为一系列基本运算(如加法、乘法、sin、exp),并在这些运算层面反复应用链式法则来进行操作。与使用有限差分近似导数(并存在截断误差和舍入误差)的数值微分,或可能导致指数级大表达式的符号微分不同,AD以可控的计算成本计算精确导数(达到机器精度)。
前向模式AD在计算导数的同时,也进行原函数的求值。可以将其视为将中间变量对输入的“敏感度”向前传播通过计算图。
考虑函数 。前向模式特别擅长计算雅可比向量积(JVP)。函数 在点 的雅可比矩阵 包含所有偏导数 。JVP 计算此雅可比矩阵 与“切向”向量 的乘积,其中 通常表示输入空间中的一个方向:
此运算表明了当输入 沿方向 发生微小扰动时,函数 的输出如何变化。前向模式无需显式构建完整的雅可比矩阵 ,即可高效计算此 JVP。它通过在函数求值的同时,通过链式法则传播方向导数 来实现这一点。
使用前向模式计算一个 JVP 的计算成本大致与计算原函数 本身的成本成比例,无论输出维度 为何。然而,如果您需要完整的雅可比矩阵,则需要计算 个 JVP,输入空间中每个标准基向量对应一个。因此,当前向模式的输入数量 明显小于输出数量 时(即“高”雅可比矩阵),其计算效率高。
反向模式AD,在神经网络 (neural network)背景下常与反向传播 (backpropagation)同义,其工作方式不同。它首先执行前向传播来计算函数的输出值 ,并可能存储中间值(激活)。然后,它执行反向传播,从输出开始,使用链式法则将导数反向传播通过计算图。
反向模式擅长计算向量雅可比积(VJP)。VJP 计算“余切”向量 与雅可比矩阵 的乘积:
此运算接受一个向量 ,该向量代表函数输出处的敏感度或梯度,并计算出函数输入对应的敏感度或梯度。
使用反向模式计算一个 VJP 的计算成本大致与计算原函数 的成本成比例,无论输入维度 为何。这在处理具有许多输入和少量输出的函数时是一个显著优势,这是机器学习 (machine learning)模型训练中的常见情况,函数即损失计算(许多参数 (parameter)作为输入,单个标量损失作为输出)。
要计算标量值函数 (其中 )的梯度,我们只需要一次 VJP 计算。由于 是标量,其雅可比矩阵是一个行向量(即梯度的转置,)。那么,当标量“敏感度” 时,VJP 为:
结果是梯度向量的转置。这就是反向模式作为 jax.grad 基础的原因。使用反向模式计算完整的雅可比矩阵将需要 次 VJP 计算,输出空间中每个标准基向量对应一次。
前向模式 (JVP) 和反向模式 (VJP) 自动微分的流程。前向模式在前向传播中同时计算函数值和导数。反向模式在一次前向传播中计算函数值,然后在一个单独的反向传播中计算导数。
jax.grad 的默认选择。回顾了这些基础思想后,我们现在可以查看 JAX 如何通过 jax.jvp 和 jax.vjp 提供对 JVP 和 VJP 的显式控制,支持更复杂的微分技术。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•