← 返回首页
02_attention_viz.html

自注意力:让每个位置学会"往回看"

配套 02_attention.py。上一章 bigram 只看前 1 个字符,撞在 val loss ≈ 2.49 的墙上。 这一章一步步搭出自注意力 —— 六步看懂直觉,再回去读代码就豁然开朗。

STEP 1
往回看 = 因果
STEP 2
数据的形状 B·T·C
STEP 3
矩阵做平均
STEP 4
Q / K / V
STEP 5
算注意力权重
STEP 6
聚合 · 全貌
超参速查 batch_size 32 句子数 block_size 8 位置数 T n_embd 32 维度 C head_size 32 头维度 ⚠ 三个 32 含义不同,纯属撞脸;block_size 是 8

Step 1 · 核心想法:让每个位置"往回看"它的过去

bigram 的天花板,在于它做预测时只能看前 1 个字符自注意力的想法很简单: 让第 t 个位置,能参考它前面所有位置(0..t)。 点下面任意一个位置,看它能看到谁。

📜 出处:自注意力来自 2017 年的 《Attention Is All You Need》 —— Transformer 的奠基之作。本章搭的是它的核心零件(单头自注意力),第 3 关会完整复现整套架构。

只看过去 = 因果(causal)

预测"下一个字符"时,未来的答案不能偷看。所以每个位置只允许看 0..t,看不到 t+1 之后。

最朴素的"参考过去" = 平均

先不讲复杂的:把能看到的那些位置的向量取平均,就是一种最简单的"往回看"。
💡 这"最多往回看几个 token"就是 block_size = 8block —— 也就是大家天天挂嘴边的上下文窗口(context window)。本章设成 8,只为把直觉一眼看清。
上下文窗口能多大?它和 batch_size 是一回事吗?
GPT-2 是 1024,如今 100 万 token 已是标配 —— 原理一样:再长也只是个固定上限, 超出窗口的更早历史模型看不见(代码 generateidx[:, -block_size:] 就是在裁这个)。
⚠️ 它和 batch_size 不是一回事:block_size=8block 是"往回看几个 token", batch_size=32batch 是"同时并行几条句子"。

下一步:先把喂进去的数据"形状"看清 —— B、T、C 和那几个撞脸的 32 分别是谁。

Step 2 · 数据的形状:B、T、C(以及那几个撞脸的 32)

动手搭注意力前,先看清喂进去的数据长什么"形状"。代码里张量都按 (B, T, C) 三个大写字母标注 —— 拨一拨下面三组交互,把它们和那几个 32 一次对清楚。

A一个 batch 长什么样:B 条句子 × T 个位置

每次训练,get_batch 从全书随机取 B 个起点,每个起点连抓 T 个字符。 点「🎲 再取一个 batch」换一批;点任意一个格子,看它其实是一根向量。

⋮ 这里只画 5 条,真实一个 batch 有 batch_size = 32 条,32 条并行一起算
B = Batch 一批几条句子(并行)= batch_size · T = Time 一条几个位置 = block_size · C = Channels 每个位置的向量多长(下面 B 段细说)。
BC 会变:形状沿着网络一层层流动

点每一站,看 (B,T,C) 里的 C 怎么变 —— BT 一路不变,只有 C 在变。或点「▶ 一路跑到 logits」自动走一遍。

idx
(B,T,)
+emb
(B,T,32)
q/k/v
(B,T,32)
out
(B,T,32)
logits
(B,T,65)
C那几个 32(和一个 8)分别定哪根轴
batch_size = 32→ B 句子数
block_size = 8→ T 位置数
n_embd = 32→ C 嵌入维度
head_size = 32→ C 头维度
⚠️ 别把 Q / K / V 对到 B / T / C 三根轴上

QKV三个张量(三种角色),不是三根轴各分一个。同一个 x 复制 3 份、各用自己的 W 重调,得到的 q / k / v 各自都是完整的 (B,T,32):

x
(B,T,32)
复制 3 份,
各用自己的 W 重调
Q (B,T,32) K (B,T,32) V (B,T,32)
三个都是 32×8×32 的完整立方体,没有谁"只分到一根轴"。正因如此,q @ kᵀ 才能拿两个完整张量两两位置算相关、得到 B×T×T 的分数表。 ("三个"角色 和 "32" 维是两个无关的东西:把 head_size 改成 30,还是张表,只是每个向量 30 个数。Q/K/V 的完整讲解在 Step 4。)
训练时 B、T 永远等于 batch_size、block_size 吗?
训练时(一个 batch 就是 32×8)。但代码写的是 B, T, C = x.shape,动态从形状里读、不写死 —— 因为采样生成时不一样:generate 一次只喂 1 条,且上下文从 1 个字符慢慢长到 8。那时 B=1、T 从 1 涨到 8。 所以严谨说:B/T 是"这批数据当前实际的条数 / 长度",训练时恰好等于那两个超参。
C 是"纯动态"随便变的吗?
不是,它和 B/T 的"动态"是两码事:
· B / T喂进来的数据变(训练 32×8,采样可能 1×3);
· C走到网络第几层变,和数据无关 —— 给定结构,每一层的 C 是多少都是写死、确定的(嵌入处 32、logits 处 65)。
一句话:每一层有它自己固定的 C,不是随机乱跳。
这一批整数(idx)是怎么从文本文件读进来的?
02_attention.py 开头四步,把一本书变成可训练的整数:
open(path, "r", encoding="utf-8") —— "r" = read 只读模式(只读不写;"w" 才是写入并清空原文件)。
text = f.read() —— 把整本书一次性读成一条长字符串(含所有换行、空格、标点),所以 text 是一整条不断行的长字符串。
sorted(set(text)) 去重排序得词表;encode 把每个字符换成整数 id → 整条文本变一长串数字。
torch.tensor(...) 变张量,get_batch 再从中随机切出上面那个 B×T 的整数表。

形状这把尺子拿稳了,下一步:怎么用一次矩阵乘法,既"只看过去"又"取平均"?

Step 3 · 用一个矩阵,实现"对过去做平均"

这正是 02_attention.py A 段做的事。一个下三角、再归一化的权重矩阵 a, 乘上输入 b,就一次性完成"只看过去 + 取平均"。点按钮一步步看 a 怎么来。

⚠️ 这一步是 02_attention.py A 段的纯演示:用固定数字讲清"下三角加权平均"这套数学,不训练、也不是模型本体。真正带 Q/K/V、会被训练的注意力从 Step 4 起。
为什么是 3×3?这 9 个格子是什么?
9 个格子 = 权重表 a,大小 3×3它是"位置 × 位置":行 = 哪个位置在看,列 = 它在看哪个位置。 两边都是 3,是因为这个演示只取了序列的前 3 个位置(如 "To be or" 的 T o ␣)。
点第 4 个按钮「算 c=a@b」后,右边出现的 b3×2(6 格):3 个位置,每个位置一个 2 个数的向量。 所以 位置数a 的行列(3×3)、定 b 的行(3);每个向量多少个数只定 b 的列(2)。
真实尺寸:位置数 = block_size=8block、向量长度 = head_size=32head, 这张 a 就是 8×8 —— 即 Step 5 的热力图。
权重 a (3×3)

下三角 ▾ 管「因果」

权重在哪里 = 0。右上角恒为 0 → 未来位置被钉死,谁也偷看不到。

归一化 ▾ 管「平均」

权重总量 = 1。每行和为 1 → 是加权平均,输出和输入同量级,不会越往后越爆。

下一步:均匀平均太"蠢"——每个过去位置一视同仁。怎么让"谁相关就多看谁"?

Step 4 · 升级:把"死平均"换成"按内容加权"——Q / K / V

上一步的权重是写死的(每个过去位置一样多)。真实语言里,有的词更该被关注。 我们想让权重由内容算出来。办法:把同一个输入 x,投影成三种角色。 像在图书馆找书 ——

Query (q)

我在找什么
心里的检索词:
"我要找讲恐龙的内容"

Key (k)

我宣传自己是什么
书脊上的标签:
"本书关于:恐龙"

Value (v)

你看我,我给你什么
翻开后的正文:
真正被取走的信息
代码里就是三个 nn.Linear:self.query / self.key / self.value,把同一个 x 投影成 qkv这三个投影矩阵,就是这一层真正要训练的参数。
为什么 k、v 要分开,不合成一个?
因为"我宣传的标签"和"我实际给的货"可以不同 —— 标签(key)为了好被搜到,正文(value)才是真信息。
那 x 到底怎么变成 q / k / v?

先说 x 是什么。每个位置(每个 token)进注意力前,已经是一根向量: x = tok_emb + pos_emb(内容嵌入 + 位置嵌入),长度 n_embd=32。 8 个位置就是 8 根 x。它是"这个 token、在这个位置"的数字表示,是还没分角色的原料。 (注意:这里没有"用户提问"—— attention 里的 query 是每个位置自己发出的,只是个比喻。)

这根 x 是从第 1 章"升级"来的(对比一下就懂它的前因):

第1章 bigramtoken id ─查表→ logits(65vocab) 嵌入表 65vocab×65vocab,查出来直接就是得分;没有语义空间、没有位置(只看 1 个字符,顺序无意义)。见 01_bigram.pyBigramLanguageModel
第2章 attentiontoken id ─查表→ 32n_embd维向量 +位置向量 ─注意力→ ─lm_head→ logits(65vocab) 嵌入表改成 65vocab×32n_embd,先进 32n_embd 维语义空间(还不是得分);位置另学一份加进来;混合完最后才用 lm_head 映射回 65vocab 维得分。见 02_attention.pyforward

所以 x = tok_emb + pos_emb 只在第 2 章才有:第 1 章嵌入一步出得分;第 2 章把它拆成「嵌入 → 注意力 → lm_head」,中间那个 32n_embd 维向量才是 attention 处理的 x

再说怎么变。x 变成 q/k/v,靠一次线性投影(其实就是矩阵乘法): q = x @ Wqk = x @ Wkv = x @ Wv。三个权重矩阵各不相同, 代码里就是三个 nn.Linear(n_embd, head_size)。它们不是谁手写的规则 —— 出生时随机初始化,训练中被梯度一点点学出来。下面拨一下看(x 画成 4 维方便上屏,真实是 n_embd=32 维 —— 注意是维度,不是 batch_size):

x(输入向量)· 这 4 个数是随手填的占位
(4 个独立的数 2、0、1、3,不是"2013")
真实是 n_embd=32 个(维度,非 batch),挑小数好心算
× Wq可训练矩阵
q
× Wk可训练矩阵
k
× Wv可训练矩阵
v
一次线性投影 = 输出的每个数字,都是输入 x 所有数字的一个加权组合(权重就是矩阵里的数)。 同一个 x 过三个不同矩阵,被"调"出三种不同向量 —— 不是把 x 切成三段,而是把整根 x 重新组合三次。
q/k/v 凭什么"思路不一样"?哪张是 Q,谁定的?
不一样,是练出来的:Wq/Wk/Wv 初始化是三组不同的随机数,所以 q/k/v 出生就天生不同 (因为矩阵不同,不是因为输入不同 —— 它们仨用同一根 x)。训练只为降"预测下一字符"的 loss, 它不懂 query/key/value,但三个能独立拧的旋钮让梯度下降试出"一个擅长发问、一个擅长被检索、一个擅长搬运"最省 loss。
但"哪张是 Q"是写死的,不是模型自己选:角色由"输出被塞进公式哪个槽位"决定 —— q = x @ Wq 的输出被放去"当发问方点积",v = x @ Wv 被放去"被聚合", 这些槽位训练前就由程序员钉死(self.query=…/self.value=…)。涌现的只是每张表把自己那份活干好,模型内部从不会分不清谁是谁。
"三个 QKV" 和维度 head_size 是一回事吗?q/k/v 从哪段数据算?
是两个无关的轴:""=有几种角色(query/key/value,架构写死=3); 32head_size=每个向量多长(超参,可调)。把 head_size 改成 30, 还是恰好三张表,只是每个 q/k/v 从 32 个数变 30 个 —— 不会变成"30 个 QKV"。想要第 4 种角色得换架构,拉维度做不到。
另外:q、k、v 全部只从输入段 idx[:,0:8];偏移 1 的目标段 idx[:,1:9] 只在最末尾给 logits 当"标准答案"判分,从不进入 q/k/v。好比 "the cat sat on the ___",答案 "mat" 只判对错,不掺进你的思考。
为什么偏偏是 QKV?是为将来 SFT 问答"挖的坑"吗?
一个常见的因果倒置,正过来:
  • QKV(2017)为机器翻译而生,解决当下问题:每个位置按内容去别处取相关信息。而问答 SFT(2022)晚了五年 —— 不可能"先挖坑等填"。
  • "3" 照搬查字典结构:Query 找什么 · Key 怎么宣传自己 · Value 取走的内容。Q、K 分开是因为"按什么匹配"和"匹配上给什么"本是两件事。
  • SFT 不"填坑":它只是续训同一批参数(含 Wq/Wk/Wv),把预训练早已长出的"抓上下文"能力调成助手口吻。能力是预训练建的,SFT 只拧风格。
  • 能换成别的吗?换过。RNN/CNN/加性注意力都做过同样的事。QKV 胜在能并行、表达力强、好堆叠,不是"适合问答"。问答能力来自"规模 × 预训练一个通用机制"。
一句话:不是挖坑等填料,而是造了把通用瑞士军刀,后来发现它能开的罐头远超当年只想削的那个苹果。

下一步:有了 qk,怎么算出"谁该多看谁"的那张权重表?

Step 5 · 算注意力权重:打分 → 缩放 → 掩码 → softmax

用 8 个位置演示注意力分数表 wei(8×8)。 第 i 行第 j 列 = "位置 i 对位置 j 的关注度"。逐步点四个按钮看它怎么成形。

看穿这一步:最终的 wei 就是 Step 3 那个矩阵 a 的""版本 —— 同守两条铁律:① 因果(右上为 0)② 每行和为 1(softmax)。唯一升级:权重不再写死,而是 q·k 由内容算出来的。
第 ② 步为什么非要 ÷√d 不可?
分数 q·k32head_size 个乘积相加,摆动幅度天然 ≈ √32 ≈ 5.66 —— 维度越大点积越容易爆炸。 而 softmax 对大数极敏感:分数差 ±5.7,exp 后差约 6 万倍,权重塌成只认一个位置的 one-hot → 梯度消失,学不动。 ÷√head_size(代码 * head_size**-0.5)把摆幅拉回 ~1,让 softmax 不论维度多大都待在能学习的温和区间
代码里 transpose / masked_fill / register_buffer 在干嘛?
都在 Head.forward:
k.transpose(-2,-1):把 k8T×32head 对调成 32head×8T, 这样 q @ kᵀ 才乘得出 8T×8T 分数表。只为形状对得上,不改数值。
masked_fill(tril==0, -inf):就是掩码,把未来位置填 -∞,softmax 后归零。
③ 那张 trilregister_buffer 挂成常量(不训练,但随 .to(device) 搬家),区别于会被梯度拧的 nn.Parameter

下一步:有了权重表 wei,拿它去聚合 value,得到每个位置的新向量。

Step 6 · 加权聚合,与一次 forward 的全貌

最后一步:用权重表 wei 去对所有 value 加权求和, out = wei @ v —— 这正好对应 Step 3 的 c = a @ b。 每个位置由此得到一根"已经看过它的历史"的新向量。

A聚合:out = wei @ v(点一行,看它怎么把历史揉成一根新向量)

和 Step 3 的 c = a @ b同一件事 —— 用权重表 weivalue 加权求和。 唯一不同:这里的权重不再均匀(是 Step 5 按相关性算出来的),相关的位置占比大。wei 任意一行,看 out 那一行怎么来。

权重 wei(4×4,每行和=1)
@
value v(4 位置 × 3 维示意)
=
out(4 × 3)

out 就是每个位置"看过历史后"的新向量。权重均匀 → 就是 Step 3 的死平均;权重由内容算出来 → 才叫注意力

wei @ v 在算什么?wei 是不是就"全部"了,out 只是 v 的拷贝?
注意力分两半,角色完全不同:
算权重(路由):q·k → 缩放 → 掩码 → softmax 得到 wei —— 一张"谁看谁、各看多少"的调度表, 每行是和为 1 的比例,本身不含任何内容。两个意思天差地别的句子也可能算出相同的 wei,真正区分它们的是 v。 所以 wei 不是"全部",它是"空的",内容全在 v 里。
搬内容(聚合):wei @ v 才是信息真正跨位置流动的一步,也是注意力的主戏。
out 不是拷贝,是加权混合出的新向量。看上面 out 任意一行:若该行权重是 [0.7, 0.2, 0.1], 则 out = 0.7·v₀ + 0.2·v₁ + 0.1·v₂ —— 一根原来不存在的向量,既不是 v₀ 也不是任何单行 v。 只有当某个权重 = 1、其余 = 0(硬注意力)时,out 才恰好等于某一行 v,那才叫"拷贝";一般 softmax 出来都是小数,所以是混合
out 的形状(高、宽)是怎么定的?
矩阵乘法 (m×n) @ (n×p) = (m×p)。这里 weiT×T,vhead_size, 所以 out = wei @ vhead_size:
· out 的高 = wei 的高(T 个位置,每个位置输出一行结果);
· out 的宽 = v 的宽(head_size,每个位置混出的新向量仍是 head_size 维)。
中间那个 T消掉了 —— 这正是"把 T 个 value 沿位置维加权求和、压成一根"的数学体现: 聚合前每个位置面对 T 个历史向量,聚合后只剩它自己那一根。
B放进模型:一次 forward 走完整条路

代码里 AttentionLanguageModel.forward 就是这条链:

idx输入字符
tok_emb + pos_emb内容 + 位置
自注意力头q,k,v → wei → wei@v
lm_head映射回词表
logits下一字符打分
🔑 全章其实就两个 forward,而且是承上启下的 —— 抓住这点就抓住了第 2 章。
谁的 forward干的事角色
Head
.forward
一个零件:输入 32n_embd 的向量 → q·k、缩放、掩码、softmax、wei@v → 输出 32head 的"看过历史"向量。这是本章全新的东西。 启下
新机制
Attention
LM

.forward
总装线:嵌入(tok+pos)→ 调用 Head.forwardlm_head 映射回词表 → cross_entropy这就是第 1 章那条老骨架,只多插了一行 x = sa_head(x) 承上
老骨架

外层承接 bigram(嵌入→logits→交叉熵一字没改),内层启出新本事(把"看过去"封成零件)。bigram 只有一个 forward,因为它没有零件可调用。

loss 到底怎么从 logits 算出来?(慢动作)
logits32B×8T×65vocab,targets32B×8T(每位置"真正下一个字符"id,偏移 1)。 交叉熵 只认"一摞预测+一摞答案",先摊平(256, 65)(256,) —— 32×8=256 道独立的"预测下一字符"小考题
逐题:65 个分数 → softmax → 取"正确答案"被分到的概率 p → 损失 -log(p)(p 越高越不扣分);256 题取平均 = 那个 loss
和第 1 章 bigram 同一行、同一算法 —— 区别只是 bigram 的 65 个分数只看当前 1 个字符,attention 的分数来自已吸收前文的向量。判分没变,被判的预测信息更足了。
为什么要多一个 pos_emb 位置嵌入?
注意力本身看不见顺序 —— 把输入打乱,输出只会跟着一起打乱,权重不变。 所以要额外加一份"位置信息"告诉模型"这是第几个位置",x = tok_emb + pos_emb 就是把内容和位置相加。
C参数真身:训练到底练出了什么

和第 1 章那张 6×6 表一样,训练的产物就是一堆数字旋钮。第 2 章的"真身"是这 6 块矩阵 —— 出生是随机雪花,训练后被梯度调出结构。点按钮看"训练前 → 训练后"。

(格子是示意,不是真实数值;真实每块更大,如 token_emb65vocab×32n_embdq/k/v/wei/out不在这 6 块里 —— 它们是现算的中间产物,不是被训练的旋钮。)

这 6 块各多大?参数量怎么算(对照 bigram)?
和 bigram 用同一个交叉熵、同一套梯度下降,只是旋钮变多了:
模型训练出来的"旋钮"(参数真身)规模
第1章
bigram
token_embedding_table
一张 65vocab×65vocab 表,既是"内容"又直接是"下一字符打分"
65×65
≈ 4,225
第2章
attention
token_embedding 65vocab×32n_embd
position_embedding 8block×32n_embd
Wq/Wk/Wv32n_embd×32head
lm_head 32head×65vocab
2080+256
+1024×3
+2080
≈ 9,488
这些维度谁定的、越大越好吗?head_size 必须等于 n_embd 吗?
32n_embd32head 都是你自己设的超参(容量旋钮):维度越大能塞的信息越多 —— 但不是无脑越大越好,参数算力跟着涨、数据少会过拟合。取 32 是玩具级(GPT-2 是 768,GPT-3 是 12288)。
head_size 不"必须等于" n_embd:nn.Linear(n_embd, head_size)输入被钉死成 32n_embd, 输出 head_size 随你挑。本章单头图省事才设相等;第 3 章多头时每个头取 head_size = n_embd / 头数(更小),拼起来才凑回 32n_embd
方向上:比 n_embd 小是常用的(多头压缩,每头看一个子空间);比 n_embd 大基本是冗余 —— 输出再多也只是 32n_embd 维输入的线性组合,装不下更多信息。
哪些"不是真身"?(q / k / v / wei / out 别记混)
它们每来一个输入就现算一遍,从不存储,梯度下降也不直接拧:
  • q / k / v 向量 —— 拿旋钮 Wq/Wk/Wvx 临时算的。
  • 得分表 wei 和它的 softmax 权重 —— 当场算的中间结果。"那张表"不是参数。
  • out —— Head.forward 的返回值,是"看过历史"的新向量;不是三张权重表,真身 Wq/Wk/Wv 一直留在模块里。
  • 被拧的只有矩阵 W(旋钮),q/k/v/wei/out 全是"旋钮+输入"算出的咖啡。
本章出现了两个 softmax,别混
① Step 5 内部那次,作用在 wei 上,决定"对每个历史位置关注多少"; ② 最末尾那次,作用在 logits(65 个)上,决定"下一个字符是谁" —— 第 ② 个才和 bigram 结尾那次是同一个。softmax 自己没有参数。
D跑起来:loss 怎么跌破 bigram 那道 2.49 的墙

02_attention.py 训练 5000 步。点「▶ 跑训练」,看注意力的 loss 一路下探,穿过 bigram 卡死的 2.49 虚线 —— 这就是这一章的兑现。

attention(本章)val loss bigram 的墙 ≈ 2.49 step 0 · loss 4.17

(示意曲线:数量级与趋势对得上 02_attention.py 的实际打印,非逐位精确复现。)

🎯 结果:打破 2.49 的墙。装上注意力后,每个位置都能参考前文, val loss 终于跌破 bigram 那道 2.49 的天花板(2.49 → 2.41,小但真实)。这就是这一章的兑现。

点右下角「完成 🎉」收下这一章。