跳转至

为什么注意力机制中要除以 \(\sqrt{d_k}\) :从方差到梯度的推导

\[ Attention(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]

在 Transformer 的注意力机制中,计算点积注意力 \(QK^T\) 之后,需要除以一个 \(\sqrt{d_k}\) 进行缩放。这一操作通常被解释为“为了数值稳定性”。这里的“稳定”究竟指的是什么?如果不除以 \(\sqrt{d_k}\) 就不稳定了么?为什么不除以 \(d_k\) 或其他数值呢?

本文分析了点积的方差如何随维度增长而增大,并进一步推导 \(\text{softmax}\) 变换得到的行向量的雅可比矩阵,展示当输入数值过大时梯度如何逐渐趋近于零。通过这一过程,我们将会理解,除以 \(\sqrt{d_k}\) 并不是随意的设置,而是确保注意力机制在高维空间中仍能保持可训练性的必要条件。

问题拆解:从方差到梯度

对于为什么要对点积进行缩放,Transformer 原论文中给出了如下解释:

Attention Is All You Need

We suspect that for large values of \(d_k\), the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients.

也就是说,当 \(d_k\) 取较大值时,点积 \(QK^T\) 中的数值会变得非常大,从而使 \(\text{softmax}\) 的输入落入梯度极小的区域,导致训练困难。

因此,接下来的分析可以拆解为两个问题:

  1. 为什么当 \(d_k\) 较大时,\(QK^T\) 这个点积注意力矩阵中有较高的概率会出现数值较大的元素?
  2. 为什么当 \(QK^T\) 中存在数值较大的元素时,\(\text{softmax}\) 的梯度会趋近于 \(0\)

为什么当 \(d_k\) 较大时,\(QK^T\) 这个点积注意力矩阵中有较高的概率会出现数值较大的元素?

Transformer 原论文中指出,若 \(q\)\(k\) 的各维度独立,且均值为 0、方差为 1,则点积注意力矩阵中的元素 \(q \cdot k\) 的均值为 \(0\)、方差为 \(d_k\)

Attention Is All You Need

To illustrate why the dot products get large, assume that the components of \(q\) and \(k\) are independent random variables with mean 0 and variance 1. Then their dot product

\[ q \cdot k = \sum_{i=1}^{d_k} (q_i \cdot k_i) \]

has mean 0 and variance \(d_k\).

下面我们来用概率统计的知识推导一下这个结论。

均值 \(\mathbb E[q\cdot k] = 0\)

\[ \begin{aligned} \mathbb E[q\cdot k] &= \mathbb E\left[\sum_{i=1}^{d_k} q_i k_i\right] \\ &= \sum_{i=1}^{d_k} \mathbb E[q_i k_i] && \text{(期望线性)} \\ &= \sum_{i=1}^{d_k} \mathbb E[q_i]\mathbb E[k_i] && \text{(由于 $q_i$ 和 $k_i$ 独立)}\\ &= \sum_{i=1}^{d_k} 0 \cdot 0 && \text{(由于 $q_i$ 和 $k_i$ 均值为 0)} \\ &= 0 \end{aligned} \]

方差 \(\text{Var}(q\cdot k) = d_k\)

\[ \begin{aligned} \text{Var}(q\cdot k) &= \text{Var}\left(\sum_{i=1}^{d_k} q_i k_i\right) \\ &= \sum_{i=1}^{d_k} \text{Var}(q_i k_i) && \text{(由于 $q_i k_i$ 和 $q_j k_j$ 独立 $\forall i \neq j$)}\\ &= \sum_{i=1}^{d_k} \left[ \mathbb E[(q_i k_i)^2] - (\mathbb E[q_i k_i])^2 \right] && \text{(方差定义)}\\ &= \sum_{i=1}^{d_k} \left[ \mathbb E[q_i^2]\mathbb E[k_i^2] - (\mathbb E[q_i]\mathbb E[k_i])^2 \right] && \text{(由于 $q_i$ 和 $k_i$ 独立)}\\ &= \sum_{i=1}^{d_k} \left[ 1 \cdot 1 - \left(0 \cdot 0\right)^2 \right] && \text{($\mathbb E[q_i^2] = \text{Var}(q_i) + (\mathbb E[q_i])^2 = 1 + 0^2 = 1$)}\\ &= d_k \end{aligned} \]

以上推导告诉我们,对于点积注意力矩阵中的每个元素 \(q \cdot k\),其均值为 0,方差为 \(d_k\)。当 \(d_k\) 很大的时候,\(q \cdot k\) 的方差也会相应增大,这意味着 \(q \cdot k\) 会有较高的概率取到数值较大的元素。

为什么当 \(QK^T\) 中存在数值较大的元素时,\(\text{softmax}\) 的梯度会趋近于 \(0\)

接下来我们要看,当 \(q \cdot k\) 数值太大时,为什么 \(\text{softmax}\) 的梯度会趋近于 \(0\)

我们将 \(QK^T\) 中的第 \(i\) 行第 \(j\) 列的元素表示为 \(q_i \cdot k_j\)

\[ QK^T= \begin{bmatrix} q_{1} \cdot k_{1} & q_{1} \cdot k_{2} & \cdots & q_{1} \cdot k_{n} \\ q_{2} \cdot k_{1} & q_{2} \cdot k_{2} & \cdots & q_{2} \cdot k_{n} \\ \vdots & \vdots & \ddots & \vdots \\ q_{n} \cdot k_{1} & q_{n} \cdot k_{2} & \cdots & q_{n} \cdot k_{n} \end{bmatrix} \]

即:

\[ QK^T_{ij} = q_i \cdot k_j \]

其中,矩阵 \(Q\)\(K\) 的形状均为 \(n \times d_k\)\(n\) 表示序列中 token 的数量,\(d_k\) 表示键向量的维度,它通常等于嵌入向量的维度 \(d_\text{model}\)\(q_i\) 表示查询矩阵 \(Q\) 的第 \(i\) 行(第 \(i\) 个 token 的查询向量),\(k_i\) 表示键矩阵 \(K\) 的第 \(i\) 行(第 \(i\) 个 token 的键向量)。

\(QK^T\) 应用 \(\text{softmax}\) 时,我们会对每一行进行 \(\text{softmax}\) 变换,即

\[ \text{softmax}(QK^T)_{ij} = \frac{\exp(q_i \cdot k_j)}{\sum_{l=1}^{n} \exp(q_i \cdot k_l)} \]

其中 \(i\) 表示行索引,\(j\) 表示列索引。

由于 \(\text{softmax}\) 是对每一行进行的,因此,不同行之间的数据是互不影响其 \(\text{softmax}\) 变换的结果的。

不失一般性,我们考察第 \(1\) 行的导数,也就是 \(\text{softmax}(QK^T)\) 中第一行向量关于 \(q_1 \cdot k_i, \forall i \in \{1, 2, \cdots, n\}\) 的雅可比矩阵。

\(\text{softmax}\) 变换前的值 \(z_i = q_1 \cdot k_i\)\(\text{softmax}\) 变换后的值 \(s_i = \frac{\exp(z_i)}{\sum_{l=1}^{n} \exp(z_l)}\)

则有:

\[ J_{softmax} = \begin{pmatrix} \frac{\partial s_1}{\partial z_1} & \frac{\partial s_1}{\partial z_2} & \cdots & \frac{\partial s_1}{\partial z_n} \\ \frac{\partial s_2}{\partial z_1} & \frac{\partial s_2}{\partial z_2} & \cdots & \frac{\partial s_2}{\partial z_n} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial s_n}{\partial z_1} & \frac{\partial s_n}{\partial z_2} & \cdots & \frac{\partial s_n}{\partial z_n} \end{pmatrix} \]

为了推导 \(s_i\) 关于 \(z_j\) 的偏导数,我们先把 \(\text{softmax}\) 写成更方便求导的形式。记分母归一化项为

\[ Z = \sum_{l=1}^{n} \exp(z_l) \]

则有

\[ s_i = \frac{\exp(z_i)}{Z} \]

\(z_j\) 求偏导时分两种情况讨论:

  • \(i = j\) 时,
\[ \frac{\partial s_i}{\partial z_i} = \frac{\exp(z_i) \cdot Z - \exp(z_i) \cdot \exp(z_i)}{Z^2} = \frac{\exp(z_i)}{Z}\left(1 - \frac{\exp(z_i)}{Z}\right) = s_i (1 - s_i) \]
  • \(i \neq j\) 时,
\[ \frac{\partial s_i}{\partial z_j} = \frac{0 \cdot Z - \exp(z_i) \cdot \exp(z_j)}{Z^2} = -\frac{\exp(z_i)}{Z} \cdot \frac{\exp(z_j)}{Z} = -s_i s_j \]

将上述两种情况合并,可以写成更紧凑的形式:

\[ \frac{\partial s_i}{\partial z_j} = s_i \left( \mathbf{1}\{i = j\} - s_j \right), \quad \forall i, j = 1, \ldots, n \]

将结果代回雅可比矩阵,可以具体写出:

\[ J_{softmax} = \begin{pmatrix} s_1 (1 - s_1) & -s_1 s_2 & \cdots & -s_1 s_n \\ -s_2 s_1 & s_2 (1 - s_2) & \cdots & -s_2 s_n \\ \vdots & \vdots & \ddots & \vdots \\ -s_n s_1 & -s_n s_2 & \cdots & s_n (1 - s_n) \end{pmatrix} \]

当任意一个 \(z_i\) 非常大、对应的 \(s_i \approx 1\) 时,其他 \(s_j, j \neq i\) 都会趋近于 \(0\)。也就是 \(\text{softmax}\) 变换后得到的行向量中,只有一个元素为 \(1\),其他元素均为 \(0\)。此时,\(J_{softmax}\) 中的所有元素都会趋近于 \(0\)

数值示例

通过数值示例,可以更好地帮助我们理解,\(\text{softmax}\) 中存在数值较大的元素时,\(\text{softmax}\) 变换后最大的元素会趋近于 \(1\),其他元素会趋近于 \(0\)

Python
"""
Visualizing how softmax output changes as input variance increases.
Variances: [1, 10, 20, 30, 50, 100]
Each sample has 64 elements.
Generates two 3x2 figures:
  1) Before softmax
  2) After softmax
"""
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)
variances = [1, 10, 20, 30, 50, 100]
n = 64


def stable_softmax(x: np.ndarray) -> np.ndarray:
    x = np.asarray(x, dtype=np.float64)
    x_shift = x - np.max(x)
    exps = np.exp(x_shift)
    return exps / np.sum(exps)


# Generate samples
samples = [np.random.normal(0.0, np.sqrt(v), size=n) for v in variances]
softmaxed = [stable_softmax(x) for x in samples]

# Plot before softmax
fig1, axes1 = plt.subplots(3, 2, figsize=(14, 8), constrained_layout=True)
axes1 = axes1.ravel()
for ax, var, x in zip(axes1, variances, samples):
    ax.bar(np.arange(n), x)
    ax.set_title(f"Before softmax: variance={var} (std≈{np.sqrt(var):.2f})")
    ax.set_xlabel("Index")
    ax.set_ylabel("Value")
fig1.suptitle(
    "Samples with mean=0 and different variances (Before softmax)", fontsize=14, y=1.02
)

# Plot after softmax
fig2, axes2 = plt.subplots(3, 2, figsize=(14, 8), constrained_layout=True)
axes2 = axes2.ravel()
for ax, var, s in zip(axes2, variances, softmaxed):
    ax.bar(np.arange(n), s)
    ax.set_title(f"After softmax: variance={var}")
    ax.set_xlabel("Index")
    ax.set_ylabel("Probability")
    ax.set_ylim(0, max(0.2, s.max() * 1.1))
fig2.suptitle(
    "Softmax probability distributions under different variances", fontsize=14, y=1.02
)

plt.show()

Softmax 前

Softmax 前

Softmax 后

Softmax 后

回顾与总结

我们首先证明了在 \(q\)\(k\) 各维度独立且均值为 \(0\)、方差为 \(1\) 的假设下,点积注意力的元素方差满足 \(\operatorname{Var}(q\cdot k)=d_k\)。因此当维度增大时,\(q\cdot k\) 的取值会越频繁地落到较大的区间。

\(\text{softmax}\) 的输入元素的存在非常大的数值时,向量中的最大值在转换后会趋近于 \(1\),其他元素会趋近于 \(0\)。这会导致 \(\text{softmax}\) 的雅可比矩阵中的每一项元素 \(\partial s_i/\partial z_j = s_i(\mathbf 1\{i=j\}-s_j)\) 几乎全部衰减为 \(0\)。梯度消失后,模型将难以通过反向传播来更新参数。

Transformer 原论文中将点积注意力缩放为 \(\frac{q\cdot k}{\sqrt{d_k}}\) 后,可以把注意力矩阵中元素的方差重新拉回 \(1\)

\[ \operatorname{Var}\!\left(\frac{q\cdot k}{\sqrt{d_k}}\right)=\frac{\operatorname{Var}(q\cdot k)}{d_k}=1, \]

从而让 \(\text{softmax}\) 的输入维持在既不会饱和又足够区分的范围,使得梯度不再因为输入数值太大而消失。

评论