趋近智
jax.jitjitjit 的常见问题jitgrad 进行自动微分jax.gradgrad的grad)jax.value_and_grad)vmap 实现自动向量化jax.vmapin_axes,out_axes)vmapvmap 与 jit 和 gradvmap的性能考量pmap 在多设备上并行计算jax.pmapin_axes, out_axes)lax.psum、lax.pmean等)pmap 与其他变换结合使用pmap 化的函数vmap 实现自动向量化在数值计算和机器学习中,您经常需要同时对多个数据点应用同一个函数,这个过程通常被称为批处理。尽管您可以编写显式循环或依赖于为批处理操作设计的函数,但这些方法有时会很冗长,或者需要手动仔细管理数组维度。
JAX 提供了一个转换,jax.vmap,专门用于自动向量化。它允许您将一个为处理单个数据点而编写的函数,有效地应用到整个批次(或多个批次)的数据上,通常无需重写原函数的逻辑。vmap 相当于自动为您的计算添加了一个“批次维度”。
在这一章,您将学习:
jax.vmap 对处理单个或多个参数的函数进行向量化。in_axes 和 out_axes 参数控制哪些轴被映射。vmap 来应对更复杂的情形。vmap 如何与 jit 和 grad 等 JAX 的其他转换配合。vmap 的一些重要事项。学完本章后,您将能够使用 vmap 来让 JAX 中的批处理代码更简洁,并通常运行得更快。
4.1 向量化的原理
4.2 介绍 `jax.vmap`
4.3 对特定参数进行映射(`in_axes`,`out_axes`)
4.4 处理多个批处理参数
4.5 嵌套 `vmap`
4.6 结合 `vmap` 与 `jit` 和 `grad`
4.7 `vmap`的性能考量
4.8 动手实践:函数向量化
© 2026 ApX Machine Learning用心打造