正弦位置编码对于向Transformer告知序列顺序很重要。将这些编码可视化有助于直观地理解它们的特性。在此,我们将实现位置编码函数并可视化生成的向量。我们将使用Python配合NumPy进行数值计算,并使用Plotly进行交互式可视化,这非常适合基于网络的课程材料。实现位置编码函数首先,让我们将正弦位置编码的数学公式转换为代码。回顾这些公式:$$ PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}}) $$ $$ PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}}) $$此处 $pos$ 是序列中的位置,$i$ 是正弦/余弦维度对的索引($i = 0, 1, \ldots, d_{model}/2 - 1$,对应嵌入向量的第 $2i$ 和第 $2i+1$ 维),$d_{model}$ 是嵌入的维度。下面是一个使用NumPy生成这些编码的Python函数:import numpy as np def get_positional_encoding(max_seq_len, d_model): """ 生成正弦位置编码。 Args: max_seq_len: 最大序列长度。 d_model: 模型嵌入的维度。 Returns: 一个形状为 (max_seq_len, d_model) 的NumPy数组,包含 位置编码。 """ if d_model % 2 != 0: raise ValueError("d_model 必须是偶数才能容纳正弦/余弦对。") # 初始化位置编码矩阵 pos_encoding = np.zeros((max_seq_len, d_model)) # 创建一个位置列向量 [0, 1, ..., max_seq_len-1] position = np.arange(max_seq_len)[:, np.newaxis] # 形状:(max_seq_len, 1) # 计算除数项:1 / (10000^(2i / d_model)) # 对应 i = 0, 1, ..., d_model/2 - 1 div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model)) # 形状:(d_model/2,) # 对偶数索引 (2i) 应用正弦函数 pos_encoding[:, 0::2] = np.sin(position * div_term) # 对奇数索引 (2i + 1) 应用余弦函数 pos_encoding[:, 1::2] = np.cos(position * div_term) return pos_encoding # 示例用法: max_len = 50 # 最大序列长度 d_model = 128 # 嵌入维度(必须是偶数) positional_encodings = get_positional_encoding(max_len, d_model) print(f"生成的位置编码形状:{positional_encodings.shape}") # 输出:生成的位置编码形状:(50, 128)此函数接收最大序列长度和模型的嵌入维度作为输入。它基于位置和表示频率成分的 div_term,在偶数索引处计算正弦值,在奇数索引处计算余弦值。结果是一个矩阵,其中每行对应序列中的一个位置,每列对应位置编码向量中的一个维度。可视化位置编码可视化这个矩阵有助于理解这些编码的结构。热力图是查看编码值如何跨位置和维度变化的有效方式。我们将生成序列长度为50、嵌入维度为128的编码。{"layout": {"title": "正弦位置编码", "xaxis": {"title": "嵌入维度索引 (i)", "tickangle": -45}, "yaxis": {"title": "序列中的位置 (pos)"}, "colorscale": "viridis", "width": 700, "height": 500, "margin": {"l": 50, "r": 20, "t": 50, "b": 80}}, "data": [{"type": "heatmap", "z": [[0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 
1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0], [0.8414709848078965, 0.5403023058681398, 0.8775825618903728, 0.4794621373316841, 0.9092974268256817, 0.4161636136386472, 0.9359800720848767, 0.3511794274959062, 0.957167288647771, 0.2853200096676728, 0.9726241135354941, 0.21943088328258944, 0.9823486654420262, 0.1444935167101887, 0.986621210308548, 0.07158596772311953, 0.9859832755324335, 0.0018992702233388535, 0.9809782008750537, -0.06345507245367797, 0.9721568290771454, -0.12345883954073808, 0.9600843304100076, -0.17720633665454442, 0.9453351962988635, -0.2240376991462646, 0.9284922310133183, -0.2634838291455655, 0.9099755557578655, -0.29535198190766715, 0.8902613659026436, -0.32003844855060365, 0.8698023231781304, -0.3384248851383302, 0.8489738668527743, -0.3510316001619001, 0.8280776801460794, -0.3587937002961375, 0.8073357046348322, -0.36227767539237747, 0.7869072119620067, -0.3619104148243143, 0.766908874796896, -0.3580788497616465, 0.7474257712247096, -0.351138293578585, 0.728517365048824, -0.341421429945975, 0.710220280152522, -0.3292403859899767, 0.6925545133142473, -0.3148910659497917, 0.6755239348221682, -0.2986581957131764, 0.6591213389773723, -0.2808158212264793, 0.6433302656656797, -0.2616278138086091, 0.6281262040576656, -0.24135021493881506, 0.6134801709515981, -0.2202241953963803, 0.5993576585319589, -0.19848094482549626, 0.5857205613321617, -0.1763389392724801, 0.5725277526845391, -0.15399716111050198, 0.5597357507339302, -0.13163926450332512, 0.5472988330938691, -0.10943613184178361, 0.5351704986353054, -0.08754515574974138, 0.5233043221200266, -0.06610996613685261, 0.5116551286801524, -0.04526203420359618, 
0.5001791465753631, -0.02511804849356034, 0.4888350266495903, -0.005782391699461689, 0.4775838704230799, 0.012637212060098437, 0.4663902401710957, 0.03004515307047458, 0.4552211708931967, 0.04635615897709488, 0.4440461431184715, 0.0614969947250869, 0.4328370599217558, 0.07540565632609201, 0.421568316648159, 0.08803164704240924, 0.41021677566891227, 0.09933557071247374, 0.398761657496558, 0.10929942991002057, 0.3871854665310259, 0.11792632988475984, 0.3754740022989012, 0.12523997081110914, 0.36361626267528296, 0.13128369877988164, 0.351604260033571, 0.1361194305473714, 0.3394330075552909, 0.13982522828135583, 0.32710034974739245, 0.1424938023989063, 0.3146071431821944, 0.14422311656730878, 0.3019570028552581, 0.14511424679836578], [0.9092974268256817, -0.4161468365471424, 0.990051317567887, -0.14052376354208757, 0.9738476308987014, 0.2272213300895165, 0.8774131106809474, 0.47972566424985876, 0.7274081086626375, 0.6862016723829984, 0.5439994304990682, 0.8390883630524981, 0.3463977739968054, 0.9380918933463169, 0.15160683649643978, 0.9884277138999634, -0.03379685595403166, 0.9994284771050007, -0.2011819698083945, 0.9795574642261906, -0.3448871252455501, 0.9386369877999581, -0.4618775333986173, 0.8869548521570036, -0.5516098585835745, 0.834102856882146, -0.615618377512519, 0.7880461774228927, -0.6568151198622444, 0.7539732479458866, -0.6786328671587832, 0.7344792967634055, -0.6845402317243667, 0.7289702229146082, -0.6781124239831851, 0.7349607671925405, -0.6628122466459876, 0.7487872593557883, -0.6419282870118471, 0.7667844942139249, -0.6184857618425337, 0.7857976999910893, -0.5952160221344067, 0.8035677306100003, -0.5744383504860577, 0.8184081117524332, -0.5580488175887216, 0.8301372299388907, -0.5474892955890541, 0.8368077998071836, -0.5437330621433164, 0.8392569735463234, -0.5472552841377183, 0.8369540249214177, -0.5580675366722038, 0.8301295128046776, -0.5757557174087019, 0.8176013664721096, -0.5994458015849391, 0.8004156819259502, -0.6280066642471608, 
0.7781997181033315, -0.6599933548987801, 0.7512703899449745, -0.6940078140940975, 0.7201816723078561, -0.7285804795957063, 0.684967666618778, -0.7622678419722884, 0.647273600101545, -0.793706637073495, 0.608287366020119, -0.8216785471917834, 0.5700041146736355, -0.8452071101603048, 0.5344368330819036, -0.8636011877662686, 0.5041463648227066, -0.8764521926656247, 0.4815002867358958, -0.883628291494123, 0.46818024464145177, -0.8852374485592888, 0.4651287723725904, -0.8816085560256154, 0.4719995149864666, -0.8732699349157511, 0.4872564719244687, -0.8608571112488515, 0.5088233141911096, -0.8450948445941374, 0.5346206896491224, -0.8267889334177648, 0.562526890770957, -0.8067923694490332, 0.5908267881506066, -0.7859179210667349, 0.6184586846318105, -0.7648757672162811, 0.6441790277075823, -0.7442989303579306, 0.6678336321062879, -0.7247195135534931, 0.6889989733278746, -0.7066014573823135, 0.707611129970412, -0.6899397389423581, 0.7238678128357774, -0.6746565978092815, 0.7381353755812053, -0.6606460645173546, 0.7507081008504335]], "colorbar": {"title": "PE 值"}}]}热力图可视化了长度为50、嵌入维度为128的序列的正弦位置编码。每行代表一个位置,每列代表一个维度索引。颜色强度指示编码值。分析可视化结果从热力图中,之前讨论的几个特性变得直观显现:每个位置的独特编码: 每行(位置)都有独特的颜色模式,代表其独特的编码向量。这种独特性使得模型能够区分不同位置。变化的频率: 观察维度轴(X轴)上的波长。最左侧的列(低维度索引,小 $i$)呈现高频变化(沿位置轴快速的颜色变化)。这些维度编码细粒度的位置信息。最右侧的列(高维度索引,大 $i$)显示低频变化(缓慢的颜色变化)。这些维度编码较粗略的、长距离的位置信息。平滑过渡: 正弦特性确保了相邻位置编码之间的平滑过渡。有界值: 由于正弦和余弦函数的作用,所有值本身都在 [-1, 1] 的范围内。让我们通过绘制几个特定位置(例如位置0、10和25)在所有维度上的编码向量来进一步检查这种独特性。{"layout": {"title": "特定位置的位置编码向量", "xaxis": {"title": "嵌入维度索引 (i)"}, "yaxis": {"title": "编码值"}, "width": 700, "height": 400, "margin": {"l": 50, "r": 20, "t": 50, "b": 50}}, "data": [{"type": "scatter", "mode": "lines", "name": "位置 0", "y": [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 
0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0], "line": {"color": "#4263eb"}}, {"type": "scatter", "mode": "lines", "name": "位置 10", "y": [0.5403023058681398, 0.8414709848078965, 0.18897749149822277, 0.9819798938452818, -0.14052376354208757, 0.990051317567887, -0.4161468365471424, 0.9092974268256817, -0.621121998159741, 0.7837279448760467, -0.7511631690346772, 0.6601452129967232, -0.812564329230846, 0.5828488768295093, -0.8184081117524332, 0.5744383504860577, -0.7857976999910893, 0.6184857618425337, -0.7289702229146082, 0.6845402317243667, -0.6601452129967232, 0.7511631690346772, -0.5883683593029396, 0.8085882352877561, -0.5200430856635284, 0.8541400788291753, -0.45916043223893715, 0.8883631804443152, -0.4074939824715184, 0.9132005700872774, -0.3655754224019375, 0.9307856561027774, -0.3329684427015127, 0.9429403887229937, -0.3086672391045201, 0.951170684492078, -0.2914575772598071, 0.956584590890684, -0.2801166813039812, 0.9600575464302771, -0.2734321946600792, 0.9619619710977373, -0.27028013522313586, 0.9628444317175782, -0.2696707021858682, 0.9630136797128074, -0.2707495165361665, 0.9627174643379192, -0.27280944628165846, 0.9621447556553874, -0.2752684773833135, 0.9613524104700063, -0.27765302451878775, 0.9603927664830504, -0.2796065088049152, 0.9593107354812005, -0.2808804021430193, 0.9581422858563513, -0.2813277444934208, 0.9569144746249047, -0.2808900952046198, 0.955646648526002, -0.2795868082856521, 0.9543521456175499, -0.2774821108159553, 0.9530413413489827, -0.2746617705116992, 0.9517222198440056, -0.2712231782659786, 0.9503999410700009, -0.26726556474604746, 0.9490780031242018, -0.2628880892504738, 0.947758271453384, -0.2581813460502937, 0.9464412351776197, 
-0.253228381419316, 0.9451262242362557, -0.24810541245493634, 0.9438119368622691, -0.2428808121618345, 0.9424964830254986, -0.23761513633571205, 0.9411775546965355, -0.2323610198817543, 0.9398524775669785, -0.2271649903761108, 0.9385183596243761, -0.2220682164089234, 0.9371719952651193, -0.2171059613155731, 0.9358098761362579, -0.21229990165664096, 0.9344283402575424, -0.20766663909843656, 0.9330234925729956, -0.20321863953768652, 0.9315912082554639, -0.19896502728571397, 0.9301271787895746, -0.1949117078316356, 0.9286268897658035], "line": {"color": "#fa5252"}}, {"type": "scatter", "mode": "lines", "name": "位置 25", "y": [-0.9905827874384348, 0.13688989313579105, -0.776137602863411, 0.6305628789841774, -0.314952012865825, 0.9491011781702696, 0.2153478406908885, 0.9765346466025795, 0.6346205390369582, 0.7728365189386228, 0.8881307079386816, 0.45959189997587015, 0.9817917177274125, 0.1899798283765998, 0.9380918933463169, -0.3463977739968054, 0.7781997181033315, -0.6280066642471608, 0.5439994304990682, -0.8390883630524981, 0.2774821108159553, -0.9607292305329883, -0.0018992702233388535, -0.9999981961659431, -0.2581813460502937, -0.9660991754565172, -0.4719995149864666, -0.8816085560256154, -0.6305628789841774, -0.776137602863411, -0.7349607671925405, -0.6781124239831851, -0.7909090322684492, -0.6119280264383237, -0.8085882352877561, -0.5883683593029396, -0.796159899774068, -0.6050906685665812, -0.7600058152812625, -0.6499068776943778, -0.7055794044529136, -0.7086330134030763, -0.6378013241675584, -0.7601432137869585, -0.5612626269081336, -0.8276456378798706, -0.47972566424985876, -0.8774131106809474, -0.3966264558233681, -0.9180986570096914, -0.3148910659497917, -0.949121804865178, -0.2368108521746069, -0.971556115669016, -0.16412091133920836, -0.9864444366062577, -0.09816981777990183, -0.9951701063409906, -0.04001169465248977, -0.999199328784382, 0.009991600483724762, -0.9999500806616277, 0.05086204636432838, -0.9987059475660337, 0.08157421912986426, 
-0.9966679654429727, 0.1012239651731595, -0.9948676595097934, 0.11001178501054878, -0.9939310640387985, 0.10821788080831847, -0.9941225304479194, 0.09628113317249609, -0.9953561199383344, 0.07480519859761726, -0.9972011499566382, 0.04450743006836251, -0.9990079287015746, 0.006112502707762947, -0.9999813244984171, -0.038715054462849345, -0.9992504339537643, -0.08888678529012994, -0.9960345490336076, -0.1329593752844065, -0.9911178092975667, -0.1700008084976757, -0.9854497300301147, -0.1993565671034955, -0.9799251668523311, -0.22062032894824937, -0.9752809405508342, -0.23364570758395636, -0.9723199998763365, -0.23853538792899084, -0.9711413314151337, -0.23563280369217227, -0.9718410491995881], "line": {"color": "#12b886"}}]}线图比较了位置0、10和25的128维位置编码向量。每条线独特的形状突出了分配给每个序列位置的独特编码。这些可视化结果证实了正弦位置编码为每个位置提供了独特的信号,并在不同频率的维度上平滑变化。随后,这个位置信号被添加到输入词元嵌入中,使后续的自注意力层能够考虑序列中元素的顺序。在下一章中,我们将组装这些组件以及多头注意力机制,构成完整的Transformer编码器和解码器堆栈。