← 返回首页
05_sampling_viz.html

跑起来:从 logits 到文字(采样与 KV-cache)

本页对应 phase1-124m/05_sample.py。模型每一步只做一件事:对下一个 token 打分(logits)。 这一页把"打分 → 怎么挑一个字 → 一个字接一个字滚成句子 → 怎么加速"五步拆开,每步亲手拨一个旋钮。 演示用一个小示意词表(标着「示意」),真实里词表是 50304vocab 个 token —— 其中 50257 个是真 token,另 47 个是为对齐 GPU 补的空位,采样时会被屏蔽掉,所以每步实际只在 50257 个里挑。

STEP 1
base 只会续写
STEP 2
temperature 温度
STEP 3
top_k 砍尾巴
STEP 4
自回归循环
STEP 5
KV-cache 加速
采样旋钮 temperature 0.9 默认 top_k 50 默认 vocab 50304 词表大小 block_size 1024 能回看多远 演示词表只 8 个是示意 · 真实每步在 50257 个真 token 里挑 · 配套 05_sample.py

① base 模型只会"续写":它永远在猜下一个 token

给一段开头(prompt), 模型把它过一遍前向,在最后一个位置吐出对下一个 token 的一整排打分,再 softmax 成概率。 它不答题、只续写。换上一个"问句"开头,你会看到正确答案那个 token 并不在高处 —— 因为基座模型从没学过"回答"。

下一个 token概率分布(只截了示意词表的前 8 名)
"它不会答题" 到底什么意思?那 ChatGPT 怎么会答?
base 模型只学过"海量文本里,这个词后面通常接什么"。所以给它 Q: 2+2=? A:, 它会接着续一段像那种文本的东西(再问一句、换行、复述题目…),而不是真去算"4"。 "会顺着指令回答问题"是 SFTPhase 2 用「指令→回答」数据另外教出来的。本页(Phase 1)的模型还停在"只会续写"。

同一排 logits,挑字的方式不只一种。先拨第一个旋钮:temperature —— 它决定模型有多"敢"。

② temperature:一个旋钮,把分布捏尖或摊平

采样前先做一步 logits / temperature 再 softmax。温度低 → 分布变,几乎只挑最高分那个(保守、稳,但容易复读); 温度高 → 分布变,小众 token 也有机会(发散、有创意,但容易胡说)。真实默认 0.9temp。拖滑块实时看分布形状变化(用 STEP 1 的续写示意分布)。

0.90
0.1 极保守0.9 默认2.0 极发散
为什么除以温度就能改变"尖/平"?
softmax 比的是 logits 之间的差距。除以一个小温度(如 0.3)会把所有差距放大,最高分一骑绝尘 → 分布更尖; 除以一个大温度(如 1.8)把差距压扁,大家拉近 → 分布更平。温度 = 1 就是原样。 代码里就是 sample_next 里的一行 logits = logits / temperature

温度只是把整排概率重新塑形 —— 长长的"垃圾尾巴"还在。下一步用 top_k 直接把尾巴一刀切掉。

③ top_k:只在概率最高的 k 个里采样,尾巴一刀切掉

就算温度不高,词表里仍有一长串低概率 token,偶尔被抽中就是一处"胡言"。 top_k 只保留最高的 k 个,其余打分设成 −∞(等于概率 0),再归一化采样。 真实默认 50top_k(在 50257 个真 token 里留 50 个)。这里示意词表只有 8 个,拖 k 看尾巴被切、概率质量重新分配。

4
1 只留冠军(=贪心)越大留得越多8 全留(=不裁剪)
0.90
先裁 top_k,再按这个温度采样
top_k=1 和"直接取最高分"是一回事吗?为什么不总用它?
是的,top_k=1 等于每步都取最高分(贪心 argmax)。 问题是它没有多样性:同一个开头永远生成同一句,还容易陷入"复读机"死循环。所以实践里留一小撮候选(k=50)再 随机采样, 兼顾"不胡说"和"有变化"。

现在你会"挑一个 token"了。可一句话有几十个 token —— 怎么从一个变成一串?靠自回归循环。

④ 自回归:采一个 → 接回末尾 → 再采下一个

模型一次只吐一个 token。把它拼回输入末尾,整条序列再喂一遍、再采下一个,如此滚动 —— 这就是 自回归生成。 点「采下一个 token」,看序列一格一格长出来(灰=prompt,橙=生成;用预设示意分布,每步顺带亮一下被采中的候选)。

序列当前文本(末尾闪烁处 = 下一个要采的位置)
这一步末位 token 的 logits → softmax → 采样(示意)
点「采下一个 token」开始 ↑
什么时候停?会一直生成下去吗?
base 模型不会自己"说完了"。05_sample.py--max_new_tokens(默认 200)定额生成多少个就停。 每步还会把上下文截到最近 1024block_size 个 token(idx[:, -block_size:]), 因为模型一次最多看这么长。(SFT 之后的对话模型会专门学一个"结束符"来自己收尾,那是 Phase 2 的事。)

注意:每采一个新字,上面都把整条序列重过一遍 —— 越生成越长、越来越慢。最后一步:KV-cache 把这份浪费省掉。

⑤ KV-cache:别把过去重算一遍(05_sample.py 里加的推理优化)

注意力要用到序列里每个位置的 K/V无 cache:每生成一个 token,都把过去所有位置的 K/V 从头重算一遍(下图不断变大的橙色三角 = 纯浪费); 有 cache:把过去的 K/V 缓存下来,每步只算新 token 那一列(绿)。点「下一步」并排看。

无 cache · 每步重算整段
每生成一格,过去所有列全部重算一遍(橙)。
K/V 累计计算量0
有 cache · 每步只算新列
过去的 K/V 缓存复用(浅绿),每步只新增 1 列(深绿)。
K/V 累计计算量0
生成 n 个 token:无 cache 要算 1+2+…+n 列(随长度平方级膨胀);有 cache 只算 n 列(线性)。点上面逐步对比。
那"计算量省了一大截",为什么实测只快 2.5–3.1 倍,不是好几十倍?
上图省掉的是注意力里重复的 K/V 计算,但一次前向还有很多别的开销(MLP、LayerNorm、采样、框架调度),这些没省。 所以墙钟加速没有列数比那么夸张05_sample.py 在 Mac MPS 上实测:有 cache 比无 cache 快约 2.5–3.1× (例如某次 7.4 → 23.1 tok/s ≈ 3.1×,另一次 17.5 → 48 tok/s ≈ 2.7×;序列越长、省得越多)。 代码里就是 generate_nocache(每步重喂整段)对 generate_cached(prompt 预填充一次,之后每步只喂上一个新 token + 历史 KV)。
缓存了过去,新 token 还看得到全部历史吗?会不会偷看未来?
看得到全部历史:新 token 的 query 去和"缓存的过去 K/V + 自己这一列"做注意力,信息一点不少。 也不会偷看未来 —— 解码时每步只来 1 个新 token,它后面根本还没生成,自然无未来可看; 所以 05_sample.py 里这步连因果掩码都省了(is_causal = q.size(2) > 1,只有一次喂多个 token 的 prompt 预填充才需要掩码)。

Phase 1 到此收口:你已经能让这台机器"跑起来"续写文字了。但它只会续写、不会答题 —— 教它"听话回答",是 Phase 2 的 SFT。点右下角「完成 🎉」。