10.5. 多头注意力
Open the notebook in Colab
Open the notebook in Colab
Open the notebook in Colab

在实践中,当给定相同的查询、键和值的集合时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,例如捕获序列内各种范围的依赖关系(例如,短距离依赖和长距离依赖)。因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces)可能是有益的。

为此,与使用单独一个注意力汇聚不同,我们可以用独立学习得到的 \(h\) 组不同的 线性投影(linear projections)来变换查询、键和值。然后,这 \(h\) 组变换后的查询、键和值将并行地送到注意力汇聚中。最后,将这 \(h\) 个注意力汇聚的输出拼接在一起,并且通过另一个可以学习的线性投影进行变换,以产生最终输出。这种设计被称为 多头注意力,其中 \(h\) 个注意力汇聚输出中的每一个输出都被称作一个 (head) [Vaswani et al., 2017]图10.5.1 展示了使用全连接层来实现可学习的线性变换的多头注意力。

../_images/multi-head-attention.svg

图10.5.1 多头注意力,多个头连结然后线性变换。

10.5.1. 模型

在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。给定查询 \(\mathbf{q} \in \mathbb{R}^{d_q}\)、键 \(\mathbf{k} \in \mathbb{R}^{d_k}\) 和值 \(\mathbf{v} \in \mathbb{R}^{d_v}\),每个注意力头 \(\mathbf{h}_i\) (\(i = 1, \ldots, h\)) 的计算方法为:

(10.5.1)\[\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},\]

其中,可学习的参数包括 \(\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}\)\(\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}\)\(\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}\) ,以及代表注意力汇聚的函数 \(f\)\(f\) 可以是 10.3节 中的加性注意力和缩放点积注意力。多头注意力的输出需要经过另一个线性转换,它对应着 \(h\) 个头连结后的结果,因此其可学习参数是 \(\mathbf W_o\in\mathbb R^{p_o\times h p_v}\)

(10.5.2)\[\begin{split}\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.\end{split}\]

基于这种设计,每个头都可能会关注输入的不同部分。可以表示比简单加权平均值更复杂的函数。

import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()
import math
import torch
from torch import nn
from d2l import torch as d2l

10.5.2. 实现

在实现过程中,我们选择缩放点积注意力作为每一个注意力头。为了避免计算成本和参数数量的大幅增长,我们设定 \(p_q = p_k = p_v = p_o / h\)。值得注意的是,如果我们将查询、键和值的线性变换的输出数量设置为 \(p_q h = p_k h = p_v h = p_o\),则可以并行计算 \(h\) 个头。在下面的实现中,\(p_o\) 是通过参数 num_hiddens 指定的。

#@save
class MultiHeadAttention(nn.Block):
    def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
                 **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)

    def forward(self, queries, keys, values, valid_lens):
        # `queries`, `keys`, or `values` 的形状:
        # (`batch_size`, 查询或者“键-值”对的个数, `num_hiddens`)
        # `valid_lens` 的形状:
        # (`batch_size`,) 或 (`batch_size`, 查询的个数)
        # 经过变换后,输出的 `queries`, `keys`, or `values` 的形状:
        # (`batch_size` * `num_heads`, 查询或者“键-值”对的个数,
        # `num_hiddens` / `num_heads`)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 在轴 0,将第一项(标量或者矢量)复制 `num_heads` 次,
            # 然后如此复制第二项,然后诸如此类。
            valid_lens = valid_lens.repeat(self.num_heads, axis=0)

        # `output` 的形状: (`batch_size` * `num_heads`, 查询的个数,
        # `num_hiddens` / `num_heads`)
        output = self.attention(queries, keys, values, valid_lens)

        # `output_concat` 的形状: (`batch_size`, 查询的个数, `num_hiddens`)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
#@save
class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # `queries`, `keys`, or `values` 的形状:
        # (`batch_size`, 查询或者“键-值”对的个数, `num_hiddens`)
        # `valid_lens` 的形状:
        # (`batch_size`,) or (`batch_size`, 查询的个数)
        # 经过变换后,输出的 `queries`, `keys`, or `values` 的形状:
        # (`batch_size` * `num_heads`, 查询或者“键-值”对的个数,
        # `num_hiddens` / `num_heads`)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 在轴 0,将第一项(标量或者矢量)复制 `num_heads` 次,
            # 然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # `output` 的形状: (`batch_size` * `num_heads`, 查询的个数,
        # `num_hiddens` / `num_heads`)
        output = self.attention(queries, keys, values, valid_lens)

        # `output_concat` 的形状: (`batch_size`, 查询的个数, `num_hiddens`)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

为了能够使多个头并行计算,上面的 MultiHeadAttention 类使用了下面定义的两个转置函数。具体来说,transpose_output 函数反转了 transpose_qkv 函数的操作。

#@save
def transpose_qkv(X, num_heads):
    # 输入 `X` 的形状: (`batch_size`, 查询或者“键-值”对的个数, `num_hiddens`).
    # 输出 `X` 的形状: (`batch_size`, 查询或者“键-值”对的个数, `num_heads`,
    # `num_hiddens` / `num_heads`)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # 输出 `X` 的形状: (`batch_size`, `num_heads`, 查询或者“键-值”对的个数,
    # `num_hiddens` / `num_heads`)
    X = X.transpose(0, 2, 1, 3)

    # `output` 的形状: (`batch_size` * `num_heads`, 查询或者“键-值”对的个数,
    # `num_hiddens` / `num_heads`)
    return X.reshape(-1, X.shape[2], X.shape[3])


#@save
def transpose_output(X, num_heads):
    """逆转 `transpose_qkv` 函数的操作"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.transpose(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
#@save
def transpose_qkv(X, num_heads):
    # 输入 `X` 的形状: (`batch_size`, 查询或者“键-值”对的个数, `num_hiddens`).
    # 输出 `X` 的形状: (`batch_size`, 查询或者“键-值”对的个数, `num_heads`,
    # `num_hiddens` / `num_heads`)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # 输出 `X` 的形状: (`batch_size`, `num_heads`, 查询或者“键-值”对的个数,
    # `num_hiddens` / `num_heads`)
    X = X.permute(0, 2, 1, 3)

    # `output` 的形状: (`batch_size` * `num_heads`, 查询或者“键-值”对的个数,
    # `num_hiddens` / `num_heads`)
    return X.reshape(-1, X.shape[2], X.shape[3])


#@save
def transpose_output(X, num_heads):
    """逆转 `transpose_qkv` 函数的操作"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

让我们使用键和值相同的小例子来测试我们编写的 MultiHeadAttention 类。多头注意力输出的形状是 (batch_size, num_queries, num_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()

batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
Y = np.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
(2, 4, 100)
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
torch.Size([2, 4, 100])

10.5.3. 小结

  • 多头注意力融合了来自于相同的注意力汇聚产生的不同的知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。

  • 基于适当的张量操作,可以实现多头注意力的并行计算。

10.5.4. 练习

  1. 分别可视化这个实验中的多个头的注意力权重。

  2. 假设我们已经拥有一个完成训练的基于多头注意力的模型,现在希望修剪最不重要的注意力头以提高预测速度。应该如何设计实验来衡量注意力头的重要性?