.. _sec_multihead-attention: 多头注意力 ========== 在实践中,当给定相同的查询、键和值的集合时, 我们希望模型可以基于相同的注意力机制学习到不同的行为, 然后将不同的行为作为知识组合起来, 捕获序列内各种范围的依赖关系 (例如,短距离依赖和长距离依赖关系)。 因此,允许注意力机制组合使用查询、键和值的不同 *子空间表示*\ (representation subspaces)可能是有益的。 为此,与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的\ :math:`h`\ 组不同的 *线性投影*\ (linear projections)来变换查询、键和值。 然后,这\ :math:`h`\ 组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这\ :math:`h`\ 个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。 这种设计被称为\ *多头注意力*\ (multihead attention) :cite:`Vaswani.Shazeer.Parmar.ea.2017`\ 。 对于\ :math:`h`\ 个注意力汇聚输出,每一个注意力汇聚都被称作一个\ *头*\ (head)。 :numref:`fig_multi-head-attention` 展示了使用全连接层来实现可学习的线性变换的多头注意力。 .. _fig_multi-head-attention: .. figure:: ../img/multi-head-attention.svg 多头注意力:多个头连结然后线性变换 模型 ---- 在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。 给定查询\ :math:`\mathbf{q} \in \mathbb{R}^{d_q}`\ 、 键\ :math:`\mathbf{k} \in \mathbb{R}^{d_k}`\ 和 值\ :math:`\mathbf{v} \in \mathbb{R}^{d_v}`\ , 每个注意力头\ :math:`\mathbf{h}_i`\ (\ :math:`i = 1, \ldots, h`\ )的计算方法为: .. math:: \mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v}, 其中,可学习的参数包括 :math:`\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}`\ 、 :math:`\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}`\ 和 :math:`\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}`\ , 以及代表注意力汇聚的函数\ :math:`f`\ 。 :math:`f`\ 可以是 :numref:`sec_attention-scoring-functions`\ 中的 加性注意力和缩放点积注意力。 多头注意力的输出需要经过另一个线性转换, 它对应着\ :math:`h`\ 个头连结后的结果,因此其可学习参数是 :math:`\mathbf W_o\in\mathbb R^{p_o\times h p_v}`\ : .. math:: \mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}. 基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。 .. raw:: html
mxnetpytorchtensorflowpaddle
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python import math from mxnet import autograd, np, npx from mxnet.gluon import nn from d2l import mxnet as d2l npx.set_np() .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python import math import torch from torch import nn from d2l import torch as d2l .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python import tensorflow as tf from d2l import tensorflow as d2l .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python import warnings from d2l import paddle as d2l warnings.filterwarnings("ignore") import math import paddle from paddle import nn .. raw:: html
.. raw:: html
实现 ---- 在实现过程中通常选择缩放点积注意力作为每一个注意力头。 为了避免计算代价和参数代价的大幅增长, 我们设定\ :math:`p_q = p_k = p_v = p_o / h`\ 。 值得注意的是,如果将查询、键和值的线性变换的输出数量设置为 :math:`p_q h = p_k h = p_v h = p_o`\ , 则可以并行计算\ :math:`h`\ 个头。 在下面的实现中,\ :math:`p_o`\ 是通过参数\ ``num_hiddens``\ 指定的。 .. raw:: html
mxnetpytorchtensorflowpaddle
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save class MultiHeadAttention(nn.Block): """多头注意力""" def __init__(self, num_hiddens, num_heads, dropout, use_bias=False, **kwargs): super(MultiHeadAttention, self).__init__(**kwargs) self.num_heads = num_heads self.attention = d2l.DotProductAttention(dropout) self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False) self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False) self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False) self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False) def forward(self, queries, keys, values, valid_lens): # queries,keys,values的形状: # (batch_size,查询或者“键-值”对的个数,num_hiddens) # valid_lens 的形状: # (batch_size,)或(batch_size,查询的个数) # 经过变换后,输出的queries,keys,values 的形状: # (batch_size*num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) queries = transpose_qkv(self.W_q(queries), self.num_heads) keys = transpose_qkv(self.W_k(keys), self.num_heads) values = transpose_qkv(self.W_v(values), self.num_heads) if valid_lens is not None: # 在轴0,将第一项(标量或者矢量)复制num_heads次, # 然后如此复制第二项,然后诸如此类。 valid_lens = valid_lens.repeat(self.num_heads, axis=0) # output的形状:(batch_size*num_heads,查询的个数, # num_hiddens/num_heads) output = self.attention(queries, keys, values, valid_lens) # output_concat的形状:(batch_size,查询的个数,num_hiddens) output_concat = transpose_output(output, self.num_heads) return self.W_o(output_concat) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save class MultiHeadAttention(nn.Module): """多头注意力""" def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs): super(MultiHeadAttention, self).__init__(**kwargs) self.num_heads = num_heads self.attention = d2l.DotProductAttention(dropout) self.W_q = nn.Linear(query_size, num_hiddens, bias=bias) self.W_k = nn.Linear(key_size, num_hiddens, bias=bias) self.W_v = nn.Linear(value_size, num_hiddens, bias=bias) self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias) def forward(self, queries, keys, values, valid_lens): # queries,keys,values的形状: # (batch_size,查询或者“键-值”对的个数,num_hiddens) # valid_lens 的形状: # (batch_size,)或(batch_size,查询的个数) # 经过变换后,输出的queries,keys,values 的形状: # (batch_size*num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) queries = transpose_qkv(self.W_q(queries), self.num_heads) keys = transpose_qkv(self.W_k(keys), self.num_heads) values = transpose_qkv(self.W_v(values), self.num_heads) if valid_lens is not None: # 在轴0,将第一项(标量或者矢量)复制num_heads次, # 然后如此复制第二项,然后诸如此类。 valid_lens = torch.repeat_interleave( valid_lens, repeats=self.num_heads, dim=0) # output的形状:(batch_size*num_heads,查询的个数, # num_hiddens/num_heads) output = self.attention(queries, keys, values, valid_lens) # output_concat的形状:(batch_size,查询的个数,num_hiddens) output_concat = transpose_output(output, self.num_heads) return self.W_o(output_concat) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save class MultiHeadAttention(tf.keras.layers.Layer): """多头注意力""" def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs): super().__init__(**kwargs) self.num_heads = num_heads self.attention = d2l.DotProductAttention(dropout) self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=bias) self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=bias) self.W_v = tf.keras.layers.Dense(num_hiddens, use_bias=bias) self.W_o = tf.keras.layers.Dense(num_hiddens, use_bias=bias) def call(self, queries, keys, values, valid_lens, **kwargs): # queries,keys,values的形状: # (batch_size,查询或者“键-值”对的个数,num_hiddens) # valid_lens 的形状: # (batch_size,)或(batch_size,查询的个数) # 经过变换后,输出的queries,keys,values 的形状: # (batch_size*num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) queries = transpose_qkv(self.W_q(queries), self.num_heads) keys = transpose_qkv(self.W_k(keys), self.num_heads) values = transpose_qkv(self.W_v(values), self.num_heads) if valid_lens is not None: # 在轴0,将第一项(标量或者矢量)复制num_heads次, # 然后如此复制第二项,然后诸如此类。 valid_lens = tf.repeat(valid_lens, repeats=self.num_heads, axis=0) # output的形状:(batch_size*num_heads,查询的个数, # num_hiddens/num_heads) output = self.attention(queries, keys, values, valid_lens, **kwargs) # output_concat的形状:(batch_size,查询的个数,num_hiddens) output_concat = transpose_output(output, self.num_heads) return self.W_o(output_concat) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save class MultiHeadAttention(nn.Layer): def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs): super(MultiHeadAttention, self).__init__(**kwargs) self.num_heads = num_heads self.attention = d2l.DotProductAttention(dropout) self.W_q = nn.Linear(query_size, num_hiddens, bias_attr=bias) self.W_k = nn.Linear(key_size, num_hiddens, bias_attr=bias) self.W_v = nn.Linear(value_size, num_hiddens, bias_attr=bias) self.W_o = nn.Linear(num_hiddens, num_hiddens, bias_attr=bias) def forward(self, queries, keys, values, valid_lens): # queries,keys,values的形状: # (batch_size,查询或者“键-值”对的个数,num_hiddens) # valid_lens 的形状: # (batch_size,)或(batch_size,查询的个数) # 经过变换后,输出的queries,keys,values 的形状: # (batch_size*num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) queries = transpose_qkv(self.W_q(queries), self.num_heads) keys = transpose_qkv(self.W_k(keys), self.num_heads) values = transpose_qkv(self.W_v(values), self.num_heads) if valid_lens is not None: # 在轴0,将第一项(标量或者矢量)复制num_heads次, # 然后如此复制第二项,然后诸如此类。 valid_lens = paddle.repeat_interleave( valid_lens, repeats=self.num_heads, axis=0) # output的形状:(batch_size*num_heads,查询的个数, # num_hiddens/num_heads) output = self.attention(queries, keys, values, valid_lens) # output_concat的形状:(batch_size,查询的个数,num_hiddens) output_concat = transpose_output(output, self.num_heads) return self.W_o(output_concat) .. raw:: html
.. raw:: html
为了能够使多个头并行计算, 上面的\ ``MultiHeadAttention``\ 类将使用下面定义的两个转置函数。 具体来说,\ ``transpose_output``\ 函数反转了\ ``transpose_qkv``\ 函数的操作。 .. raw:: html
mxnetpytorchtensorflowpaddle
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save def transpose_qkv(X, num_heads): """为了多注意力头的并行计算而变换形状""" # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens) # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads, # num_hiddens/num_heads) X = X.reshape(X.shape[0], X.shape[1], num_heads, -1) # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) X = X.transpose(0, 2, 1, 3) # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) return X.reshape(-1, X.shape[2], X.shape[3]) #@save def transpose_output(X, num_heads): """逆转transpose_qkv函数的操作""" X = X.reshape(-1, num_heads, X.shape[1], X.shape[2]) X = X.transpose(0, 2, 1, 3) return X.reshape(X.shape[0], X.shape[1], -1) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save def transpose_qkv(X, num_heads): """为了多注意力头的并行计算而变换形状""" # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens) # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads, # num_hiddens/num_heads) X = X.reshape(X.shape[0], X.shape[1], num_heads, -1) # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) X = X.permute(0, 2, 1, 3) # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) return X.reshape(-1, X.shape[2], X.shape[3]) #@save def transpose_output(X, num_heads): """逆转transpose_qkv函数的操作""" X = X.reshape(-1, num_heads, X.shape[1], X.shape[2]) X = X.permute(0, 2, 1, 3) return X.reshape(X.shape[0], X.shape[1], -1) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save def transpose_qkv(X, num_heads): """为了多注意力头的并行计算而变换形状""" # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens) # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads, # num_hiddens/num_heads) X = tf.reshape(X, shape=(X.shape[0], X.shape[1], num_heads, -1)) # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) X = tf.transpose(X, perm=(0, 2, 1, 3)) # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) return tf.reshape(X, shape=(-1, X.shape[2], X.shape[3])) #@save def transpose_output(X, num_heads): """逆转transpose_qkv函数的操作""" X = tf.reshape(X, shape=(-1, num_heads, X.shape[1], X.shape[2])) X = tf.transpose(X, perm=(0, 2, 1, 3)) return tf.reshape(X, shape=(X.shape[0], X.shape[1], -1)) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save def transpose_qkv(X, num_heads): """为了多注意力头的并行计算而变换形状""" # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens) # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads, # num_hiddens/num_heads) X = X.reshape((X.shape[0], X.shape[1], num_heads, -1)) # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) X = X.transpose((0, 2, 1, 3)) # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) return X.reshape((-1, X.shape[2], X.shape[3])) #@save def transpose_output(X, num_heads): """逆转transpose_qkv函数的操作""" X = X.reshape((-1, num_heads, X.shape[1], X.shape[2])) X = X.transpose((0, 2, 1, 3)) return X.reshape((X.shape[0], X.shape[1], -1)) .. raw:: html
.. raw:: html
下面使用键和值相同的小例子来测试我们编写的\ ``MultiHeadAttention``\ 类。 多头注意力输出的形状是(\ ``batch_size``\ ,\ ``num_queries``\ ,\ ``num_hiddens``\ )。 .. raw:: html
mxnetpytorchtensorflowpaddle
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python num_hiddens, num_heads = 100, 5 attention = MultiHeadAttention(num_hiddens, num_heads, 0.5) attention.initialize() batch_size, num_queries = 2, 4 num_kvpairs, valid_lens = 6, np.array([3, 2]) X = np.ones((batch_size, num_queries, num_hiddens)) Y = np.ones((batch_size, num_kvpairs, num_hiddens)) attention(X, Y, Y, valid_lens).shape .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output [07:00:49] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output (2, 4, 100) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python num_hiddens, num_heads = 100, 5 attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5) attention.eval() .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output MultiHeadAttention( (attention): DotProductAttention( (dropout): Dropout(p=0.5, inplace=False) ) (W_q): Linear(in_features=100, out_features=100, bias=False) (W_k): Linear(in_features=100, out_features=100, bias=False) (W_v): Linear(in_features=100, out_features=100, bias=False) (W_o): Linear(in_features=100, out_features=100, bias=False) ) .. raw:: latex \diilbookstyleinputcell .. code:: python batch_size, num_queries = 2, 4 num_kvpairs, valid_lens = 6, torch.tensor([3, 2]) X = torch.ones((batch_size, num_queries, num_hiddens)) Y = torch.ones((batch_size, num_kvpairs, num_hiddens)) attention(X, Y, Y, valid_lens).shape .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output torch.Size([2, 4, 100]) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python num_hiddens, num_heads = 100, 5 attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5) batch_size, num_queries = 2, 4 num_kvpairs, valid_lens = 6, tf.constant([3, 2]) X = tf.ones((batch_size, num_queries, num_hiddens)) Y = tf.ones((batch_size, num_kvpairs, num_hiddens)) attention(X, Y, Y, valid_lens, training=False).shape .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output TensorShape([2, 4, 100]) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python num_hiddens, num_heads = 100, 5 attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5) attention.eval() batch_size, num_queries = 2, 4 num_kvpairs, valid_lens = 6, paddle.to_tensor([3, 2]) X = paddle.ones((batch_size, num_queries, num_hiddens)) Y = paddle.ones((batch_size, num_kvpairs, num_hiddens)) attention(X, Y, Y, valid_lens).shape .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output [2, 4, 100] .. raw:: html
.. raw:: html
小结 ---- - 多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。 - 基于适当的张量操作,可以实现多头注意力的并行计算。 练习 ---- 1. 分别可视化这个实验中的多个头的注意力权重。 2. 假设有一个完成训练的基于多头注意力的模型,现在希望修剪最不重要的注意力头以提高预测速度。如何设计实验来衡量注意力头的重要性呢? .. raw:: html
mxnetpytorchpaddle
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html