配套 02_attention.py。上一章 bigram 只看前 1 个字符,撞在 val loss ≈ 2.49 的墙上。
这一章一步步搭出自注意力 —— 六步看懂直觉,再回去读代码就豁然开朗。
bigram 的天花板,在于它做预测时只能看前 1 个字符。自注意力的想法很简单:
让第 t 个位置,能参考它前面所有位置(0..t)。
点下面任意一个位置,看它能看到谁。
📜 出处:自注意力来自 2017 年的 《Attention Is All You Need》 —— Transformer 的奠基之作。本章搭的是它的核心零件(单头自注意力),第 3 关会完整复现整套架构。
0..t,看不到 t+1 之后。
block_size = 8block ——
也就是大家天天挂嘴边的上下文窗口(context window)。本章设成 8,只为把直觉一眼看清。
generate 里 idx[:, -block_size:] 就是在裁这个)。batch_size 不是一回事:block_size=8block 是"往回看几个 token",
batch_size=32batch 是"同时并行几条句子"。
动手搭注意力前,先看清喂进去的数据长什么"形状"。代码里张量都按 (B, T, C) 三个大写字母标注 ——
拨一拨下面三组交互,把它们和那几个 32 一次对清楚。
每次训练,get_batch 从全书随机取 B 个起点,每个起点连抓 T 个字符。
点「🎲 再取一个 batch」换一批;点任意一个格子,看它其实是一根向量。
batch_size ·
T = Time 一条几个位置 = block_size ·
C = Channels 每个位置的向量多长(下面 B 段细说)。
点每一站,看 (B,T,C) 里的 C 怎么变 —— B、T 一路不变,只有 C 在变。或点「▶ 一路跑到 logits」自动走一遍。
32(和一个 8)分别定哪根轴
Q、K、V 是三个张量(三种角色),不是三根轴各分一个。同一个 x 复制 3 份、各用自己的 W 重调,得到的 q / k / v 各自都是完整的 (B,T,32):
q @ kᵀ 才能拿两个完整张量两两位置算相关、得到 B×T×T 的分数表。
("三个"角色 和 "32" 维是两个无关的东西:把 head_size 改成 30,还是三张表,只是每个向量 30 个数。Q/K/V 的完整讲解在 Step 4。)
B, T, C = x.shape,动态从形状里读、不写死 ——
因为采样生成时不一样:generate 一次只喂 1 条,且上下文从 1 个字符慢慢长到 8。那时 B=1、T 从 1 涨到 8。
所以严谨说:B/T 是"这批数据当前实际的条数 / 长度",训练时恰好等于那两个超参。
02_attention.py 开头四步,把一本书变成可训练的整数:open(path, "r", encoding="utf-8") —— "r" = read 只读模式(只读不写;"w" 才是写入并清空原文件)。text = f.read() —— 把整本书一次性读成一条长字符串(含所有换行、空格、标点),所以 text 是一整条不断行的长字符串。sorted(set(text)) 去重排序得词表;encode 把每个字符换成整数 id → 整条文本变一长串数字。torch.tensor(...) 变张量,get_batch 再从中随机切出上面那个 B×T 的整数表。
这正是 02_attention.py A 段做的事。一个下三角、再归一化的权重矩阵 a,
乘上输入 b,就一次性完成"只看过去 + 取平均"。点按钮一步步看 a 怎么来。
02_attention.py A 段的纯演示:用固定数字讲清"下三角加权平均"这套数学,不训练、也不是模型本体。真正带 Q/K/V、会被训练的注意力从 Step 4 起。
a,大小 3×3。它是"位置 × 位置":行 = 哪个位置在看,列 = 它在看哪个位置。
两边都是 3,是因为这个演示只取了序列的前 3 个位置(如 "To be or" 的 T o ␣)。c=a@b」后,右边出现的 b 是 3×2(6 格):3 个位置,每个位置一个 2 个数的向量。
所以 位置数定 a 的行列(3×3)、定 b 的行(3);每个向量多少个数只定 b 的列(2)。block_size=8block、向量长度 = head_size=32head,
这张 a 就是 8×8 —— 即 Step 5 的热力图。
上一步的权重是写死的(每个过去位置一样多)。真实语言里,有的词更该被关注。
我们想让权重由内容算出来。办法:把同一个输入 x,投影成三种角色。
像在图书馆找书 ——
nn.Linear:self.query / self.key / self.value,把同一个 x
投影成 q、k、v。
这三个投影矩阵,就是这一层真正要训练的参数。
先说 x 是什么。每个位置(每个 token)进注意力前,已经是一根向量:
x = tok_emb + pos_emb(内容嵌入 + 位置嵌入),长度 n_embd=32。
8 个位置就是 8 根 x。它是"这个 token、在这个位置"的数字表示,是还没分角色的原料。
(注意:这里没有"用户提问"—— attention 里的 query 是每个位置自己发出的,只是个比喻。)
这根 x 是从第 1 章"升级"来的(对比一下就懂它的前因):
01_bigram.py 的 BigramLanguageModel
lm_head 映射回 65vocab 维得分。见 02_attention.py 的 forward
所以 x = tok_emb + pos_emb 只在第 2 章才有:第 1 章嵌入一步出得分;第 2 章把它拆成「嵌入 → 注意力 → lm_head」,中间那个 32n_embd 维向量才是 attention 处理的 x。
再说怎么变。x 变成 q/k/v,靠一次线性投影(其实就是矩阵乘法):
q = x @ Wq、k = x @ Wk、v = x @ Wv。三个权重矩阵各不相同,
代码里就是三个 nn.Linear(n_embd, head_size)。它们不是谁手写的规则 ——
出生时随机初始化,训练中被梯度一点点学出来。下面拨一下看(x 画成 4 维方便上屏,真实是 n_embd=32 维 —— 注意是维度,不是 batch_size):
Wq/Wk/Wv 初始化是三组不同的随机数,所以 q/k/v 出生就天生不同
(因为矩阵不同,不是因为输入不同 —— 它们仨用同一根 x)。训练只为降"预测下一字符"的 loss,
它不懂 query/key/value,但三个能独立拧的旋钮让梯度下降试出"一个擅长发问、一个擅长被检索、一个擅长搬运"最省 loss。q = x @ Wq 的输出被放去"当发问方点积",v = x @ Wv 被放去"被聚合",
这些槽位训练前就由程序员钉死(self.query=…/self.value=…)。涌现的只是每张表把自己那份活干好,模型内部从不会分不清谁是谁。
idx[:,0:8] 算;偏移 1 的目标段 idx[:,1:9] 只在最末尾给 logits 当"标准答案"判分,从不进入 q/k/v。好比 "the cat sat on the ___",答案 "mat" 只判对错,不掺进你的思考。
Wq/Wk/Wv),把预训练早已长出的"抓上下文"能力调成助手口吻。能力是预训练建的,SFT 只拧风格。
用 8 个位置演示注意力分数表 wei(8×8)。
第 i 行第 j 列 = "位置 i 对位置 j 的关注度"。逐步点四个按钮看它怎么成形。
wei 就是 Step 3 那个矩阵 a 的"活"版本 ——
同守两条铁律:① 因果(右上为 0)② 每行和为 1(softmax)。唯一升级:权重不再写死,而是 q·k 由内容算出来的。
q·k 是 32head_size 个乘积相加,摆动幅度天然 ≈ √32 ≈ 5.66 —— 维度越大点积越容易爆炸。
而 softmax 对大数极敏感:分数差 ±5.7,exp 后差约 6 万倍,权重塌成只认一个位置的 one-hot →
梯度消失,学不动。
÷√head_size(代码 * head_size**-0.5)把摆幅拉回 ~1,让 softmax 不论维度多大都待在能学习的温和区间。
Head.forward:k.transpose(-2,-1):把 k 从 8T×32head 对调成 32head×8T,
这样 q @ kᵀ 才乘得出 8T×8T 分数表。只为形状对得上,不改数值。masked_fill(tril==0, -inf):就是掩码,把未来位置填 -∞,softmax 后归零。tril 用 register_buffer 挂成常量(不训练,但随 .to(device) 搬家),区别于会被梯度拧的 nn.Parameter。
最后一步:用权重表 wei 去对所有 value 加权求和,
out = wei @ v —— 这正好对应 Step 3 的 c = a @ b。
每个位置由此得到一根"已经看过它的历史"的新向量。
和 Step 3 的 c = a @ b 是同一件事 —— 用权重表 wei 对 value 加权求和。
唯一不同:这里的权重不再均匀(是 Step 5 按相关性算出来的),相关的位置占比大。点 wei 任意一行,看 out 那一行怎么来。
out 就是每个位置"看过历史后"的新向量。权重均匀 → 就是 Step 3 的死平均;权重由内容算出来 → 才叫注意力。
wei @ v 在算什么?wei 是不是就"全部"了,out 只是 v 的拷贝?q·k → 缩放 → 掩码 → softmax 得到 wei —— 一张"谁看谁、各看多少"的调度表,
每行是和为 1 的比例,本身不含任何内容。两个意思天差地别的句子也可能算出相同的 wei,真正区分它们的是 v。
所以 wei 不是"全部",它是"空的",内容全在 v 里。wei @ v 才是信息真正跨位置流动的一步,也是注意力的主戏。out 不是拷贝,是加权混合出的新向量。看上面 out 任意一行:若该行权重是 [0.7, 0.2, 0.1],
则 out = 0.7·v₀ + 0.2·v₁ + 0.1·v₂ —— 一根原来不存在的向量,既不是 v₀ 也不是任何单行 v。
只有当某个权重 = 1、其余 = 0(硬注意力)时,out 才恰好等于某一行 v,那才叫"拷贝";一般 softmax 出来都是小数,所以是混合。
out 的形状(高、宽)是怎么定的?(m×n) @ (n×p) = (m×p)。这里 wei 是 T×T,v 是 T×head_size,
所以 out = wei @ v 是 T×head_size:out 的高 = wei 的高(T 个位置,每个位置输出一行结果);out 的宽 = v 的宽(head_size,每个位置混出的新向量仍是 head_size 维)。T 被消掉了 —— 这正是"把 T 个 value 沿位置维加权求和、压成一根"的数学体现:
聚合前每个位置面对 T 个历史向量,聚合后只剩它自己那一根。
代码里 AttentionLanguageModel.forward 就是这条链:
forward,而且是承上启下的 —— 抓住这点就抓住了第 2 章。
| 谁的 forward | 干的事 | 角色 |
|---|---|---|
Head.forward |
一个零件:输入 32n_embd 的向量 → q·k、缩放、掩码、softmax、wei@v → 输出 32head 的"看过历史"向量。这是本章全新的东西。 |
启下 新机制 |
Attention.forward |
总装线:嵌入(tok+pos)→ 调用 Head.forward → lm_head 映射回词表 → cross_entropy。这就是第 1 章那条老骨架,只多插了一行 x = sa_head(x)。 |
承上 老骨架 |
外层承接 bigram(嵌入→logits→交叉熵一字没改),内层启出新本事(把"看过去"封成零件)。bigram 只有一个 forward,因为它没有零件可调用。
loss 到底怎么从 logits 算出来?(慢动作)logits 是 32B×8T×65vocab,targets 是 32B×8T(每位置"真正下一个字符"id,偏移 1)。
交叉熵 只认"一摞预测+一摞答案",先摊平成 (256, 65) 和 (256,) ——
32×8=256 道独立的"预测下一字符"小考题。p → 损失 -log(p)(p 越高越不扣分);256 题取平均 = 那个 loss。pos_emb 位置嵌入?x = tok_emb + pos_emb 就是把内容和位置相加。
和第 1 章那张 6×6 表一样,训练的产物就是一堆数字旋钮。第 2 章的"真身"是这 6 块矩阵 —— 出生是随机雪花,训练后被梯度调出结构。点按钮看"训练前 → 训练后"。
(格子是示意,不是真实数值;真实每块更大,如 token_emb 是 65vocab×32n_embd。
q/k/v/wei/out 都不在这 6 块里 —— 它们是现算的中间产物,不是被训练的旋钮。)
| 模型 | 训练出来的"旋钮"(参数真身) | 规模 |
|---|---|---|
| 第1章 bigram |
token_embedding_table一张 65vocab×65vocab 表,既是"内容"又直接是"下一字符打分" |
65×65 ≈ 4,225 |
| 第2章 attention |
token_embedding 65vocab×32n_embdposition_embedding 8block×32n_embdWq/Wk/Wv 各 32n_embd×32headlm_head 32head×65vocab
|
2080+256 +1024×3 +2080 ≈ 9,488 |
nn.Linear(n_embd, head_size) 里输入被钉死成 32n_embd,
输出 head_size 随你挑。本章单头图省事才设相等;第 3 章多头时每个头取 head_size = n_embd / 头数(更小),拼起来才凑回 32n_embd。q / k / v 向量 —— 拿旋钮 Wq/Wk/Wv 乘 x 临时算的。wei 和它的 softmax 权重 —— 当场算的中间结果。"那张表"不是参数。out —— Head.forward 的返回值,是"看过历史"的新向量;不是三张权重表,真身 Wq/Wk/Wv 一直留在模块里。W(旋钮),q/k/v/wei/out 全是"旋钮+输入"算出的咖啡。wei 上,决定"对每个历史位置关注多少";
② 最末尾那次,作用在 logits(65 个)上,决定"下一个字符是谁" ——
第 ② 个才和 bigram 结尾那次是同一个。softmax 自己没有参数。
02_attention.py 训练 5000 步。点「▶ 跑训练」,看注意力的 loss 一路下探,穿过 bigram 卡死的 2.49 虚线 —— 这就是这一章的兑现。
(示意曲线:数量级与趋势对得上 02_attention.py 的实际打印,非逐位精确复现。)
val loss 终于跌破 bigram 那道 2.49 的天花板(2.49 → 2.41,小但真实)。这就是这一章的兑现。