趋近智
学习 JAX,用于高性能数值计算和机器学习 (machine learning)研究。本课程涵盖 JAX 核心知识,包括其 NumPy 接口、jit、grad、vmap 和 pmap 等函数变换,以及用于状态管理的函数式编程模式。您将获得适配现代硬件(GPU/TPU)来加速和求导 Python 代码的实践经验。
先修课程 熟悉 Python 和 NumPy。
级别:
JAX 基础知识
理解 JAX 的核心知识、它与 NumPy 的关联及其函数式编程方法。
函数变换
应用 JAX 的主要变换:jit 用于编译,grad 用于自动求导,vmap 用于向量化,pmap 用于并行化。
高性能代码
编写能高效运用 GPU 和 TPU 等现代加速器的 JAX 代码。
自动求导
使用 grad 自动计算 Python 函数的梯度。
状态管理
使用适合 JAX 的函数式编程模式实现有状态计算。
调试与性能分析
识别调试 JAX 代码时遇到的常见问题和基本方法。