ApX 标志

趋近智

JAX 入门
章节 1: JAX 简介
JAX 是什么?
JAX 对比 NumPy
核心设计理念:函数变换
安装与设置
使用 JAX 数组
设备管理:CPU、GPU、TPU
动手练习:基本数组操作
第 1 章测验
章节 2: 通过 JIT 编译加速函数
速度提升:为何需要编译?
介绍 jax.jit
JIT 工作原理:追踪与编译
Python 控制流与 jit
静态值与跟踪值
jit 的常见问题
动手实践:应用 jit
第 2 章测验
章节 3: 使用 grad 进行自动微分
理解梯度
介绍 jax.grad
自动微分的工作方式:反向模式
关于参数求导
高阶导数(gradgrad
值和梯度 (jax.value_and_grad)
求导与控制流
局限性与注意事项
动手实践:计算梯度
第 3 章测验
章节 4: 使用 vmap 实现自动向量化
向量化的原理
介绍 jax.vmap
对特定参数进行映射(in_axesout_axes
处理多个批处理参数
嵌套 vmap
结合 vmapjitgrad
vmap的性能考量
动手实践:函数向量化
第 4 章测验
章节 5: 使用 pmap 在多设备上并行计算
数据并行 (SPMD) 介绍
介绍 jax.pmap
将数据映射到设备 (in_axes, out_axes)
设备网格与轴名称
集体操作(lax.psumlax.pmean等)
pmap 与其他变换结合使用
调试 pmap 化的函数
动手实践:并行计算
第 5 章测验
章节 6: JAX 中的状态管理
函数纯粹性与副作用
函数式代码中的状态挑战
模式:显式状态传递
使用 PyTree 管理分层状态
示例:有状态计数器
例子:简单的优化器状态
将状态管理与变换结合
实践:实现有状态函数
第 6 章测验

测验

章节: JAX 简介

测试您对本章概念的理解并进行练习

测验说明

  • 此测验包含 10 道问题来帮助您练习。
  • 您需要至少获得 70% 的分数才能通过。
  • 尝试次数:无限制。
  • 将保留您的最高分数。
  • 请在没有帮助的情况下尝试此测验;但是,如果需要,您可以随时参考章节笔记或使用代码解释器。
  • 完成所有章节测验即可获得课程完成证书。 了解更多
问题格式

问题设计得引人入胜,侧重于理解、应用和解释,而不是死记硬背。期待基于场景的问题,测试您应用所学知识的能力。

尝试记录

最佳分数和测验尝试将显示在这里。

© 2025 趋近智