趋近智
PyTorch 自动微分功能的主要组成部分是计算图。它并非预先定义的静态结构;相反,PyTorch 在对张量执行操作时动态地构建它。可以将其看作一个有向无环图(DAG),其中节点代表张量或操作,边代表数据流和功能依赖关系。
理解这个图是根本所在,因为 autograd 引擎在反向传播过程中正是遍历它来使用链式法则计算梯度。涉及跟踪梯度的张量的每个操作,都在幕后有助于构建这个图结构。
像 TensorFlow 1.x 或 Theano 这样的框架采用的是静态计算图。在这些系统中,你首先定义整个图结构,编译它,然后用不同的输入数据执行它,可能多次执行。这种“先定义后运行”的方法在执行前允许进行重要的图级别优化。
相反,PyTorch 采用的是动态计算图方法,通常被称为“边运行边定义”。该图是隐式地、逐个操作地构建的,随着你的 Python 代码执行而形成。如果你的模型前向传播中包含循环或条件语句(例如 if 块),图结构实际上可以根据所采取的执行路径在不同迭代之间发生变化。
动态图的优点:
pdb 或打印语句)来检查中间值或图连接性。权衡:
尽管极其灵活,但边运行边定义的特性可能对某些在静态图环境中更简单的整体图优化带来挑战。然而,PyTorch 通过 TorchScript(第 4 章介绍)等工具弥补了这一点,这些工具允许图捕获和优化。
grad_fn 属性PyTorch 实际如何跟踪操作以构建这个图?当你对一个 requires_grad=True 的张量执行操作时,生成的输出张量会自动获得对其创建函数的引用。
这个引用存储在输出张量的 grad_fn 属性中。
让我们用一个简单例子来说明:
import torch
# 需要梯度的输入张量
a = torch.tensor([2.0, 3.0], requires_grad=True)
# 操作 1: 乘以 3
b = a * 3
# 操作 2: 计算均值
c = b.mean()
# 检查 grad_fn 属性
print(f"Tensor a: requires_grad={a.requires_grad}, grad_fn={a.grad_fn}")
# 预期输出: 张量 a: requires_grad=True, grad_fn=None
print(f"Tensor b: requires_grad={b.requires_grad}, grad_fn={b.grad_fn}")
# 预期输出: 张量 b: requires_grad=True, grad_fn=<MulBackward0 object at 0x...>
print(f"Tensor c: requires_grad={c.requires_grad}, grad_fn={c.grad_fn}")
# 预期输出: 张量 c: requires_grad=True, grad_fn=<MeanBackward0 object at 0x...>
请注意以下几点:
a 是图中的一个叶节点。它是用户直接创建的,而非 autograd 跟踪的操作结果。因此,它的 grad_fn 为 None。b 是由 a 乘以 3 产生的。它的 grad_fn 指向 MulBackward0,代表乘法操作。这个对象持有对乘法输入(张量 a 和标量 3)的引用,并且知道如何计算对 a 的梯度。c 是由对 b 进行 mean 操作产生的。它的 grad_fn 指向 MeanBackward0,它知道如何计算对它的输入 b 的梯度。这些 grad_fn 引用形成了一个链表,从输出张量(c)经过操作(MeanBackward0,MulBackward0)向后追溯到输入叶张量(a)。这个链式结构就是 autograd 使用的反向计算图。
尽管 PyTorch 不提供像 TensorBoard 为静态图提供的图视图那样的内置实时图可视化工具,但我们可以将前面例子中构建的图进行可视化。前向传播创建张量并关联 grad_fn 对象。反向传播(c.backward())反向遍历这个结构。
c = (a * 3).mean()的计算图表示。矩形是张量,椭圆形是操作。边显示数据流。grad_fn将创建的张量链接到它们的生成操作,从而形成反向路径。
当你对一个标量张量(像我们例子中的 c,或通常是一个损失值)调用 .backward() 时,autograd 引擎会从该张量开始向后遍历图。
grad_fn 关联的函数(c 的 MeanBackward0)。c)对其输入(b)的梯度。grad_fn 对象。因此,为 b 计算的梯度被传递给 MulBackward0。MulBackward0 计算对其输入(a)的梯度。a 是叶节点(grad_fn 为 None)并且 requires_grad=True,计算出的梯度会累积在 a.grad 中。这个过程一直持续,直到所有路径都到达叶节点或不需要梯度的张量。计算图为链式法则的这种应用提供了路线图。
理解计算图不仅仅是理论性的。它告诉你如何构建模型、调试梯度问题(例如,None 梯度通常意味着图的一部分已断开连接或不需要梯度),以及如何实现带有自己反向传播的自定义操作,正如我们将在本章后面看到的那样。它是使 PyTorch 自动微分得以实现的看不见的机制。
这部分内容有帮助吗?
autograd引擎,详细介绍了动态计算图、grad_fn属性以及反向传播机制。© 2026 ApX Machine Learning用心打造