为什么注意力机制中要除以 \(\sqrt{d_k}\) :从方差到梯度的推导¶
在 Transformer 的注意力机制中,计算点积注意力 \(QK^T\) 之后,需要除以一个 \(\sqrt{d_k}\) 进行缩放。这一操作通常被解释为“为了数值稳定性”。这里的“稳定”究竟指的是什么?如果不除以 \(\sqrt{d_k}\) 就不稳定了么?为什么不除以 \(d_k\) 或其他数值呢?
本文分析了点积的方差如何随维度增长而增大,并进一步推导 \(\text{softmax}\) 变换得到的行向量的雅可比矩阵,展示当输入数值过大时梯度如何逐渐趋近于零。通过这一过程,我们将会理解,除以 \(\sqrt{d_k}\) 并不是随意的设置,而是确保注意力机制在高维空间中仍能保持可训练性的必要条件。
问题拆解:从方差到梯度¶
对于为什么要对点积进行缩放,Transformer 原论文中给出了如下解释:
Attention Is All You Need
We suspect that for large values of \(d_k\), the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients.
也就是说,当 \(d_k\) 取较大值时,点积 \(QK^T\) 中的数值会变得非常大,从而使 \(\text{softmax}\) 的输入落入梯度极小的区域,导致训练困难。
因此,接下来的分析可以拆解为两个问题:
- 为什么当 \(d_k\) 较大时,\(QK^T\) 这个点积注意力矩阵中有较高的概率会出现数值较大的元素?
- 为什么当 \(QK^T\) 中存在数值较大的元素时,\(\text{softmax}\) 的梯度会趋近于 \(0\)?
为什么当 \(d_k\) 较大时,\(QK^T\) 这个点积注意力矩阵中有较高的概率会出现数值较大的元素?¶
Transformer 原论文中指出,若 \(q\) 和 \(k\) 的各维度独立,且均值为 0、方差为 1,则点积注意力矩阵中的元素 \(q \cdot k\) 的均值为 \(0\)、方差为 \(d_k\)。
Attention Is All You Need
To illustrate why the dot products get large, assume that the components of \(q\) and \(k\) are independent random variables with mean 0 and variance 1. Then their dot product
has mean 0 and variance \(d_k\).
下面我们来用概率统计的知识推导一下这个结论。
均值 \(\mathbb E[q\cdot k] = 0\)¶
方差 \(\text{Var}(q\cdot k) = d_k\)¶
以上推导告诉我们,对于点积注意力矩阵中的每个元素 \(q \cdot k\),其均值为 0,方差为 \(d_k\)。当 \(d_k\) 很大的时候,\(q \cdot k\) 的方差也会相应增大,这意味着 \(q \cdot k\) 会有较高的概率取到数值较大的元素。
为什么当 \(QK^T\) 中存在数值较大的元素时,\(\text{softmax}\) 的梯度会趋近于 \(0\)?¶
接下来我们要看,当 \(q \cdot k\) 数值太大时,为什么 \(\text{softmax}\) 的梯度会趋近于 \(0\)。
我们将 \(QK^T\) 中的第 \(i\) 行第 \(j\) 列的元素表示为 \(q_i \cdot k_j\):
即:
其中,矩阵 \(Q\) 和 \(K\) 的形状均为 \(n \times d_k\),\(n\) 表示序列中 token 的数量,\(d_k\) 表示键向量的维度,它通常等于嵌入向量的维度 \(d_\text{model}\)。\(q_i\) 表示查询矩阵 \(Q\) 的第 \(i\) 行(第 \(i\) 个 token 的查询向量),\(k_i\) 表示键矩阵 \(K\) 的第 \(i\) 行(第 \(i\) 个 token 的键向量)。
对 \(QK^T\) 应用 \(\text{softmax}\) 时,我们会对每一行进行 \(\text{softmax}\) 变换,即
其中 \(i\) 表示行索引,\(j\) 表示列索引。
由于 \(\text{softmax}\) 是对每一行进行的,因此,不同行之间的数据是互不影响其 \(\text{softmax}\) 变换的结果的。
不失一般性,我们考察第 \(1\) 行的导数,也就是 \(\text{softmax}(QK^T)\) 中第一行向量关于 \(q_1 \cdot k_i, \forall i \in \{1, 2, \cdots, n\}\) 的雅可比矩阵。
记 \(\text{softmax}\) 变换前的值 \(z_i = q_1 \cdot k_i\),\(\text{softmax}\) 变换后的值 \(s_i = \frac{\exp(z_i)}{\sum_{l=1}^{n} \exp(z_l)}\)。
则有:
为了推导 \(s_i\) 关于 \(z_j\) 的偏导数,我们先把 \(\text{softmax}\) 写成更方便求导的形式。记分母归一化项为
则有
对 \(z_j\) 求偏导时分两种情况讨论:
- 当 \(i = j\) 时,
- 当 \(i \neq j\) 时,
将上述两种情况合并,可以写成更紧凑的形式:
将结果代回雅可比矩阵,可以具体写出:
当任意一个 \(z_i\) 非常大、对应的 \(s_i \approx 1\) 时,其他 \(s_j, j \neq i\) 都会趋近于 \(0\)。也就是 \(\text{softmax}\) 变换后得到的行向量中,只有一个元素为 \(1\),其他元素均为 \(0\)。此时,\(J_{softmax}\) 中的所有元素都会趋近于 \(0\)。
数值示例¶
通过数值示例,可以更好地帮助我们理解,\(\text{softmax}\) 中存在数值较大的元素时,\(\text{softmax}\) 变换后最大的元素会趋近于 \(1\),其他元素会趋近于 \(0\)。
"""
Visualizing how softmax output changes as input variance increases.
Variances: [1, 10, 20, 30, 50, 100]
Each sample has 64 elements.
Generates two 3x2 figures:
1) Before softmax
2) After softmax
"""
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(0)
variances = [1, 10, 20, 30, 50, 100]
n = 64
def stable_softmax(x: np.ndarray) -> np.ndarray:
x = np.asarray(x, dtype=np.float64)
x_shift = x - np.max(x)
exps = np.exp(x_shift)
return exps / np.sum(exps)
# Generate samples
samples = [np.random.normal(0.0, np.sqrt(v), size=n) for v in variances]
softmaxed = [stable_softmax(x) for x in samples]
# Plot before softmax
fig1, axes1 = plt.subplots(3, 2, figsize=(14, 8), constrained_layout=True)
axes1 = axes1.ravel()
for ax, var, x in zip(axes1, variances, samples):
ax.bar(np.arange(n), x)
ax.set_title(f"Before softmax: variance={var} (std≈{np.sqrt(var):.2f})")
ax.set_xlabel("Index")
ax.set_ylabel("Value")
fig1.suptitle(
"Samples with mean=0 and different variances (Before softmax)", fontsize=14, y=1.02
)
# Plot after softmax
fig2, axes2 = plt.subplots(3, 2, figsize=(14, 8), constrained_layout=True)
axes2 = axes2.ravel()
for ax, var, s in zip(axes2, variances, softmaxed):
ax.bar(np.arange(n), s)
ax.set_title(f"After softmax: variance={var}")
ax.set_xlabel("Index")
ax.set_ylabel("Probability")
ax.set_ylim(0, max(0.2, s.max() * 1.1))
fig2.suptitle(
"Softmax probability distributions under different variances", fontsize=14, y=1.02
)
plt.show()
回顾与总结¶
我们首先证明了在 \(q\) 和 \(k\) 各维度独立且均值为 \(0\)、方差为 \(1\) 的假设下,点积注意力的元素方差满足 \(\operatorname{Var}(q\cdot k)=d_k\)。因此当维度增大时,\(q\cdot k\) 的取值会越频繁地落到较大的区间。
当 \(\text{softmax}\) 的输入元素的存在非常大的数值时,向量中的最大值在转换后会趋近于 \(1\),其他元素会趋近于 \(0\)。这会导致 \(\text{softmax}\) 的雅可比矩阵中的每一项元素 \(\partial s_i/\partial z_j = s_i(\mathbf 1\{i=j\}-s_j)\) 几乎全部衰减为 \(0\)。梯度消失后,模型将难以通过反向传播来更新参数。
Transformer 原论文中将点积注意力缩放为 \(\frac{q\cdot k}{\sqrt{d_k}}\) 后,可以把注意力矩阵中元素的方差重新拉回 \(1\):
从而让 \(\text{softmax}\) 的输入维持在既不会饱和又足够区分的范围,使得梯度不再因为输入数值太大而消失。