Sinusoidal positional encodings are what tell the Transformer about the order of the sequence. Seeing their properties directly builds useful intuition, so in this section we implement the positional encoding function and visualize the resulting vectors. We use Python with NumPy for the numerical computation and Plotly for interactive visualization, which suits web-based course material well.

### Implementing the Positional Encoding Function

First, let's translate the mathematical formulas for sinusoidal positional encoding into code. Recall the formulas:

$$ PE_{(pos, 2i)} = \sin\left(pos / 10000^{2i/d_{model}}\right) $$

$$ PE_{(pos, 2i+1)} = \cos\left(pos / 10000^{2i/d_{model}}\right) $$

Here $pos$ is the position in the sequence, $i$ indexes the sine/cosine pairs of embedding dimensions (dimensions $2i$ and $2i+1$ share the same frequency, with $i = 0, 1, \ldots, d_{model}/2 - 1$), and $d_{model}$ is the embedding dimension.

Below is a Python function that generates these encodings with NumPy:

```python
import numpy as np

def get_positional_encoding(max_seq_len, d_model):
    """
    Generate sinusoidal positional encodings.

    Args:
        max_seq_len: Maximum sequence length.
        d_model: Dimension of the model embeddings.

    Returns:
        A NumPy array of shape (max_seq_len, d_model) containing
        the positional encodings.
    """
    if d_model % 2 != 0:
        raise ValueError("d_model must be even to hold sine/cosine pairs.")

    # Initialize the positional encoding matrix
    pos_encoding = np.zeros((max_seq_len, d_model))

    # Column vector of positions [0, 1, ..., max_seq_len-1]
    position = np.arange(max_seq_len)[:, np.newaxis]  # shape: (max_seq_len, 1)

    # Divisor term: 1 / (10000^(2i / d_model))
    # for i = 0, 1, ..., d_model/2 - 1
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))  # shape: (d_model/2,)

    # Apply sine to even indices (2i)
    pos_encoding[:, 0::2] = np.sin(position * div_term)

    # Apply cosine to odd indices (2i + 1)
    pos_encoding[:, 1::2] = np.cos(position * div_term)

    return pos_encoding

# Example usage:
max_len = 50   # maximum sequence length
d_model = 128  # embedding dimension (must be even)
positional_encodings = get_positional_encoding(max_len, d_model)
print(f"Shape of generated positional encodings: {positional_encodings.shape}")
# Output: Shape of generated positional encodings: (50, 128)
```

The function takes the maximum sequence length and the model's embedding dimension as input. It computes sine values for the even indices and cosine values for the odd indices from the position and `div_term`, which holds one frequency per sine/cosine pair. `div_term` is computed as $\exp\left(-2i \ln(10000)/d_{model}\right)$, which equals $1/10000^{2i/d_{model}}$. The result is a matrix in which each row corresponds to a position in the sequence and each column to one dimension of the positional encoding vector.

### Visualizing the Positional Encodings

Visualizing this matrix helps reveal the structure of the encodings. A heatmap is an effective way to see how the values vary across positions and dimensions. We generate encodings for a sequence length of 50 and an embedding dimension of 128.

[Figure: "Sinusoidal Positional Encoding" heatmap. x-axis: embedding dimension index; y-axis: position in sequence (pos); color bar: PE value.]

The heatmap visualizes the sinusoidal positional encodings for a sequence of length 50 with embedding dimension 128. Each row represents a position and each column a dimension index; color intensity indicates the encoding value.

### Analyzing the Visualization

Several of the properties discussed earlier become visible in the heatmap:

- **A unique encoding for each position:** Each row (position) has its own color pattern, representing its own encoding vector. This uniqueness lets the model distinguish different positions.
- **Varying frequencies:** Look at how the wavelength changes along the dimension axis (x-axis).
  - The leftmost columns (low dimension indices, small $i$) vary at high frequency: the color changes quickly along the position axis. These dimensions encode fine-grained positional information.
  - The rightmost columns (high dimension indices, large $i$) vary at low frequency: the color changes slowly. These dimensions encode coarser, long-range positional information.
- **Smooth transitions:** Because the encodings are sinusoidal, neighboring positions have similar encodings, so the pattern changes smoothly from one row to the next.
- **Bounded values:** Every value is a sine or a cosine, so all values lie in the range $[-1, 1]$.

To examine this uniqueness further, let's plot the encoding vectors of a few specific positions (for example, positions 0, 10, and 25) across all dimensions.

[Figure: "Positional Encoding Vectors at Specific Positions" line plot. x-axis: embedding dimension index; y-axis: encoding value; lines: position 0, position 10, position 25.]

The line plot compares the 128-dimensional positional encoding vectors at positions 0, 10, and 25. The distinct shape of each line highlights the unique encoding assigned to each sequence position. Position 0, for example, simply alternates between 0 and 1, since $\sin 0 = 0$ and $\cos 0 = 1$.

These visualizations confirm that sinusoidal positional encodings give each position a unique signal that varies smoothly across dimensions of different frequencies. This positional signal is then added to the input token embeddings, allowing the subsequent self-attention layers to take the order of elements in the sequence into account. In the next chapter, we will assemble these components together with multi-head attention to build the complete Transformer encoder and decoder stacks.
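The properties read off the plots in this section can also be checked numerically. The sketch below redefines the same `get_positional_encoding` function so it runs on its own. It then checks four things: boundedness, the value of position 0, that no two positions share an encoding, and that nearby positions are closer than distant ones. The names `dists`, `off_diag`, and `wavelengths` are illustrative choices, not part of the original material.

```python
import numpy as np

def get_positional_encoding(max_seq_len, d_model):
    """Same sinusoidal encoding as above: sin on even columns, cos on odd columns."""
    if d_model % 2 != 0:
        raise ValueError("d_model must be even to hold sine/cosine pairs.")
    position = np.arange(max_seq_len)[:, np.newaxis]                              # (max_seq_len, 1)
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))    # (d_model/2,)
    pe = np.zeros((max_seq_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe

max_len, d_model = 50, 128
pe = get_positional_encoding(max_len, d_model)

# 1. Bounded values: every entry is a sine or cosine, so it lies in [-1, 1].
print("min / max:", pe.min(), pe.max())

# 2. Position 0: sin(0) = 0 on even columns, cos(0) = 1 on odd columns.
print("position 0, first 4 values:", pe[0, :4])

# 3. Unique encodings: all pairwise Euclidean distances between different rows are > 0.
dists = np.linalg.norm(pe[:, None, :] - pe[None, :, :], axis=-1)   # (max_len, max_len)
off_diag = dists[~np.eye(max_len, dtype=bool)]
print("smallest distance between two different positions:", off_diag.min())

# 4. Wavelength (in positions) of each sine/cosine pair: 2*pi * 10000^(2i/d_model).
#    Short for small i (left columns), very long for large i (right columns).
wavelengths = 2 * np.pi / np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
print("shortest / longest wavelength:", wavelengths[0], wavelengths[-1])

# 5. Smoothness: position 10 is closer to position 11 than to position 25.
print("dist(10, 11) =", dists[10, 11], "  dist(10, 25) =", dists[10, 25])
```

The shortest wavelength is $2\pi \approx 6.28$ positions, while the longest works out to roughly 54,000 positions, far more than the 50 positions in the heatmap. That is why the rightmost columns look almost constant there.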
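The last step described in this section, adding the positional signal to the input token embeddings, is a single broadcasted addition. Below is a minimal sketch, assuming a hypothetical random `embedding_table` and made-up `token_ids` (in a real model the embedding weights are learned). It places the same token at two different positions to show that the summed input vectors then differ.

```python
import numpy as np

def get_positional_encoding(max_seq_len, d_model):
    """Sinusoidal encodings as defined above (sin on even columns, cos on odd)."""
    if d_model % 2 != 0:
        raise ValueError("d_model must be even to hold sine/cosine pairs.")
    position = np.arange(max_seq_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    pe = np.zeros((max_seq_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe

vocab_size, d_model, max_len = 1000, 128, 50   # hypothetical sizes
rng = np.random.default_rng(0)

# Hypothetical embedding table; in a trained model these weights are learned.
embedding_table = rng.normal(size=(vocab_size, d_model))

# Hypothetical batch: 2 sequences of 12 token ids each.
token_ids = rng.integers(0, vocab_size, size=(2, 12))
# Put the same token at positions 0 and 5 of the first sequence.
token_ids[0, 5] = token_ids[0, 0]

# Positional table computed once, then sliced to the actual sequence length.
pe = get_positional_encoding(max_len, d_model)      # (max_len, d_model)
seq_len = token_ids.shape[1]
if seq_len > max_len:
    raise ValueError("sequence is longer than the precomputed positional table")

token_embeddings = embedding_table[token_ids]       # (2, 12, d_model)
x = token_embeddings + pe[:seq_len]                 # (12, d_model) broadcasts over the batch axis

print(x.shape)                        # (2, 12, 128)
# Same token id, different positions -> different input vectors.
print(np.allclose(x[0, 0], x[0, 5]))  # False
```

Because the encoding depends only on the position and `d_model`, the table can be computed once for the longest expected sequence and sliced for shorter ones. The guard in this sketch simply rejects longer inputs rather than extending the table.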