去掉Attention的Softmax,复杂度降为O(n)

众所周知,尽管基于Attention机制的Transformer类模型有着良好的并行性能,但它的空间和时间复杂度都是 O ( n 2 ) \mathcal{O}(n^2) O(n2)级别的, n n n是序列长度,所以当 n n n比较大时Transformer模型的计算量难以承受。近来,也有不少工作致力于降低Transformer模型的计算量,比如模型剪枝、量化、蒸馏等精简技术,又或者修改Attention结构,使得其复杂度能降低到 O ( n l o g ⁡ n ) \mathcal{O}(nlog⁡n) O(nlogn)甚至 O ( n ) \mathcal{O}(n) O(n)

论文《Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention》当中提到一种线性化Attention(Linear Attention)的方法,由此引发了我的兴趣,继而阅读了一些相关博客,有一些不错的收获,最后将自己对线性化Attention的理解汇总在此文中

Attention

当前最流行的Attention机制当属Scaled-Dot Attention,即
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K ⊤ ) V (1) \begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax\left(\boldsymbol{Q}\boldsymbol{K}^{\top}\right)\boldsymbol{V}\tag{1}\end{aligned} Attention(Q,K,V)=softmax(QK)V(1)
这里的 Q ∈ R n × d k , K ∈ R m × d k , V ∈ R m × d v \boldsymbol{Q}\in \mathbb{R}^{n\times d_k}, \boldsymbol{K}\in \mathbb{R}^{m\times d_k}, \boldsymbol{V}\in \mathbb{R}^{m\times d_v} QRn×dk,KRm×dk,VRm×dv,简单起见我就没显示的写出Attention的缩放因子 1 d \frac{1}{\sqrt{d}} d 1了。本文我们主要关心Self Attention的场景,所以为了介绍上的方便,统一设 Q , K , V ∈ R n × d \boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}\in \mathbb{R}^{n\times d} Q,K,VRn×d

摘掉Softmax

读者也许想不到,制约Attention性能的关键因素,其实是定义里边的Softmax!事实上,简单地推导一下就可以得到这个结论。 Q K T QK^T QKT这一步我们得到一个 n × n n\times n n×n的矩阵,之后还要做一个Softmax

对一个 1 × n 1\times n 1×n的行向量进行Softmax,时间复杂度是 O ( n ) O(n) O(n),但是对一个 n × n n\times n n×n矩阵的每一行做一个Softmax,时间复杂度就是 O ( n 2 ) O(n^2) O(n2)

如果没有Softmax,那么Attention的公式就变为三个矩阵连乘 Q K ⊤ V \boldsymbol{QK^{\top}V} QKV,而矩阵乘法是满足结合率的,所以我们可以先算 K ⊤ V \boldsymbol{K^{\top}V} KV,得到一个 d × d d\times d d×d的矩阵(这一步的时间复杂度是 O ( d 2 n ) O(d^2n) O(d2n)),然后再用 Q Q Q左乘它(这一步的时间复杂度是 O ( d 2 n ) O(d^2n) O(d2n)),由于 d ≪ n d \ll n dn,所以这样算大致的时间复杂度只是 O ( n ) O(n) O(n)

对于BERT base来说, d = 64 d=64 d=64而不是768,why?因为768实际上是通过Multi-Head拼接得到的,而每个head的 d = 64 d=64 d=64

也就是说,去掉Softmax的Attention复杂度可以降到最理想的线性级别 O ( n ) \mathcal{O}(n) O(n)!这显然就是我们的终极追求:Linear Attention

一般的定义

问题是,直接去掉Softmax还能算是Attention吗?他还能有标准的Attention的效果吗?为了回答这个问题,我们先将Scaled-Dot Attention的定义等价的改写为(本文的向量都是列向量)
A t t e n t i o n ( Q , K , V ) i = ∑ j = 1 n e q i ⊤ k j v j ∑ j = 1 n e q i ⊤ k j (2) \begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\boldsymbol{v}_j}{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}}\tag{2}\end{aligned} Attention(Q,K,V)i=j=1neqikjj=1neqikjvj(2)

这里稍微解释下,首先我们知道 Q , K ∈ R n × d \boldsymbol{Q},\boldsymbol{K}\in \mathbb{R}^{n\times d} Q,KRn×d,令 M = Q × K ⊤ \boldsymbol{M} = \boldsymbol{Q}\times \boldsymbol{K^{\top}} M=Q×K,由矩阵乘法法则可知, M \boldsymbol{M} M的第一行是由 Q \boldsymbol{Q} Q的第一行乘以 K ⊤ \boldsymbol{K^{\top}} K的所有列得到的

A t t e n t i o n ( Q , K , V ) i Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i Attention(Q,K,V)i表示最终输出结果矩阵的第 i i i

q i ⊤ \boldsymbol{q}_i^{\top} qi表示 Q ∈ R n × d \boldsymbol{Q}\in \mathbb{R}^{n\times d} QRn×d矩阵的第 i i i行(行向量)

k j \boldsymbol{k}_j kj表示 K ⊤ ∈ R d × n \boldsymbol{K^{\top}}\in \mathbb{R}^{d\times n} KRd×n矩阵的第 j j j列(列向量)

v j \boldsymbol{v}_j vj表示 V ⊤ ∈ R d × n V^{\top}\in \mathbb{R}^{d\times n} VRd×n矩阵的的第 j j j列(列向量)

所以,Scaled-Dot Attention其实就是以 e q i ⊤ k j e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} eqikj为权重对 v j \boldsymbol{v}_j vj做加权平均。所以我们可以提出一个Attention的一般化定义
A t t e n t i o n ( Q , K , V ) i = ∑ j = 1 n sim ( q i , k j ) v j ∑ j = 1 n sim ( q i , k j ) (3) \begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\boldsymbol{v}_j}{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)}\tag{3}\end{aligned} Attention(Q,K,V)i=j=1nsim(qi,kj)j=1nsim(qi,kj)vj(3)
也就是把 e q i ⊤ k j e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} eqikj换成 q i , k i \boldsymbol{q}_i,\boldsymbol{k}_i qi,ki的一般函数 sim ( q i , k j ) \text{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j) sim(qi,kj),为了保留Attention相似的分布特性,我们要求 sim ( q i , k j ) ≥ 0 \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0 sim(qi,kj)0恒成立。也就是说,我们如果要定义新的Attention,必须要保留式(3)的形式,并且满足 sim ( q i , k j ) ≥ 0 \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0 sim(qi,kj)0

这种一般形式的Attention在CV中也被称为Non-Local网络,出自论文《Non-local Neural Networks》

几个例子

如果直接去掉Softmax,那么就是 sim ( q i , k j ) = q i ⊤ k j \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \boldsymbol{q}_i^{\top}\boldsymbol{k}_j sim(qi,kj)=qikj,问题是内积无法保证非负性,所以这还不是一个合理的选择。下面我们介绍几种可取的方案

值得一提的是,下面介绍的这几种Linear Attention,前两种来自CV领域,第三种是苏剑林大佬构思的(除了下面的介绍外,还有EMANet等CV领域对Attention的改进工作)

核函数形式

一个自然的想法是:如果 q i , k j \boldsymbol{q}_i, \boldsymbol{k}_j qi,kj的每个元素都是非负的,那么内积自然也是非负的。为了完成这点,我们可以给 q i , k j \boldsymbol{q}_i, \boldsymbol{k}_j qi,kj各自加个激活函数 ϕ , φ \phi,\varphi ϕ,φ,即
sim ( q i , k j ) = ϕ ( q i ) ⊤ φ ( k j ) (4) \begin{aligned}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)\tag{4}\end{aligned} sim(qi,kj)=ϕ(qi)φ(kj)(4)
其中 ϕ ( ⋅ ) , φ ( ⋅ ) \phi(\cdot), \varphi(\cdot) ϕ(),φ()是值域非负的激活函数。本文开头提到的论文《Transformers are RNNs》选择的是 ϕ ( x ) = φ ( x ) = elu ( x ) + 1 \phi(x)=\varphi(x)=\text{elu}(x)+1 ϕ(x)=φ(x)=elu(x)+1,其中

elu ( x ) = { x if  x > 0 α ( e x − 1 ) if  x < 0 \text{elu}(x)=\begin{cases}x& \text{if} \ x>0\\ \alpha (e^x-1) & \text{if}\ x<0\end{cases} elu(x)={xα(ex1)if x>0if x<0

常见的 α \alpha α取值为 [ 0.1 , 0.3 ] [0.1, 0.3] [0.1,0.3]

非要讲故事的话,式(4)可以联想到"核方法",尤其是 ϕ = φ \phi=\varphi ϕ=φ时, ϕ \phi ϕ就相当于一个核函数,而 ⟨ ϕ ( q i ) , ϕ ( k j ) ⟩ \langle \phi(\boldsymbol{q}_i), \phi(\boldsymbol{k}_j)\rangle ϕ(qi),ϕ(kj)就是通过核函数所定义的内积。这方面的思考可以参考论文《Transformer dissection: An unified understanding for transformer’s attention via the lens of kernel》,此处不做过多延伸

妙用Softmax

另一篇更早的文章《Efficient Attention: Attention with Linear Complexities》则给出了一个更有意思的选择。它留意到在 Q K ⊤ \boldsymbol{QK^{\top}} QK中, Q , K ∈ R n × d \boldsymbol{Q},\boldsymbol{K}\in \mathbb{R}^{n\times d} Q,KRn×d,如果“ Q \boldsymbol{Q} Q d d d那一维是归一化的,并且 K \boldsymbol{K} K n n n那一维是归一化的”,那么 Q K ⊤ \boldsymbol{QK^{\top}} QK就是自动满足归一化了,所以它给出的选择是
A t t e n t i o n ( Q , K , V ) = s o f t m a x 2 ( Q ) s o f t m a x 1 ( K ) ⊤ V (5) \begin{aligned}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax_2\left(\boldsymbol{Q}\right)softmax_1(\boldsymbol{K})^{\top}\boldsymbol{V}\tag{5}\end{aligned} Attention(Q,K,V)=softmax2(Q)softmax1(K)V(5)
其中 s o f t m a x 1 softmax_1 softmax1 s o f t m a x 2 softmax_2 softmax2分别表示在第一个 ( n ) (n) (n)、第二个维度 ( d ) (d) (d)进行Softmax运算。也就是说,这时候我们是各自给 Q , K \boldsymbol{Q},\boldsymbol{K} Q,K加Softmax,而不是算完 Q K ⊤ \boldsymbol{QK^{\top}} QK之后再加Softmax

其实可以证明这个形式也是式(4)​的一个特例,此时对应于 ϕ ( q i ) = s o f t m a x ( q i ) , φ ( k j ) = e k j \phi(\boldsymbol{q}_i)=softmax(\boldsymbol{q}_i),\varphi(\boldsymbol{k}_j)=e^{\boldsymbol{k}_j} ϕ(qi)=softmax(qi),φ(kj)=ekj,读者可以自行推导一下

苏神的构思

在这里,苏神给出了一种构思。这个构思的出发点不再是式(4),而是源于我们对原始定义(2)​的泰勒展开。由泰勒展开我们有
e q i ⊤ k j ≈ 1 + q i ⊤ k j (6) \begin{aligned}e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} \approx 1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j\tag{6}\end{aligned} eqikj1+qikj(6)
如果 q i ⊤ k j ≥ − 1 \boldsymbol{q}_i^{\top}\boldsymbol{k}_j\geq -1 qikj1,那么就可以保证右端的非负性,从而可以让 sim ( q i , k j ) = 1 + q i ⊤ k j \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)=1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j sim(qi,kj)=1+qikj。到这里读者可能已经想到了,想要保证 q i ⊤ k j ≥ − 1 \boldsymbol{q}_i^{\top}\boldsymbol{k}_j\geq -1 qikj1,只需要分别对 q i , k j \boldsymbol{q}_i,\boldsymbol{k}_j qi,kj l 2 l_2 l2归一化。所以,苏神最终提出的方案就是:
sim ( q i , k j ) = 1 + ( q i ∥ q i ∥ ) ⊤ ( k j ∥ k j ∥ ) (7) \begin{aligned}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = 1 + \left( \frac{\boldsymbol{q}_i}{\Vert \boldsymbol{q}_i\Vert}\right)^{\top}\left(\frac{\boldsymbol{k}_j}{\Vert \boldsymbol{k}_j\Vert}\right)\tag{7}\end{aligned} sim(qi,kj)=1+(qiqi)(kjkj)(7)

x = [ x 1 , x 2 , . . . , x n ] \boldsymbol{x}=[x_1,x_2,...,x_n] x=[x1,x2,...,xn],则 ∥ x ∥ = x 1 2 + x 2 2 + ⋅ ⋅ ⋅ + x n 2 \Vert x\Vert=\sqrt{x_1^2+x_2^2+···+x_n^2} x=x12+x22++xn2

这不同于式(4),但理论上它更加接近原始的Scaled-Dot Attention

实现

这里主要是针对苏神所提出的方法进行实现,但是由于笔者本人水平有限,因此最终实现的代码当中其实存在一些问题,主要是:

  1. 从测试结果来看,改进后的计算速度并没有提升
  2. 无法做到求和为1

代码实现主要是针对BERT的PyTorch实现这篇文章的代码,更具体的说,其实仅修改了ScaledDotProductAttention这个函数,因此下面只放出这部分代码

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        Q = F.normalize(Q, dim=3)
        K = F.normalize(K, dim=3)
        M = (torch.ones(Q.shape[0], Q.shape[1], Q.shape[2], K.shape[2]) + torch.matmul(Q, K.transpose(-1, -2))) # scores : [batch_size, n_heads, seq_len, seq_len]
        M_sum = torch.sum(M, dim=3)
        M = M / M_sum.unsqueeze(3).repeat(1, 1, 1, M.shape[3])
        attn = M.masked_fill(attn_mask, 0) # Fills elements of self tensor with value where mask is one.
        context = torch.matmul(attn, V)
        return context

如果您有更好的实现方法,还望不吝赐教

Reference

相关推荐
<p> 需要学习Windows系统YOLOv4的同学请前往《Windows版YOLOv4目标检测实战:原理与源码解析》, </p> <p> 课程链接 https://edu.csdn.net/course/detail/29865 </p> <h3> <span style="color:#3598db;">【为什么要学习这门课】</span> </h3> <p> <span>Linux</span>创始人<span>Linus Torvalds</span>有一句名言:<span>Talk is cheap. Show me the code. </span><strong><span style="color:#ba372a;">冗谈不够,放码过来!</span></strong> </p> <p> <span> </span>代码阅读是从基础到提高的必由之路。尤其对深度学习,许多框架隐藏了神经网络底层的实现,只能在上层调包使用,对其内部原理很难认识清晰,不利于进一步优化和创新。 </p> <p> YOLOv4是最近推出的基于深度学习的端到端实时目标检测方法。 </p> <p> YOLOv4的实现darknet是使用C语言开发的轻型开源深度学习框架,依赖少,可移植性好,可以作为很好的代码阅读案例,让我们深入探究其实现原理。 </p> <h3> <span style="color:#3598db;">【课程内容与收获】</span> </h3> <p> 本课程将解析YOLOv4的实现原理和源码,具体内容包括: </p> <p> - YOLOv4目标检测原理<br /> - 神经网络及darknet的C语言实现,尤其是反向传播的梯度求解和误差计算<br /> - 代码阅读工具及方法<br /> - 深度学习计算的利器:BLAS和GEMM<br /> - GPU的CUDA编程方法及在darknet的应用<br /> - YOLOv4的程序流程 </p> <p> - YOLOv4各层及关键技术的源码解析 </p> <p> 本课程将提供注释后的darknet的源码程序文件。 </p> <h3> <strong><span style="color:#3598db;">【相关课程】</span></strong> </h3> <p> 除本课程《YOLOv4目标检测:原理与源码解析》外,本人推出了有关YOLOv4目标检测的系列课程,包括: </p> <p> 《YOLOv4目标检测实战:训练自己的数据集》 </p> <p> 《YOLOv4-tiny目标检测实战:训练自己的数据集》 </p> <p> 《YOLOv4目标检测实战:人脸口罩佩戴检测》<br /> 《YOLOv4目标检测实战:中国交通标志识别》 </p> <p> 建议先学习一门YOLOv4实战课程,对YOLOv4的使用方法了解以后再学习本课程。 </p> <h3> <span style="color:#3598db;">【YOLOv4网络模型架构图】</span> </h3> <p> 下图由白勇老师绘制 </p> <p> <img alt="" src="https://img-bss.csdnimg.cn/202006291526195469.jpg" /> </p> <p>   </p> <p> <img alt="" src="https://img-bss.csdnimg.cn/202007011518185782.jpg" /> </p>
©️2020 CSDN 皮肤主题: Age of Ai 设计师:meimeiellie 返回首页
实付 9.90元
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值