
https://docs.pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding
1. 类比 “词典”
nn.Embedding(num_embeddings, embedding_dim) 可以看成是一个查表词典:
- 有
num_embeddings行,每一行是一个embedding_dim维的向量。 - 输入是索引(比如单词 ID、类别 ID),输出是对应行的向量。
max_norm 的作用:
给这个“词典”里的每一行向量设一个“最大长度”限制:
- 这里的“长度”是向量的 p-范数,默认是 L2 范数(欧几里得长度)。
- 当某一行被用到(即该 index 出现在输入中)时,如果它的范数大于
max_norm,就会被按比例缩小,使得它的范数刚好等于max_norm。
重要点:
- 这个“缩小”发生在 forward 的时候,是原地修改 embedding.weight 的部分行。
- 只有被当前 batch 索引到的那些行才会被检查和可能被缩放。
2. 一个具体、可手算的小例子
2.1 定义一个简单的 Embedding
假设我们手动构造一个 embedding 权重(方便算):
num_embeddings = 5 # 总共有 5 行
embedding_dim = 3 # 每行是 3 维向量
max_norm = 1.5
我们设定当前(某一次训练时刻)的权重矩阵为:
$$
W \in \mathbb{R}^{5 \times 3} =
\begin{bmatrix}
0.3367 & 0.1288 & 0.2345 \\
0.2303 & -1.1229 & -0.1863 \\
2.2082 & -0.6380 & 0.4617 \\
0.2674 & 0.5349 & 0.8094 \\
1.1103 & -1.6898 & -0.9890
\end{bmatrix}
$$
也就是:
- 第 0 行:
[0.3367, 0.1288, 0.2345] - 第 1 行:
[0.2303, -1.1229, -0.1863] - 第 2 行:
[2.2082, -0.6380, 0.4617] - 第 3 行:
[0.2674, 0.5349, 0.8094] - 第 4 行:
[1.1103, -1.6898, -0.9890]
2.2 计算每一行的 L2 范数
对每一行向量 $v = (x, y, z)$,L2 范数定义为:
$$
\|v\|_2 = \sqrt{x^2 + y^2 + z^2}
$$
逐行计算:
向量 0
$$
v_0 = (0.3367, 0.1288, 0.2345)
$$
$$
\|v_0\|_2 = \sqrt{0.3367^2 + 0.1288^2 + 0.2345^2}
\approx 0.4300
$$
(约等于 0.43,小于 1.5,不会被改动)
向量 1
$$
v_1 = (0.2303, -1.1229, -0.1863)
$$
$$
\|v_1\|_2 = \sqrt{0.2303^2 + (-1.1229)^2 + (-0.1863)^2}
\approx 1.1613
$$
(小于 1.5,不会被改动)
向量 2(会被裁剪)
$$
v_2 = (2.2082, -0.6380, 0.4617)
$$
$$
\|v_2\|_2 = \sqrt{2.2082^2 + (-0.6380)^2 + 0.4617^2}
\approx 2.3444 > 1.5
$$
这个超过了 max_norm=1.5,需要被缩小。
向量 3
$$
v_3 = (0.2674, 0.5349, 0.8094)
$$
$$
\|v_3\|_2 \approx 1.0063 < 1.5
$$
不改。
向量 4(会被裁剪)
$$
v_4 = (1.1103, -1.6898, -0.9890)
$$
$$
\|v_4\|_2 = \sqrt{1.1103^2 + (-1.6898)^2 + (-0.9890)^2}
\approx 2.2508 > 1.5
$$
也超过了 1.5,需要缩小。
小结:
- 需要被处理的只有:行 2 和 行 4(它们的范数 > 1.5)
3. max_norm 的具体计算公式(重归一化)
对每一个需要被裁剪的向量 $ v $:
- 先算出当前范数:$ \|v\| $
- 如果 $ \|v\| > \text{max_norm} $,就按系数
$$
\text{scale} = \frac{\text{max_norm}}{\|v\|}
$$
去缩小它:
$$
v_{\text{new}} = v \times \text{scale}
$$
这样就有:
$$
\|v_{\text{new}}\| = \|v \times \text{scale}\| = \text{scale} \cdot \|v\|
= \frac{\text{max_norm}}{\|v\|} \cdot \|v\| = \text{max_norm}
$$
3.1 对行 2 的计算
- 原始向量:
$$
v_2 = (2.2082, -0.6380, 0.4617), \quad \|v_2\| \approx 2.3444
$$
- 缩放系数:
$$
\text{scale}_2 = \frac{1.5}{2.3444} \approx 0.6396
$$
- 新向量(逐元素相乘):
$$
v_{2,\text{new}} = v_2 \times 0.6396 \approx
(2.2082 \times 0.6396,\; -0.6380 \times 0.6396,\; 0.4617 \times 0.6396)
$$
数值约为:
$$
v_{2,\text{new}} \approx (1.4128,\; -0.4082,\; 0.2954)
$$
再检查一下新范数:
$$
\|v_{2,\text{new}}\|_2 \approx 1.5000
$$
3.2 对行 4 的计算
- 原始向量:
$$
v_4 = (1.1103, -1.6898, -0.9890), \quad \|v_4\| \approx 2.2508
$$
- 缩放系数:
$$
\text{scale}_4 = \frac{1.5}{2.2508} \approx 0.6665
$$
- 新向量:
$$
v_{4,\text{new}} \approx
(1.1103 \times 0.6665,\; -1.6898 \times 0.6665,\; -0.9890 \times 0.6665)
\approx (0.7399,\; -1.1261,\; -0.6591)
$$
新范数:
$$
\|v_{4,\text{new}}\|_2 \approx 1.5000
$$
4. 重归一化后的权重矩阵
归一化之后,新的权重矩阵 $W’$ 变为:
$$
W’ =
\begin{bmatrix}
0.3367 & 0.1288 & 0.2345 \\
0.2303 & -1.1229 & -0.1863 \\
1.4128 & -0.4082 & 0.2954 \\
0.2674 & 0.5349 & 0.8094 \\
0.7399 & -1.1261 & -0.6591
\end{bmatrix}
$$
新的各行范数:
- 第 0 行:0.4300(未变)
- 第 1 行:1.1613(未变)
- 第 2 行:1.5000(从 2.3444 被缩到 1.5)
- 第 3 行:1.0063(未变)
- 第 4 行:1.5000(从 2.2508 被缩到 1.5)
5. 对权重的具体影响总结
只影响被访问到且范数超限的行
- forward 时只拿到了某些 index(比如本 batch 里用到的词 ID),
max_norm只会根据这些 index 去检查并缩放对应行。 - 其它没被访问到的行,这一轮 forward 不会去动它。
操作是 in-place 的
官方文档明确说明:当 max_norm 不为 None 时,Embedding.forward 会原地修改 weight。
这意味着:
- 缩放后的权重会被保留下来,用于之后的训练步骤。
- 如果你在 forward 之前对
embedding.weight做可微操作,需要先.clone()一份再用,否则会和 autograd 的 in-place 规则冲突。
不改变方向,只改变长度
- 向量被按比例整体缩小:
$$
v_{\text{new}} = v \cdot \frac{\text{max_norm}}{\|v\|}
$$ - 方向(单位向量)不变,只是“缩短”到指定长度。
起到正则化 / 稳定作用
- 限制每个 embedding 行向量的最大范数,可以避免某些向量过大,防止梯度爆炸或某些词向量“过度主导”模型。
6. 一句话记住
在 nn.Embedding 里:
max_norm= “给每一行 embedding 向量设一个最大长度,
每次 forward 时,凡是被用到且长度超过这个上限的行,都会被按比例缩到这个长度,并且是直接改写权重矩阵的。”