趋近智
jax.jit
jit
jit
的常见问题jit
grad
进行自动微分jax.grad
grad
的grad
)jax.value_and_grad
)vmap
实现自动向量化jax.vmap
in_axes
,out_axes
)vmap
vmap
与 jit
和 grad
vmap
的性能考量pmap
在多设备上并行计算jax.pmap
in_axes
, out_axes
)lax.psum
、lax.pmean
等)pmap
与其他变换结合使用pmap
化的函数先决条件: 熟悉 Python 和 NumPy。
级别:
JAX 基础知识
理解 JAX 的核心知识、它与 NumPy 的关联及其函数式编程方法。
函数变换
应用 JAX 的主要变换:jit
用于编译,grad
用于自动求导,vmap
用于向量化,pmap
用于并行化。
高性能代码
编写能高效运用 GPU 和 TPU 等现代加速器的 JAX 代码。
自动求导
使用 grad
自动计算 Python 函数的梯度。
状态管理
使用适合 JAX 的函数式编程模式实现有状态计算。
调试与性能分析
识别调试 JAX 代码时遇到的常见问题和基本方法。