Transformer架构:位置编码(sin/cos编码) 您所在的位置:网站首页 为什么要叫正弦 Transformer架构:位置编码(sin/cos编码)

Transformer架构:位置编码(sin/cos编码)

2024-06-04 05:53| 来源: 网络整理| 查看: 265

原文链接:Transformer Architecture: The Positional Encoding 本文将原文做了一个翻译。

文章目录 什么是位置编码?为什么我们需要把它放在一开始的地方?提出的方法直觉其他细节相对位置FAQ为什么位置嵌入是与字词嵌入相加的而不是相连?一旦它到了更上面的层,位置信息不会消失吗?为什么正弦和余弦都用了? 总结本文参考评论区置顶

Transformer架构是由Vaswani等人提出的一个新颖的纯注意力序列到序列的架构。它并行训练的能力和它普遍的性能提升使得它在NLP和CV等领域的研究人员当中十分热门。

感谢Transofmer在几种常用深度学习框架的实现,很多学生包括我自己都可以很容易地拿它来做实验。尽管更容易使用它是件好事,但不好的是这样可能会导致模型的细节被忽略。

在这篇文章中,我不打算深挖它的架构,因为当前这方面已经有几篇很棒的教程了(这里、这里和这里),但相对地,我想去讨论Transformer架构中具体的一部分——位置编码。

当我读到论文中的这一部分的时候,我对它产生了很多疑问,但不巧的是作者没有提供充分的信息来回答它们。所以在这篇文章里,我想尝试把这一模块拆开来看看它是怎么运作的。

NOTE:如果要理解这篇博客剩下的部分,我非常建议你阅读一篇上面提到地那些教程来熟悉Transformer架构。 图1 - Transformer架构图

图1 - Transformer架构图 什么是位置编码?为什么我们需要把它放在一开始的地方?

字词的位置和顺序是所有语言的基本组成部分。它们定义了语法从而定义了句子的实际语义。循环神经网络(RNNs)自带了对字词顺序的考虑,它们通过顺序地逐字解析一条句子,这样就把字词的顺序集成到了RNNs的主干当中。

但是Transformer架构抛弃了循环机制,选择了多头自注意力机制。它避免了RNNs循环的方法,使得训练时间大大加速。并且理论上,它可以捕获句子中更长距离的依赖。

由于句子当中的每个字都同时流过了Transformer编码/解码栈,模型本身没有关于每个字词位置/顺序的任何概念。因此,模型仍然需要一种方法来把字词的顺序合并到我们的模型当中。

一种让模型感知顺序的方法是对每个词加一小块它在句子中的位置信息,我们把这"一小块信息"叫做位置编码。

脑子冒出来的第一种想法是对每个时间戳赋一个[0,1]的值,其中0代表第一个字,1代表最后一个时间戳。你能想到这样会导致什么样的问题吗?其中一个它会产生的问题是,指定了这样的范围,你不知道里面总共表示了多少个词。换句话说,不同句子之间,时间戳的增量没有一致的意义。

另一个想法是线性地赋予每个时间戳一个数字。就是说,第一个字给它个"1",第二个字给它个"2",以此类推。这个方法的问题是不仅这个值可能会变得很大,而且我们的模型可能会遇到比训练里更长的句子。再更进一步说,我们的模型可能会从来没看过某个特定长度的样本,从而损伤了我们模型的泛化能力。

理想地说,位置编码要满足以下这些标准:

它应该对每个时间戳(句子当中字词的位置)输出一个唯一的编码不同长度的句子当中,任意两个时间戳的距离应该要一致。(前提这两个时间戳相对距离一样)我们的模型应该在不付出任何努力的条件下泛化到更长的句子。它的编码值应该要有限(有界)。它必须是确定性的。 提出的方法

作者提出的编码方法是一个简单而又天才的技术,满足了上面所有的标准。首先,它不只是一个数字,而是一个 d d d 维的向量,包含了句子里某个特定位置的信息;其次,这个编码没有集成到模型本身,而是把这个带有句子位置信息的向量配备进字词当中。换句话说,我们通过注入字词的顺序来加强了模型的输入。

用 t t t 表示输入句子中需要的位置, p t ⃗ ∈ R d \vec{p_t}\in\Bbb{R}^d pt​ ​∈Rd 表示对应的编码,其中 d d d 表示编码维度( d d d 为偶数)。然后 f : N → R d f:\Bbb{N}\to\Bbb{R}^d f:N→Rd 表示产生输出向量 p t ⃗ \vec{p_t} pt​ ​ 的函数,它被定义为: p t ⃗ ( i ) = f ( t ) ( i ) : = { sin ⁡ ( ω k ⋅ t ) , if  i = 2 k cos ⁡ ( ω k ⋅ t ) , if  i = 2 k + 1 \vec{p_t}^{(i)}=f(t)^{(i)}:= \begin{cases} \sin(\omega_k\cdot t), & \text{if $i=2k$} \\ \cos(\omega_k\cdot t), & \text{if $i=2k+1$} \end{cases} pt​ ​(i)=f(t)(i):={sin(ωk​⋅t),cos(ωk​⋅t),​if i=2kif i=2k+1​其中 ω t = 1 1000 0 2 k / d \omega_t=\frac{1}{10000^{2k/d}} ωt​=100002k/d1​可以从函数定义中推导出,频率会随着向量维度 k k k 的增大而减小。因此它在波长上会形成从 2 π 2\pi 2π 到 10000 ⋅ 2 π 10000\cdot2\pi 10000⋅2π 的几何级数。

你可以把位置嵌入 p t ⃗ \vec{p_t} pt​ ​ 想象成一个包含每个频率的正余弦对的向量(注意 d d d 能被2整除): p t ⃗ = [ sin ⁡ ( ω 1 ⋅ t ) cos ⁡ ( ω 1 ⋅ t ) sin ⁡ ( ω 2 ⋅ t ) cos ⁡ ( ω 2 ⋅ t ) ⋮ sin ⁡ ( ω d / 2 ⋅ t ) cos ⁡ ( ω d / 2 ⋅ t ) ] d × 1 \vec{p_t}= \begin{bmatrix} \sin(\omega_1\cdot t) \\ \cos(\omega_1\cdot t) \\ \\ \sin(\omega_2\cdot t) \\ \cos(\omega_2\cdot t) \\ \\ \vdots \\ \\ \sin(\omega_{d/2}\cdot t) \\ \cos(\omega_{d/2}\cdot t) \\ \end{bmatrix}_{d\times1} pt​ ​= ​sin(ω1​⋅t)cos(ω1​⋅t)sin(ω2​⋅t)cos(ω2​⋅t)⋮sin(ωd/2​⋅t)cos(ωd/2​⋅t)​ ​d×1​

直觉

你可能会想这个正余弦的组合怎么能表示一个位置或者顺序呢?它实际上非常简单,假设你想要用一个二进制的形式表示一个数字,那会怎么样? 0 : 0 0 0 0 8 : 1 0 0 0 1 : 0 0 0 1 9 : 1 0 0 1 2 : 0 0 1 0 10 : 1 0 1 0 3 : 0 0 1 1 11 : 1 0 1 1 4 : 0 1 0 0 12 : 1 1 0 0 5 : 0 1 0 1 13 : 1 1 0 1 6 : 0 1 1 0 14 : 1 1 1 0 7 : 0 1 1 1 15 : 1 1 1 1 \begin{matrix} 0:&\color{orange}0&\color{green}0&\color{blue}0&\color{red}0&&8:&\color{orange}1&\color{green}0&\color{blue}0&\color{red}0 \\ 1:&\color{orange}0&\color{green}0&\color{blue}0&\color{red}1&&9:&\color{orange}1&\color{green}0&\color{blue}0&\color{red}1 \\ 2:&\color{orange}0&\color{green}0&\color{blue}1&\color{red}0&&10:&\color{orange}1&\color{green}0&\color{blue}1&\color{red}0 \\ 3:&\color{orange}0&\color{green}0&\color{blue}1&\color{red}1&&11:&\color{orange}1&\color{green}0&\color{blue}1&\color{red}1 \\ 4:&\color{orange}0&\color{green}1&\color{blue}0&\color{red}0&&12:&\color{orange}1&\color{green}1&\color{blue}0&\color{red}0 \\ 5:&\color{orange}0&\color{green}1&\color{blue}0&\color{red}1&&13:&\color{orange}1&\color{green}1&\color{blue}0&\color{red}1 \\ 6:&\color{orange}0&\color{green}1&\color{blue}1&\color{red}0&&14:&\color{orange}1&\color{green}1&\color{blue}1&\color{red}0 \\ 7:&\color{orange}0&\color{green}1&\color{blue}1&\color{red}1&&15:&\color{orange}1&\color{green}1&\color{blue}1&\color{red}1 \\ \end{matrix} 0:1:2:3:4:5:6:7:​00000000​00001111​00110011​01010101​​8:9:10:11:12:13:14:15:​11111111​00001111​00110011​01010101​你可以观察到不同位之间变化的速度,最低位每个数字之间交替变化(红),第二低位每两个数字翻转一次(蓝),以此类推。

但是用二进制值会在浮点世界里造成空间浪费。所以反过来,我们可以使用它们对应的浮点连续值——三角函数。实际上,它们等价于交替变换的位。而且,通过降低它们的频率,我们也可以实现从红色的位走到橙色的位。(上面的例子而不是下面的图) 图2 - 句子长度最长50的128维位置编码。每行表示一个嵌入向量

图2 - 句子长度最长50的128维位置编码。每行表示一个嵌入向量 其他细节

在博客前面,我提到了位置嵌入是用来将位置信息配备到输入的字词中的。但它是怎么实现的?实际上,原论文把位置编码加到了实际嵌入的上面。就是说对于一个句子 [ ω 1 , … ω n ] [\omega_1,\dots\omega_n] [ω1​,…ωn​] 中的每个字词 ω t \omega_t ωt​,计算对应喂给模型的嵌入的过程如下: ψ ′ ( ω t ) = ψ ( ω t ) + p t ⃗ \psi'(\omega_t)=\psi(\omega_t)+\vec{p_t} ψ′(ωt​)=ψ(ωt​)+pt​ ​为了使这样的求和能够实现,我们保持位置嵌入的维度与字词嵌入的维度相等,即 d word embedding = d positional embedding d_\text{word embedding}=d_\text{positional embedding} dword embedding​=dpositional embedding​

相对位置

另外一个三角位置编码的特性是,它允许模型毫无费力地把握相对位置。这是一段原论文的引用:

We chose this function because we hypothesized it would allow the model to easily learn to attend by relative positions, since for any fixed offset k k k , P E p o s + k PE_{pos+k} PEpos+k​ can be represented as a linear function of P E p o s PE_{pos} PEpos​.

但为什么这个论述成立呢?为了去完全理解它,请参考这篇优秀的文章来阅读详细的证明。而我这里准备了一个更短的版本。

对于频率 ω k \omega_k ωk​ 对应的每一对正余弦对,会有一个线性变换 M ∈ R 2 × 2 M\in\Bbb{R}^{2\times2} M∈R2×2 (独立于 t t t)使得下面的方程成立: M ⋅ [ sin ⁡ ( ω k ⋅ t ) cos ⁡ ( ω k ⋅ t ) ] = [ sin ⁡ ( ω k ⋅ ( t + ϕ ) ) cos ⁡ ( ω k ⋅ ( t + ϕ ) ) ] M\cdot\begin{bmatrix} \sin(\omega_k\cdot t) \\ \cos(\omega_k\cdot t) \\ \end{bmatrix}=\begin{bmatrix} \sin(\omega_k\cdot(t+\phi)) \\ \cos(\omega_k\cdot(t+\phi)) \\ \end{bmatrix} M⋅[sin(ωk​⋅t)cos(ωk​⋅t)​]=[sin(ωk​⋅(t+ϕ))cos(ωk​⋅(t+ϕ))​]证明: 让 M M M 表示一个 2 × 2 2\times2 2×2 矩阵,我们想要找到 u 1 u_1 u1​、 v 1 v_1 v1​、 u 2 u_2 u2​ 和 v 2 v_2 v2​ 使得: [ u 1 v 1 u 2 v 2 ] ⋅ [ sin ⁡ ( ω k ⋅ t ) cos ⁡ ( ω k ⋅ t ) ] = [ sin ⁡ ( ω k ⋅ ( t + ϕ ) ) cos ⁡ ( ω k ⋅ ( t + ϕ ) ) ] \begin{bmatrix} u_1 & v_1 \\ u_2 & v_2 \\ \end{bmatrix}\cdot\begin{bmatrix} \sin(\omega_k\cdot t) \\ \cos(\omega_k\cdot t) \\ \end{bmatrix}=\begin{bmatrix} \sin(\omega_k\cdot(t+\phi)) \\ \cos(\omega_k\cdot(t+\phi)) \\ \end{bmatrix} [u1​u2​​v1​v2​​]⋅[sin(ωk​⋅t)cos(ωk​⋅t)​]=[sin(ωk​⋅(t+ϕ))cos(ωk​⋅(t+ϕ))​]通过和角定理,我们可以将右手边展开: [ u 1 v 1 u 2 v 2 ] ⋅ [ sin ⁡ ( ω k ⋅ t ) cos ⁡ ( ω k ⋅ t ) ] = [ sin ⁡ ( ω k ⋅ t ) cos ⁡ ( ω k ⋅ ϕ ) + cos ⁡ ( ω k ⋅ t ) sin ⁡ ( ω k ⋅ ϕ ) cos ⁡ ( ω k ⋅ t ) cos ⁡ ( ω k ⋅ ϕ ) − sin ⁡ ( ω k ⋅ t ) sin ⁡ ( ω k ⋅ ϕ ) ] \begin{bmatrix} u_1 & v_1 \\ u_2 & v_2 \\ \end{bmatrix}\cdot\begin{bmatrix} \sin(\omega_k\cdot t) \\ \cos(\omega_k\cdot t) \\ \end{bmatrix}=\begin{bmatrix} \sin(\omega_k\cdot t)\cos(\omega_k\cdot\phi)+\cos(\omega_k\cdot t)\sin(\omega_k\cdot\phi)\\ \cos(\omega_k\cdot t)\cos(\omega_k\cdot\phi)-\sin(\omega_k\cdot t)\sin(\omega_k\cdot\phi)\\ \end{bmatrix} [u1​u2​​v1​v2​​]⋅[sin(ωk​⋅t)cos(ωk​⋅t)​]=[sin(ωk​⋅t)cos(ωk​⋅ϕ)+cos(ωk​⋅t)sin(ωk​⋅ϕ)cos(ωk​⋅t)cos(ωk​⋅ϕ)−sin(ωk​⋅t)sin(ωk​⋅ϕ)​]从而得到下面两条方程: u 1 sin ⁡ ( ω k ⋅ t ) + v 1 cos ⁡ ( ω k ⋅ t ) = cos ⁡ ( ω k ⋅ ϕ ) sin ⁡ ( ω k ⋅ t ) + sin ⁡ ( ω k ⋅ ϕ ) cos ⁡ ( ω k ⋅ t ) u 2 sin ⁡ ( ω k ⋅ t ) + v 2 cos ⁡ ( ω k ⋅ t ) = − sin ⁡ ( ω k ⋅ ϕ ) sin ⁡ ( ω k ⋅ t ) + cos ⁡ ( ω k ⋅ ϕ ) cos ⁡ ( ω k ⋅ t ) \begin{align} u_1\sin(\omega_k\cdot t)+v_1\cos(\omega_k\cdot t)&=\cos(\omega_k\cdot\phi)\sin(\omega_k\cdot t)+\sin(\omega_k\cdot\phi)\cos(\omega_k\cdot t)\tag1\\ u_2\sin(\omega_k\cdot t)+v_2\cos(\omega_k\cdot t)&=-\sin(\omega_k\cdot \phi)\sin(\omega_k\cdot t)+\cos(\omega_k\cdot\phi)\cos(\omega_k\cdot t)\tag2\\ \end{align} u1​sin(ωk​⋅t)+v1​cos(ωk​⋅t)u2​sin(ωk​⋅t)+v2​cos(ωk​⋅t)​=cos(ωk​⋅ϕ)sin(ωk​⋅t)+sin(ωk​⋅ϕ)cos(ωk​⋅t)=−sin(ωk​⋅ϕ)sin(ωk​⋅t)+cos(ωk​⋅ϕ)cos(ωk​⋅t)​(1)(2)​解上面的方程,我们得到: u 1 = cos ⁡ ( ω k ⋅ ϕ ) v 1 = sin ⁡ ( ω k ⋅ ϕ ) u 2 = − sin ⁡ ( ω k ⋅ ϕ ) v 2 = cos ⁡ ( ω k ⋅ ϕ ) \begin{aligned} u_1=&\cos(\omega_k\cdot\phi)&&v_1=\sin(\omega_k\cdot\phi)\\ u_2=&-\sin(\omega_k\cdot \phi)&&v_2=\cos(\omega_k\cdot\phi)\\ \end{aligned} u1​=u2​=​cos(ωk​⋅ϕ)−sin(ωk​⋅ϕ)​​v1​=sin(ωk​⋅ϕ)v2​=cos(ωk​⋅ϕ)​所以最后的变换矩阵 M M M 是: M ϕ , k = [ cos ⁡ ( ω k ⋅ ϕ ) sin ⁡ ( ω k ⋅ ϕ ) − sin ⁡ ( ω k ⋅ ϕ ) cos ⁡ ( ω k ⋅ ϕ ) ] M_{\phi,k}=\begin{bmatrix} \cos(\omega_k\cdot\phi)&\sin(\omega_k\cdot\phi)\\ -\sin(\omega_k\cdot\phi)&\cos(\omega_k\cdot\phi)\\ \end{bmatrix} Mϕ,k​=[cos(ωk​⋅ϕ)−sin(ωk​⋅ϕ)​sin(ωk​⋅ϕ)cos(ωk​⋅ϕ)​]你可以看到,最后的变换时不依赖于 t t t的,注意你可以发现这个矩阵 M M M 和旋转矩阵很相似。

相似地,我们可以为其他正余弦对找到 M M M,从而最终允许我们对任何固定的偏置 ϕ \phi ϕ,将 p t + ϕ ⃗ \vec{p_{t+\phi}} pt+ϕ​ ​ 表示成 p t ⃗ \vec{p_t} pt​ ​ 的线性函数。这个性质使得模型可以很容易地学习把握相对位置。

另一个三角位置编码的性质是,相邻时间戳的距离是对称的,并且随时间很好地衰减。 图3 - 所有时间戳位置嵌入的点积

图3 - 所有时间戳位置嵌入的点积 FAQ 为什么位置嵌入是与字词嵌入相加的而不是相连?

对于这个问题我找不到任何理论的解释。因为相加(对比于相连)保留了模型的参数,可以将原来的问题重改成“把位置嵌入加到字词中有什么缺点吗?”。我会回答说:不一定!

最开始,如果我们注意图2,我们会发现整个嵌入只有最开始的一些维度用来存储位置的信息(注意报告的维度是512的,尽管我们的过家家示例很小)。而因为Transformer中的嵌入是重新训练过的,所以参数可能是以一种字词的语义不储存在最开始的一些维度的方式来避免干扰到位置编码。

同样的原因,我认为最后Transformer可以将字词的语义从他们的位置信息中分离开来。而且,没有理由把可分离性作为一种优点。可能这种相加为模型学习特征提供了一个好的来源吧。

对于更多信息,我推荐你看看这些链接:链接1、链接2。

一旦它到了更上面的层,位置信息不会消失吗?

好巧不巧,Transformer架构是带有残差链接的。所以模型输入的信息(保留了位置嵌入)可以高效地传播到更远的网络层中,这里会有更复杂的计算。

为什么正弦和余弦都用了?

从个人角度来说,我觉得,只有都用正弦和余弦,我们才能将 sin ⁡ ( x + k ) \sin(x+k) sin(x+k) 和 cos ⁡ ( x + k ) \cos(x+k) cos(x+k) 表达成 sin ⁡ ( x ) \sin(x) sin(x) 和 cos ⁡ ( x ) \cos(x) cos(x) 的线性组合。而你只用一个正弦或者余弦好像不能做到同样的事。如果你能找到单个正弦或余弦的线性变换,请在评论部分告诉我。

总结

感谢一直您陪我看到这篇文章的结尾。我希望这篇文章对你的问题有所帮助。请随意指正或者反馈,评论区随时为您敞开。

这样引用:

@article{kazemnejad2019:pencoding, title = "Transformer Architecture: The Positional Encoding", author = "Kazemnejad, Amirhossein", journal = "kazemnejad.com", year = "2019", url = "https://kazemnejad.com/blog/transformer_architecture_positional_encoding/" } 本文参考 The Illustrated TransformerAttention Is All You Need - The TransformerLinear Relationships in the Transformer’s Positional Encodingposition_encoding.ipynbTensor2Tensor Github issue #1591Reddit thread - Positional Encoding in TransformerReddit thread - Positional Encoding in Transformer model 评论区置顶

太牛啦!我觉得位置嵌入还有一种更直觉的解释,就是把它想象成一个钟(因为余弦和正弦就是单元圆的概念)。位置编码的每两个维度可以看成是钟的针(时针、分针、秒针等)。它们从一个位置移动到下一个位置就是在不同的频率下旋转它们的针。所以,尽管没有公式推导,它也能立刻告诉你为什么那个旋转矩阵存在。 (其他评论去原文看吧,里面还讨论了如果指针转回到原来的地方了怎么办)



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有