趋近智
先决条件: Python, 机器学习入门, JAX 入门
级别:
性能优化
对 JAX 代码进行性能分析和优化,以在 GPU 和 TPU 等加速器上获得最佳性能。
分布式计算
实现数据并行和模型并行,运用 pmap
及其他分布式基本操作以进行多设备执行。
进阶变换
使用复杂的控制流基本操作(如 scan
, cond
, while_loop
),并理解它们与 JAX 变换的相互关系。
自定义自动微分规则
定义自定义的向量-雅可比乘积(VJP)和雅可比-向量乘积(JVP),用于非标准操作。
JAX 核心机制
透彻理解 JAX 的编译过程 (XLA) 及其内部表示 (jaxprs)。
大型训练
应用 JAX 进阶模式和库(如 Flax 或 Haiku),以高效训练大型神经网络。