JAX Control Flow Primitives, JAX core contributors, 2024 - Official documentation explaining how JAX's automatic differentiation interacts with lax.cond, lax.while_loop, and lax.scan, including tracing behavior and its implications for reverse-mode differentiation.
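A minimal sketch (not taken from the cited documentation; names are illustrative) of the behavior that entry describes: lax.scan and lax.cond compose with jax.grad, whereas lax.while_loop supports only forward-mode differentiation.

```python
import jax
import jax.numpy as jnp
from jax import lax

def cumulative_loss(w, xs):
    # lax.scan traces the body once and supports reverse-mode AD;
    # lax.while_loop, by contrast, is forward-mode only, so jax.grad
    # through a while_loop raises an error.
    def step(carry, x):
        # lax.cond is differentiable through whichever branch is taken.
        y = lax.cond(x > 0, lambda v: w * v, lambda v: -w * v, x)
        return carry + y ** 2, y
    total, _ = lax.scan(step, 0.0, xs)
    return total

xs = jnp.array([1.0, -2.0, 3.0])
# Gradient with respect to w: here total = 14 * w**2, so d/dw = 28 * w.
print(jax.grad(cumulative_loss)(0.5, xs))
```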
Training Deep Nets with Sublinear Memory Cost, Tianqi Chen, Bing Xu, Chiyuan Zhang, Carlos Guestrin, 2016, arXiv preprint arXiv:1604.06174, DOI: 10.48550/arXiv.1604.06174 - Introduces a practical application of gradient checkpointing to reduce memory consumption during the training of deep neural networks, directly relevant to managing memory in long lax.while_loop and lax.scan operations.
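A minimal sketch (an assumption about how one might apply the paper's idea in JAX, not the paper's own code) of checkpointing a scan body with jax.checkpoint, so per-step activations are recomputed during the backward pass rather than stored for every iteration.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.checkpoint  # rematerialize this step's activations during backprop
def step(h, x):
    # Hypothetical recurrent update chosen for illustration only.
    h = jnp.tanh(h + x)
    return h, h

def run(h0, xs):
    # Without checkpointing, reverse-mode AD would store residuals for
    # all 1000 steps; with it, memory stays roughly constant per step.
    h, _ = lax.scan(step, h0, xs)
    return jnp.sum(h)

h0 = jnp.zeros(4)
xs = jnp.ones((1000, 4))
print(jax.grad(run)(h0, xs))
```

The trade-off is the one the paper analyzes: extra recomputation in the backward pass in exchange for sublinear activation memory in the sequence length.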