除了对现有序列中的值进行分类或预测,循环神经网络还具备生成全新序列的能力,这些序列能模拟从训练数据中学到的模式。这种生成能力拓展了应用,从生成可信文本、创作音乐到合成逼真时间序列数据。
使用RNN进行序列生成的基本思路出奇地简单:我们训练模型来预测序列中的下一个元素,根据它之前的元素。一旦训练完成,我们就可以迭代地使用这种预测能力来构建新序列。
核心机制:预测下一步
设想你有一个序列 x1,x2,...,xN。在训练期间,我们向模型提供子序列,并要求它预测紧随该子序列的元素。例如,给定 x1,...,xt,模型学习预测 xt+1。
- 对于离散序列(如文本): 此任务通常被视为分类问题。如果我们在字符级别操作,模型会预测词汇表中所有可能字符在下一个时间步的概率分布。如果操作在词语级别,它会预测词汇表中所有词语的概率。为此,输出层通常使用
softmax激活函数。
- 对于连续序列(如时间序列): 这通常被视为回归问题。给定之前的值 x1,...,xt,模型预测连续值 x^t+1。输出层通常使用线性激活函数。
RNN的隐藏状态在此处扮演重要作用,它总结了先前元素(x1,...,xt)的信息,以预测xt+1。LSTM和GRU特别有效,因为它们的门控机制使它们能够捕捉到更长范围的依赖关系,这对于生成连贯的序列通常非常重要。
迭代生成序列
一旦模型被训练来预测下一个元素,我们就可以一步步地生成新序列:
- 提供一个种子: 从一个初始的“种子”序列开始。这可以是一个单一元素、训练数据中的一个短序列,或者一个自定义提示。对于文本生成,这可能是句子的开头,例如“今天的天气是”。
- 预测下一个元素: 将种子序列输入到训练好的RNN中。模型会输出对下一个应出现元素的预测。对于分类(例如文本),此输出是词汇表上的概率分布。对于回归,它是一个预测值。
- 选择下一个元素:
- 回归: 直接使用预测值。
- 分类: 根据概率分布选择下一个元素。我们很少只选择最可能的单个元素(贪婪搜索),因为这通常会导致重复和可预测的结果。相反,我们使用采样策略(如下所述)。
- 添加并重复: 将选定的元素附加到当前序列。这个新的、更长的序列成为下一个时间步的输入。重复步骤2和3,逐个生成序列元素,直到达到所需长度或生成一个特殊的序列结束标记。
这个迭代过程使模型能够增量地构建序列,每个新元素都以迄今生成的序列为条件。
使用RNN进行迭代序列生成的过程。一个种子序列启动该过程,RNN预测下一个元素的概率,然后采样一个元素,并将序列扩展以作为下一步的输入。
离散序列的采样策略
生成文本等离散序列时,从预测的概率分布中选择下一个元素并非易事。不同的策略在连贯性、多样性和可预测性之间提供权衡:
- 贪婪搜索: 在每一步只选择概率最高的元素。这是确定性的,并且通常导致重复或无趣的序列。它可能会陷入诸如“is is is is...”之类的循环中。
- 带温度的采样: 这是一种流行的方法,用于控制选择的随机性。在对网络的原始输出分数(logits,zi)应用
softmax函数之前,我们将logits除以一个温度值(T):
Pi=∑jexp(zj/T)exp(zi/T)
- T=1:根据学到的概率进行标准采样。
- T<1(例如0.5):使分布“更尖锐”。高概率元素变得更加可能,从而产生更集中和可预测、可能更连贯的输出,更接近贪婪搜索。
- T>1(例如1.2):使分布“更平坦”。低概率元素变得更可能,增加随机性、新颖性和多样性,但可能降低连贯性或语法正确性。
- 找到合适的温度通常需要通过实验。
- Top-k 采样: 不考虑词汇表中的所有元素,只考虑概率最高的k个元素。重新归一化这k个最高概率元素中的概率,并从这个缩小集合中采样。这可以避免低概率(通常无意义)的选择,同时仍允许一定的变化。
- Top-p(核)采样: 这是Top-k的一种改进。不是选择固定数量的k个元素,而是选择累积概率大于或等于阈值p(例如p=0.9)的最小元素集合。只从这个“核”中采样,即那些可能的元素。这会根据模型在每一步的确定性来调整采样池的大小。如果模型非常确定(某个词的概率非常高),则核很小;如果模型不确定(概率分散),则核会更大。
字符级与词语级生成
对于文本生成,一个重要的选择是你操作的级别:
- 字符级:
- 词汇表由单个字符(字母、标点符号、空格)组成。
- 优点:可以处理任何词(没有词汇表外问题),可以生成新颖的拼写或词形,词汇量较小。
- 缺点:模型不仅需要学习句子结构,还需要学习词语结构。由于序列长度(以字符计)变得非常大,捕获连贯意义所需的远距离依赖关系变得更加困难。训练可能计算量很大。
- 词语级:
- 词汇表由训练数据中遇到的唯一词语组成。
- 优点:更直接地建模语言结构(词语是意义的主要单位),通常在更长的范围内生成更连贯的文本,相同文本的序列长度(以词语计)比字符序列更短。
- 缺点:词汇量可能变得非常大,导致内存和计算挑战。对训练期间未见的词语(词汇表外或OOV问题)处理能力不足,尽管使用特殊
<UNK>(未知)标记或子词分词(如BPE或WordPiece)等方法可以缓解这种情况。无法完全发明新词。
选择取决于具体的任务、数据集大小和期望的输出特性。
实际考量
- 种子选择: 初始种子序列极大地影响生成输出的主题和风格。尝试不同的种子是很常见的。
- 训练数据: 模型从训练数据中学习模式。如果你用莎士比亚的作品训练,它会生成听起来像莎士比亚风格的文本。如果你用代码训练,它会生成代码。训练数据的质量、大小和领域非常重要。
- 评估: 评估生成的序列可能带有主观性。虽然存在诸如困惑度(衡量模型预测测试集表现的指标,将在下一章讨论)之类的指标,但对于故事创作或对话生成等任务,通常需要人工判断来评估其连贯性、创造性和相关性。
序列生成展示了RNN学习和再现复杂时间模式的能力。虽然这里描述的方法构成了根本,但生成真正具有长距离连贯性和符合上下文的序列,通常会受益于更先进的架构,如编码器-解码器框架和注意力机制,这些将在本章后面简要介绍。