趋近智
jax.jitjitjit 的常见问题jitgrad 进行自动微分jax.gradgrad的grad)jax.value_and_grad)vmap 实现自动向量化jax.vmapin_axes,out_axes)vmapvmap 与 jit 和 gradvmap的性能考量pmap 在多设备上并行计算jax.pmapin_axes, out_axes)lax.psum、lax.pmean等)pmap 与其他变换结合使用pmap 化的函数grad 进行自动微分许多计算任务,特别是在机器学习模型训练中,需要计算函数的梯度。自动微分提供了一种高效的方法来计算这些导数。本章将介绍 jax.grad,它是JAX的核心变换,用于从处理数值输入的Python代码中获取梯度函数。
您将学习如何:
jax.grad 计算标量值函数 f 的梯度 ∇f(x)。grad 所使用的反向模式自动微分的基本原理。grad 计算高阶导数。jax.value_and_grad 高效地同时获取函数的输出值及其梯度。到本章结束时,您将能够有效使用 jax.grad 在JAX框架内对您的数值函数进行求导。
3.1 理解梯度
3.2 介绍 `jax.grad`
3.3 自动微分的工作方式:反向模式
3.4 关于参数求导
3.5 高阶导数(`grad`的`grad`)
3.6 值和梯度 (`jax.value_and_grad`)
3.7 求导与控制流
3.8 局限性与注意事项
3.9 动手实践:计算梯度
© 2026 ApX Machine Learning用心打造