Files
xserv/docs/22-speculative-decoding.md
Gahow Wang ce7229f4fe speculative: Qwen3 draft-model v0 with paged verify parity
Phase 22 lands a correctness-only speculative decoding loop for Qwen3
target + Qwen3 small draft (batch=1, greedy, gamma=4). Phase 23 turns
verify logits into the authoritative acceptance signal so mirror-decode
per accepted token is no longer needed.

- paged_kv_cache: truncate_sequence(slot, new_len) shrinks a registered
  sequence, freeing whole physical blocks no longer reachable and
  leaving the slot registered. Covered by a CUDA-gated unit test.
- qwen3: forward_verify_paged_decode_attention writes the draft window
  into the target cache, runs the same paged decode attention kernel per
  draft token, and uses matmul_rows_gemv so linear layers follow the
  single-token decode BF16 rounding path.
- bench-speculative: new bench binary drives the state machine with
  --gamma / --gen-tokens / --prompts / --use-verify-logits /
  --verify-path flash|paged-decode / --dump-verify-mismatches, and
  compares baseline vs spec token sequences plus TPOT / tok/s / speedup.
- docs/22 records the decode-authoritative v0 result and dash5 numbers
  (matched=true, speedup_e2e ~0.29x, verify_decode_mismatches>0 under
  --use-verify-logits).
- docs/23 records the paged-decode verify path (matched=true,
  verify_decode_mismatches=0, 50x64 speedup_e2e ~0.44x) and the
  next-step performance TODO.
2026-07-01 14:16:30 +08:00

187 lines
8.1 KiB
Markdown

# Phase 22: Draft-Model Speculative Decoding v0
> 目标:实现一个可验证的 speculative decoding 最小闭环。先只覆盖
> Qwen3 target + 同 tokenizer 的小 Qwen3 draft、batch=1、greedy
> (`temperature=0`)。本阶段不做 gpt-oss,不做 sampling rejection,不接入
> continuous batching。
## 1. Scope
本阶段只解决一个窄问题:
- target:现有 Qwen3 paged KV 路径,优先 Qwen3-8B;
- draft:同 tokenizer 的小 Qwen3,例如 Qwen3-0.6B;
- batch size:1;
- decoding:greedy argmax;
- draft window:`gamma=4`;
- acceptance:exact-match,即 `target_argmax == draft_token`
HTTP flag 可以后续接入。v0 先提供独立 bench/CLI,因为它能直接输出 token
一致性、acceptance rate、tokens/target-step、TPOT/tok/s,也避免把尚未稳定的
rollback 行为放进服务端调度循环。
bench 为了让 baseline/spec 对比不受跨 prompt KV pool 复用影响,每个 prompt 的
baseline run 和 speculative run 都使用新建的 paged KV cache。cache 分配发生在
单次 run 的计时外,输出的 TPOT/tok/s 只覆盖模型 prefill/decode 工作。
## 2. Why Qwen3 First
Qwen3 是现有代码里最适合作为 speculative v0 的模型族:
1. target 已有稳定的 `forward_prefill_paged``forward_decode_paged`;
2. 小 Qwen3 与 Qwen3-8B 共享 tokenizer,可以直接比较 token id;
3. Qwen3 是 dense decoder-only,没有 gpt-oss 的 harmony 格式、MoE sparse 路径、
sliding-window 或 CUDA Graph 状态;
4. greedy 输出的正确性定义简单:只要 spec 生成的 token 序列与纯 target greedy
完全一致即可。
gpt-oss spec 需要先定义 harmony prompt、MoE draft 选择、graph replay 与 rollback
的交互,这些都不属于本阶段。
## 3. Algorithm
对每个 prompt 建两套模型、三套 KV 状态:
```text
target model + target commit PagedKVCache
target model + target verify PagedKVCache
draft model + draft PagedKVCache
```
先把 prompt 分别 prefill 到三套 cache。此时 cache 都包含 prompt,并各自持有
"下一个 token" 的 logits。
每个 speculative round:
1. draft 从当前 draft logits 取 argmax,连续生成 `gamma` 个 draft token;
2. draft 每生成一个 token 就用自己的 paged decode append 到 draft KV,所以 round
结束时 draft cache 暂时包含整个草稿序列;
3. target verify cache 对完整 draft token 序列调用一次 paged prefill,覆盖
"target 可一次验证草稿窗口" 这条执行路径;
4. target verify cache 立刻 rollback 到 round 起点,避免把 prefill 临时写入污染
commit cache;
5. 用 target decode 轨迹作为权威结果,从左到右比较
`target_next_argmax == draft_token`,只接受连续匹配前缀;
6. 对每个接受 token,用 target decode 重放一次来提交 target KV,并得到下一步
`target_next_argmax`;verify cache 也 mirror decode 同一个 token,保持长度与 prefix 对齐;
7. 若全部匹配,draft cache 已经包含完整草稿,三套 cache 长度重新对齐;
8. 若在第 `k` 个 token 拒绝,提交前 `k` 个 draft token,再提交 target 在该位置的
argmax 作为修正 token。draft cache rollback 到 round 起点后重放接受 token 和修正
token,target commit/verify cache 都由 decode 路径提交到同一 prefix。
v0 不使用完整 speculative sampling 的概率校正。它只利用小模型猜测 greedy 轨迹,
因此生成序列必须与纯 target greedy 完全一致。
当前实现选择 decode 轨迹作为提交路径,而不是直接保留 target prefill 写入的 KV。
原因是 v0 验收要求 token 序列与纯 target greedy 完全一致;如果 prefill 和 decode
路径在数值或 KV 写入顺序上存在细微差异,直接提交 prefill KV 会让后续 greedy 输出
漂移。这个保守实现仍会执行 target paged prefill 验证和 rollback,但 verify 写入放在
独立 cache,不会影响权威 commit cache。代价是额外 mirror decode,速度收益预期较差,
主要用于先验证 draft-model speculative 的状态机和一致性。
为保证 greedy exactness,decode 里两个原有非确定点也需要固定:
- BF16 GEMV 不再用跨 K-block `atomicAdd`;改为写 K-block partials,再按固定顺序
reduce;
- paged decode attention 不再用 `atomicAdd` 合并 warp 输出;改为 per-warp partials
后按 warp id 顺序 reduce。
## 4. KV Commit And Rollback
现有 `forward_prefill_paged` 会一次性把传入 token 写进 paged KV,并提前推进
`seq_len`。验证草稿时 target verify cache 因此会临时包含整个 draft window。
新增的 cache 操作只做逻辑截断:
```text
truncate_sequence(slot, new_len)
```
约束:
- 只允许 `new_len <= current_len`;
- 保留覆盖 `[0, new_len)` 所需的物理 block;
- 释放右侧多余 block;
- 不清零仍在保留 block 内的旧字节,因为后续逻辑长度会阻止 attention 读取它们,
同一位置再次写入时会覆盖旧值;
- slot 仍保持 registered,`new_len=0` 时也保留第一个 block。
这让 target 和 draft 都能在拒绝时安全丢弃多写 KV,并在修正 token decode 后重新
对齐。
## 5. Acceptance Criteria
本阶段验收:
- `cargo fmt`;
- `cargo check`;
- `cargo test`;
- `bench-speculative` 可加载 target+draft 两套 Qwen3;
- 50 prompts,greedy,baseline target 与 speculative token id 序列完全一致;
- 输出 acceptance rate、tokens/target-step、TPOT、tok/s 和 speedup;
- 若 draft 模型缺失或磁盘不足,明确报告阻塞条件,不盲目下载大模型。
## 6. Validation Results
dash5 环境:
- GPU:RTX 5090,device 0;
- target:`/opt/wjh/models/qwen3-8b`;
- draft:`/dashscope-tmp/wjh/models/qwen3-0.6b`;
- command:`bench-speculative ... --prompts 50 --gen-tokens 32 --gamma 4 --device 0`;
- log:`/dashscope-tmp/wjh/xserv-spec-default-50x32-final.log`
默认 `acceptance_mode=decode` 的结果:
```text
prompts=50 matched=true
acceptance_rate=0.3664 accepted=1020 proposed=2784
tokens_per_target_step=0.3639 target_steps=4397
verify_steps=729 mirror_decode_steps=1550 commit_decode_steps=1550 correction_steps=568
verify_decode_mismatches=10
baseline_e2e_tpot_ms=13.123 baseline_e2e_tok_s=76.204
spec_e2e_tpot_ms=44.867 spec_e2e_tok_s=22.288 speedup_e2e=0.2925
baseline_decode_tpot_ms=12.638 baseline_decode_tok_s=79.127
spec_decode_tpot_ms=43.731 spec_decode_tok_s=22.867 speedup_decode=0.2890
decode_token_counts baseline=1600 spec=1600
```
诊断 `--use-verify-logits` 的结果:
- command:`bench-speculative ... --prompts 10 --gen-tokens 32 --gamma 4 --device 0 --use-verify-logits`;
- log:`/dashscope-tmp/wjh/xserv-spec-verify-logits-10x32.log`;
- exit status:`2`;
- summary:`matched=false`, `verify_decode_mismatches=4`;
- prompt 0/2/7 出现 baseline/spec token 序列分叉。
结论:当前可以做 correctness-first 的 speculative decoding 状态机,但还不能把
target batched prefill verify logits 作为 greedy 接受依据。verify prefill 路径与
逐 token decode 路径存在 top-1 不一致;默认模式必须继续以 decode 轨迹为权威,
因此 v0 是正确性闭环,不是性能优化。
## 7. Known Limits
- 只支持 batch=1;
- 只支持 Qwen3-family dense models;
- 只支持 greedy exact-match acceptance;
- 未实现 probabilistic rejection sampling,所以 temperature/top-k/top-p 不支持;
- 未接 HTTP/continuous batching;
- 未与 CUDA Graph decode 结合;
- 当前 v0 为保证 greedy exactness,接受 token 也会用 target decode 重放提交,因此
即使 acceptance 高也可能变慢;
- draft prefill 和 target prefill 都会计入端到端耗时,短输出可能没有收益。
## 8. Next Phase TODO
如果继续 speculative decoding,下一阶段不要先接 HTTP,应先解决 verify 路径:
1. 做最小 prefill-vs-decode parity harness:固定 prompt、cache len、draft token,
dump 每层/最终 logits 的 top-k,定位 top-1 分叉来自 attention、GEMV 还是 KV 写入顺序;
2.`--use-verify-logits` 在至少 50 prompts x 64 tokens 下 `matched=true`
`verify_decode_mismatches=0`;
3. parity 过后再做真正 multi-token target commit:要么安全保留 verify prefill 写入的
KV,要么实现专用 paged multi-token verify/commit kernel,避免当前的 mirror+commit
decode 重放;
4. 只有 `speedup_e2e > 1` 后再考虑 HTTP flag、continuous batching、sampling 或
gpt-oss speculative decoding。