xserv/docs/24-speculative-batched-verify.md
Gahow Wang 42e13f33dd docs: Phase 24 investigation notes and revised speedup plan
Attempted the simple win — replace matmul_rows_gemv with matmul_2d in
forward_verify_paged_decode_attention — and it worked (0.44x -> 0.68x
on 5 prompts) but produced matched=false. Root cause is K/V drift, not
just logit rounding: matmul_2d at m=1 uses the custom GEMV path, at
m>=2 it uses cuBLAS GEMM, and the two produce different BF16 bits.
Verify then writes K/V with GEMM values while baseline decode would
have written GEMV values, and every downstream position drifts.

A near-tie fallback for the current row's logit does nothing to fix
already-diverged history, so it was reverted in the same session.

Docs/24 captures the finding and lays out the actual path forward:
implement a launch_gemv_bf16_batched kernel that runs gamma m=1 GEMVs
in a single launch with bit-identical output to gamma sequential
calls, then add draft-side CUDA graph and adaptive gamma.

Also includes a back-of-envelope estimate showing that, at the current
acceptance rate of 0.39 and verify=13ms, spec decode lands close to 1.0x
speedup; hitting speedup_e2e > 1 needs launch-overhead savings AND either
higher acceptance or a cheaper draft.

Reverts: none (Phase 24 attempts never landed on main). Only the doc.
2026-07-01 15:35:11 +08:00


Phase 24: Speculative Decoding Performance — target speedup_e2e > 1

Status (2026-07-01): investigation-in-progress. Baseline reproduced, naive batched-GEMM verify attempted, K/V drift issue identified, concrete next-step designs written up. Nothing landed on main yet.

1. Baseline (Phase 23, verified on dash5)

--prompts 50 --gen-tokens 64 --gamma 4 --use-verify-logits:

  • acceptance_rate = 0.39
  • matched = true, verify_decode_mismatches = 0
  • spec_e2e_tpot_ms = 30.07, baseline_e2e_tpot_ms = 13.09
  • speedup_e2e = 0.44×
  • tokens_per_target_step = 0.91

5-prompt sanity re-run reproduces the same shape (~0.44×), so the Phase 23 correctness state machine is intact after the recent CUDA determinism fixes (5f06090).

2. Cost budget & the ceiling

Rough numbers on an RTX 5090, TP=1:

  • baseline decode: ~12.6 ms / token (Qwen3-8B BF16, paged).
  • draft decode (Qwen3-0.6B): ~2.5 ms / token (rough estimate).
  • verify (Phase 23 row-GEMV, γ=4): ~13 ms.

Best-case per accepted spec token cost with acceptance α, γ tokens per round:

spec_time_per_token ≈ (γ · draft + verify + correction) / (1 + α · γ)

With draft=2.5, verify=13, correction≈13, α=0.4, γ=4:

spec_time_per_token ≈ (10 + 13 + 13) / (1 + 1.6) ≈ 13.8 ms/token

Baseline is 12.6 ms/token, so this estimate is ≈0.91×. With verify at ~13 ms, the current acceptance rate of 0.39 gives us at best ~1× speedup.
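The §2 arithmetic can be reproduced in a few lines. This is a sketch of the doc's own back-of-envelope model, not code from the repo. `SpecCosts`, `spec_time_per_token` and `breakeven_alpha` are hypothetical names. The model ignores prefill, sampling and KV-cache bookkeeping.

```rust
/// Per-round costs in ms for the back-of-envelope model in §2 (hypothetical type).
struct SpecCosts {
    draft_ms: f64,      // one draft (Qwen3-0.6B) decode step
    verify_ms: f64,     // one target verify pass over γ tokens
    correction_ms: f64, // one target decode step for the correction token
}

impl SpecCosts {
    /// Cost of one speculative round: γ draft steps + verify + correction.
    fn round_ms(&self, gamma: u32) -> f64 {
        gamma as f64 * self.draft_ms + self.verify_ms + self.correction_ms
    }
}

/// spec_time_per_token ≈ (γ·draft + verify + correction) / (1 + α·γ)
fn spec_time_per_token(c: &SpecCosts, alpha: f64, gamma: u32) -> f64 {
    c.round_ms(gamma) / (1.0 + alpha * gamma as f64)
}

/// α at which the model's spec_time_per_token equals baseline_ms;
/// any higher α gives speedup_e2e > 1 under this model.
fn breakeven_alpha(c: &SpecCosts, gamma: u32, baseline_ms: f64) -> f64 {
    (c.round_ms(gamma) / baseline_ms - 1.0) / gamma as f64
}

fn main() {
    let baseline_ms = 12.6;
    let current = SpecCosts { draft_ms: 2.5, verify_ms: 13.0, correction_ms: 13.0 };
    let t = spec_time_per_token(&current, 0.4, 4);
    println!("spec ≈ {:.1} ms/token, speedup ≈ {:.2}x", t, baseline_ms / t);
    println!("break-even α at γ=4: {:.2}", breakeven_alpha(&current, 4, baseline_ms));
}
```

With the §2 inputs this prints ≈13.8 ms/token (≈0.91×). It also prints a break-even α of ≈0.46 at γ=4, which is above the measured 0.39.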

3. What we tried (2026-07-01)

Naive Phase 24: replace matmul_rows_gemv in forward_verify_paged_decode_attention with matmul_2d (batched cuBLAS GEMM). Result on 5 prompts × 32 tokens:

  • speedup_e2e = 0.68× (up from 0.44×) — verify itself much faster.
  • matched = false on 3/5 prompts, with divergence at multiple positions per failed prompt, not just at the first mismatch.

Root cause: K/V drift, not logit rounding.

matmul_2d at m=1 routes through the custom launch_gemv_bf16 kernel; at m≥2 it goes through cuBLAS GemmEx. Those two paths produce different BF16 bits for the same math because their accumulation orders differ. Therefore:

  • Verify's QKV projection at m=γ writes K/V into the paged cache with cuBLAS-GEMM values.
  • Baseline decode's QKV projection at m=1 would have written K/V with GEMV values.
  • Downstream attention reads these K/V; the two paths diverge starting at the very next position. A near-tie fallback for the current row's logit does not fix already-diverged history.
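The toy example below shows why the two paths disagree. Floating-point addition is not associative, so summing the same products in a different order can change the f32 result, and with it the rounded BF16 bits. The two orders are illustrative stand-ins: one left-to-right accumulator versus per-block partials reduced afterwards. They are not the actual launch_gemv_bf16 or cuBLAS GemmEx orders. The 4-element row is contrived to make the cancellation obvious; real differences are typically far smaller, but one differing bit in K/V is enough to start the drift.

```rust
/// Round an f32 to BF16 bits with round-to-nearest-even (NaN inputs not handled).
fn f32_to_bf16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    let lsb = (bits >> 16) & 1;
    (bits.wrapping_add(0x7FFF + lsb) >> 16) as u16
}

/// Order A: one accumulator, left to right.
fn sum_sequential(xs: &[f32]) -> f32 {
    xs.iter().fold(0.0f32, |acc, &x| acc + x)
}

/// Order B: a partial sum per block of `block` elements, then the partials
/// reduced in fixed order (split-K style).
fn sum_blocked(xs: &[f32], block: usize) -> f32 {
    xs.chunks(block).map(sum_sequential).fold(0.0f32, |acc, p| acc + p)
}

fn main() {
    // Same four products, two accumulation orders.
    let products = [1.0e8_f32, 1.0, -1.0e8, 1.0];
    let a = sum_sequential(&products);
    let b = sum_blocked(&products, 2);
    println!("sequential = {a}, blocked = {b}");
    println!(
        "bf16 bits: {:#06x} vs {:#06x}",
        f32_to_bf16_bits(a),
        f32_to_bf16_bits(b)
    );
}
```

Here the sequential order yields 1.0 (BF16 0x3f80) and the blocked order yields 0.0 (0x0000). Each intermediate 1e8 + 1 rounds back to 1e8 in f32, so which additions share an accumulator decides whether the 1.0 terms survive.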

Near-tie fallback (added and reverted in the same session, kept only in this doc) attempted to correct the verify argmax when the top1 − top2 logit gap was small. It did nothing about the K/V drift, so mismatches persisted.

4. Revised path to speedup_e2e > 1

Three independent levers (4.1 to 4.3). Combining them is the plan.

4.1 A batched-GEMV kernel with GEMV-identical numerics

Write a launch_gemv_bf16_batched that runs γ separate m=1 GEMVs in a single kernel launch, sharing the K panel across rows and producing output bit-identical to γ sequential launch_gemv_bf16 calls. This gives Phase 24's launch-overhead savings without changing K/V bits. Estimated saving vs the row loop: ~25 ms per verify at γ=4 (720 fewer launches × ~35 μs each).

Concrete kernel design:

  • Grid: (N / TILE_N, num_k_blocks, γ) — same layout as current gemv, plus γ in the z-axis.
  • Each block reads its row's x[γ_idx, :] panel once, then writes partials[γ_idx, k_block, n_tile].
  • Reduction kernel: (N / TILE_N, γ), reduces K-blocks in fixed order per row (same as current gemv_reduce_to_bf16_kernel).
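A CPU reference model of this design pins down the invariant the CUDA kernel must keep. The γ index only selects which x row and which slab of partials a block touches. It never changes the order in which one row's products are accumulated or its K-block partials are reduced.

This is a sketch under stated assumptions, not the kernel itself:
  • `K_BLOCK`, `block_partial`, `gemv_ref` (modelling the existing m=1 path) and `gemv_batched_ref` are hypothetical names.
  • Inputs are f32 rather than BF16.
  • TILE_N tiling is collapsed to one output per loop iteration.

```rust
/// Hypothetical split-K block size; the real kernel's K blocking is not specified here.
const K_BLOCK: usize = 4;

/// f32 -> BF16 bits, round-to-nearest-even (NaN inputs not handled).
fn bf16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    (bits.wrapping_add(0x7FFF + ((bits >> 16) & 1)) >> 16) as u16
}

/// Dot product of one weight row with one x row over K block `b`, ascending k.
fn block_partial(w_row: &[f32], x: &[f32], b: usize) -> f32 {
    let end = ((b + 1) * K_BLOCK).min(x.len());
    let mut acc = 0.0f32;
    for i in b * K_BLOCK..end {
        acc += w_row[i] * x[i];
    }
    acc
}

/// Model of the existing m=1 path: y[j] = Σ_k w[j][k]·x[k], W row-major N×K.
/// Phase 1 writes partials[b][j]; phase 2 reduces b = 0, 1, ... in fixed order.
fn gemv_ref(w: &[f32], n: usize, x: &[f32]) -> Vec<u16> {
    let k = x.len();
    let kb = (k + K_BLOCK - 1) / K_BLOCK;
    let mut partials = vec![0.0f32; kb * n];
    for j in 0..n {
        for b in 0..kb {
            partials[b * n + j] = block_partial(&w[j * k..(j + 1) * k], x, b);
        }
    }
    (0..n)
        .map(|j| bf16_bits((0..kb).fold(0.0f32, |acc, b| acc + partials[b * n + j])))
        .collect()
}

/// Model of the batched kernel: grid (n, k_block, γ) writes partials[r][b][j],
/// then a (n, γ) reduction. The row index r never changes the order of operations.
fn gemv_batched_ref(w: &[f32], n: usize, xs: &[f32], gamma: usize) -> Vec<u16> {
    let k = xs.len() / gamma;
    let kb = (k + K_BLOCK - 1) / K_BLOCK;
    let mut partials = vec![0.0f32; gamma * kb * n];
    for j in 0..n {
        for b in 0..kb {
            for r in 0..gamma {
                let x = &xs[r * k..(r + 1) * k];
                partials[(r * kb + b) * n + j] = block_partial(&w[j * k..(j + 1) * k], x, b);
            }
        }
    }
    let mut out = Vec::with_capacity(gamma * n);
    for r in 0..gamma {
        for j in 0..n {
            let sum = (0..kb).fold(0.0f32, |acc, b| acc + partials[(r * kb + b) * n + j]);
            out.push(bf16_bits(sum));
        }
    }
    out
}

fn main() {
    let (n, k, gamma) = (8usize, 19usize, 4usize); // K deliberately not a multiple of K_BLOCK
    // Deterministic pseudo-random inputs so the example needs no external crates.
    let w: Vec<f32> = (0..n * k).map(|i| ((i * 37 % 101) as f32 - 50.0) * 0.013).collect();
    let xs: Vec<f32> = (0..gamma * k).map(|i| ((i * 53 % 97) as f32 - 48.0) * 0.021).collect();

    let batched = gemv_batched_ref(&w, n, &xs, gamma);
    let sequential: Vec<u16> = xs.chunks(k).flat_map(|x| gemv_ref(&w, n, x)).collect();
    assert_eq!(batched, sequential, "batched output must be bit-identical");
    println!("γ={gamma}: batched == sequential m=1 calls ({} BF16 values)", batched.len());
}
```

On the GPU the same discipline extends to details a CPU model hides. These include in-block (warp-shuffle) reduction order and FMA contraction, which nvcc enables by default. Reusing the existing kernel's per-row inner loop, rather than re-deriving it, is the simplest way to keep them identical.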

Bit-exactness check against m=1: run the γ=1 special case through the new kernel and compare it to launch_gemv_bf16; the outputs must be bit-identical.
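For the parity test itself, comparing raw bit patterns is stricter than comparing decoded floats. For example, +0.0 and -0.0 compare equal as floats but differ in bits, and a NaN never equals itself.

This is a minimal sketch, assuming BF16 outputs are copied back to the host as `u16` bit patterns. `bf16_bit_mismatches`, `bf16_to_f32` and `Mismatch` are hypothetical helpers, not existing test utilities.

```rust
/// One differing element in a bit-exact comparison (hypothetical helper type).
#[derive(Debug, PartialEq)]
struct Mismatch {
    index: usize,
    expected: u16,
    actual: u16,
}

/// Compare two BF16 buffers as raw bit patterns and report every differing index.
fn bf16_bit_mismatches(expected: &[u16], actual: &[u16]) -> Vec<Mismatch> {
    assert_eq!(expected.len(), actual.len(), "buffer lengths differ");
    expected
        .iter()
        .zip(actual)
        .enumerate()
        .filter(|(_, (e, a))| e != a)
        .map(|(index, (&expected, &actual))| Mismatch { index, expected, actual })
        .collect()
}

/// Decode BF16 bits to f32 (BF16 is the upper half of an f32).
fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

fn main() {
    // Placeholder buffers: element 1 is +0.0 in one and -0.0 in the other.
    let sequential = [0x3F80u16, 0x0000, 0x4000];
    let batched = [0x3F80u16, 0x8000, 0x4000];
    let floats_equal = sequential
        .iter()
        .zip(&batched)
        .all(|(&s, &b)| bf16_to_f32(s) == bf16_to_f32(b));
    let diffs = bf16_bit_mismatches(&sequential, &batched);
    println!("float compare says equal: {floats_equal}; bit compare: {diffs:?}");
}
```

Reporting every mismatching index, rather than stopping at the first, also matches how §3 characterised the failure: divergence at multiple positions.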

4.2 Reduce draft cost: draft-side CUDA graph

Draft decode is currently a full eager Qwen3-0.6B forward per γ step. Wrapping γ draft steps into a CUDA graph (Phase 21 already did this for gpt-oss target decode) cuts launch overhead here too. Estimated: ~11.5 ms per γ=4 window.

4.3 Adaptive γ

Currently γ=4 fixed. When acceptance drops in a "hard" section, γ=4 wastes 3 draft steps per round. Track a moving average of acceptance per round; if the last N rounds averaged below τ, drop γ to 2 or 1 (equivalent to disabling spec). If it climbs above τ_high, restore.
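Below is a minimal sketch of the adaptive-γ policy with hysteresis. It rests on assumptions this doc does not fix:
  • Acceptance per round is accepted/proposed draft tokens.
  • γ is halved on each low-acceptance decision (4 → 2 → 1, with `gamma_min = 1` as the floor).
  • γ is restored to `gamma_max` in one step.

`AdaptiveGamma` and its fields are hypothetical names, and N, τ and τ_high still need tuning.

```rust
use std::collections::VecDeque;

/// Hypothetical adaptive-γ controller with hysteresis (tau_low < tau_high).
struct AdaptiveGamma {
    window: VecDeque<f64>, // acceptance of the most recent rounds, newest at the back
    n: usize,              // rounds to average ("last N rounds")
    tau_low: f64,
    tau_high: f64,
    gamma_min: u32,
    gamma_max: u32,
    gamma: u32,
}

impl AdaptiveGamma {
    fn new(n: usize, tau_low: f64, tau_high: f64, gamma_min: u32, gamma_max: u32) -> Self {
        assert!(n > 0 && tau_low < tau_high && 1 <= gamma_min && gamma_min <= gamma_max);
        Self { window: VecDeque::with_capacity(n), n, tau_low, tau_high, gamma_min, gamma_max, gamma: gamma_max }
    }

    /// Record one round (`accepted` of `proposed` draft tokens); return γ for the next round.
    fn observe(&mut self, accepted: u32, proposed: u32) -> u32 {
        let rate = if proposed == 0 { 0.0 } else { accepted as f64 / proposed as f64 };
        if self.window.len() == self.n {
            self.window.pop_front();
        }
        self.window.push_back(rate);
        if self.window.len() < self.n {
            return self.gamma; // not enough history for a decision yet
        }
        let avg = self.window.iter().sum::<f64>() / self.n as f64;
        let next = if avg < self.tau_low {
            (self.gamma / 2).max(self.gamma_min) // e.g. 4 -> 2 -> 1
        } else if avg > self.tau_high {
            self.gamma_max
        } else {
            self.gamma
        };
        if next != self.gamma {
            self.gamma = next;
            self.window.clear(); // collect N fresh rounds at the new γ before deciding again
        }
        self.gamma
    }
}

fn main() {
    // Placeholder outcomes, not measured data: a hard stretch, then an easy one.
    let rounds = [(0, 4), (1, 4), (0, 4), (0, 2), (0, 2), (1, 2), (1, 1), (1, 1), (1, 1)];
    let mut ctl = AdaptiveGamma::new(3, 0.3, 0.6, 1, 4);
    for (accepted, proposed) in rounds {
        let gamma = ctl.observe(accepted, proposed);
        println!("accepted {accepted}/{proposed} -> next γ = {gamma}");
    }
}
```

The window is cleared after every change because accepted/proposed measured at one γ is not directly comparable with another. Clearing it also guarantees N fresh rounds before the next decision, which keeps γ from oscillating.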

5. Revised acceptance criteria

  1. cargo fmt && cargo check && cargo test on dash5.
  2. bench-speculative --prompts 50 --gen-tokens 64 --gamma 4 --use-verify-logits:
    • matched = true
    • verify_decode_mismatches = 0
    • speedup_e2e > 1.0
  3. GSM8K-50 (if time permits) token-identical with baseline.

6. What's on main today

  • 5f06090: fixed flash decode kernel atomicAdd nondeterminism + two int32 overflow bugs (causal_mask, dequant_fp8).
  • ce10e4a: sampling NaN-safe on top-k/top-p path.
  • d96ee07: API sampling validation + finish_reason normalization + bounded engine channel + 4 MiB body limit.

The Phase 24 attempt (batched matmul_2d in verify) is not on main. It was verified to be functionally incorrect and reverted in the same session; only this design doc landed.

7. Next actions

In order:

  1. Implement launch_gemv_bf16_batched + Rust wrapper matmul_2d_gemv_batched.
  2. Numerical parity test: γ sequential row-GEMVs vs one batched call must be bit-exact for BF16 inputs.
  3. Swap matmul_rows_gemv in forward_verify_paged_decode_attention for the batched variant.
  4. Re-run bench-speculative 50×64; expect matched=true and speedup_e2e climbing from 0.44× toward the ~1.0× ceiling from §2 on 4.1's launch-overhead savings alone.
  5. If still <1×, layer on 4.2 (draft CUDA graph) and 4.3 (adaptive γ).
  6. If still <1× after 4.1 through 4.3, the arithmetic in §2 suggests this draft/target pair is fundamentally unfavourable. At that point Phase 25 should look at (a) a smaller draft, or (b) drafting via n-gram / prompt-lookup speculators.