本页对应 phase1-124m/05_sample.py。模型每一步只做一件事:对下一个 token 打分(logits)。
这一页把"打分 → 怎么挑一个字 → 一个字接一个字滚成句子 → 怎么加速"五步拆开,每步亲手拨一个旋钮。
演示用一个小示意词表(标着「示意」),真实里词表是 50304vocab 个 token —— 其中 50257 个是真 token,另 47 个是为对齐 GPU 补的空位,采样时会被屏蔽掉,所以每步实际只在 50257 个里挑。
给一段开头(prompt), 模型把它过一遍前向,在最后一个位置吐出对下一个 token 的一整排打分,再 softmax 成概率。 它不答题、只续写。换上一个"问句"开头,你会看到正确答案那个 token 并不在高处 —— 因为基座模型从没学过"回答"。
Q: 2+2=? A:,
它会接着续一段像那种文本的东西(再问一句、换行、复述题目…),而不是真去算"4"。
"会顺着指令回答问题"是 SFT
在 Phase 2 用「指令→回答」数据另外教出来的。本页(Phase 1)的模型还停在"只会续写"。
采样前先做一步 logits / temperature 再 softmax。温度低 → 分布变尖,几乎只挑最高分那个(保守、稳,但容易复读);
温度高 → 分布变平,小众 token 也有机会(发散、有创意,但容易胡说)。真实默认
0.9temp。拖滑块实时看分布形状变化(用 STEP 1 的续写示意分布)。
sample_next 里的一行 logits = logits / temperature。
就算温度不高,词表里仍有一长串低概率 token,偶尔被抽中就是一处"胡言"。 top_k 只保留最高的 k 个,其余打分设成 −∞(等于概率 0),再归一化采样。 真实默认 50top_k(在 50257 个真 token 里留 50 个)。这里示意词表只有 8 个,拖 k 看尾巴被切、概率质量重新分配。
top_k=1 等于每步都取最高分(贪心 argmax)。
问题是它没有多样性:同一个开头永远生成同一句,还容易陷入"复读机"死循环。所以实践里留一小撮候选(k=50)再
随机采样,
兼顾"不胡说"和"有变化"。
模型一次只吐一个 token。把它拼回输入末尾,整条序列再喂一遍、再采下一个,如此滚动 —— 这就是 自回归生成。 点「采下一个 token」,看序列一格一格长出来(灰=prompt,橙=生成;用预设示意分布,每步顺带亮一下被采中的候选)。
05_sample.py 用 --max_new_tokens(默认 200)定额生成多少个就停。
每步还会把上下文截到最近 1024block_size 个 token(idx[:, -block_size:]),
因为模型一次最多看这么长。(SFT 之后的对话模型会专门学一个"结束符"来自己收尾,那是 Phase 2 的事。)
05_sample.py 里加的推理优化)注意力要用到序列里每个位置的 K/V。 无 cache:每生成一个 token,都把过去所有位置的 K/V 从头重算一遍(下图不断变大的橙色三角 = 纯浪费); 有 cache:把过去的 K/V 缓存下来,每步只算新 token 那一列(绿)。点「下一步」并排看。
1+2+…+n 列(随长度平方级膨胀);有 cache 只算 n 列(线性)。点上面逐步对比。
05_sample.py 在 Mac MPS 上实测:有 cache 比无 cache 快约 2.5–3.1×
(例如某次 7.4 → 23.1 tok/s ≈ 3.1×,另一次 17.5 → 48 tok/s ≈ 2.7×;序列越长、省得越多)。
代码里就是 generate_nocache(每步重喂整段)对 generate_cached(prompt 预填充一次,之后每步只喂上一个新 token + 历史 KV)。
05_sample.py 里这步连因果掩码都省了(is_causal = q.size(2) > 1,只有一次喂多个 token 的 prompt 预填充才需要掩码)。