趋近智
熟练运用 JAX 的进阶方法,以达成高性能机器学习。本课程内容包含 JAX 核心机制、针对 GPU 和 TPU 的性能优化策略、使用 pmap 进行分布式计算、进阶自动微分、自定义操作,以及大型模型训练技术。运用 JAX 的函数式编程和编译能力,构建复杂且高效的机器学习系统。
先修课程 Python, 机器学习入门, JAX 入门
级别:
性能优化
对 JAX 代码进行性能分析和优化,以在 GPU 和 TPU 等加速器上获得最佳性能。
分布式计算
实现数据并行和模型并行,运用 pmap 及其他分布式基本操作以进行多设备执行。
进阶变换
使用复杂的控制流基本操作(如 scan, cond, while_loop),并理解它们与 JAX 变换的相互关系。
自定义自动微分规则
定义自定义的向量-雅可比乘积(VJP)和雅可比-向量乘积(JVP),用于非标准操作。
JAX 核心机制
透彻理解 JAX 的编译过程 (XLA) 及其内部表示 (jaxprs)。
大型训练
应用 JAX 进阶模式和库(如 Flax 或 Haiku),以高效训练大型神经网络。
本课程没有先修课程。
目前没有推荐的后续课程。
登录以撰写评论
分享您的反馈以帮助其他学习者。
© 2026 ApX Machine Learning用心打造