趋近智
jax.jitjitjit 的常见问题jitgrad 进行自动微分jax.gradgrad的grad)jax.value_and_grad)vmap 实现自动向量化jax.vmapin_axes,out_axes)vmapvmap 与 jit 和 gradvmap的性能考量pmap 在多设备上并行计算jax.pmapin_axes, out_axes)lax.psum、lax.pmean等)pmap 与其他变换结合使用pmap 化的函数pmap 在多设备上并行计算现代硬件通常包含多个加速器,例如 GPU 或 TPU。尽管 jax.jit 优化单设备代码,且 jax.vmap 高效处理批处理,但它们本身并不会将计算分布到*多个*设备上。本章介绍 jax.pmap(并行映射),它是 JAX 用于将计算分布到不同设备上并行运行的函数变换。
您将学到:
pmap 所采用的执行模型。jax.pmap 在多个设备上同时执行相同的函数,每个设备处理不同的数据切片。in_axes 和 out_axes 等参数将数据分发到设备并从设备收集的方法。pmap 转换的函数中,集体操作(例如使用 jax.lax 原语在设备间求和或求平均)的用法。pmap 如何与其他 JAX 变换(例如 jit 和 grad)结合使用。学完本章后,您将能够在多设备系统上应用 pmap 为您的 JAX 程序实现数据并行。
5.1 数据并行 (SPMD) 介绍
5.2 介绍 `jax.pmap`
5.3 将数据映射到设备 (`in_axes`, `out_axes`)
5.4 设备网格与轴名称
5.5 集体操作(`lax.psum`、`lax.pmean`等)
5.6 将 `pmap` 与其他变换结合使用
5.7 调试 `pmap` 化的函数
5.8 动手实践:并行计算
© 2026 ApX Machine Learning用心打造