趋近智
自动微分是JAX的核心组成部分,它使得机器学习 (machine learning)模型能够进行基于梯度的优化。虽然jax.grad提供了获取梯度的便捷接口,但了解并控制其内部运作将为复杂任务带来更大的灵活性和更高的效率。
本章将基于jax.grad的初步知识进行拓展。你将学习自动微分的两种基本模式:正向模式(雅可比向量 (vector)积,JVP)和反向模式(向量雅可比积,VJP)。我们将介绍如何直接使用jax.jvp和jax.vjp计算它们,这些构成了jax.grad及其他变换的构建块。
主要内容包括:
jax.custom_vjp和jax.custom_jvp为函数定义自定义微分规则,这可用于提高数值稳定性、优化性能或处理非JAX代码。lax.scan、lax.cond和lax.while_loop等控制流原语中如何运作。jax.lax.stop_gradient。在本章结束时,你将对JAX的自动微分系统有更全面的理解,并掌握在进阶场景中有效应用它的工具。
4.1 前向和反向模式自动微分回顾
4.2 雅可比向量积 (JVPs) 与 jax.jvp
4.3 向量-雅可比积 (VJPs) 与 jax.vjp
4.4 高阶导数
4.5 计算完整的雅可比矩阵和海森矩阵
4.6 使用 jax.custom_vjp 的自定义微分规则
4.7 jax.custom_jvp 自定义求导规则
4.8 通过控制流原语求导
4.9 处理不可微分函数
4.10 实践:实现自定义梯度