AI 一月 11, 2026

torch.nn.Embedding 中 max_norm 的作用

文章字数 7.7k 阅读约需 7 mins. 阅读次数

https://docs.pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding

1. 类比 “词典”

nn.Embedding(num_embeddings, embedding_dim) 可以看成是一个查表词典

  • num_embeddings 行,每一行是一个 embedding_dim 维的向量。
  • 输入是索引(比如单词 ID、类别 ID),输出是对应行的向量。

max_norm 的作用:

给这个“词典”里的每一行向量设一个“最大长度”限制:

  • 这里的“长度”是向量的 p-范数,默认是 L2 范数(欧几里得长度)。
  • 当某一行被用到(即该 index 出现在输入中)时,如果它的范数大于 max_norm,就会被按比例缩小,使得它的范数刚好等于 max_norm

重要点:

  • 这个“缩小”发生在 forward 的时候,是原地修改 embedding.weight 的部分行
  • 只有被当前 batch 索引到的那些行才会被检查和可能被缩放。

2. 一个具体、可手算的小例子

2.1 定义一个简单的 Embedding

假设我们手动构造一个 embedding 权重(方便算):

num_embeddings = 5   # 总共有 5 行
embedding_dim = 3   # 每行是 3 维向量
max_norm = 1.5

我们设定当前(某一次训练时刻)的权重矩阵为:

$$
W \in \mathbb{R}^{5 \times 3} =
\begin{bmatrix}
0.3367 & 0.1288 & 0.2345 \\
0.2303 & -1.1229 & -0.1863 \\
2.2082 & -0.6380 & 0.4617 \\
0.2674 & 0.5349 & 0.8094 \\
1.1103 & -1.6898 & -0.9890
\end{bmatrix}
$$

也就是:

  • 第 0 行:[0.3367, 0.1288, 0.2345]
  • 第 1 行:[0.2303, -1.1229, -0.1863]
  • 第 2 行:[2.2082, -0.6380, 0.4617]
  • 第 3 行:[0.2674, 0.5349, 0.8094]
  • 第 4 行:[1.1103, -1.6898, -0.9890]

2.2 计算每一行的 L2 范数

对每一行向量 $v = (x, y, z)$,L2 范数定义为:

$$
\|v\|_2 = \sqrt{x^2 + y^2 + z^2}
$$

逐行计算:

向量 0

$$
v_0 = (0.3367, 0.1288, 0.2345)
$$

$$
\|v_0\|_2 = \sqrt{0.3367^2 + 0.1288^2 + 0.2345^2}
\approx 0.4300
$$

(约等于 0.43,小于 1.5,不会被改动)

向量 1

$$
v_1 = (0.2303, -1.1229, -0.1863)
$$

$$
\|v_1\|_2 = \sqrt{0.2303^2 + (-1.1229)^2 + (-0.1863)^2}
\approx 1.1613
$$

(小于 1.5,不会被改动)

向量 2(会被裁剪)

$$
v_2 = (2.2082, -0.6380, 0.4617)
$$

$$
\|v_2\|_2 = \sqrt{2.2082^2 + (-0.6380)^2 + 0.4617^2}
\approx 2.3444 > 1.5
$$

这个超过了 max_norm=1.5需要被缩小

向量 3

$$
v_3 = (0.2674, 0.5349, 0.8094)
$$

$$
\|v_3\|_2 \approx 1.0063 < 1.5
$$

不改。

向量 4(会被裁剪)

$$
v_4 = (1.1103, -1.6898, -0.9890)
$$

$$
\|v_4\|_2 = \sqrt{1.1103^2 + (-1.6898)^2 + (-0.9890)^2}
\approx 2.2508 > 1.5
$$

也超过了 1.5,需要缩小。

小结:

  • 需要被处理的只有:行 2 和 行 4(它们的范数 > 1.5)

3. max_norm 的具体计算公式(重归一化)

对每一个需要被裁剪的向量 $ v $:

  1. 先算出当前范数:$ \|v\| $
  2. 如果 $ \|v\| > \text{max_norm} $,就按系数
    $$
    \text{scale} = \frac{\text{max_norm}}{\|v\|}
    $$
    去缩小它:
    $$
    v_{\text{new}} = v \times \text{scale}
    $$

这样就有:

$$
\|v_{\text{new}}\| = \|v \times \text{scale}\| = \text{scale} \cdot \|v\|
= \frac{\text{max_norm}}{\|v\|} \cdot \|v\| = \text{max_norm}
$$

3.1 对行 2 的计算

  • 原始向量:

$$
v_2 = (2.2082, -0.6380, 0.4617), \quad \|v_2\| \approx 2.3444
$$

  • 缩放系数:

$$
\text{scale}_2 = \frac{1.5}{2.3444} \approx 0.6396
$$

  • 新向量(逐元素相乘):

$$
v_{2,\text{new}} = v_2 \times 0.6396 \approx
(2.2082 \times 0.6396,\; -0.6380 \times 0.6396,\; 0.4617 \times 0.6396)
$$

数值约为:

$$
v_{2,\text{new}} \approx (1.4128,\; -0.4082,\; 0.2954)
$$

再检查一下新范数:

$$
\|v_{2,\text{new}}\|_2 \approx 1.5000
$$

3.2 对行 4 的计算

  • 原始向量:

$$
v_4 = (1.1103, -1.6898, -0.9890), \quad \|v_4\| \approx 2.2508
$$

  • 缩放系数:

$$
\text{scale}_4 = \frac{1.5}{2.2508} \approx 0.6665
$$

  • 新向量:

$$
v_{4,\text{new}} \approx
(1.1103 \times 0.6665,\; -1.6898 \times 0.6665,\; -0.9890 \times 0.6665)
\approx (0.7399,\; -1.1261,\; -0.6591)
$$

新范数:

$$
\|v_{4,\text{new}}\|_2 \approx 1.5000
$$

4. 重归一化后的权重矩阵

归一化之后,新的权重矩阵 $W’$ 变为:

$$
W’ =
\begin{bmatrix}
0.3367 & 0.1288 & 0.2345 \\
0.2303 & -1.1229 & -0.1863 \\
1.4128 & -0.4082 & 0.2954 \\
0.2674 & 0.5349 & 0.8094 \\
0.7399 & -1.1261 & -0.6591
\end{bmatrix}
$$

新的各行范数:

  • 第 0 行:0.4300(未变)
  • 第 1 行:1.1613(未变)
  • 第 2 行:1.5000(从 2.3444 被缩到 1.5)
  • 第 3 行:1.0063(未变)
  • 第 4 行:1.5000(从 2.2508 被缩到 1.5)

5. 对权重的具体影响总结

只影响被访问到且范数超限的行

  • forward 时只拿到了某些 index(比如本 batch 里用到的词 ID),max_norm 只会根据这些 index 去检查并缩放对应行。
  • 其它没被访问到的行,这一轮 forward 不会去动它。

操作是 in-place 的

官方文档明确说明:当 max_norm 不为 None 时,Embedding.forward原地修改 weight

这意味着:

  • 缩放后的权重会被保留下来,用于之后的训练步骤。
  • 如果你在 forward 之前对 embedding.weight 做可微操作,需要先 .clone() 一份再用,否则会和 autograd 的 in-place 规则冲突。

不改变方向,只改变长度

  • 向量被按比例整体缩小:
    $$
    v_{\text{new}} = v \cdot \frac{\text{max_norm}}{\|v\|}
    $$
  • 方向(单位向量)不变,只是“缩短”到指定长度。

起到正则化 / 稳定作用

  • 限制每个 embedding 行向量的最大范数,可以避免某些向量过大,防止梯度爆炸或某些词向量“过度主导”模型。

6. 一句话记住

nn.Embedding 里:

max_norm = “给每一行 embedding 向量设一个最大长度
每次 forward 时,凡是被用到且长度超过这个上限的行,都会被按比例缩到这个长度,并且是直接改写权重矩阵的。”

0%