diff --git a/docs/24-speculative-batched-verify.md b/docs/24-speculative-batched-verify.md
new file mode 100644
index 0000000..9411143
--- /dev/null
+++ b/docs/24-speculative-batched-verify.md
@@ -0,0 +1,144 @@
+# Phase 24: Speculative Decoding Performance — target `speedup_e2e > 1`
+
+> Status (2026-07-01): investigation-in-progress. Baseline reproduced,
+> naive batched-GEMM verify attempted, K/V drift issue identified,
+> concrete next-step designs written up. **Nothing landed on main yet.**
+
+## 1. Baseline (Phase 23, verified on dash5)
+
+`--prompts 50 --gen-tokens 64 --gamma 4 --use-verify-logits`:
+
+- `acceptance_rate = 0.39`
+- `matched = true`, `verify_decode_mismatches = 0`
+- `spec_e2e_tpot_ms = 30.07`, `baseline_e2e_tpot_ms = 13.09`
+- **`speedup_e2e = 0.44×`**
+- `tokens_per_target_step = 0.91`
+
+A 5-prompt sanity re-run reproduces the same shape (~0.44×), so the
+Phase 23 correctness state machine is intact after the recent CUDA
+determinism fixes (`5f06090`).
+
+## 2. Cost budget & the ceiling
+
+Rough numbers on 5090 TP=1:
+- `baseline decode`: ~12.6 ms / token (Qwen3-8B BF16, paged).
+- `draft decode` (Qwen3-0.6B): ~2.5 ms / token (rough estimate).
+- `verify` (Phase 23 row-GEMV, γ=4): ~13 ms.
+
+Best-case cost per emitted token, with acceptance rate α and γ draft
+tokens per round:
+```
+spec_time_per_token ≈ (γ · draft + verify + correction) / (1 + α · γ)
+```
+With draft=2.5, verify=13, correction≈13, α=0.4, γ=4:
+```
+spec_time_per_token ≈ (10 + 13 + 13) / (1 + 1.6) ≈ 13.8 ms/token
+```
+Baseline is 12.6 ms/token, so this best case is ≈0.91×. **Even with verify as cheap
+as one baseline decode step, the current acceptance rate of 0.39 gives at best ~1× speedup.**
+
+## 3. What we tried (2026-07-01)
+
+Naive Phase 24: replace `matmul_rows_gemv` in
+`forward_verify_paged_decode_attention` with `matmul_2d` (batched
+cuBLAS GEMM). Result on 5 prompts × 32 tokens:
+
+- `speedup_e2e = 0.68×` (up from 0.44×) — verify itself much faster.
+- **`matched = false` on 3/5 prompts** — divergence at multiple
+  positions per failed prompt, not just at the first mismatch.
+
+Root cause: **K/V drift, not logit rounding**.
+
+`matmul_2d` at `m=1` routes through the custom `launch_gemv_bf16`
+kernel; at `m≥2` it goes through cuBLAS `GemmEx`. Those two paths
+produce **different BF16 bits** for the same math because their
+accumulation orders differ. Therefore:
+
+- Verify's QKV projection at `m=γ` writes K/V into the paged cache
+  with cuBLAS-GEMM values.
+- Baseline decode's QKV projection at `m=1` would have written K/V
+  with GEMV values.
+- Downstream attention reads these K/V; the two paths diverge starting
+  at the very next position. A near-tie fallback for the *current*
+  row's logit does not fix already-diverged history.
+
+The near-tie fallback (added and reverted in the same session, kept only
+in this doc) attempted to correct verify-argmax when top1−top2 was
+small. It did nothing about the K/V drift, so mismatches persisted.
+
+## 4. Revised path to `speedup_e2e > 1`
+
+Three independent levers. Combining them is the plan.
+
+### 4.1 A batched-GEMV kernel with GEMV-identical numerics
+
+Write a `launch_gemv_bf16_batched` that runs γ separate `m=1` GEMVs in
+a **single kernel launch**, sharing the K panel across rows and
+producing output bit-identical to γ sequential `launch_gemv_bf16`
+calls. This gives Phase 24's launch-overhead savings without breaking
+K/V bits. Estimated saving vs row-loop: ~2–4 ms per verify at γ=4
+(720 fewer launches × 3–5 μs each).
+
+Concrete kernel design:
+- Grid: `(N / TILE_N, num_k_blocks, γ)` — same layout as current
+  gemv, plus γ in the z-axis.
+- Each block reads its row's `x[γ_idx, :]` panel once, then writes
+  `partials[γ_idx, k_block, n_tile]`.
+- Reduction kernel: `(N / TILE_N, γ)`, reduces K-blocks in fixed
+  order per row (same as current `gemv_reduce_to_bf16_kernel`).
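
The kernel design above rests on one property: if every (row, K-block) partial is computed with the same arithmetic as the `m=1` path, and the partials are reduced in the same fixed order, the batched result is bit-identical to γ sequential calls. Below is a minimal CPU reference sketch of that property in Rust. It is not the CUDA kernel: `f32_to_bf16_bits`, `block_partial`, `gemv_row` and `gemv_batched` are hypothetical names, and f32 accumulation with a final round-to-nearest-even BF16 store is an assumption about the real kernels.

```rust
/// Round an f32 to BF16 (round-to-nearest-even) and return the raw bits.
/// NaN handling is omitted for brevity (assumption: inputs are finite).
fn f32_to_bf16_bits(v: f32) -> u16 {
    let bits = v.to_bits();
    let rounding = 0x7FFFu32 + ((bits >> 16) & 1);
    (bits.wrapping_add(rounding) >> 16) as u16
}

/// Partial dot product of one input row with column `n` of `w`
/// (row-major, shape [K, n_out]) over K-block `b`. Models one CUDA block.
fn block_partial(x: &[f32], w: &[f32], n_out: usize, n: usize, b: usize, k_block: usize) -> f32 {
    let start = b * k_block;
    let end = (start + k_block).min(x.len());
    let mut partial = 0.0f32;
    for kk in start..end {
        partial += x[kk] * w[kk * n_out + n];
    }
    partial
}

/// Reference for the existing m=1 path: split-K partials, reduced in
/// ascending K-block order, then stored as BF16.
fn gemv_row(x: &[f32], w: &[f32], n_out: usize, k_block: usize) -> Vec<u16> {
    let k = x.len();
    assert_eq!(w.len(), k * n_out);
    let num_blocks = (k + k_block - 1) / k_block;
    let mut out = Vec::with_capacity(n_out);
    for n in 0..n_out {
        let mut acc = 0.0f32;
        for b in 0..num_blocks {
            acc += block_partial(x, w, n_out, n, b, k_block);
        }
        out.push(f32_to_bf16_bits(acc));
    }
    out
}

/// Reference for the batched path: the row index is just one more grid
/// axis. The visiting order of (n, b, r) does not matter because each
/// partial is written to its own slot; only the reduction order does.
fn gemv_batched(xs: &[f32], rows: usize, w: &[f32], n_out: usize, k_block: usize) -> Vec<u16> {
    let k = xs.len() / rows;
    assert_eq!(xs.len(), rows * k);
    assert_eq!(w.len(), k * n_out);
    let num_blocks = (k + k_block - 1) / k_block;
    // partials[r][b][n], flattened.
    let mut partials = vec![0.0f32; rows * num_blocks * n_out];
    for n in 0..n_out {
        for b in 0..num_blocks {
            for r in 0..rows {
                let x = &xs[r * k..(r + 1) * k];
                partials[(r * num_blocks + b) * n_out + n] =
                    block_partial(x, w, n_out, n, b, k_block);
            }
        }
    }
    // Reduction pass: fixed ascending K-block order per (row, n).
    let mut out = Vec::with_capacity(rows * n_out);
    for r in 0..rows {
        for n in 0..n_out {
            let mut acc = 0.0f32;
            for b in 0..num_blocks {
                acc += partials[(r * num_blocks + b) * n_out + n];
            }
            out.push(f32_to_bf16_bits(acc));
        }
    }
    out
}
```

A parity check against this reference would concatenate `gemv_row` over each row and compare the result to `gemv_batched` with `assert_eq!` on the raw `u16` bits. Any tolerance-based comparison would hide exactly the drift described in §3.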
+
+Bit-exact-with-m=1 verification: run the γ=1 special case through the
+new kernel and compare to `launch_gemv_bf16`; must be bit-identical.
+
+### 4.2 Reduce draft cost — draft-side CUDA graph
+
+Draft decode is currently a full eager Qwen3-0.6B forward per draft step.
+Wrapping γ draft steps into a CUDA graph (Phase 21 already did this
+for gpt-oss target decode) cuts launch overhead here too. Estimated saving:
+~1–1.5 ms per γ=4 window.
+
+### 4.3 Adaptive γ
+
+Currently γ is fixed at 4. When acceptance drops in a "hard" section, γ=4
+wastes up to 3 draft steps per round. Track a moving average of acceptance
+per round; if the last N rounds averaged below τ, drop γ to 2 or 1
+(the latter is close to disabling spec). If it climbs above τ_high, restore γ=4.
+
+## 5. Revised acceptance criteria
+
+1. `cargo fmt && cargo check && cargo test` pass on dash5.
+2. `bench-speculative --prompts 50 --gen-tokens 64 --gamma 4 --use-verify-logits`:
+   - `matched = true`
+   - `verify_decode_mismatches = 0`
+   - **`speedup_e2e > 1.0`**
+3. GSM8K-50 (if time permits) token-identical with baseline.
+
+## 6. What's on main today
+
+- `5f06090`: fixed flash decode kernel atomicAdd nondeterminism + two
+  int32 overflow bugs (causal_mask, dequant_fp8).
+- `ce10e4a`: sampling NaN-safe on top-k/top-p path.
+- `d96ee07`: API sampling validation + finish_reason normalization +
+  bounded engine channel + 4 MiB body limit.
+
+The Phase 24 attempt (batched matmul_2d in verify) is **not** on
+main. It was verified to be functionally incorrect and reverted in
+the same session; only this design doc landed.
+
+## 7. Next actions
+
+In order:
+
+1. Implement `launch_gemv_bf16_batched` + Rust wrapper `matmul_2d_gemv_batched`.
+2. Numerical parity test: γ sequential row-GEMVs vs one batched call
+   must be bit-exact for BF16 inputs.
+3. Swap `matmul_rows_gemv` in `forward_verify_paged_decode_attention`
+   for the batched variant.
+4. Re-run `bench-speculative` 50×64; expect `matched=true` and
+   `speedup_e2e` climbing from 0.44× toward the ~1× ceiling from §2
+   on 4.1's launch-overhead savings alone.
+5. If still <1×, layer on 4.2 (draft CUDA graph) and 4.3 (adaptive γ).
+6. If still <1× after 4.1–4.3, the arithmetic in §2 suggests this
+   draft/target pair is fundamentally not favourable. At that point
+   Phase 25 should look at (a) a smaller draft model, or (b) drafting via
+   n-gram / prompt-lookup speculators.
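
Step 6 leans on the §2 cost model, so it helps to be able to re-evaluate that model as measured numbers change. The sketch below encodes the §2 formula as written and solves it for the break-even acceptance rate. The function names are hypothetical, and the inputs in the usage note are the rough §2 estimates, not new measurements.

```rust
/// Best-case time per emitted token under the §2 model:
/// (γ · draft + verify + correction) / (1 + α · γ).
fn spec_time_per_token(gamma: u32, draft_ms: f64, verify_ms: f64, correction_ms: f64, alpha: f64) -> f64 {
    let g = gamma as f64;
    (g * draft_ms + verify_ms + correction_ms) / (1.0 + alpha * g)
}

/// Speedup relative to baseline decode; values above 1.0 mean spec wins.
fn speedup(baseline_ms: f64, spec_ms: f64) -> f64 {
    baseline_ms / spec_ms
}

/// Acceptance rate at which the model predicts speedup == 1.0, obtained
/// by solving the formula above for α with spec_time = baseline_ms.
fn break_even_alpha(gamma: u32, draft_ms: f64, verify_ms: f64, correction_ms: f64, baseline_ms: f64) -> f64 {
    let g = gamma as f64;
    ((g * draft_ms + verify_ms + correction_ms) / baseline_ms - 1.0) / g
}
```

With the §2 inputs (γ=4, draft=2.5, verify=13, correction=13, α=0.4, baseline=12.6), `spec_time_per_token` evaluates to about 13.8 ms and `break_even_alpha` to about 0.46, which is above the measured acceptance rate of 0.39. Under this model, the cost reductions in 4.1 and 4.2 lower the break-even α, while a different draft/target pair would change α itself.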