标准VAE凭借其连续潜在空间提供了一个扎实的基础,但有时会生成缺乏清晰度的样本,这通常归因于连续表示的特性和KL散度项。当您的目标是生成更清晰的输出或处理本质上离散的数据结构时,向量 (vector)量化 (quantization)变分自编码器(VQ-VAE)提供了一种值得考虑的架构替代方案。VQ-VAE通过引入离散潜在空间来实现这一点,该空间通过有限的嵌入 (embedding)向量码本学习得到。这种设计选择通常能明显提高样本质量,并提供了一种不同的机制来形成信息瓶颈,从而摆脱了传统VAE中显式的KL散度项。
VQ-VAE的架构
VQ-VAE的核心创新点在于对编码器输出的量化 (quantization)。它不直接使用连续潜在向量 (vector),而是将编码器的输出映射到学习到的有限嵌入 (embedding)向量集中最接近的向量,这个集合被称为码本。
VQ-VAE通常由以下部分组成:
- 编码器(Encoder): 这个神经网络 (neural network)fenc处理输入x并产生一个连续表示ze(x)。如果输入是图像,ze(x)可能是一个形状为H′×W′×D的特征图;对于其他数据类型,它可以是一个单一向量。这个ze(x)是一个中间输出,而不是最终的潜在变量。
- 码本(嵌入空间): 这是一个可学习字典E={e1,e2,...,eK},包含K个嵌入向量,其中每个ei∈RD。您可以将这个码本视为模型学习到的代表性特征向量调色板。
- 量化器(Quantizer): 对于编码器输出ze(x)中的每个向量(或者如果ze(x)本身是一个单一向量,则直接对它),量化器在码本E中执行最近邻查找。它识别出在欧几里得距离上最接近的嵌入向量ek:
k=argminj∣∣ze(x)−ej∣∣2
量化后的潜在表示,zq(x),就是这个被选中的码本向量:zq(x)=ek。
- 解码器(Decoder): 这个网络fdec以量化表示zq(x)为输入,旨在重构原始输入x,得到x^。
VQ-VAE架构的概览。编码器产生ze(x),通过在学习到的码本E中找到最接近的向量,将其量化为zq(x)。解码器从zq(x)重构输入。
训练动态与不可微分挑战
训练VQ-VAE的一大挑战是量化 (quantization)步骤中的argmin操作不可微分。这阻止了梯度从解码器直接反向传播 (backpropagation)到编码器。VQ-VAE通过采用直通估计器(STE)巧妙地解决了这个问题。
在反向传播过程中,来自解码器输入的梯度∇zqL直接传递到编码器的输出ze(x)。本质上,解码器的梯度信号绕过了不可微分的量化操作。
∂ze(x)∂L≈∂zq(x)∂L
这使得编码器能够学习产生在量化后能得到良好重构的输出ze(x)。编码器学习生成接近有用码本条目的连续向量 (vector)。
VQ-VAE的损失函数 (loss function)
训练VQ-VAE的总损失函数通常包含三个不同的部分:
-
重构损失(Lrecon): 这一项促使解码器从其量化 (quantization)潜在表示zq(x)精确重构输入x。对于图像等连续数据,这通常是均方误差(MSE):
Lrecon=∣∣x−dec(zq(x))∣∣22
对于其他数据类型,如二值数据,二元交叉熵(BCE)可能更合适。
-
码本损失(或嵌入 (embedding)损失,Lcodebook): 这种损失负责更新码本E中的嵌入向量 (vector)ei。它促使被选中的码本向量ek(最接近ze(x)的那个)向编码器的输出ze(x)移动。对于此更新,编码器的输出ze(x)被视为常数(使用停止梯度操作符sg进行分离)。
Lcodebook=∣∣sg[ze(x)]−ek∣∣22
这个部分类似于k-means聚类中的质心更新规则。它将选定的码本向量ek拉向映射到它的编码器输出集群。
-
承诺损失(Lcommit): 这一项对编码器进行正则化 (regularization),促使其输出ze(x)保持“归属于”所选码本向量ek。它有助于防止编码器输出过度波动或变得过大,确保它们保持接近离散表示集。对于这种损失,码本向量ek被视为常数(分离)。超参数 (parameter) (hyperparameter)β控制此项的影响。
Lcommit=β∣∣ze(x)−sg[ek]∣∣22
如果没有这种损失,编码器输出ze(x)可能会严重偏离它们所映射的实际嵌入ek,可能使码本更新的稳定性和效果下降。
要最小化的组合损失函数为:
LVQVAE=Lrecon+Lcodebook+Lcommit
需要注意的是,标准VAE中显式的KL散度项在这里缺失,该项将编码器的潜在分布q(z∣x)正则化到先验p(z)。在VQ-VAE中,正则化作用主要通过有限码本施加的信息瓶颈和承诺损失来实现。
使用VQ-VAE的优点
通过学习码本引入离散潜在空间带来了几个显著的优点:
- 更清晰的生成样本:具有连续潜在空间的标准VAE有时会因为潜在空间中的平均效应而产生模糊或过于平滑的样本。VQ-VAE中zq(x)的离散性通常会迫使解码器从更明确的“原型”特征集中选择,从而产生更清晰、更详细、更高保真度的输出,特别是对于图像和音频等复杂数据。
- 缓解后验坍缩:训练标准VAE时的一个常见难题是“后验坍缩”,即如果ELBO中的KL散度项严重惩罚与先验的偏差,潜在变量就会变得不包含信息。VQ-VAE避开了这个特定问题,因为它们不对编码器输出使用相同的KL正则化 (regularization)机制。信息瓶颈转而由量化 (quantization)过程本身强制实现。
- 学习到的离散表示:VQ-VAE学习到的离散码(索引k)本身就很有价值。例如,在语音建模中,这些码可能捕获类似音素的单元。此外,这些离散序列可以作为强大自回归 (autoregressive)模型的输入,实现我们将在稍后讨论的两阶段生成过程。
- 受控的潜在容量:潜在空间的信息容量由码本大小K和嵌入 (embedding)的维度D决定。与调整标准VAE中KL散度的权重 (weight)相比,这提供了一种更直接的方式来控制表示瓶颈。
考量与潜在挑战
虽然VQ-VAE功能强大,但仍有一些实际方面和挑战需要注意:
- 码本大小(K):选择嵌入 (embedding)向量 (vector)的数量K是一个重要的设计决策。小的K可能会限制模型捕获数据完整多样性的能力,导致重构效果不佳。相反,非常大的K会增加计算成本和内存,并可能导致许多码未被充分利用。
- 码本坍缩(死码):训练期间可能只有码本向量的一个子集被主动使用,而许多嵌入(“死码”)很少或从未被编码器选中。承诺损失在一定程度上有所帮助,但有时会采用针对未使用的码的特定初始化或定期重置策略。
- 训练稳定性:编码器学习产生与现有码本条目匹配的输出,以及码本条目向编码器输出移动,两者之间的关系可能会引入与标准VAE不同的训练动态。承诺损失的超参数 (parameter) (hyperparameter)β通常需要仔细调整。
- 量化 (quantization)的计算成本:对于非常大的码本,量化步骤中的最近邻搜索可能会变得计算密集。然而,对于典型的码本大小(例如,K从几百到几千),通过高效的搜索算法通常可以很好地管理。
VQ-VAE与离散潜在变量上的自回归 (autoregressive)先验
VQ-VAE最具有影响力的用途之一是它们与自回归模型的联系,用于学习离散潜在空间上的先验。与标准VAE中假设一个简单、因子化的先验p(z)不同,您可以训练一个独立的、强大的自回归模型(例如用于图像的PixelCNN或用于序列的Transformer)来模拟VQ-VAE编码器生成的离散潜在码k的分布。
这通常涉及一个两阶段过程:
- 训练VQ-VAE:VQ-VAE按照描述进行训练,学习一个编码器、一个解码器和码本E。一旦训练完成,编码器可以将输入x映射到离散码本索引序列k1,k2,...,kM。
- 训练自回归先验:随后,在这些索引序列ki上训练一个独立的自回归模型,以学习先验分布p(k)=∏ip(ki∣k<i)。
对于生成:
- 从训练好的自回归先验p(k)中采样一个潜在码序列k1,...,kM。
- 对于每个采样的索引ki,从VQ-VAE学习到的码本E中检索相应的嵌入 (embedding)向量 (vector)eki。
- 将这个嵌入序列通过VQ-VAE的解码器,以生成新的数据样本x^。
使用VQ-VAE的两阶段生成过程。阶段1训练VQ-VAE。阶段2在学习到的离散码上训练一个自回归先验。对于生成,从该先验中采样码,转换为嵌入,然后进行解码。
这种两阶段方法在VQ-VAE-2等模型中得到了著名的展示,它有效地分离了关注点。VQ-VAE擅长学习紧凑、高质量的局部特征词汇(即“是什么”),而自回归先验则专注于模拟这些特征如何组合的远距离依赖关系和全局结构(即“如何”)。这种组合使得生成高度真实和连贯的图像和音频成为可能,展现了离散表示与富有表现力的序列模型结合时的能力。随着您学习高级架构,理解VQ-VAE将为您提供一个有效的工具,用于高保真生成和学习结构化离散表示。