ApX 标志

趋近智

JAX 入门
章节 1: JAX 简介
JAX 是什么?
JAX 对比 NumPy
核心设计理念:函数变换
安装与设置
使用 JAX 数组
设备管理:CPU、GPU、TPU
动手练习:基本数组操作
第 1 章测验
章节 2: 通过 JIT 编译加速函数
速度提升:为何需要编译?
介绍 jax.jit
JIT 工作原理:追踪与编译
Python 控制流与 jit
静态值与跟踪值
jit 的常见问题
动手实践:应用 jit
第 2 章测验
章节 3: 使用 grad 进行自动微分
理解梯度
介绍 jax.grad
自动微分的工作方式:反向模式
关于参数求导
高阶导数(gradgrad
值和梯度 (jax.value_and_grad)
求导与控制流
局限性与注意事项
动手实践:计算梯度
第 3 章测验
章节 4: 使用 vmap 实现自动向量化
向量化的原理
介绍 jax.vmap
对特定参数进行映射(in_axesout_axes
处理多个批处理参数
嵌套 vmap
结合 vmapjitgrad
vmap的性能考量
动手实践:函数向量化
第 4 章测验
章节 5: 使用 pmap 在多设备上并行计算
数据并行 (SPMD) 介绍
介绍 jax.pmap
将数据映射到设备 (in_axes, out_axes)
设备网格与轴名称
集体操作(lax.psumlax.pmean等)
pmap 与其他变换结合使用
调试 pmap 化的函数
动手实践:并行计算
第 5 章测验
章节 6: JAX 中的状态管理
函数纯粹性与副作用
函数式代码中的状态挑战
模式:显式状态传递
使用 PyTree 管理分层状态
示例:有状态计数器
例子:简单的优化器状态
将状态管理与变换结合
实践:实现有状态函数
第 6 章测验