变分自编码器(VAE)依赖于两个主要神经网络 (neural network)组成部分:编码器,它近似后验分布 qϕ(z∣x);以及解码器,它对数据似然 pθ(x∣z) 进行建模。这些网络的设计并非固定方案,而是一系列架构选择,它们会明显影响 VAE 学习有意义表示和生成连贯数据的能力。正确设计这些网络对构建高效 VAE 十分重要。下面我们来考察这些网络的构建考量。
编码器网络:近似后验 qϕ(z∣x) 的参数 (parameter)化
编码器网络,通常表示为带有参数 ϕ 的 qϕ(z∣x),其主要职责是将输入数据点 x 映射到潜在空间中某个分布的参数。对于大多数 VAE,此潜在分布通常选择为高斯分布,这意味着编码器需要为每个输入 x 输出一个均值向量 (vector) μz 和一个方差向量 σz2(或更常见的是其对数,logσz2,以获得数值稳定性并确保正性)。
常见架构模式:
- 多层感知机(MLP): 对于不具有强空间或序列结构的数据,例如表格数据或扁平化的简单图像(如 MNIST),MLP 是一个直接的选择。典型的 MLP 编码器可能包含多个全连接层,逐渐降低维度,然后分成两个输出头以输出 μz 和 logσz2。
- 卷积神经网络 (neural network)(CNN): 处理图像数据时,CNN 是标准选择。它们捕捉局部模式和空间层级的能力,使得它们在提取相关特征方面非常有用。基于 CNN 的编码器通常包含一系列卷积层(通常增加滤波器数量,步长大于 1 或使用池化层以降低空间维度),随后是一个或多个全连接层,然后生成 μz 和 logσz2。
- 循环神经网络(RNN)或 Transformer: 对于文本或时间序列等序列数据,LSTMs、GRUs 或 Transformer 等架构被用作编码器来处理时间依赖性。这些内容将在第 6 章中进行更详尽的讨论。
编码器的设计要素:
- 深度和宽度: 层的数量(深度)和每层的单元数(宽度)决定了编码器的容量。更深或更宽的网络可以建模从 x 到潜在参数的更复杂映射。然而,容量过大可能导致过拟合 (overfitting)或增加训练难度。编码器通常呈“漏斗”形状,从输入空间到潜在空间逐渐降低维度。
- 激活函数 (activation function): 对于隐藏层,修正线性单元(ReLU)及其变体,如 LeakyReLU 或指数线性单元(ELU),是常见选择,因为它们在对抗梯度消失方面作用良好。生成 μz 的输出层通常使用线性激活。生成 logσz2 的层也使用线性激活;随后的重参数化步骤将使用此对数方差。
- 归一化 (normalization)层: 批归一化(BN)或层归一化(LN)可以用来稳定训练并可能允许使用更高的学习率。然而,它们与 VAE 的影响方式,特别是批归一化,需要仔细考量。BN 在批次中引入样本间的依赖性以进行统计计算,这有时可能干扰 VAE 重建和 KL 散度项的实例级性质。如果使用,它通常放置在卷积/线性层之后和激活函数之前。
- 潜在维度 (dz): 潜在空间 z 的维度是一个重要的超参数 (hyperparameter)。非常低的 dz 会形成过于限制性的信息瓶颈,导致重建质量差。非常高的 dz 可能导致压缩程度较低、潜在地解耦性较差的表示,甚至导致“后验坍塌”,即 qϕ(z∣x) 变得与先验 p(z) 非常相似,从而使潜在变量失去信息性(这在“常见 VAE 训练难题”中有更多讨论)。
解码器网络:对数据分布 pθ(x∣z) 进行建模
解码器网络 pθ(x∣z),带有参数 (parameter) θ,从潜在空间中获取样本 z(在训练时来自 qϕ(z∣x),在生成时来自先验 p(z)),并将其映射回原始数据 x 的分布参数。
常见架构模式:
在架构上,解码器通常与其对应的编码器结构相反:
- MLP: 如果编码器是 MLP,解码器通常也是 MLP,它接收 z 并逐渐增加维度,直至恢复 x 的维度。
- 转置 CNN: 对于图像数据,解码器使用转置卷积层(有时不精确地称为反卷积层)对潜在表示进行上采样,逐步增加空间维度并减少滤波器数量,直到达到原始图像的维度。
- RNN 或 Transformer: 对于序列数据,这些架构在 z 和先前生成元素的基础上,逐步生成序列。
解码器的设计要素:
- 输出层设计和数据似然: 这可以说是解码器设计中非常核心的部分,因为它直接决定了 ELBO 中重建损失项的形式。
- 高斯似然: 对于连续数据(例如,自然图像中的像素强度,通常归一化 (normalization)到 [0,1] 或 [−1,1]),高斯似然 pθ(x∣z)=N(x∣μx(z),σx2(z)) 是常见做法。
- 解码器的最终层输出 μx(z),通常使用线性激活(如果数据归一化到 [−1,1],则可能会使用
tanh 激活)。
- 方差 σx2 可以通过几种方式处理:
- 固定标量: 通常,σx2 假定为一个固定常数(例如,σx2=1)。在这种情况下,负对数似然(重建损失)简化为输入 x 与预测均值 μx(z) 之间的缩放均方误差(MSE)。
- 学习标量: 单个全局 σx2 可以作为 θ 的一部分进行学习。
- 逐维度/像素学习: 解码器可以拥有一个用于 logσx2(z) 的额外输出头,使模型能够预测 x 每个维度的不确定性。这更灵活但增加了复杂性。logσx2(z) 的输出激活将是线性的。
- 伯努利似然: 对于二值数据(例如,像素为 0 或 1 的二值化 MNIST 图像),每个维度 xi 都被建模为伯努利试验。解码器的输出层使用 sigmoid 激活来为每个维度生成概率 pi(z)∈[0,1]。重建损失则是 x 与这些概率之间的二元交叉熵(BCE)。
- 类别似然: 对于每个 xi 可以取 K 个类别中一个的离散数据(例如,量化 (quantization)彩色图像中的像素),使用类别分布。解码器使用 softmax 激活为每个类别输出概率。重建损失是类别交叉熵。
- 隐藏层激活: 与编码器类似,ReLU、LeakyReLU 或 ELU 是隐藏层的标准选择。
- 归一化层: BN 或 LN 也可以在解码器中使用,与编码器中有着类似的考量。
一般架构原则和考量
除了每个网络的具体细节,一些总体原则也指导着 VAE 的架构设计:
- 对称性(或缺乏对称性): 解码器与编码器大致对称是常见做法(例如,一个带有 N 个下采样层的 CNN 编码器可能会与一个带有 N 个上采样层的转置 CNN 解码器配对)。然而,严格对称并非必需。每个网络的复杂性应根据数据和具体任务进行调整。例如,如果生成质量非常重要,解码器可以做得比编码器更强大。
- 网络容量: 编码器和解码器都必须拥有足够的容量(深度、宽度、滤波器数量)来执行各自的任务。编码器容量不足会限制其捕捉 x 主要特征并将其映射到 z 的能力。解码器容量不足则会使其无法从 z 生成逼真的重建 xrecon。然而,过于复杂的网络可能更难训练,容易过拟合 (overfitting),并可能加剧后验坍塌等问题,尤其是在 KL 散度正则化 (regularization)未适当加权或优化困难的情况下。
- 权重 (weight)初始化: 采用标准的权重初始化方案,例如 Xavier/Glorot 初始化(用于 tanh 或 sigmoid 激活层)或 He 初始化(用于 ReLU 激活层),以促进训练期间稳定的梯度流。
- 正则化(KL 项的考量): 尽管 ELBO 中的 DKL(qϕ(z∣x)∣∣p(z)) 项已起到对潜在空间进行正则化的作用,但标准神经网络 (neural network)正则化器如 L2 权重衰减有时也可以应用于参数 (parameter) ϕ 和 θ。Dropout 也是一个选择,但它与批归一化 (normalization)以及 VAE 随机性之间的影响方式应仔细评估。在 VAE 中,它即使使用也往往是少量应用。
示例:基于 CNN 的图像 VAE
让我们可视化一个常见的图像 VAE 架构模式,其中编码器使用卷积层,解码器使用转置卷积层。
图像数据(例如 MNIST)的常见 VAE 架构。编码器采用卷积层进行下采样和特征提取,最终由全连接层输出潜在高斯分布的参数 (parameter) (μz,logσz2)。通过重参数化技巧采样 z 后,解码器使用全连接层,随后是转置卷积层来对 z 进行上采样,重建图像。最终激活函数 (activation function)(例如 MNIST 的 sigmoid)取决于假定的数据分布。
对 ELBO 和学习动态的影响
编码器和解码器的架构选择直接影响证据下界(ELBO)的两个项:重建项 Eqϕ(z∣x)[logpθ(x∣z)] 和 KL 散度项 DKL(qϕ(z∣x)∣∣p(z))。
- 表现力强的解码器可以实现较低的重建误差(ELBO 第一项的值较高)。然而,如果解码器相对于编码器或潜在空间的信息容量(由 KL 项正则化 (regularization))过于强大,它可能学会忽略 z 仍能产生不错的重建,特别是对于简单数据集。这可能导致后验坍塌,即 qϕ(z∣x) 变得与先验 p(z) 非常接近(使 KL 项接近零),从而使得 z 对 x 不提供信息。
- 编码器的设计决定了 qϕ(z∣x) 能多大程度地近似真实的(但难以处理的)后验 p(z∣x)。能力有限的编码器可能导致 ELBO 宽松,这意味着界限不紧密,模型可能难以良好学习。
在两个网络的容量和表现力之间找到平衡点,并仔细进行超参数 (parameter) (hyperparameter)调整(包括 KL 项的权重 (weight),在 β-VAE 中通常表示为 β,在第 3 章中有讨论),对 VAE 的成功训练不可或缺。
高级网络设计预览
虽然这里讨论的架构构成了许多 VAE 的主干,但可以集成更复杂的组件以获得更好的性能或处理更复杂的数据:
- 残差连接: ResNet 等模块可以帮助训练更深层的编码器和解码器,缓解梯度消失/爆炸问题。
- 注意力机制 (attention mechanism): 特别是对于序列或高分辨率图像数据,注意力机制可以使解码器选择性地关注潜在代码或编码器特征的相关部分。(更多内容见第 6 章)
- 归一化 (normalization)流: 这些可以用来定义更灵活的(非高斯)近似后验 qϕ(z∣x) 或先验 p(z),甚至更具表现力的解码器。(在第 3 章和第 4 章中讨论)
- 自回归 (autoregressive)解码器: 使用 PixelCNN 或 WaveNet 等强大的自回归模型作为解码器可以显著提升样本质量,但通常以较慢的生成速度为代价。(在第 3 章中讨论)
理解这些编码器和解码器网络的基本设计原则将使你能够构建、诊断 VAE 并在此基础上进行创新。下一节的实际实现将使你能够将这些想法付诸实践。