趋近智
requires_grad)backward()).grad)torch.nn 搭建模型torch.nn.Module 基类torch.nn 损失)torch.optim)torch.utils.data.Datasettorchvision.transforms)torch.utils.data.DataLoader张量具有数据类型,通常称为 dtype。数据类型决定了张量可以存储的数值种类(例如整数或浮点数)以及每个元素占用多少内存。选择合适的数据类型对于管理计算资源和确保深度学习模型中的数值精度非常重要。
PyTorch 支持多种数值数据类型,与 NumPy 中的类似。每种类型都有不同的用途,平衡了内存使用、计算速度以及可表示数字的范围或精度。
dtype每个张量都有一个 dtype 属性,用于指定其元素的类型。默认情况下,PyTorch 创建浮点张量时使用 torch.float32,整数张量时使用 torch.int64。你可以这样检查张量的数据类型:
import torch
# 默认浮点张量
a = torch.tensor([1.0, 2.0, 3.0])
print(f"Tensor a: {a}")
print(f"dtype of a: {a.dtype}")
# 默认整数张量
b = torch.tensor([1, 2, 3])
print(f"\nTensor b: {b}")
print(f"dtype of b: {b.dtype}")
输出:
Tensor a: tensor([1., 2., 3.])
dtype of a: torch.float32
Tensor b: tensor([1, 2, 3])
dtype of b: torch.int64
在创建张量时,你也可以明确指定 dtype:
# 创建一个64位浮点数张量
c = torch.tensor([1.0, 2.0], dtype=torch.float64)
print(f"\nTensor c: {c}")
print(f"dtype of c: {c.dtype}")
# 创建一个32位整数张量
d = torch.ones(2, 2, dtype=torch.int32)
print(f"\nTensor d:\n{d}")
print(f"dtype of d: {d.dtype}")
输出:
Tensor c: tensor([1., 2.], dtype=torch.float64)
dtype of c: torch.float64
Tensor d:
tensor([[1, 1],
[1, 1]], dtype=torch.int32)
dtype of d: torch.int32
你可以使用 torch.get_default_dtype() 查看 PyTorch 默认使用的浮点类型。
以下是 PyTorch 中一些最常用的数据类型:
浮点类型:
torch.float32 (或 torch.float):标准的32位单精度浮点数。由于它在CPU和GPU上兼顾了精度和性能,因此是模型参数和一般计算中最常见的类型。torch.float64 (或 torch.double):64位双精度浮点数。提供更高的精度,但占用两倍内存,并且速度可能明显较慢,尤其是在未针对双精度进行优化的GPU上。当绝对需要高数值精度时使用它。torch.float16 (或 torch.half):16位半精度浮点数。占用更少内存,并可在现代GPU(如NVIDIA Tensor Cores)上显著加快计算速度。然而,其有限的范围和精度有时可能导致数值不稳定(溢出或下溢)。常用于混合精度训练。torch.bfloat16:一种替代的16位格式(脑浮点)。它与 float32 具有相似的范围,但精度较低。它正变得越来越受兼容硬件(例如,较新的NVIDIA GPU、Google TPU)上深度学习训练的欢迎,因为它提供了内存节省和速度提升,同时通常比 float16 保持更好的稳定性。整数类型:
torch.int64 (或 torch.long):64位有符号整数。默认整数类型。常用于张量索引和分类任务中表示类别标签。torch.int32 (或 torch.int):32位有符号整数。torch.int16:16位有符号整数。torch.int8:8位有符号整数。较小的整数类型可以节省内存,并对某些操作更快,常用于模型量化。torch.uint8)。布尔类型:
torch.bool:表示布尔值 True 或 False。对于逻辑操作、使用掩码进行索引以及条件逻辑非常重要。你经常需要将张量从一种数据类型转换为另一种。这就是所谓的类型转换。转换张量的主要方式是使用 .to() 方法,该方法我们在张量在设备(CPU/GPU)之间移动的背景下也见过。
float_tensor = torch.tensor([1.1, 2.2, 3.3], dtype=torch.float32)
print(f"Original tensor: {float_tensor}, dtype: {float_tensor.dtype}")
# 使用 .to() 转换为 int64
int_tensor = float_tensor.to(torch.int64)
print(f"Casted to int64: {int_tensor}, dtype: {int_tensor.dtype}") # 注意截断
# 使用 .to() 转换回 float16
half_tensor = int_tensor.to(dtype=torch.float16) # 可以只指定 dtype
print(f"Casted to float16: {half_tensor}, dtype: {half_tensor.dtype}")
输出:
Original tensor: tensor([1.1000, 2.2000, 3.3000]), dtype: torch.float32
Casted to int64: tensor([1, 2, 3]), dtype: torch.int64
Casted to float16: tensor([1., 2., 3.], dtype=torch.float16), dtype: torch.float16
注意,从浮点数转换为整数会截断小数部分。
PyTorch 还提供了便捷方法来进行常见的类型转换:
tensor_a = torch.tensor([0, 1, 0, 1])
# 使用 .float() 转换为浮点数
tensor_b = tensor_a.float() # 等同于 .to(torch.float32)
print(f"\n.float(): {tensor_b}, dtype: {tensor_b.dtype}")
# 使用 .long() 转换为长整型
tensor_c = tensor_b.long() # 等同于 .to(torch.int64)
print(f".long(): {tensor_c}, dtype: {tensor_c.dtype}")
# 使用 .bool() 转换为布尔型
tensor_d = tensor_a.bool() # 等同于 .to(torch.bool)
print(f".bool(): {tensor_d}, dtype: {tensor_d.dtype}")
输出:
.float(): tensor([0., 1., 0., 1.]), dtype: torch.float32
.long(): tensor([0, 1, 0, 1]), dtype: torch.int64
.bool(): tensor([False, True, False, True]), dtype: torch.bool
请记住,类型转换通常会在内存中创建一个具有指定数据类型的新张量,而不是就地修改原始张量。
当你对不同数据类型的张量执行操作时,PyTorch 通常会自动提升类型以确保兼容性。一般规则是,整数类型与浮点类型进行操作时,结果将是浮点类型。不同浮点类型之间的操作通常会得到更高精度的类型。
int_t = torch.tensor([1, 2], dtype=torch.int32)
float_t = torch.tensor([0.5, 0.5], dtype=torch.float32)
double_t = torch.tensor([0.1, 0.1], dtype=torch.float64)
# int32 + float32 -> float32
result1 = int_t + float_t
print(f"\nint32 + float32 = {result1}, dtype: {result1.dtype}")
# float32 + float64 -> float64
result2 = float_t + double_t
print(f"float32 + float64 = {result2}, dtype: {result2.dtype}")
输出:
int32 + float32 = tensor([1.5000, 2.5000]), dtype: torch.float32
float32 + float64 = tensor([0.6000, 0.6000], dtype=torch.float64), dtype: torch.float64
虽然方便,但要注意自动类型提升,因为它如果未被预料到,可能会导致意外结果或性能影响。使用 .to() 进行显式类型转换可以让你对计算中所需的数据类型有更清晰的控制。
理解和管理张量数据类型是高效 PyTorch 编程的重要组成部分。它能让你控制内存占用,运用硬件加速(如GPU上的FP16),并为你的特定深度学习任务保持必要的数值精度。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造