趋近智
jax.jitjitjit 的常见问题jitgrad 进行自动微分jax.gradgrad的grad)jax.value_and_grad)vmap 实现自动向量化jax.vmapin_axes,out_axes)vmapvmap 与 jit 和 gradvmap的性能考量pmap 在多设备上并行计算jax.pmapin_axes, out_axes)lax.psum、lax.pmean等)pmap 与其他变换结合使用pmap 化的函数本章介绍 JAX,它是一个为高性能数值计算设计的 Python 库,尤其适用于机器学习研究。我们会首先说明 JAX 是什么,以及它在 Python 科学计算体系中的位置。
你将了解 JAX 的核心思想,重点说明它与 NumPy 的关系以及对函数转换的依赖。我们将比较 jax.numpy 与标准 NumPy 库,指出你需要注意的主要异同点,例如不可变性。我们还将介绍安装 JAX 的具体步骤,以及如何为 CPU、GPU 和 TPU 等不同硬件进行配置。
最后,你将亲手实践创建和操作 JAX 数组(这一核心数据结构),并学习 JAX 如何在可用硬件设备上管理计算。到本章结束时,你将对 JAX 的作用、其基本 API 有一个初步认识,并知道如何设置环境来开始使用它。
1.1 JAX 是什么?
1.2 JAX 对比 NumPy
1.3 核心设计理念:函数变换
1.4 安装与设置
1.5 使用 JAX 数组
1.6 设备管理:CPU、GPU、TPU
1.7 动手练习:基本数组操作
© 2026 ApX Machine Learning用心打造