趋近智
你已经使用过 JAX 的核心变换:jit 用于编译,grad 用于微分,以及 vmap 用于向量 (vector)化。尽管这些是核心要素,但许多精密的模型需要更复杂的计算结构,例如循环步骤、条件执行或动态循环。
本章介绍 JAX 的函数式控制流原语,它们使你能够在 JAX 可追踪和可编译的框架内表达这些复杂的模式。我们将探讨以下内容:
lax.scan 有效实现序列操作,例如循环神经网络 (neural network) (RNN) 中的操作。lax.cond 在编译函数内部实现条件逻辑。lax.while_loop 处理动态迭代。jit、grad 和 vmap 配合。学完本章,你将能够构建和分析包含循环、条件和序列依赖的 JAX 函数,为你构建更复杂的模型和算法做好准备。
1.1 核心转换回顾:jit, grad, vmap
1.2 精通 lax.scan 处理序列操作
1.3 使用 lax.cond 进行条件执行
1.4 使用 lax.while_loop 进行循环
1.5 结合控制流与变换
1.6 高级掩码技术
1.7 了解闭包和JAX暂存
1.8 实践:实现复杂的循环逻辑