Phase 22 lands a correctness-only speculative decoding loop for Qwen3 target + Qwen3 small draft (batch=1, greedy, gamma=4). Phase 23 turns verify logits into the authoritative acceptance signal, so a mirror decode per accepted token is no longer needed.
- paged_kv_cache: truncate_sequence(slot, new_len) shrinks a registered sequence, freeing whole physical blocks that are no longer reachable while leaving the slot registered. Covered by a CUDA-gated unit test.
- qwen3: forward_verify_paged_decode_attention writes the draft window into the target cache, runs the same paged decode attention kernel per draft token, and uses matmul_rows_gemv so linear layers follow the single-token decode BF16 rounding path.
- bench-speculative: a new bench binary drives the state machine with --gamma / --gen-tokens / --prompts / --use-verify-logits / --verify-path flash|paged-decode / --dump-verify-mismatches, and compares baseline vs spec token sequences plus TPOT / tok/s / speedup.
- docs/22 records the decode-authoritative v0 result and dash5 numbers (matched=true, speedup_e2e ~0.29x, verify_decode_mismatches>0 under --use-verify-logits).
- docs/23 records the paged-decode verify path (matched=true, verify_decode_mismatches=0, 50x64 speedup_e2e ~0.44x) and the next-step performance TODO.
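The truncate_sequence bookkeeping can be made concrete with a minimal Python model of a paged KV block table. This is a sketch under assumptions, not the actual paged_kv_cache code. The class PagedKVCacheModel, the free-list representation, and the helpers register and append_tokens are hypothetical; only truncate_sequence(slot, new_len) comes from the description above. A block stays reachable while it holds any position below new_len, so the slot keeps ceil(new_len / block_size) blocks.

```python
from typing import Dict, List


class PagedKVCacheModel:
    """Hypothetical bookkeeping model of a paged KV cache (no GPU memory)."""

    def __init__(self, num_blocks: int, block_size: int) -> None:
        self.block_size = block_size
        self.free_blocks: List[int] = list(range(num_blocks))
        self.block_tables: Dict[int, List[int]] = {}  # slot -> physical block ids
        self.seq_lens: Dict[int, int] = {}            # slot -> logical length

    def _blocks_for(self, length: int) -> int:
        # Ceiling division: a partially filled tail block still counts.
        return -(-length // self.block_size)

    def register(self, slot: int) -> None:
        self.block_tables[slot] = []
        self.seq_lens[slot] = 0

    def append_tokens(self, slot: int, n: int) -> None:
        new_len = self.seq_lens[slot] + n
        table = self.block_tables[slot]
        missing = self._blocks_for(new_len) - len(table)
        if missing > len(self.free_blocks):
            raise MemoryError("not enough free KV blocks")
        for _ in range(missing):
            table.append(self.free_blocks.pop())
        self.seq_lens[slot] = new_len

    def truncate_sequence(self, slot: int, new_len: int) -> None:
        if not 0 <= new_len <= self.seq_lens[slot]:
            raise ValueError("truncate_sequence can only shrink a sequence")
        table = self.block_tables[slot]
        keep = self._blocks_for(new_len)
        # Blocks past `keep` hold only positions >= new_len, so nothing can reach them.
        self.free_blocks.extend(table[keep:])
        del table[keep:]
        self.seq_lens[slot] = new_len  # the slot itself stays registered


cache = PagedKVCacheModel(num_blocks=8, block_size=16)
cache.register(0)
cache.append_tokens(0, 40)      # 40 tokens occupy 3 blocks
cache.truncate_sequence(0, 17)  # 17 tokens need 2 blocks, so 1 block is freed
print(cache.block_tables[0], len(cache.free_blocks))
```

With block_size=16, a 40-token sequence occupies 3 blocks. Truncating to 17 keeps 2 blocks and frees 1. Truncating to 0 frees everything while the slot stays registered.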
Phase 23: Speculative Verify Parity
Goal: move speculative decoding from the v0 correctness-only state machine to the point where verify logits can serve as the authoritative acceptance signal. This phase still covers only Qwen3 target + Qwen3 small draft, batch=1, greedy.
1. Problem
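Phase 22's default mode treats per-token target decode as the authoritative path, so its output matches the baseline. The diagnostic --use-verify-logits mode fails, however. When the target verifies the draft window with a batched prefill, the top-1 of some logits rows disagrees with per-token decode.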
Measured top-k shows that the divergence is not a large numerical error but a BF16 near-tie:
verify_top5=17689:24.500,9856:24.375,...
decode_top5=9856:24.500,17689:24.500,...
Using these verify logits directly to accept or reject draft tokens would make the greedy token sequence drift from pure target decode.
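The two top-5 dumps above differ by exactly one BF16 step, which a quick Python check can confirm. The check assumes standard BF16 (8 exponent bits, 7 explicit mantissa bits, round-to-nearest-even). The helper names to_bf16 and bf16_ulp are illustrative and not part of the codebase.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round a float32-representable value to the nearest BF16 value (ties to even).
    NaN handling is omitted for brevity."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the low 16 bits
    bits = (bits + rounding_bias) & 0xFFFF0000   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_ulp(x: float) -> float:
    """Spacing between adjacent BF16 values at |x| (normal range only)."""
    _, exp = math.frexp(abs(x))      # abs(x) = m * 2**exp with 0.5 <= m < 1
    return math.ldexp(1.0, exp - 8)  # 7 explicit mantissa bits -> 2**(floor(log2|x|) - 7)


print(bf16_ulp(24.5), 24.5 - 24.375)
print(to_bf16(24.43), to_bf16(24.44))
```

Two float32 accumulations only 0.01 apart (24.43 vs 24.44) already round to different BF16 values. A different reduction order, such as batched prefill GEMM versus single-row GEMV, can therefore plausibly flip the top-1 without any real numerical bug.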
2. Design
Add Qwen3::forward_verify_paged_decode_attention:
- write the draft window's K/V into the target commit cache in a single pass;
- run attention with the existing paged decode attention, one metadata row per draft token, with each row's context len set to pos + 1;
- compute linear layers with per-row GEMV, matching the BF16 rounding path of forward_decode_paged;
- if every draft token is accepted, keep the KV written by verify as-is;
- if the draft is rejected at token k, truncate the target cache to the accepted prefix, then decode just one correction token.
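The acceptance and rollback rules above can be sketched as a pure planning function. This is a hypothetical Python sketch: plan_verify_step, VerifyPlan, base_len and target_choice are illustrative names. It assumes target_choice[i] already holds the target's greedy token for the position at which draft token i was proposed. How those tokens are read out of the verify rows is not modeled.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class VerifyPlan:
    context_lens: List[int]     # one paged-decode metadata row per draft token
    accepted: int               # length of the accepted draft prefix
    truncate_to: Optional[int]  # new target cache length; None keeps all verify KV
    correction: Optional[int]   # token to decode once after truncation


def plan_verify_step(base_len: int, draft: List[int], target_choice: List[int]) -> VerifyPlan:
    gamma = len(draft)
    # Draft token i is written at position base_len + i and attends over pos + 1 entries.
    context_lens = [base_len + i + 1 for i in range(gamma)]
    accepted = 0
    while accepted < gamma and draft[accepted] == target_choice[accepted]:
        accepted += 1
    if accepted == gamma:
        # Every draft token matched: keep the KV that verify already wrote.
        return VerifyPlan(context_lens, accepted, None, None)
    # Rejected at index `accepted`: drop the unaccepted KV, then decode one correction token.
    return VerifyPlan(context_lens, accepted, base_len + accepted, target_choice[accepted])


plan = plan_verify_step(10, draft=[5, 6, 7, 8], target_choice=[5, 6, 9, 1])
print(plan)
```

For base_len=10, draft [5, 6, 7, 8] and target choices [5, 6, 9, 1], the first two tokens are accepted. The cache is truncated to length 12 and token 9 is decoded as the correction.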
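New bench flags:
- --use-verify-logits: use verify logits as the acceptance signal; defaults to the paged-decode verify path;
- --verify-path flash|paged-decode: explicitly select the old flash prefill diagnostic path or the new paged-decode verify path;
- --dump-verify-mismatches: print the top-k of mismatched rows, to help locate near-ties.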
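The mismatch rows printed by --dump-verify-mismatches can be triaged offline. The Python sketch below assumes the label=id:logit,... line format of the verify_top5 / decode_top5 lines in the Problem section. The helpers parse_topk and classify_mismatch are hypothetical. The default tolerance of 0.125 is one BF16 step for logits in [16, 32), so it would need to scale with logit magnitude.

```python
from typing import List, Tuple

BF16_STEP_16_TO_32 = 0.125  # one BF16 ULP for magnitudes in [16, 32)


def parse_topk(line: str) -> List[Tuple[int, float]]:
    """Parse 'label=id:logit,id:logit,...'; entries without ':' (such as '...') are skipped."""
    _, _, body = line.partition("=")
    entries = []
    for item in body.split(","):
        item = item.strip()
        if ":" not in item:
            continue
        token, logit = item.split(":")
        entries.append((int(token), float(logit)))
    return entries


def classify_mismatch(verify, decode, tol: float = BF16_STEP_16_TO_32) -> str:
    """Return 'match', 'near_tie' or 'divergent' for one verify/decode row pair."""
    v_top_id, v_top_logit = verify[0]
    d_top_id = decode[0][0]
    if v_top_id == d_top_id:
        return "match"
    v_logits = dict(verify)
    # Decode's choice must still be in verify's top-k and within `tol` of verify's top-1.
    if d_top_id in v_logits and v_top_logit - v_logits[d_top_id] <= tol:
        return "near_tie"
    return "divergent"


verify = parse_topk("verify_top5=17689:24.500,9856:24.375,...")
decode = parse_topk("decode_top5=9856:24.500,17689:24.500,...")
print(classify_mismatch(verify, decode))
```

Decode picks 9856, and verify scores 9856 exactly one BF16 step below its own top-1, so the row is classified as a near-tie rather than a real divergence.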
3. Validation
dash5:
- GPU: RTX 5090, device 0;
- target: /opt/wjh/models/qwen3-8b;
- draft: /dashscope-tmp/wjh/models/qwen3-0.6b;
- command: bench-speculative ... --prompts 50 --gen-tokens 64 --gamma 4 --device 0 --use-verify-logits;
- log: /dashscope-tmp/wjh/xserv-spec-inplace-verify-50x64.log.
Results:
prompts=50 matched=true
acceptance_mode=verify_logits
verify_path=paged-decode
acceptance_rate=0.3927 accepted=2120 proposed=5398
tokens_per_target_step=0.9112 target_steps=3512
verify_steps=1376 mirror_decode_steps=0 commit_decode_steps=1068 correction_steps=1068
verify_decode_mismatches=0
baseline_e2e_tpot_ms=13.094 baseline_e2e_tok_s=76.372
spec_e2e_tpot_ms=30.069 spec_e2e_tok_s=33.257 speedup_e2e=0.4355
baseline_decode_tpot_ms=12.846 baseline_decode_tok_s=77.844
spec_decode_tpot_ms=29.731 spec_decode_tok_s=33.635 speedup_decode=0.4321
decode_token_counts baseline=3200 spec=3200
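The derived metrics can be recomputed from the raw counters. The Python check below reproduces the logged acceptance_rate, target_steps, tokens_per_target_step and speedup_e2e. The relationship target_steps = verify_steps + commit_decode_steps + correction_steps is inferred from the numbers (1376 + 1068 + 1068 = 3512), not taken from the bench source.

```python
# Raw counters copied from the dash5 50x64 log above.
tokens = 50 * 64  # prompts x gen-tokens, matches decode_token_counts
accepted, proposed = 2120, 5398
verify_steps, commit_decode_steps, correction_steps = 1376, 1068, 1068
baseline_e2e_tpot_ms, spec_e2e_tpot_ms = 13.094, 30.069

acceptance_rate = accepted / proposed
# Inferred, not read from the bench source: the three step counters sum to target_steps.
target_steps = verify_steps + commit_decode_steps + correction_steps
tokens_per_target_step = tokens / target_steps
# Lower TPOT is better, so speedup is baseline TPOT over spec TPOT.
speedup_e2e = baseline_e2e_tpot_ms / spec_e2e_tpot_ms

print(f"acceptance_rate={acceptance_rate:.4f}")
print(f"target_steps={target_steps} tokens_per_target_step={tokens_per_target_step:.4f}")
print(f"speedup_e2e={speedup_e2e:.4f}")
```

A tokens_per_target_step below 1 means the speculative run issued more target forward passes (3512) than it produced tokens (3200). If the baseline runs one target decode per token, spec is already behind before draft cost is counted.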
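Compared with Phase 22's conservative decode-authoritative v0:
- verify logits can now serve as the authoritative acceptance signal;
- mirror_decode_steps drops from one per accepted token to 0;
- the 50x64 e2e speedup rises from about 0.29x to 0.44x;
- spec is still slower than baseline, because the verify path uses per-row GEMV for parity and draft acceptance is only about 39%.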
4. Next TODO
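The next phase shifts from correctness parity to performance:
- replace row-GEMV with batched GEMM layer by layer, while keeping a near-tie fallback or a top-k audit;
- add a low-frequency sampled audit, --verify-audit-decode, so that target decode does not run every round;
- sweep gamma and the draft model choice, recording acceptance and TPOT curves;
- do not wire up HTTP, continuous batching, or gpt-oss spec until speedup_e2e > 1.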
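For the gamma sweep, a simple first-order model helps read the curves. The Python sketch below assumes each draft token is accepted independently with probability alpha. This is a simplification: the logged acceptance_rate = accepted / proposed is related to alpha but is not the same quantity. The sketch also assumes that a fully accepted window yields exactly gamma tokens, with no bonus token. The names expected_tokens_per_round and bonus_on_full_accept are hypothetical.

```python
def expected_tokens_per_round(alpha: float, gamma: int, bonus_on_full_accept: bool = False) -> float:
    """Expected committed tokens per verify round under i.i.d. per-token acceptance."""
    expected = 0.0
    for k in range(gamma):
        # First rejection at draft index k: k accepted tokens plus 1 correction token.
        expected += (alpha ** k) * (1 - alpha) * (k + 1)
    # All gamma draft tokens accepted.
    full_tokens = gamma + 1 if bonus_on_full_accept else gamma
    expected += (alpha ** gamma) * full_tokens
    return expected


for gamma in (1, 2, 4, 8):
    print(gamma, expected_tokens_per_round(0.5, gamma))
```

Without a bonus token the expectation simplifies to (1 - alpha^gamma) / (1 - alpha), which approaches 1 / (1 - alpha) as gamma grows. At alpha = 0.5 this gives 1.0, 1.5, 1.875 and about 1.992 tokens per round for gamma = 1, 2, 4 and 8. At low acceptance, then, a larger gamma mainly adds draft and verify work.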