趋近智
jax.jitjitjit 的常见问题jitgrad 进行自动微分jax.gradgrad的grad)jax.value_and_grad)vmap 实现自动向量化jax.vmapin_axes,out_axes)vmapvmap 与 jit 和 gradvmap的性能考量pmap 在多设备上并行计算jax.pmapin_axes, out_axes)lax.psum、lax.pmean等)pmap 与其他变换结合使用pmap 化的函数先修课程 熟悉 Python 和 NumPy。
级别:
JAX 基础知识
理解 JAX 的核心知识、它与 NumPy 的关联及其函数式编程方法。
函数变换
应用 JAX 的主要变换:jit 用于编译,grad 用于自动求导,vmap 用于向量化,pmap 用于并行化。
高性能代码
编写能高效运用 GPU 和 TPU 等现代加速器的 JAX 代码。
自动求导
使用 grad 自动计算 Python 函数的梯度。
状态管理
使用适合 JAX 的函数式编程模式实现有状态计算。
调试与性能分析
识别调试 JAX 代码时遇到的常见问题和基本方法。
本课程没有先修课程。
目前没有推荐的后续课程。
登录以撰写评论
分享您的反馈以帮助其他学习者。
© 2025 ApX Machine Learning用心打造