docs: Phase 24 investigation notes and revised speedup plan

Attempted the simple win (replace matmul_rows_gemv with matmul_2d in forward_verify_paged_decode_attention) and it worked (0.44x -> 0.68x on 5 prompts) but produced matched=false.

Root cause is K/V drift, not just logit rounding: matmul_2d at m=1 uses the custom GEMV path, at m>=2 it uses cuBLAS GEMM, and the two produce different BF16 bits. Verify then writes K/V with GEMM values while baseline decode would have written GEMV values, and every downstream position drifts. A near-tie fallback for the current row's logit does nothing to fix already-diverged history, so it was reverted in the same session.

Docs/24 captures the finding and lays out the actual path forward: implement a launch_gemv_bf16_batched kernel that runs gamma m=1 GEMVs in a single launch with bit-identical output to gamma sequential calls, then add draft-side CUDA graph and adaptive gamma.

Also includes a back-of-envelope that shows current acceptance rate 0.39 + verify=13ms lands close to 1.0x speedup even with verify made free; hitting speedup_e2e > 1 needs launch-overhead savings AND either higher acceptance or a cheaper draft.

Reverts: none (Phase 24 attempts never landed on main). Only the doc.

docs/24-speculative-batched-verify.md (new file, 144 lines)

# Phase 24: Speculative Decoding Performance (target `speedup_e2e > 1`)

> Status (2026-07-01): investigation in progress. Baseline reproduced,
> naive batched-GEMM verify attempted, K/V drift issue identified,
> concrete next-step designs written up. **Nothing landed on main yet.**

## 1. Baseline (Phase 23, verified on dash5)

`--prompts 50 --gen-tokens 64 --gamma 4 --use-verify-logits`:

- `acceptance_rate = 0.39`
- `matched = true`, `verify_decode_mismatches = 0`
- `spec_e2e_tpot_ms = 30.07`, `baseline_e2e_tpot_ms = 13.09`
- **`speedup_e2e = 0.44×`**
- `tokens_per_target_step = 0.91`

5-prompt sanity re-run reproduces the same shape (~0.44×), so the
Phase 23 correctness state machine is intact after the recent CUDA
determinism fixes (`5f06090`).

## 2. Cost budget & the ceiling

Rough numbers on 5090 TP=1:

- `baseline decode`: ~12.6 ms / token (Qwen3-8B BF16, paged).
- `draft decode` (Qwen3-0.6B): ~2.5 ms / token (rough estimate).
- `verify` (Phase 23 row-GEMV, γ=4): ~13 ms.

Best-case cost per emitted token, with acceptance rate α and γ draft
tokens per round:

```
spec_time_per_token ≈ (γ · draft + verify + correction) / (1 + α · γ)
```

With draft=2.5, verify=13, correction≈13, α=0.4, γ=4:

```
spec_time_per_token ≈ (10 + 13 + 13) / (1 + 1.6) ≈ 13.8 ms/token
```

Baseline is 12.6 ms/token, so this estimate corresponds to ≈0.91×.
**Even with the row-GEMV verify's launch overhead made perfectly free
(the ~2–4 ms of §4.1), the current acceptance rate of 0.39 gives us at
best ~1× speedup.**
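
The estimate above is easy to re-evaluate for other draft/verify costs or acceptance rates. A minimal sketch of the same formula follows; the function and parameter names are illustrative and do not come from the codebase.

```rust
/// Expected cost per emitted token for one speculative round, using the
/// back-of-envelope model from this section. All times are in milliseconds;
/// `alpha` is the acceptance rate and `gamma` the number of draft tokens.
/// (Hypothetical helper, not an existing function in the repo.)
fn spec_time_per_token(
    draft_ms: f64,
    verify_ms: f64,
    correction_ms: f64,
    alpha: f64,
    gamma: u32,
) -> f64 {
    let g = f64::from(gamma);
    (g * draft_ms + verify_ms + correction_ms) / (1.0 + alpha * g)
}

/// Speedup relative to plain decode; values below 1.0 mean that
/// speculation is slower than the baseline.
fn speedup(baseline_ms: f64, spec_ms: f64) -> f64 {
    baseline_ms / spec_ms
}
```

Calling `spec_time_per_token(2.5, 13.0, 13.0, 0.4, 4)` reproduces the ≈13.8 ms/token figure, and `speedup(12.6, …)` on that value comes out at about 0.91.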

## 3. What we tried (2026-07-01)

Naive Phase 24: replace `matmul_rows_gemv` in
`forward_verify_paged_decode_attention` with `matmul_2d` (batched
cuBLAS GEMM). Result on 5 prompts × 32 tokens:

- `speedup_e2e = 0.68×` (up from 0.44×): verify itself is much faster.
- **`matched = false` on 3/5 prompts**: divergence at multiple
  positions per failed prompt, not just at the first mismatch.

Root cause: **K/V drift, not logit rounding**.

`matmul_2d` at `m=1` routes through the custom `launch_gemv_bf16`
kernel; at `m≥2` it goes through cuBLAS `GemmEx`. Those two paths
produce **different BF16 bits** for the same math because their
accumulation orders differ. Therefore:

- Verify's QKV projection at `m=γ` writes K/V into the paged cache
  with cuBLAS-GEMM values.
- Baseline decode's QKV projection at `m=1` would have written K/V
  with GEMV values.
- Downstream attention reads these K/V; the two paths diverge starting
  at the very next position. A near-tie fallback for the *current*
  row's logit does not fix already-diverged history.

Near-tie fallback (added and reverted in the same session, kept only
in this doc) attempted to correct verify-argmax when top1−top2 was
small. It did nothing about the K/V drift, so mismatches persisted.
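
The accumulation-order effect is easy to reproduce on the CPU. The sketch below is not the GEMV kernel or cuBLAS; it only shows the mechanism, using contrived inputs, a hand-written BF16 rounding helper, and hypothetical function names. The actual accumulation order inside cuBLAS is not assumed here.

```rust
/// Round an f32 to BF16 (round-to-nearest-even) and return the raw 16 bits.
/// NaN handling is omitted to keep the sketch short.
fn f32_to_bf16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    let round = 0x7FFF + ((bits >> 16) & 1);
    (bits.wrapping_add(round) >> 16) as u16
}

/// Dot product accumulated strictly left to right in one f32 accumulator.
fn dot_sequential(x: &[f32], w: &[f32]) -> f32 {
    x.iter().zip(w).fold(0.0f32, |acc, (a, b)| acc + a * b)
}

/// Dot product accumulated as per-block partial sums that are then added
/// in block order (a split-K style reduction).
fn dot_split_k(x: &[f32], w: &[f32], block: usize) -> f32 {
    x.chunks(block)
        .zip(w.chunks(block))
        .map(|(xb, wb)| dot_sequential(xb, wb))
        .fold(0.0f32, |acc, p| acc + p)
}

/// Same inputs, two accumulation orders, returned as BF16 bit patterns.
/// The inputs are chosen so that f32 rounding absorbs the small terms
/// differently in each order.
fn drift_example() -> (u16, u16) {
    let x = [1.0e8f32, 1.0, -1.0e8, 1.0];
    let w = [1.0f32; 4];
    (
        f32_to_bf16_bits(dot_sequential(&x, &w)),
        f32_to_bf16_bits(dot_split_k(&x, &w, 2)),
    )
}
```

Here the sequential order yields `0x3F80` (1.0) and the split order yields `0x0000` (0.0). Real activations differ far less dramatically, but a single differing BF16 bit in a cached K/V entry is enough to change every later attention read.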

## 4. Revised path to `speedup_e2e > 1`

Three levers (4.1–4.3). Combining them is the plan.

### 4.1 A batched-GEMV kernel with GEMV-identical numerics

Write a `launch_gemv_bf16_batched` that runs γ separate `m=1` GEMVs in
a **single kernel launch**, sharing the K panel across rows and
producing output bit-identical to γ sequential `launch_gemv_bf16`
calls. This gives Phase 24's launch-overhead savings without breaking
K/V bits. Estimated saving vs row-loop: ~2–4 ms per verify at γ=4
(720 fewer launches × 3–5 μs each).

Concrete kernel design:

- Grid: `(N / TILE_N, num_k_blocks, γ)`, the same layout as the current
  gemv plus γ in the z-axis.
- Each block reads its row's `x[γ_idx, :]` panel once, then writes
  `partials[γ_idx, k_block, n_tile]`.
- Reduction kernel: `(N / TILE_N, γ)`, reduces K-blocks in fixed
  order per row (same as current `gemv_reduce_to_bf16_kernel`).

Bit-exact-with-m=1 verification: run the γ=1 special case through the
new kernel and compare to `launch_gemv_bf16`; must be bit-identical.
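
A CPU reference model can pin down what "bit-identical" means before any CUDA is written. The sketch below is an assumption-laden illustration, not the planned kernel: the function names, the `[N, K]` row-major weight layout, and the f32 partials are all hypothetical. It mirrors the design above: a partials pass indexed by `(n, k_block, γ)` followed by a reduction in ascending K-block order.

```rust
/// Round an f32 to BF16 bits (round-to-nearest-even; NaN handling omitted).
fn to_bf16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    (bits.wrapping_add(0x7FFF + ((bits >> 16) & 1)) >> 16) as u16
}

/// One m=1 GEMV: y[n] = sum_k x[k] * w[n * k_dim + k].
/// K is split into blocks of `k_block`; partials are reduced in ascending
/// block order and rounded once to BF16.
fn gemv_row(x: &[f32], w: &[f32], n_dim: usize, k_block: usize) -> Vec<u16> {
    let k_dim = x.len();
    (0..n_dim)
        .map(|n| {
            let w_row = &w[n * k_dim..(n + 1) * k_dim];
            let mut acc = 0.0f32;
            for (xb, wb) in x.chunks(k_block).zip(w_row.chunks(k_block)) {
                let partial: f32 = xb.iter().zip(wb).map(|(a, b)| a * b).sum();
                acc += partial; // fixed, ascending K-block order
            }
            to_bf16_bits(acc)
        })
        .collect()
}

/// Batched variant for gamma >= 1 rows of equal length: one pass fills
/// partials[g][kb][n] in "grid" order (n, kb, g), a second pass reduces
/// each (g, n) in the same ascending K-block order as `gemv_row`.
fn gemv_batched(xs: &[Vec<f32>], w: &[f32], n_dim: usize, k_block: usize) -> Vec<Vec<u16>> {
    let gamma = xs.len();
    let k_dim = xs[0].len();
    let num_kb = (k_dim + k_block - 1) / k_block;
    let mut partials = vec![0.0f32; gamma * num_kb * n_dim];
    for n in 0..n_dim {
        for kb in 0..num_kb {
            let lo = kb * k_block;
            let hi = (lo + k_block).min(k_dim);
            let w_blk = &w[n * k_dim + lo..n * k_dim + hi];
            for g in 0..gamma {
                let p: f32 = xs[g][lo..hi].iter().zip(w_blk).map(|(a, b)| a * b).sum();
                partials[(g * num_kb + kb) * n_dim + n] = p;
            }
        }
    }
    (0..gamma)
        .map(|g| {
            (0..n_dim)
                .map(|n| {
                    let mut acc = 0.0f32;
                    for kb in 0..num_kb {
                        acc += partials[(g * num_kb + kb) * n_dim + n];
                    }
                    to_bf16_bits(acc)
                })
                .collect()
        })
        .collect()
}

/// Parity check: gamma sequential row calls vs one batched call.
fn parity_holds(xs: &[Vec<f32>], w: &[f32], n_dim: usize, k_block: usize) -> bool {
    let sequential: Vec<Vec<u16>> =
        xs.iter().map(|x| gemv_row(x, w, n_dim, k_block)).collect();
    sequential == gemv_batched(xs, w, n_dim, k_block)
}
```

The point of the model is that batching changes only which launch a row belongs to; the per-row order of multiplies and adds stays the same. A GPU parity test would compare raw BF16 bits in the same way, rather than using a tolerance.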

### 4.2 Reduce draft cost: draft-side CUDA graph

Draft decode is currently a full eager Qwen3-0.6B forward per γ step.
Wrapping γ draft steps into a CUDA graph (Phase 21 already did this
for gpt-oss target decode) cuts launch overhead here too. Estimated:
~1–1.5 ms per γ=4 window.

### 4.3 Adaptive γ

Currently γ=4 fixed. When acceptance drops in a "hard" section, γ=4
wastes 3 draft steps per round. Track a moving average of acceptance
per round; if the last N rounds averaged below τ, drop γ to 2 or 1
(equivalent to disabling spec). If it climbs above τ_high, restore.
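
The policy above can be sketched as a small controller. Everything here is hypothetical: the struct name, the halving step, the choice to clear the window after each change (so a new decision needs N fresh rounds), and the threshold values in the usage note are assumptions for illustration, not a settled design.

```rust
use std::collections::VecDeque;

/// Hypothetical adaptive-gamma controller with hysteresis.
struct AdaptiveGamma {
    gamma_max: u32,
    gamma: u32,
    window: usize,          // N: number of rounds in the moving average
    tau_low: f64,           // drop gamma when the average falls below this
    tau_high: f64,          // restore gamma when the average rises above this
    history: VecDeque<f64>, // per-round acceptance fractions in [0, 1]
}

impl AdaptiveGamma {
    fn new(gamma_max: u32, window: usize, tau_low: f64, tau_high: f64) -> Self {
        Self {
            gamma_max,
            gamma: gamma_max,
            window,
            tau_low,
            tau_high,
            history: VecDeque::new(),
        }
    }

    /// Record one round and return the gamma to use for the next round.
    fn observe(&mut self, accepted: u32, drafted: u32) -> u32 {
        if drafted > 0 {
            self.history
                .push_back(f64::from(accepted) / f64::from(drafted));
        }
        if self.history.len() > self.window {
            self.history.pop_front();
        }
        if self.history.len() == self.window {
            let mean = self.history.iter().sum::<f64>() / self.window as f64;
            if mean < self.tau_low {
                self.gamma = (self.gamma / 2).max(1); // 4 -> 2 -> 1
                self.history.clear();
            } else if mean > self.tau_high {
                self.gamma = self.gamma_max;
                self.history.clear();
            }
        }
        self.gamma
    }
}
```

With `AdaptiveGamma::new(4, 3, 0.3, 0.6)`, three consecutive rounds averaging under 30% acceptance drop γ from 4 to 2, and three later rounds averaging over 60% restore it to 4. Keeping `tau_low` below `tau_high` avoids flapping between the two settings.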

## 5. Revised acceptance criteria

1. `cargo fmt && cargo check && cargo test` on dash5.
2. `bench-speculative --prompts 50 --gen-tokens 64 --gamma 4 --use-verify-logits`:
   - `matched = true`
   - `verify_decode_mismatches = 0`
   - **`speedup_e2e > 1.0`**
3. GSM8K-50 (if time permits) token-identical with baseline.

## 6. What's on main today

- `5f06090`: fixed flash decode kernel atomicAdd nondeterminism + two
  int32 overflow bugs (causal_mask, dequant_fp8).
- `ce10e4a`: sampling NaN-safe on top-k/top-p path.
- `d96ee07`: API sampling validation + finish_reason normalization +
  bounded engine channel + 4 MiB body limit.

The Phase 24 attempt (batched `matmul_2d` in verify) is **not** on
main. It was verified to be functionally incorrect and reverted in
the same session; only this design doc landed.

## 7. Next actions

In order:

1. Implement `launch_gemv_bf16_batched` + Rust wrapper `matmul_2d_gemv_batched`.
2. Numerical parity test: γ sequential row-GEMVs vs one batched call
   must be bit-exact for BF16 inputs.
3. Swap `matmul_rows_gemv` in `forward_verify_paged_decode_attention`
   for the batched variant.
4. Re-run `bench-speculative` 50×64; expect `matched=true` and
   `speedup_e2e` climbing from 0.44× toward the 1.0× ceiling
   established by 4.1's launch-overhead savings alone.
5. If still <1×, layer on 4.2 (draft CUDA graph) and 4.3 (adaptive γ).
6. If still <1× after 4.1–4.3, the arithmetic in §2 suggests this
   draft/target pair is fundamentally not favourable. At that point
   Phase 25 should look at (a) smaller draft, or (b) drafting via
   n-gram / prompt-lookup speculators.