配套 03_transformer.py。第 2 章是单头注意力。这一章加四样东西凑成完整 GPT:
多头(同时看多种关系)、FeedForward(逐位置思考)、残差 + LayerNorm(让深网络训得动),
再把它们摞成多层 —— 六步走完,val loss 从 2.41 掉到 ~1.57。
多头
最容易误会的一点:它不是把模型加宽成好几倍,而是把同一块 128n_embd
切成 4n_head 份,每份 32head_size 维。
切几份,总料不变。拨下面看 head_size = n_embd / n_head 怎么变。
Head
MultiHeadAttention 里就一行关键:ModuleList 装 n_head 个
第 2 章那个 Head,每个头有自己独立的 Wq/Wk/Wv。
self.heads = nn.ModuleList([Head(head_size) # head_size = 32 for _ in range(num_heads)]) # num_heads = 4
4 个头吃同一个输入,但各有独立的 Wq/Wk/Wv → 各自算出一张不同的注意力权重表 wei。
点下面的头,看它们关注的位置多么不一样。每个头画成 8×8 的下三角(只是上屏占位:真实 block_size=64T,只能看过去),格子越深 = 关注越多。
行 = 发问的位置 i(query);列 = 被看的位置 j(key)。每行和为 1。
Wq/Wk/Wv 随机初始化,起点就有差异;
② "别重复"的压力 —— 若 4 个头算同一件事,等于浪费 3 个头的容量,loss 降不下去,
梯度下降于是把它们推向各管一摊。dropout 多少帮点忙,
训练时随机丢连接,逼头之间别抱团长得太像。
每个头输出一根 32head_size 维的"看过历史"向量。4 根拼接(concat)正好回到
128n_embd,再过一个 proj 线性层融合一下 —— 主干维度自始至终是 128,接口不变。
逐步点:
out = torch.cat([h(x) for h in self.heads], dim=-1) # 4×32 → 128,沿特征维拼 return self.dropout(self.proj(out)) # proj: 128 → 128 融合
经验法则:别直接挑头数,先把 head_size 锚在 64 附近(32–128 都行),再反推
n_head = n_embd / head_size。真实模型几乎都这么干:
| 模型 | n_embd | n_head | head_size |
|---|---|---|---|
| GPT-2 small | 768 | 12 | 64 |
| GPT-2 medium | 1024 | 16 | 64 |
| GPT-3 | 12288 | 96 | 128 |
| 本章(玩具) | 128 | 4 | 32 |
head_size = n_embd / n_head —— 你不是在"加头",是在切同一块 128 维的蛋糕:
4 头×32 = 8 头×16 = 16 头×8,总料始终 128。显存、算力基本持平(不是线性涨,更不是指数),
拼接后那个 proj 也永远是 128×128,与头数无关。n_head 不在这张爆炸名单上。
n_embd % n_head == 0),且习惯取 2 的幂(对硬件友好)。
Block 里。
多头让每个位置横向从别人那收集信息(沟通)。收集完,总得自己消化一下 —— 这就是
FeedForward:一个逐位置(position-wise)的小
MLP,对每个位置的向量单独加工,位置之间互不干涉。
结构是放大 → ReLU 非线性 → 压回,逐步点:
nn.Sequential(
nn.Linear(n_embd, 4 * n_embd), # 128 → 512 放大 4 倍
nn.ReLU(), # 非线性(关键)
nn.Linear(4 * n_embd, n_embd), # 512 → 128 压回,方便残差相加
)
Linear ∘ Linear 数学上还是一个线性变换,
整个 FFN 会退化成一个矩阵,放大 4 倍也白搭。非线性才是 FFN "能思考"的来源。
多头、FFN 都是"子层"。要把它们叠很多层(本章 4n_layer 层),靠两个稳定器:
残差连接(子层外包一个 x = x + 子层(x))和
LayerNorm(进子层前先归一化)。
一个 Block = x = x + sa(ln1(x)) 然后 x = x + ffwd(ln2(x))。
把残差开关拨一下,看堆深时的差别:
+ 给梯度开了一条直通车:它可以绕过子层直接往下传,不被连乘衰减。
于是子层只需学"在原信号上加什么修正 Δ",而不是从头重建整个信号 —— 这就是 2015 年 ResNet 的核心思想,Transformer 全靠它堆深。
x + 子层(ln(x)) —— 先 norm 再进子层,叫 pre-norm;比原始 Transformer 的 post-norm(子层后再 norm)更好训、更稳。
把 4n_layer 个 Block 摞起来,前面接嵌入(tok+pos)、后面接 ln_f + lm_head ——
这就是完整 GPT 的 forward:
03_transformer.py 每 500 步打印一次 loss。下面的曲线自动播放训练过程,看它穿过第 2 章单头注意力卡住的 2.41 虚线,一路下探到 ~1.57。
(示意曲线:数量级与趋势对得上 03_transformer.py 的实际打印,非逐位精确复现。)
val loss 跌破第 2 章的 2.41,一路下探到 ~1.57 —— 采样开始有"词"和对话结构,不再是纯字母汤。
这就是 nanoGPT 的内核。