趋近智
requires_grad)backward()).grad)torch.nn 搭建模型torch.nn.Module 基类torch.nn 损失)torch.optim)torch.utils.data.Datasettorchvision.transforms)torch.utils.data.DataLoader尽管Autograd自动追踪操作并计算梯度的能力对模型训练不可或缺,但在某些情况下,这种追踪是不必要甚至不希望的。具体来说,在模型评估(推理)期间,或者当你执行不应影响梯度计算的操作时,追踪历史会消耗内存和计算资源,而没有任何益处。PyTorch提供了选择性禁用梯度追踪的方法。
torch.no_grad() 上下文管理器禁用代码块梯度追踪最常用且推荐的方式是使用 torch.no_grad() 上下文管理器。在此 with 块内执行的任何PyTorch操作都会表现得如同所有输入张量都不需要梯度,即使它们最初设置了 requires_grad=True。
import torch
# 示例张量
x = torch.randn(2, 2, requires_grad=True)
w = torch.randn(2, 2, requires_grad=True)
b = torch.randn(2, 2, requires_grad=True)
# 在no_grad上下文之外的操作
y = x * w + b
print(f"y.requires_grad: {y.requires_grad}") # 输出:y.requires_grad: True
print(f"y.grad_fn: {y.grad_fn}") # 输出:y.grad_fn: <AddBackward0 object at ...>
# 在no_grad上下文内执行操作
print("\n进入torch.no_grad()上下文:")
with torch.no_grad():
z = x * w + b
print(f" z.requires_grad: {z.requires_grad}") # 输出:z.requires_grad: False
print(f" z.grad_fn: {z.grad_fn}") # 输出:z.grad_fn: None
# 即使输入需要梯度,输出也不会
k = x * 5
print(f" k.requires_grad: {k.requires_grad}") # 输出:k.requires_grad: False
# 在上下文之外,如果输入需要梯度,追踪会恢复
print("\n退出torch.no_grad()上下文:")
p = x * w
print(f"p.requires_grad: {p.requires_grad}") # 输出:p.requires_grad: True
print(f"p.grad_fn: {p.grad_fn}") # 输出:p.grad_fn: <MulBackward0 object at ...>
如示例所示,with torch.no_grad(): 块内的操作会产生 requires_grad=False 且没有关联 grad_fn 的输出(z、k),这表明它们已脱离计算图历史。这正是你在评估循环中想要的结果:
# 评估循环片段
model.eval() # 将模型设置为评估模式(对dropout、batchnorm等层很重要)
total_loss = 0
correct_predictions = 0
with torch.no_grad(): # 禁用评估期间的梯度计算
for inputs, labels in validation_dataloader:
inputs, labels = inputs.to(device), labels.to(device) # 将数据移动到相应的设备
outputs = model(inputs) # 前向传播
loss = criterion(outputs, labels) # 计算损失
total_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
correct_predictions += (predicted == labels).sum().item()
# 计算平均损失和准确率...
.detach() 方法另一种阻止特定张量进行梯度追踪的方法是使用 .detach() 方法。此方法会创建一个新张量,它与原始张量共享底层数据存储,但明确地脱离了当前计算图。它将拥有 requires_grad=False。
import torch
# 需要梯度的原始张量
a = torch.randn(3, 3, requires_grad=True)
b = a * 2
print(f"b.requires_grad: {b.requires_grad}") # 输出:b.requires_grad: True
print(f"b.grad_fn: {b.grad_fn}") # 输出:b.grad_fn: <MulBackward0 object at ...>
# 分离张量 'b'
c = b.detach()
print(f"\n分离b以创建c后:")
print(f"c.requires_grad: {c.requires_grad}") # 输出:c.requires_grad: False
print(f"c.grad_fn: {c.grad_fn}") # 输出:c.grad_fn: None
# 重要的是,原始张量 'b' 未改变
print(f"\n原始张量 b 仍保持连接:")
print(f"b.requires_grad: {b.requires_grad}") # 输出:b.requires_grad: True
print(f"b.grad_fn: {b.grad_fn}") # 输出:b.grad_fn: <MulBackward0 object at ...>
# 使用分离张量 'c' 的操作将不会被追踪
d = c + 1
print(f"\n在分离张量 c 上的操作:")
print(f"d.requires_grad: {d.requires_grad}") # 输出:d.requires_grad: False
何时使用 .detach() 与 torch.no_grad()?
torch.no_grad(),这通常用于推理或评估代码段。为此目的,它通常更高效。.detach(),例如为了记录其值、在不应影响梯度的操作中使用它(如更新指标),或将其传递给期望非梯度追踪张量的函数,同时可能仍需要在其他地方使用原始张量的梯度历史。由于 .detach() 共享数据,原地修改分离的张量会影响原始张量,如果处理不当,这可能会对梯度计算产生影响。requires_grad你也可以直接原地修改张量的 requires_grad 属性,但与上下文管理器或 .detach() 相比,这种临时禁用方法通常不太常见。它通常用于定义你明确不希望训练的参数。
my_tensor = torch.randn(5, requires_grad=True)
print(f"初始requires_grad: {my_tensor.requires_grad}") # 输出:初始requires_grad: True
# 原地禁用梯度追踪
my_tensor.requires_grad_(False) # 注意下划线表示原地操作
print(f"requires_grad_(False)后: {my_tensor.requires_grad}") # 输出:requires_grad_(False)后: False
使用 torch.no_grad() 是进行高效推理和评估的标准做法,而 .detach() 在你需要将特定张量从梯度历史中隔离时,提供更细致的控制。了解何时以及如何禁用梯度追踪对于编写高效且正确的PyTorch代码非常重要,特别是在你进一步学习基础训练循环之后。
这部分内容有帮助吗?
torch.no_grad() 上下文管理器的官方文档,详细说明了其在推理和评估期间梯度跟踪的用法和影响。.detach() 方法的官方文档,解释了如何创建一个与计算图分离的新张量。© 2026 ApX Machine Learning用心打造