xserv/docs/24-speculative-batched-verify.md
Gahow Wang 42e13f33dd docs: Phase 24 investigation notes and revised speedup plan
Attempted the simple win — replace matmul_rows_gemv with matmul_2d in
forward_verify_paged_decode_attention — and it worked (0.44x -> 0.68x
on 5 prompts) but produced matched=false. Root cause is K/V drift, not
just logit rounding: matmul_2d at m=1 uses the custom GEMV path, at
m>=2 it uses cuBLAS GEMM, and the two produce different BF16 bits.
Verify then writes K/V with GEMM values while baseline decode would
have written GEMV values, and every downstream position drifts.

A near-tie fallback for the current row's logit does nothing to fix
already-diverged history, so it was reverted in the same session.

Docs/24 captures the finding and lays out the actual path forward:
implement a launch_gemv_bf16_batched kernel that runs gamma m=1 GEMVs
in a single launch with bit-identical output to gamma sequential
calls, then add draft-side CUDA graph and adaptive gamma.

Also includes a back-of-envelope that shows current acceptance rate
0.39 + verify=13ms lands close to 1.0x speedup even with verify made
free; hitting speedup_e2e > 1 needs launch-overhead savings AND either
higher acceptance or a cheaper draft.

Reverts: none (Phase 24 attempts never landed on main). Only the doc.
2026-07-01 15:35:11 +08:00

# Phase 24: Speculative Decoding Performance — target `speedup_e2e > 1`
> Status (2026-07-01): investigation-in-progress. Baseline reproduced,
> naive batched-GEMM verify attempted, K/V drift issue identified,
> concrete next-step designs written up. **Nothing landed on main yet.**
## 1. Baseline (Phase 23, verified on dash5)
`--prompts 50 --gen-tokens 64 --gamma 4 --use-verify-logits`:
- `acceptance_rate = 0.39`
- `matched = true`, `verify_decode_mismatches = 0`
- `spec_e2e_tpot_ms = 30.07`, `baseline_e2e_tpot_ms = 13.09`
- **`speedup_e2e = 0.44×`**
- `tokens_per_target_step = 0.91`
5-prompt sanity re-run reproduces the same shape (~0.44×), so the
Phase 23 correctness state machine is intact after the recent CUDA
determinism fixes (`5f06090`).
## 2. Cost budget & the ceiling
Rough numbers on 5090 TP=1:
- `baseline decode`: ~12.6 ms / token (Qwen3-8B BF16, paged).
- `draft decode` (Qwen3-0.6B): ~2.5 ms / token (rough estimate).
- `verify` (Phase 23 row-GEMV, γ=4): ~13 ms.
Best-case cost per emitted token under speculation, with acceptance
rate α and γ draft tokens per round:
```
spec_time_per_token ≈ (γ · draft + verify + correction) / (1 + α · γ)
```
With draft=2.5, verify=13, correction≈13, α=0.4, γ=4:
```
spec_time_per_token ≈ (10 + 13 + 13) / (1 + 1.6) ≈ 13.8 ms/token
```
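Baseline is 12.6 ms/token, so this estimate is about 0.91× of baseline.
**Even with the row-GEMV verify's launch overhead perfectly free
(verify at the ~13 ms assumed above), the current acceptance rate of
0.39 gives us at best ~1× speedup.**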
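The estimate above can be reproduced with a few lines of Rust. This is only a sketch of the back-of-envelope model in this section; the function names `spec_time_per_token` and `speedup` are made up for illustration and do not exist in the codebase.

```rust
/// Back-of-envelope cost per emitted token for one speculative round.
/// All times are in milliseconds; `alpha` is the acceptance rate and
/// `gamma` the number of draft tokens per round.
/// (Hypothetical helper, mirrors the formula in this section.)
fn spec_time_per_token(
    gamma: f64,
    draft_ms: f64,
    verify_ms: f64,
    correction_ms: f64,
    alpha: f64,
) -> f64 {
    // Numerator: total work per round. Denominator: expected tokens per round.
    (gamma * draft_ms + verify_ms + correction_ms) / (1.0 + alpha * gamma)
}

/// Speedup relative to plain decode: values above 1.0 mean speculation wins.
fn speedup(baseline_ms: f64, spec_ms: f64) -> f64 {
    baseline_ms / spec_ms
}
```

Calling `spec_time_per_token(4.0, 2.5, 13.0, 13.0, 0.4)` gives roughly 13.8 ms/token, and `speedup(12.6, ...)` on that value gives roughly 0.91. Varying `alpha` or `draft_ms` in this model is a quick way to see how much acceptance or draft cost would have to move before the ratio crosses 1.0.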
## 3. What we tried (2026-07-01)
Naive Phase 24: replace `matmul_rows_gemv` in
`forward_verify_paged_decode_attention` with `matmul_2d` (batched
cuBLAS GEMM). Result on 5 prompts × 32 tokens:
- `speedup_e2e = 0.68×` (up from 0.44×) — verify itself much faster.
- **`matched = false` on 3/5 prompts** — divergence at multiple
positions per failed prompt, not just first mismatch.
Root cause: **K/V drift, not logit rounding**.
`matmul_2d` at `m=1` routes through the custom `launch_gemv_bf16`
kernel; at `m≥2` it goes through cuBLAS `GemmEx`. Those two paths
produce **different BF16 bits** for the same math because their
accumulation orders differ. Therefore:
- Verify's QKV projection at `m=γ` writes K/V into the paged cache
with cuBLAS-GEMM values.
- Baseline decode's QKV projection at `m=1` would have written K/V
with GEMV values.
- Downstream attention reads these K/V; the two paths diverge starting
at the very next position. A near-tie fallback for the *current*
row's logit does not fix already-diverged history.
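Why accumulation order alone changes the bits can be shown on the CPU with plain `f32`. The sketch below is not the GEMV or cuBLAS code path; it only illustrates that floating-point addition is not associative, so two kernels that sum the same products in a different order can round differently. The function names are illustrative.

```rust
/// Sum left to right: ((0 + v[0]) + v[1]) + ...
fn sum_forward(v: &[f32]) -> f32 {
    v.iter().fold(0.0f32, |acc, &x| acc + x)
}

/// Sum right to left: ((0 + v[n-1]) + v[n-2]) + ...
fn sum_reverse(v: &[f32]) -> f32 {
    v.iter().rev().fold(0.0f32, |acc, &x| acc + x)
}

/// True if the two orders produce different bit patterns for the same input.
fn order_changes_bits(v: &[f32]) -> bool {
    sum_forward(v).to_bits() != sum_reverse(v).to_bits()
}
```

For the input `[1.0e8, -1.0e8, 1.0]`, the forward sum is `1.0` while the reverse sum is `0.0`, because `1.0 + -1.0e8` rounds back to `-1.0e8` in `f32`. Real GEMV/GEMM differences are far smaller, but a one-ulp difference in an accumulator is enough to change a rounded BF16 output, and once that value is written into the K/V cache every later position reads it.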
Near-tie fallback (added and reverted in the same session, kept only
in this doc) attempted to correct the verify argmax when the
`top1 - top2` logit margin was small. It did nothing about the K/V
drift, so mismatches persisted.
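For reference, a near-tie check of this kind might look like the sketch below. The reverted code is not reproduced here; `top2_margin`, `is_near_tie` and the `eps` threshold are hypothetical names chosen for illustration.

```rust
/// Returns (argmax index, top1 - top2 margin), or None for fewer than two logits.
/// Single pass, tracking the best and second-best values seen so far.
fn top2_margin(logits: &[f32]) -> Option<(usize, f32)> {
    if logits.len() < 2 {
        return None;
    }
    let mut best_idx = 0usize;
    let mut best = f32::NEG_INFINITY;
    let mut second = f32::NEG_INFINITY;
    for (i, &v) in logits.iter().enumerate() {
        if v > best {
            second = best;
            best = v;
            best_idx = i;
        } else if v > second {
            second = v;
        }
    }
    Some((best_idx, best - second))
}

/// A row is a "near tie" when the margin is below `eps` (assumed threshold).
fn is_near_tie(logits: &[f32], eps: f32) -> bool {
    matches!(top2_margin(logits), Some((_, margin)) if margin < eps)
}
```

A check like this only inspects the logits of the row being verified. It has no way to repair K/V entries that were already written with GEMM-rounded values, which is why it could not restore `matched = true`.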
## 4. Revised path to `speedup_e2e > 1`
Two independent levers. Combining them is the plan.
### 4.1 A batched-GEMV kernel with GEMV-identical numerics
Write a `launch_gemv_bf16_batched` that runs γ separate `m=1` GEMVs in
a **single kernel launch**, sharing the K panel across rows and
producing bit-exact-same output as γ sequential `launch_gemv_bf16`
calls. This gives Phase 24's launch-overhead savings without breaking
K/V bits. Estimated saving vs row-loop: ~2–4 ms per verify at γ=4
(720 fewer launches × 3–5 μs each).
Concrete kernel design:
- Grid: `(N / TILE_N, num_k_blocks, γ)` — same layout as current
gemv, plus γ in the z-axis.
- Each block reads its row's `x[γ_idx, :]` panel once, then writes
`partials[γ_idx, k_block, n_tile]`.
- Reduction kernel: `(N / TILE_N, γ)`, reduces K-blocks in fixed
order per row (same as current `gemv_reduce_to_bf16_kernel`).
Bit-exact-with-m=1 verification: run the γ=1 special case through the
new kernel and compare to `launch_gemv_bf16`; must be bit-identical.
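The parity argument can be prototyped on the CPU before any CUDA is written. The sketch below is an `f32` reference model, not the planned kernel: `gemv_single`, `gemv_batched` and `block_dot` are hypothetical names, the weight layout `w[n * k_dim + k]` is an assumption, and BF16 conversion is left out. It shows that the batched variant may reorder the loops across rows and tiles freely, as long as each output element sees the same per-block dot products reduced in the same ascending K-block order.

```rust
/// Dot product over one K-block, always accumulated in ascending k order.
fn block_dot(w_tile: &[f32], x_tile: &[f32]) -> f32 {
    let mut partial = 0.0f32;
    for (wv, xv) in w_tile.iter().zip(x_tile) {
        partial += wv * xv;
    }
    partial
}

/// Reference m=1 GEMV: per-block partials, reduced in ascending block order.
fn gemv_single(w: &[f32], x: &[f32], n_dim: usize, k_dim: usize, k_block: usize) -> Vec<f32> {
    let num_k_blocks = (k_dim + k_block - 1) / k_block;
    let mut y = vec![0.0f32; n_dim];
    for n in 0..n_dim {
        let mut acc = 0.0f32;
        for kb in 0..num_k_blocks {
            let start = kb * k_block;
            let end = (start + k_block).min(k_dim);
            acc += block_dot(&w[n * k_dim + start..n * k_dim + end], &x[start..end]);
        }
        y[n] = acc;
    }
    y
}

/// Batched variant: the weight tile is the outer loop (so it is reused across
/// rows), partials land in partials[row][k_block][n], then a second pass
/// reduces K-blocks in the same fixed order as `gemv_single`.
fn gemv_batched(
    w: &[f32],
    xs: &[f32],
    rows: usize,
    n_dim: usize,
    k_dim: usize,
    k_block: usize,
) -> Vec<f32> {
    let num_k_blocks = (k_dim + k_block - 1) / k_block;
    let mut partials = vec![0.0f32; rows * num_k_blocks * n_dim];
    for kb in 0..num_k_blocks {
        let start = kb * k_block;
        let end = (start + k_block).min(k_dim);
        for n in 0..n_dim {
            let w_tile = &w[n * k_dim + start..n * k_dim + end];
            for r in 0..rows {
                let x_tile = &xs[r * k_dim + start..r * k_dim + end];
                partials[(r * num_k_blocks + kb) * n_dim + n] = block_dot(w_tile, x_tile);
            }
        }
    }
    let mut y = vec![0.0f32; rows * n_dim];
    for r in 0..rows {
        for n in 0..n_dim {
            let mut acc = 0.0f32;
            for kb in 0..num_k_blocks {
                acc += partials[(r * num_k_blocks + kb) * n_dim + n];
            }
            y[r * n_dim + n] = acc;
        }
    }
    y
}
```

Comparing `gemv_batched` against `rows` separate `gemv_single` calls with `f32::to_bits` gives the same kind of bit-exact check described above. On the GPU the same property additionally depends on the intra-block reduction order and on the compiler not fusing or reassociating the arithmetic differently between the two kernels, which this CPU model does not capture.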
### 4.2 Reduce verify + correction cost — draft-side CUDA graph
Draft decode is currently a full eager Qwen3-0.6B forward per γ step.
Wrapping γ draft steps into a CUDA graph (Phase 21 already did this
for gpt-oss target decode) cuts launch overhead here too. Estimated
saving: ~1–1.5 ms per γ=4 window.
### 4.3 Adaptive γ
Currently γ=4 fixed. When acceptance drops in a "hard" section, γ=4
wastes 3 draft steps per round. Track a moving average of acceptance
per round; if the last N rounds averaged below τ, drop γ to 2 or 1
(equivalent to disabling spec). If it climbs above τ_high, restore.
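One possible shape for the controller is sketched below. Nothing here exists in the codebase yet: `AdaptiveGamma`, its fields and the hysteresis rule (switch down below `tau_low`, switch back up above `tau_high`, otherwise hold) are assumptions made for illustration.

```rust
use std::collections::VecDeque;

/// Hypothetical adaptive-gamma controller with a fixed-size moving window.
struct AdaptiveGamma {
    window: usize,      // N: number of recent rounds to average over
    tau_low: f64,       // below this mean acceptance, drop to gamma_min
    tau_high: f64,      // above this mean acceptance, restore gamma_max
    gamma_min: usize,
    gamma_max: usize,
    gamma: usize,       // current setting
    history: VecDeque<f64>,
}

impl AdaptiveGamma {
    fn new(window: usize, tau_low: f64, tau_high: f64, gamma_min: usize, gamma_max: usize) -> Self {
        Self {
            window,
            tau_low,
            tau_high,
            gamma_min,
            gamma_max,
            gamma: gamma_max,
            history: VecDeque::with_capacity(window),
        }
    }

    /// Record one round (accepted / proposed draft tokens) and return the
    /// gamma to use for the next round.
    fn observe(&mut self, accepted: usize, proposed: usize) -> usize {
        let rate = if proposed == 0 { 0.0 } else { accepted as f64 / proposed as f64 };
        if self.history.len() == self.window {
            self.history.pop_front();
        }
        self.history.push_back(rate);
        // Only act once the window is full, to avoid reacting to a single round.
        if self.history.len() == self.window {
            let mean = self.history.iter().sum::<f64>() / self.window as f64;
            if mean < self.tau_low {
                self.gamma = self.gamma_min;
            } else if mean > self.tau_high {
                self.gamma = self.gamma_max;
            }
        }
        self.gamma
    }
}
```

Keeping `tau_low` strictly below `tau_high` leaves a dead band in which γ is held, which avoids flipping between settings every round when acceptance hovers near a single threshold.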
## 5. Revised acceptance criteria
1. `cargo fmt && cargo check && cargo test` on dash5.
2. `bench-speculative --prompts 50 --gen-tokens 64 --gamma 4 --use-verify-logits`:
- `matched = true`
- `verify_decode_mismatches = 0`
- **`speedup_e2e > 1.0`**
3. GSM8K-50 (if time permits) token-identical with baseline.
## 6. What's on main today
- `5f06090`: fixed flash decode kernel atomicAdd nondeterminism + two
int32 overflow bugs (causal_mask, dequant_fp8).
- `ce10e4a`: sampling NaN-safe on top-k/top-p path.
- `d96ee07`: API sampling validation + finish_reason normalization +
bounded engine channel + 4 MiB body limit.
The Phase 24 attempt (batched matmul_2d in verify) is **not** on
main. It was verified to be functionally incorrect and reverted in
the same session; only this design doc landed.
## 7. Next actions
In order:
1. Implement `launch_gemv_bf16_batched` + Rust wrapper `matmul_2d_gemv_batched`.
2. Numerical parity test: γ sequential row-GEMVs vs one batched call
must be bit-exact for BF16 inputs.
3. Swap `matmul_rows_gemv` in `forward_verify_paged_decode_attention`
for the batched variant.
4. Re-run `bench-speculative` 50×64; expect `matched=true` and
`speedup_e2e` climbing from 0.44× toward the 1.0× ceiling
established by 4.1's launch-overhead savings alone.
5. If still <1×, layer on 4.2 (draft CUDA graph) and 4.3 (adaptive γ).
6. If still <1× after 4.1–4.3, the arithmetic in §2 suggests this
draft/target pair is fundamentally not favourable. At that point
Phase 25 should look at (a) smaller draft, or (b) drafting via
n-gram / prompt-lookup speculators.