自动微分是JAX的核心组成部分,它使得机器学习模型能够进行基于梯度的优化。虽然jax.grad提供了获取梯度的便捷接口,但了解并控制其内部运作将为复杂任务带来更大的灵活性和更高的效率。本章将基于jax.grad的初步知识进行拓展。你将学习自动微分的两种基本模式:正向模式(雅可比向量积,JVP)和反向模式(向量雅可比积,VJP)。我们将介绍如何直接使用jax.jvp和jax.vjp计算它们,这些构成了jax.grad及其他变换的构建块。主要内容包括:计算通过组合微分变换得到的高阶导数。高效计算完整雅可比矩阵和海森矩阵的方法。使用jax.custom_vjp和jax.custom_jvp为函数定义自定义微分规则,这可用于提高数值稳定性、优化性能或处理非JAX代码。理解自动微分在lax.scan、lax.cond和lax.while_loop等控制流原语中如何运作。处理不可微分函数或需要明确停止梯度传播的函数的方法,例如使用jax.lax.stop_gradient。在本章结束时,你将对JAX的自动微分系统有更全面的理解,并掌握在进阶场景中有效应用它的工具。