← 返回首页
03_transformer_viz.html

完整 Transformer:四块新积木,拼出 GPT

配套 03_transformer.py。第 2 章是单头注意力。这一章加四样东西凑成完整 GPT: 多头(同时看多种关系)、FeedForward(逐位置思考)、残差 + LayerNorm(让深网络训得动), 再把它们摞成多层 —— 六步走完,val loss 从 2.41 掉到 ~1.57。

STEP 1
切蛋糕:128 → 4 个头
STEP 2
每个头各看各的
STEP 3
拼回去 · 头数怎么选
STEP 4
FFN:逐位置思考
STEP 5
残差 + LayerNorm
STEP 6
堆成 GPT · 破墙
超参速查 n_embd 128 主干维度 C n_head 4 头数 head_size 32 每头维度 block_size 64 位置数 T n_layer 4 层数 vocab_size 65 字符表 head_size = n_embd / n_head = 128/4 = 32

Step 1 · 切蛋糕:把 128 维切成 4 个头(不是加 4 份)

多头 最容易误会的一点:它不是把模型加宽成好几倍,而是把同一块 128n_embd 切成 4n_head 份,每份 32head_size 维。 切几份,总料不变。拨下面看 head_size = n_embd / n_head 怎么变。

代码4 个头 = 4 份第 2 章的 Head

MultiHeadAttention 里就一行关键:ModuleListn_head 个 第 2 章那个 Head,每个头有自己独立的 Wq/Wk/Wv

self.heads = nn.ModuleList([Head(head_size) # head_size = 32
                            for _ in range(num_heads)])  # num_heads = 4

下一步:4 个头各有各的 Wq/Wk/Wv → 各算出不一样的注意力。它们会自己分工。

Step 2 · 每个头各看各的:分工是训练"逼"出来的

4 个头吃同一个输入,但各有独立的 Wq/Wk/Wv → 各自算出一张不同的注意力权重表 wei点下面的头,看它们关注的位置多么不一样。每个头画成 8×8 的下三角(只是上屏占位:真实 block_size=64T,只能看过去),格子越深 = 关注越多。

行 = 发问的位置 i(query);列 = 被看的位置 j(key)。每行和为 1。

这 4 个头的"专长",是谁分配的?
没人分配,是训练中自己涌现的。没有哪行代码写"1 号头管前一个词"。让它们分开的是两股力: ① 初始随机不同 —— 每个头的 Wq/Wk/Wv 随机初始化,起点就有差异; ② "别重复"的压力 —— 若 4 个头算同一件事,等于浪费 3 个头的容量,loss 降不下去, 梯度下降于是把它们推向各管一摊
训练完用可解释性工具去看,确实能认出一些有名的头:专盯前一个词的、回看句首的、配对引号/括号的、 以及 induction head(归纳头:见过 "AB",再遇 A 就预测 B)。
(上面四张图是示意,用来体现"每个头一种模式";真实的头多数更杂、更难一句话概括。)
这跟 Q/K/V 的分工是一回事吗?
不是,两种分工的来源相反:
· Q/K/V 是"公式钦定"的 —— q 站在点积左边(发问)、k 在右边(被问)、v 被聚合搬运, 三者在算式里位置不同,角色天生被劈开,绝不会冗余
· 多个头是"训练涌现"的 —— 架构上 4 个头完全对称、可互换,是训练动力学才把它们卷出了不同专长, 有时还会重复(见下条)。
一句话:Q/K/V 是三个不同工种;4 个头是四个一样的新人,在干活中自己卷出了专长。
会不会两个头分化成一模一样?(冗余头)
会,这叫"冗余头(redundant heads)"。"别重复"的压力是软的、间接的,不是硬约束,头一多总有几个长得彼此很像。 有研究(《Are Sixteen Heads Really Better than One?》)发现很多头是冗余的,训练后剪掉(prune)性能几乎不掉。 冗余不是错,只是浪费 —— 这也是"头并非越多越好"的一个原因。代码里那行 dropout 多少帮点忙, 训练时随机丢连接,逼头之间别抱团长得太像。

下一步:4 个头各吐一根 32 维向量,怎么拼回 128?头数又该怎么挑?

Step 3 · 拼回去,与"头数怎么选"

每个头输出一根 32head_size 维的"看过历史"向量。4 根拼接(concat)正好回到 128n_embd,再过一个 proj 线性层融合一下 —— 主干维度自始至终是 128,接口不变。 逐步点:

代码拼接 + 投影
out = torch.cat([h(x) for h in self.heads], dim=-1)  # 4×32 → 128,沿特征维拼
return self.dropout(self.proj(out))                  # proj: 128 → 128 融合
实战头数怎么选?有没有最佳实践

经验法则:别直接挑头数,先把 head_size 锚在 64 附近(32–128 都行),再反推 n_head = n_embd / head_size。真实模型几乎都这么干:

模型n_embdn_headhead_size
GPT-2 small7681264
GPT-2 medium10241664
GPT-31228896128
本章(玩具)128432
加头数会爆显存、让训练时间暴涨吗?
几乎不会。因为 head_size = n_embd / n_head —— 你不是在"加头",是在切同一块 128 维的蛋糕: 4 头×32 = 8 头×16 = 16 头×8,总料始终 128。显存、算力基本持平(不是线性涨,更不是指数), 拼接后那个 proj 也永远是 128×128,与头数无关。
真正会爆显存/时间的是另几个:block_size(注意力 O(T²),序列翻倍算力四倍)、 n_embdn_layerbatch_sizen_head 不在这张爆炸名单上。
那是头越多越好,还是越少越好?
都不是,是"固定预算怎么切"的平衡:
· 头多 → 每个头更薄(head_size 小),单头表达力弱,但同时并行的视角多;
· 头少 → 每个头更厚更强,但一次能看的关系少;
· 切太碎(如 128 切 64 个头、每头才 2 维)→ 每头薄到啥也表示不了,反而变差。
所以有个甜点区,把 head_size 放在 64 左右最省心;还要能整除(n_embd % n_head == 0),且习惯取 2 的幂(对硬件友好)。
🎯 多头一句话:把 128 切成几个子空间,各看一种关系,再拼回 128。 接口(进 128、出 128)没变,所以它能像积木一样,叠进后面每一层 Block 里。

下一步:横向收集完信息,每个位置还得自己"消化"一下 —— 这就是 FeedForward。

Step 4 · FeedForward:注意力管"沟通",它管"思考"

多头让每个位置横向从别人那收集信息(沟通)。收集完,总得自己消化一下 —— 这就是 FeedForward:一个逐位置(position-wise)的小 MLP,对每个位置的向量单独加工,位置之间互不干涉。 结构是放大 → ReLU 非线性 → 压回,逐步点:

代码FeedForward
nn.Sequential(
    nn.Linear(n_embd, 4 * n_embd),  # 128 → 512 放大 4 倍
    nn.ReLU(),                      # 非线性(关键)
    nn.Linear(4 * n_embd, n_embd),  # 512 → 128 压回,方便残差相加
)
"逐位置(position-wise)"到底什么意思?
同一个 MLP 独立作用在每个位置的向量上,位置之间不交换任何信息 —— 8 个位置就是把同一个 FFN 跑 8 遍。 换句话说:注意力横向沟通(让 token 互相看),FFN 纵向加工(每个 token 自己想)。 一个 Transformer 层就是这"沟通 + 思考"的组合拳。
为什么中间放大 4 倍?为什么非要 ReLU?
放大 4 倍是惯例(经验值):给非线性变换更大的"草稿纸",表达力更强;算完再压回 n_embd。
ReLU 不能省:它把负数清零,提供非线性。否则 Linear ∘ Linear 数学上还是一个线性变换, 整个 FFN 会退化成一个矩阵,放大 4 倍也白搭。非线性才是 FFN "能思考"的来源。

下一步:多头、FFN 这些子层,怎么稳稳叠很多层而不"训崩"?靠残差 + LayerNorm。

Step 5 · 残差 + LayerNorm:让深网络训得动

多头、FFN 都是"子层"。要把它们叠很多层(本章 4n_layer 层),靠两个稳定器: 残差连接(子层外包一个 x = x + 子层(x))和 LayerNorm(进子层前先归一化)。 一个 Block = x = x + sa(ln1(x)) 然后 x = x + ffwd(ln2(x))把残差开关拨一下,看堆深时的差别:

残差到底解决了什么?(梯度高速公路)
反向传播时,梯度要从顶层一路乘回底层。层一深,连乘会让梯度越来越小 → 梯度消失,底层学不动。 残差那个 + 给梯度开了一条直通车:它可以绕过子层直接往下传,不被连乘衰减。 于是子层只需学"在原信号上加什么修正 Δ",而不是从头重建整个信号 —— 这就是 2015 年 ResNet 的核心思想,Transformer 全靠它堆深。
LayerNorm 在归一化什么?pre-norm vs post-norm
LayerNorm每个位置那一根向量(沿特征维)做归一化:减均值、除标准差,把数值拉回稳定范围,深网络才不发散。 它不跨位置、不跨 batch,所以和序列长度无关。
代码里是 x + 子层(ln(x)) —— 先 norm 再进子层,叫 pre-norm;比原始 Transformer 的 post-norm(子层后再 norm)更好训、更稳。

下一步:把 4 个这样的 Block 摞起来,配上嵌入和输出头 —— 完整 GPT。

Step 6 · 堆成 GPT,跑出结果(破墙)

4n_layerBlock 摞起来,前面接嵌入(tok+pos)、后面接 ln_f + lm_head —— 这就是完整 GPT 的 forward:

idx输入字符
tok_emb + pos_emb内容 + 位置
4 × Block沟通+思考,带残差/norm
ln_f末层归一化
lm_head映射回词表
logits下一字符打分
跑起来loss 跌破第 2 章的 2.41

03_transformer.py 每 500 步打印一次 loss。下面的曲线自动播放训练过程,看它穿过第 2 章单头注意力卡住的 2.41 虚线,一路下探到 ~1.57。

完整 Transformer(本章)val loss 第 2 章单头 ≈ 2.41

(示意曲线:数量级与趋势对得上 03_transformer.py 的实际打印,非逐位精确复现。)

🎯 结果:破墙。多头(看多种关系)+ FFN(逐位置思考)+ 残差/LayerNorm(堆得动 4 层), val loss 跌破第 2 章的 2.41,一路下探到 ~1.57 —— 采样开始有"词"和对话结构,不再是纯字母汤。 这就是 nanoGPT 的内核。

点右下角「完成 🎉」收下整章 —— 你已经从零搭出一个 GPT。