卷积自编码器(ConvAEs)非常适用于图像数据,通过使用卷积层和池化层来保持空间层级结构。构建一个 ConvAE 以从图像中提取特征。这个练习将使用流行的 MNIST 数据集,该数据集包含手写数字的灰度图像。这个实践应用将巩固您对 ConvAE 架构及其在特征学习中应用的理解。我们的目标是训练一个 ConvAE 来重构 MNIST 图像,然后使用其编码器部分将这些图像转换成低维特征表示。设置环境首先,请确保您已安装 PyTorch 和 Torchvision。如果您一直跟着课程学习,您的环境应该已经准备就绪。我们还将使用 NumPy 进行数值运算,并使用 Matplotlib 或 Plotly 进行可视化。对于此处嵌入的可视化,我们将准备 Plotly JSON。import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision.datasets import MNIST from torchvision.transforms import ToTensor, Resize # 如果您在notebook中运行,用于可视化: # import matplotlib.pyplot as plt # 对于 t-SNE: # from sklearn.manifold import TSNE1. 加载和预处理 MNIST 数据集MNIST 图像为 28x28 像素。对于 PyTorch 中的卷积层,我们需要 (通道, 高度, 宽度) 的格式。我们还会将像素值归一化到 [0, 1] 范围,这对于训练神经网络而言是很好的做法。torchvision 库使这变得简单。# 加载 MNIST 数据集并应用转换 transform = ToTensor() # 将图像转换为 PyTorch 张量并归一化到 [0, 1] train_dataset = MNIST(root='./data', train=True, download=True, transform=transform) test_dataset = MNIST(root='./data', train=False, download=True, transform=transform) train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False) # 获取一个样本以检查形状 sample_data, _ = next(iter(train_loader)) print(f"样本批次形状: {sample_data.shape}") 您应该看到类似如下的输出: Sample batch shape: torch.Size([128, 1, 28, 28])2. 构建卷积编码器编码器的作用是将输入图像压缩成紧凑的潜在表示。它通常包含一系列 nn.Conv2d 层(用于学习特征),然后是 nn.MaxPool2d 层(用于下采样和降维)。让我们定义一个编码器,将 1x28x28 的输入图像映射到 64 维的潜在向量。latent_dim = 64 # 潜在空间的维度 class Encoder(nn.Module): def __init__(self): super(Encoder, self).__init__() self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1) # -> 16x28x28 self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # -> 16x14x14 self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1) # -> 32x14x14 self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # -> 32x7x7 self.flatten = nn.Flatten() # 展平后的尺寸是 32 * 7 * 7 = 1568 self.fc = nn.Linear(32 * 7 * 7, latent_dim) self.relu = nn.ReLU() def forward(self, x): x = self.relu(self.conv1(x)) x = self.pool1(x) x = self.relu(self.conv2(x)) x = self.pool2(x) x = self.flatten(x) x = self.relu(self.fc(x)) return x # 实例化并打印编码器 encoder = Encoder() print(encoder)打印 encoder 将显示其架构和层。请注意,空间维度如何减小,而滤波器(特征)的数量可以增加,在被压缩成 latent_dim 向量之前,捕获更复杂的模式。padding=1 和 kernel_size=3 确保输出特征图与输入具有相同的空间维度(在池化之前),使架构设计更加直接。3. 构建卷积解码器解码器的任务与编码器相反:从潜在表示中重构原始图像。它通常镜像编码器的架构,但使用 nn.ConvTranspose2d 层来增加空间维度。class Decoder(nn.Module): def __init__(self): super(Decoder, self).__init__() # 密集层,用于从潜在维度上采样到展平前的大小 self.fc = nn.Linear(latent_dim, 32 * 7 * 7) # 重塑将在正向传播中使用 .view() 完成 self.convT1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2) # -> 16x14x14 self.convT2 = nn.ConvTranspose2d(16, 1, kernel_size=2, stride=2) # -> 1x28x28 self.relu = nn.ReLU() self.sigmoid = nn.Sigmoid() def forward(self, x): x = self.relu(self.fc(x)) x = x.view(-1, 32, 7, 7) # 重塑为 32x7x7 x = self.relu(self.convT1(x)) x = self.sigmoid(self.convT2(x)) # Sigmoid 用于 [0,1] 像素值 return x # 实例化并打印解码器 decoder = Decoder() print(decoder)带有 stride=2 的 nn.ConvTranspose2d 层能有效地在每一步将空间维度加倍。最后一层使用 sigmoid 激活,因为我们的输入图像被归一化到 0 到 1 之间。4. 组装自编码器并定义损失函数/优化器现在,我们将编码器和解码器组合成一个完整的自编码器模型。在 PyTorch 中,这是另一个 nn.Module,它按顺序调用编码器和解码器。我们还定义了损失函数和优化器。class Autoencoder(nn.Module): def __init__(self, encoder, decoder): super(Autoencoder, self).__init__() self.encoder = encoder self.decoder = decoder def forward(self, x): encoded = self.encoder(x) decoded = self.decoder(encoded) return decoded autoencoder = Autoencoder(encoder, decoder) print(autoencoder) # 定义损失函数和优化器 criterion = nn.BCELoss() # 用于像素级比较的二元交叉熵损失 optimizer = optim.Adam(autoencoder.parameters(), lr=1e-3) 我们使用 BCELoss 作为损失函数,它适合比较 0 到 1 之间的像素值(由于解码器最后一层中的 sigmoid 激活)。Adam 优化器是一个常用且有效的选择。一个图表可以帮助可视化这个架构:digraph G { rankdir=TB; node [shape=box, style="filled", fillcolor="#e9ecef", fontname="sans-serif"]; edge [fontname="sans-serif"]; subgraph cluster_encoder { label = "编码器"; style="dashed"; fillcolor="#f8f9fa"; InputImg [label="输入图像\n(1x28x28)", fillcolor="#a5d8ff"]; Conv1 [label="nn.Conv2d (16 filters, 3x3, ReLU)\n+ nn.MaxPool2d (2x2)", fillcolor="#74c0fc"]; Conv2 [label="nn.Conv2d (32 filters, 3x3, ReLU)\n+ nn.MaxPool2d (2x2)", fillcolor="#4dabf7"]; FlattenLayer [label="nn.Flatten\n(32x7x7 -> 1568)", fillcolor="#339af0"]; LatentVec [label="nn.Linear (瓶颈层)\n潜在向量 (64 维, ReLU)", fillcolor="#228be6"]; InputImg -> Conv1 -> Conv2 -> FlattenLayer -> LatentVec; } subgraph cluster_decoder { label = "解码器"; style="dashed"; fillcolor="#f8f9fa"; DenseDecode [label="nn.Linear (1568, ReLU)\n+ 重塑 (32x7x7)", fillcolor="#91a7ff"]; ConvT1 [label="nn.ConvTranspose2d (16 filters, 2x2, stride=2, ReLU)", fillcolor="#748ffc"]; ConvT2 [label="nn.ConvTranspose2d (1 filter, 2x2, stride=2, Sigmoid)\n重构图像 (1x28x28)", fillcolor="#5c7cfa"]; DenseDecode -> ConvT1 -> ConvT2; } LatentVec -> DenseDecode [label="潜在表示"]; }卷积自编码器架构。编码器将输入图像映射到低维潜在向量,而解码器尝试从该向量重构原始图像。5. 训练自编码器定义好模型、损失函数和优化器后,我们可以编写训练循环。自编码器学习重构其输入,因此输入图像既作为输入也作为目标。import torch import matplotlib.pyplot as plt from torchvision.transforms import Resize import numpy as np # 假设 autoencoder, test_loader 和 device 已在之前的设置中定义 # 在测试图像上进行预测 autoencoder.eval() # 将模型设置为评估模式 reconstructed_imgs = [] original_imgs = [] # 定义一个将图像大小调整为 16x16 像素的转换 resize_transform = Resize((16, 16)) with torch.no_grad(): for i, data in enumerate(test_loader): imgs, _ = data imgs = imgs.to(device) outputs = autoencoder(imgs) # 从第一个批次中存储前 5 张图像 if i == 0: # 将原始图像和重构图像调整为 16x16 original_imgs = resize_transform(imgs).cpu().numpy() reconstructed_imgs = resize_transform(outputs).cpu().numpy() break # 准备使用 Matplotlib 显示 n_display = 5 fig, axes = plt.subplots(2, n_display, figsize=(n_display * 2, 4)) # 根据需要调整 figsize for i in range(n_display): # 原始图像 axes[0, i].imshow(original_imgs[i, 0], cmap='Greys') axes[0, i].axis('off') # 关闭坐标轴 if i == 0: axes[0, i].set_title("Original Images", fontsize=12) # 重构图像 axes[1, i].imshow(reconstructed_imgs[i, 0], cmap='Greys') axes[1, i].axis('off') # 关闭坐标轴 if i == 0: axes[1, i].set_title("Reconstructed Images", fontsize=12) plt.tight_layout(rect=[0, 0, 1, 0.95]) # 调整布局以腾出标题空间 plt.suptitle("Original vs. Reconstructed Images", fontsize=16, y=1.0) # 总体标题 plt.show(){ "data": [ { "type": "heatmap", "z": [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.9, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ], "colorscale": "Greys", "showscale": false, "xaxis": "x1", "yaxis": "y1" }, { "type": "heatmap", "z": [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ], "colorscale": "Greys", "showscale": false, "xaxis": "x2", "yaxis": "y2" }, { "type": "heatmap", "z": [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ], "colorscale": "Greys", "showscale": false, "xaxis": "x3", "yaxis": "y3" }, { "type": "heatmap", "z": [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ], "colorscale": "Greys", "showscale": false, "xaxis": "x4", "yaxis": "y4" }, { "type": "heatmap", "z": [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ], "colorscale": "Greys", "showscale": false, "xaxis": "x5", "yaxis": "y5" }, { "type": "heatmap", "z": [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.4, 0.8, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ], "colorscale": "Greys", "showscale": false, "xaxis": "x6", "yaxis": "y6" }, { "type": "heatmap", "z": [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ], "colorscale": "Greys", "showscale": false, "xaxis": "x7", "yaxis": "y7" }, { "type": "heatmap", "z": [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ], "colorscale": "Greys", "showscale": false, "xaxis": "x8", "yaxis": "y8" }, { "type": "heatmap", "z": [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ], "colorscale": "Greys", "showscale": false, "xaxis": "x9", "yaxis": "y9" }, { "type": "heatmap", "z": [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ], "colorscale": "Greys", "showscale": false, "xaxis": "x10", "yaxis": "y10" } ], "layout": { "height": 350, "width": 700, "grid": { "rows": 2, "columns": 5, "pattern": "independent" }, "annotations": [ { "text": "原始图像", "showarrow": false, "xref": "paper", "yref": "paper", "x": 0.5, "y": 1.05, "yanchor": "bottom", "font": { "size": 16 } }, { "text": "重构图像", "showarrow": false, "xref": "paper", "yref": "paper", "x": 0.5, "y": 0.48, "yanchor": "bottom", "font": { "size": 16 } } ], "margin": { "l": 20, "r": 20, "t": 70, "b": 20 }, "xaxis1": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.0, 0.18] }, "yaxis1": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.55, 0.95] }, "xaxis2": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.2, 0.38] }, "yaxis2": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.55, 0.95] }, "xaxis3": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.4, 0.58] }, "yaxis3": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.55, 0.95] }, "xaxis4": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.6, 0.78] }, "yaxis4": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.55, 0.95] }, "xaxis5": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.8, 0.98] }, "yaxis5": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.55, 0.95] }, "xaxis6": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.0, 0.18] }, "yaxis6": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.0, 0.45] }, "xaxis7": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.2, 0.38] }, "yaxis7": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.0, 0.45] }, "xaxis8": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.4, 0.58] }, "yaxis8": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.0, 0.45] }, "xaxis9": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.6, 0.78] }, "yaxis9": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.0, 0.45] }, "xaxis10": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.8, 0.98] }, "yaxis10": { "showticklabels": false, "showgrid": false, "zeroline": false, "domain": [0.0, 0.45] } } }MNIST 原始测试图像(上排)与 ConvAE 重构图像(下排)的比较。重构的图像应该可识别,尽管可能比原始图像略模糊。重构的质量取决于模型架构、潜在维度大小和训练时长。更复杂的模型或更长的训练时间可能会产生更清晰的图像。7. 提取特征本次练习的主要目标是特征提取。我们训练好的自编码器的编码器部分现在可以用于将输入图像转换为其 latent_dim 维度的特征向量。# 使用训练好的编码器获取潜在表示(特征) encoder.eval() # 将编码器设置为评估模式 all_features = [] all_labels = [] # 对于完整数据集,遍历加载器 full_train_loader = DataLoader(train_dataset, batch_size=1024) full_test_loader = DataLoader(test_dataset, batch_size=1024) with torch.no_grad(): for data in full_train_loader: # 使用更大的批次大小进行推理 imgs, labels = data imgs = imgs.to(device) features = encoder(imgs) all_features.append(features.cpu().numpy()) all_labels.append(labels.numpy()) encoded_features_train = np.concatenate(all_features, axis=0) y_train = np.concatenate(all_labels, axis=0) # 对测试集重复此操作 all_features = [] all_labels = [] with torch.no_grad(): for data in full_test_loader: imgs, labels = data imgs = imgs.to(device) features = encoder(imgs) all_features.append(features.cpu().numpy()) all_labels.append(labels.numpy()) encoded_features_test = np.concatenate(all_features, axis=0) y_test = np.concatenate(all_labels, axis=0) print(f"训练特征的形状: {encoded_features_train.shape}") print(f"测试特征的形状: {encoded_features_test.shape}")这将输出: Shape of training features: (60000, 64) Shape of test features: (10000, 64)现在,每张图像都由一个 64 个数字的向量表示。这些特征由自编码器学习,用于捕获重构原始图像所需的核心信息。它们通常比原始像素值在语义上更具意义,可用于分类或聚类等后续任务。8. 可视化潜在空间(可选)为了了解自编码器如何在其潜在空间中组织数据,我们可以使用 t-SNE 等降维技术,将 64 维特征投影到 2 维,然后绘制出来,并按原始数字标签着色。# # 以下代码使用 scikit-learn 进行 t-SNE,将在 Python 环境中运行。 # # 对于完整数据集而言,它可能计算量较大,因此通常使用子集进行可视化。 # from sklearn.manifold import TSNE # import plotly.express as px # # 使用测试特征的子集进行 t-SNE(例如,前 5000 个样本) # num_samples_tsne = 5000 # tsne = TSNE(n_components=2, random_state=42, perplexity=30, n_iter=300) # latent_2d = tsne.fit_transform(encoded_features_test[:num_samples_tsne]) # # 创建 Plotly 散点图 # fig = px.scatter(x=latent_2d[:, 0], y=latent_2d[:, 1], color=y_test[:num_samples_tsne].astype(str), # # 标签={'color': '数字'}, 标题="MNIST 潜在空间的 t-SNE 可视化 (ConvAE 特征)") # # fig.show() # 在 notebook 中 # 对于静态显示,这里有一个 t-SNE 图的 Plotly JSON 结构示例。 # 这将填充实际的 t-SNE 结果。 tsne_plot_data = [] # 占位符:手动为 10 个类别创建一些样本点作为 Plotly JSON # 这仅是说明性的;真实的 t-SNE 将生成这些点。 sample_points = { 0: [[-5, -5], [-5.5, -4.5]], 1: [[5, 5], [5.5, 4.5]], 2: [[-5, 5], [-4.5, 5.5]], 3: [[5, -5], [4.5, -5.5]], 4: [[0, 0], [0.5, 0.5]], 5: [[-2, -2], [-1.5, -2.5]], 6: [[2, 2], [1.5, 2.5]], 7: [[-2, 2], [-2.5, 1.5]], 8: [[2, -2], [2.5, -1.5]], 9: [[0, 3], [0, 3.5]] } colors_map_plotly = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'] for digit_class in range(10): points = np.array(sample_points.get(digit_class, [])) if points.shape[0] > 0: tsne_plot_data.append({ "type": "scatter", "mode": "markers", "x": points[:,0].tolist(), "y": points[:,1].tolist(), "name": f"数字 {digit_class}", "marker": {"color": colors_map_plotly[digit_class], "size": 8} }){"data": [{"type": "scatter", "mode": "markers", "x": [-5, -5.5], "y": [-5, -4.5], "name": "数字 0", "marker": {"color": "#1f77b4", "size": 8}}, {"type": "scatter", "mode": "markers", "x": [5, 5.5], "y": [5, 4.5], "name": "数字 1", "marker": {"color": "#ff7f0e", "size": 8}}, {"type": "scatter", "mode": "markers", "x": [-5, -4.5], "y": [5, 5.5], "name": "数字 2", "marker": {"color": "#2ca02c", "size": 8}}, {"type": "scatter", "mode": "markers", "x": [5, 4.5], "y": [-5, -5.5], "name": "数字 3", "marker": {"color": "#d62728", "size": 8}}, {"type": "scatter", "mode": "markers", "x": [0, 0.5], "y": [0, 0.5], "name": "数字 4", "marker": {"color": "#9467bd", "size": 8}}, {"type": "scatter", "mode": "markers", "x": [-2, -1.5], "y": [-2, -2.5], "name": "数字 5", "marker": {"color": "#8c564b", "size": 8}}, {"type": "scatter", "mode": "markers", "x": [2, 1.5], "y": [2, 2.5], "name": "数字 6", "marker": {"color": "#e377c2", "size": 8}}, {"type": "scatter", "mode": "markers", "x": [-2, -2.5], "y": [2, 1.5], "name": "数字 7", "marker": {"color": "#7f7f7f", "size": 8}}, {"type": "scatter", "mode": "markers", "x": [2, 2.5], "y": [-2, -1.5], "name": "数字 8", "marker": {"color": "#bcbd22", "size": 8}}, {"type": "scatter", "mode": "markers", "x": [0, 0], "y": [3, 3.5], "name": "数字 9", "marker": {"color": "#17becf", "size": 8}}], "layout": {"title": "MNIST 潜在空间的示例性 t-SNE (ConvAE 特征)", "xaxis": {"title": "t-SNE 分量 1"}, "yaxis": {"title": "t-SNE 分量 2"}, "width": 600, "height": 500, "legend": {"title": {"text":"数字"}}}}ConvAE 学习到的潜在特征的示例性 t-SNE 可视化。理想情况下,对应于相同数字的点将聚集在一起,而不同数字将形成独立(或有些分离)的簇。如果自编码器学习得好,您应该会看到不同数字簇之间存在一些分离。这表明潜在特征捕获了关于数字类别的判别信息,即使自编码器纯粹是基于重构进行训练,没有任何标签信息。本次实践环节总结在本次会话中,您已成功:使用 PyTorch 的 torchvision 加载并预处理了 MNIST 图像数据集。在 PyTorch 中设计并构建了卷积编码器,将图像映射到潜在空间。在 PyTorch 中设计并构建了卷积解码器,从潜在表示重构图像。将它们组合成一个卷积自编码器,并使用 PyTorch 训练循环对其进行了训练。可视化了图像重构的质量。使用了训练好的编码器从图像中提取特征向量。可选地,可视化了这些特征在潜在空间中的结构。这些提取的特征,即 encoded_features_train 和 encoded_features_test,现在可用于各种后续机器学习任务,例如分类,我们将在第 7 章中进一步讨论。本次练习展示了 ConvAEs 从图像数据中学习紧凑且有用表示的能力。