趋近智
vmap 实现自动向量化在数值计算和机器学习 (machine learning)中,您经常需要同时对多个数据点应用同一个函数,这个过程通常被称为批处理。尽管您可以编写显式循环或依赖于为批处理操作设计的函数,但这些方法有时会很冗长,或者需要手动仔细管理数组维度。
JAX 提供了一个转换,jax.vmap,专门用于自动向量 (vector)化。它允许您将一个为处理单个数据点而编写的函数,有效地应用到整个批次(或多个批次)的数据上,通常无需重写原函数的逻辑。vmap 相当于自动为您的计算添加了一个“批次维度”。
在这一章,您将学习:
jax.vmap 对处理单个或多个参数 (parameter)的函数进行向量化。in_axes 和 out_axes 参数控制哪些轴被映射。vmap 来应对更复杂的情形。vmap 如何与 jit 和 grad 等 JAX 的其他转换配合。vmap 的一些重要事项。学完本章后,您将能够使用 vmap 来让 JAX 中的批处理代码更简洁,并通常运行得更快。
4.1 向量化的原理
4.2 介绍 `jax.vmap`
4.3 对特定参数进行映射(`in_axes`,`out_axes`)
4.4 处理多个批处理参数
4.5 嵌套 `vmap`
4.6 结合 `vmap` 与 `jit` 和 `grad`
4.7 `vmap`的性能考量
4.8 动手实践:函数向量化