趋近智
通常,你会发现现有张量的结构不太适合后续计算步骤,尤其是在将数据送入特定的神经网络 (neural network)层时。PyTorch 提供了灵活的工具,可以在不改变底层数据元素本身的情况下,改变张量的形状或调整其维度。用于这些操作的主要方法是:view()、reshape() 和 permute()。
view() 和 reshape() 改变形状view() 和 reshape() 都允许你改变张量的维度,前提是总元素数量保持不变。它们在将多维张量展平后传递给线性层,或增加/移除大小为1的维度等任务中非常有用。
view()view() 方法返回一个新的张量,该张量与原始张量共享相同的底层数据,但具有不同的形状。它非常高效,因为它避免了数据复制。然而,view() 要求张量在内存中是连续的。连续张量是指其元素在内存中按维度顺序连续存储,没有间隙的张量。大多数新创建的张量是连续的,但某些操作(如切片或使用 t() 进行转置)会产生非连续张量。
我们来看一个例子:
import torch
# 创建一个连续张量
x = torch.arange(12) # tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
print(f"原始张量: {x}")
print(f"原始形状: {x.shape}")
print(f"是否连续? {x.is_contiguous()}")
# 使用 view() 重塑
y = x.view(3, 4)
print("\nview(3, 4) 后的张量:")
print(y)
print(f"新形状: {y.shape}")
print(f"与 x 共享存储吗? {y.storage().data_ptr() == x.storage().data_ptr()}") # 检查它们是否共享内存
print(f"y 是否连续? {y.is_contiguous()}")
# 尝试另一个视图
z = y.view(2, 6)
print("\nview(2, 6) 后的张量:")
print(z)
print(f"新形状: {z.shape}")
print(f"与 x 共享存储吗? {z.storage().data_ptr() == x.storage().data_ptr()}")
print(f"z 是否连续? {z.is_contiguous()}")
你可以在 view() 调用中对一个维度使用 -1,PyTorch 将根据总元素数量和其它维度的尺寸自动推断出该维度的正确尺寸。
# 使用 -1 进行推断
w = x.view(2, 2, -1) # 推断出最后一个维度为 3 (12 / (2*2) = 3)
print("\nview(2, 2, -1) 后的张量:")
print(w)
print(f"新形状: {w.shape}")
如果你尝试在非连续张量上调用 view(),你会得到一个 RuntimeError。
# view() 在非连续张量上失败的例子
a = torch.arange(12).view(3, 4)
b = a.t() # 转置操作会创建一个非连续张量
print(f"\nb 是否连续? {b.is_contiguous()}")
try:
c = b.view(12)
except RuntimeError as e:
print(f"\n尝试 b.view(12) 时出错: {e}")
reshape()reshape() 方法的行为类似于 view(),但提供了更多灵活性。如果张量对于目标形状是连续的,它会尝试返回一个视图。如果无法返回视图(例如,因为原始张量在与新形状兼容的方式上不是连续的),reshape() 将把数据复制到一个新的、具有所需形状的连续张量中。这使得 reshape() 通常更安全、更通用,尽管如果发生复制,性能可能会降低。
我们再次查看使用 reshape() 的转置例子:
# 在非连续张量 'b' 上使用 reshape()
print(f"\n原始非连续张量 b:\n{b}")
print(f"b 的形状: {b.shape}")
print(f"b 是否连续? {b.is_contiguous()}")
# 即使 'b' 不连续,reshape 也能工作
c = b.reshape(12)
print(f"\nb.reshape(12) 后的张量 c:\n{c}")
print(f"c 的形状: {c.shape}")
print(f"c 是否连续? {c.is_contiguous()}")
# 检查 'c' 是否与 'b' 共享存储。由于 reshape 可能进行了复制,所以它们很可能不共享。
print(f"与 b 共享存储吗? {c.storage().data_ptr() == b.storage().data_ptr()}")
# reshape 也可以用 -1 推断维度
d = b.reshape(2, -1) # 推断出最后一个维度为 6
print(f"\nb.reshape(2, -1) 后的张量 d:\n{d}")
print(f"d 的形状: {d.shape}")
何时使用哪个方法?
view()。如果连续性假设有误,请准备好处理可能的 RuntimeError。reshape() 适用于连续和非连续张量。如果可能,它会返回一个视图,否则会创建一个副本。除非性能绝对关键且你能保证连续性,否则这通常是优选方法。permute() 调整维度顺序view() 和 reshape() 通过重新安排元素在维度间的解释方式来改变形状,而 permute() 则明确地交换维度本身。它不改变总元素数量,也不改变每个轴上元素数量方面的形状,但它改变的是哪个轴对应哪个原始维度。
假设你有一个图像数据,存储为(通道,高,宽)格式,但为了特定的库或可视化需求,需要其格式为(高,宽,通道)。permute() 就是为此而设计的工具。你将所需的维度顺序作为参数 (parameter)提供。
# 创建一个三维张量(例如,表示通道、高、宽)
image_tensor = torch.randn(3, 32, 32) # 通道,高,宽
print(f"原始形状: {image_tensor.shape}") # torch.Size([3, 32, 32])
# 调整为(高,宽,通道)
permuted_tensor = image_tensor.permute(1, 2, 0) # 指定新顺序:维度 1,维度 2,维度 0
print(f"调整后的形状: {permuted_tensor.shape}") # torch.Size([32, 32, 3])
# permute 通常返回一个非连续的视图
print(f"permuted_tensor 是否连续? {permuted_tensor.is_contiguous()}")
# 调回原状
original_again = permuted_tensor.permute(2, 0, 1) # 回到通道,高,宽
print(f"调回后的形状: {original_again.shape}") # torch.Size([3, 32, 32])
print(f"original_again 是否连续? {original_again.is_contiguous()}") # (可能仍然是非连续的)
# 检查存储共享
print(f"与原始张量共享存储吗? {original_again.storage().data_ptr() == image_tensor.storage().data_ptr()}")
和 view() 一样,permute() 返回一个与原始张量共享底层数据的张量。它不复制数据。然而,生成的张量通常不是连续的。如果你在调换维度后需要一个连续张量(例如,为了后续使用 view()),你可以链式调用 .contiguous() 方法:
# 使调整维度的张量连续
contiguous_permuted = permuted_tensor.contiguous()
print(f"\ncontiguous_permuted 是否连续? {contiguous_permuted.is_contiguous()}")
# 现在可以安全地使用 view()
flattened_permuted = contiguous_permuted.view(-1)
print(f"展平后的形状: {flattened_permuted.shape}")
掌握 view()、reshape() 和 permute() 让你能够精确控制张量的结构,这是将数据适配到不同 PyTorch 操作和模型层要求所需的一项必备技能。请记住这些权衡:view() 速度快但要求连续性,reshape() 灵活但可能会复制,而 permute() 交换维度而不复制,但通常会产生非连续张量。
这部分内容有帮助吗?
view()方法在张量重塑中的用法、对张量连续性的要求以及内存共享机制。reshape()方法在重塑连续和非连续张量时的灵活性,以及其在数据复制方面的行为。permute()方法如何交换张量维度,并阐明了其对张量形状和连续性的影响。© 2026 ApX Machine LearningAI伦理与透明度•