Files
xserv/docs/26-eagle3-bug-hunt.md

301 lines
12 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Phase 26: EAGLE3 Implementation Follow-up & Bug Hunt
> Companion to docs/25 (which explains the three speculative paradigms).
> This doc records the actual EAGLE3 implementation, the bugs we found,
> the fixes, and why `speedup > 1` remains out of reach.
## Implementation Timeline
Commits are on `main`:
1. **`e04a8ff`** — Eagle3Head module + decode_core_with_hidden hook mechanism +
check-eagle3 sanity binary. Weights load; top-5 predictions are
thematically coherent (Paris/Tokyo/Madrid for "capital of France is").
2. **`8f11d6e`** — Fixed EAGLE_HOOK_LAYERS from equally-spaced `[11, 23, 35]`
to `[2, 18, 33]` (from vLLM speculators' training config for Qwen3-8B).
3. **`68b55fa`** — First bench-eagle3 γ=1 loop. matched=true but acceptance
only 1.3%.
4. **`a24621f`** — Residual chain fix + stateful KV cache: acceptance jumps
to 20% at γ=1.
5. **`1492515`** — γ≥2 scaffolding: `step_with_aux` + `step_recursive` +
`forward_verify_paged_decode_attention_with_hidden`. matched=false at
γ≥2 due to K/V bugs.
6. **`d2c55c4`** — γ≥2 correctness fixes: matched=true across full sweep.
## Bugs Fixed (γ≥2)
### Bug A: Truncate dropped needed K/V
Old code:
```rust
cache.truncate_sequence(slot, round_pos - 1).unwrap();
let (verify_logits, _) = target.forward_verify_...(&[prev_token, d0, d1], ...);
```
`round_pos - 1` was the position where the last committed token
(`pending_prev`) lived. Truncating dropped its K/V. Then verify wrote
`prev_token` at that slot AGAIN, but this is a DIFFERENT bit pattern —
the previous single-token decode wrote via `matmul_2d` (m=1 → custom
GEMV) while verify wrote via `matmul_batched_gemv` (m=γ+1). Same math,
same output bytes... IN PRINCIPLE. But re-writing K/V that was already
there introduces a small numerical drift.
**Fix**: Don't truncate. Let verify start at `cache.seq_len` and write
γ+1 new positions forward. `pending_prev`'s K/V stays intact from the
previous round's write.
### Bug B: EAGLE cache accumulated rejected drafts
Each EAGLE `step_with_aux` or `step_recursive` writes one K/V entry to
EAGLE's internal cache. Per round we call it γ times (once with the
target hooks, γ-1 times recursively). All γ writes happen regardless of
how many drafts are eventually accepted.
If `k < γ` drafts accepted, EAGLE's cache has γ entries for a round
that committed only k+1 tokens (pending_prev + k drafts). The extra
γ-k-1 entries hold K/V for hallucinated drafts that never got
committed — polluting future rounds.
**Fix**: Add `Eagle3Head::truncate_to(new_len)`. After acceptance,
truncate to `eagle_len_before + k + 1`.
### Bug C: aux output was normed, should be pre-norm
vLLM's `llama_eagle3.py` (line ~150):
```python
hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
aux_output = hidden_states if self.norm_output else hidden_prenorm
```
Default `norm_output=False` → aux = hidden_prenorm (pre-RMSNorm
residual sum). I was returning `hidden_states` (normed).
**Fix**: return the second output of `add_rmsnorm`, which is `x + residual`
(pre-norm). Small effect on acceptance (~1%).
### Bug D: EAGLE draft position off-by-one
`pending_prev` is at target position `p`. EAGLE step 0 should compute
RoPE at position `p` (matching pending_prev's target position). I was
passing `p + 1`.
**Fix**: pass `p + k` for the k-th EAGLE step (k = 0..γ-1).
## Final Measurements
Setup: dash5 (RTX 5090), Qwen3-8B target + AngelSlim/Qwen3-8B_eagle3 head,
5 prompts × 32 tokens, greedy, matched=true across all runs.
| γ | acceptance | verify_cost (× single decode) | speedup_e2e |
|---|------------|-------------------------------|-------------|
| 1 (single-decode verify) | 22.7% | 1.00 | **0.95×** |
| 1 (batched verify) | 20.6% | ~1.5 | 0.75× |
| 2 | 12.6% | ~1.7 | 0.59× |
| 3 | 9.1% | ~2.1 | 0.48× |
| 4 | 7.6% | ~2.4 | 0.41× |
| 6 | 5.2% | ~3.1 | 0.32× |
| 8 | 4.1% | ~3.7 | 0.27× |
Per-slot diagnostic (γ=8, aggregated over 5 prompts):
```
d[0]=12/125(0.10) d[1]=8/122(0.07) d[2]=5/119(0.04)
d[3]=6/116(0.05) d[4]=8/113(0.07) d[5]=13/110(0.12)
d[6]=17/107(0.16) d[7]=17/104(0.16)
```
Later positions (d[5..7]) surprisingly show HIGHER acceptance than d[1..3].
Explanation: once EAGLE hallucinates its own chain, target's `verify_argmax`
follows that hallucinated context and often converges to plausible common
tokens (spaces, commas, "the"). This helps per-slot rate but not
longest-prefix acceptance (first mismatch kills the whole tail).
## Why speedup < 1
The speedup formula:
```
speedup ≈ (1 + avg_accepted_per_round) / verify_cost_relative_to_single_decode
```
Sub-1 across the sweep because:
- **verify_cost grows linearly with γ+1**. Each verify slot is one BF16 GEMV
row across all Qwen3-8B layers. Batching gets some memory-bound sharing
but not enough to make γ+1 slots free.
- **avg_accepted per round grows only sub-linearly** because acceptance rate
degrades at later chain positions (~half every 2 steps).
To reach `speedup > 1` we need avg_accepted > (verify_cost - 1). With
verify_cost ≈ 1.7 at γ=2, need avg_accepted > 0.7. Observed 0.25.
## Path Forward
Three levers, all significant work:
### 1. Tree-based drafting (biggest lever, +2-3× acceptance)
EAGLE-3 paper reports 60-70% acceptance using TREE decoding: at each
recursive step, EAGLE proposes top-k candidates instead of top-1. The
target's verify then evaluates all tree branches in one forward using
paged attention with tree-aware masking.
Reference: `SafeAILab/EAGLE` uses trees with depth 6 and 26+ nodes.
Implementation cost: significant. Requires:
- Tree-aware batched verify (multi-branch attention masking).
- Tree navigation / longest-accepted-path selection.
- KV cache management for accepted branch vs discarded branches.
### 2. Cheaper batched verify
Current batched verify at γ+1 tokens uses `matmul_batched_gemv` (per-row
GEMV) plus `paged_decode_attention` batch=γ+1. Both scale roughly
linearly with γ+1.
Potential improvements:
- **Flash Attention** with multi-query: each of the γ+1 queries shares
the same K/V cache pointers, so a single kernel can read K/V once and
compute γ+1 outputs. Currently they're independent kernel launches per
query.
- **Cheaper QKV projection at m>1**: matmul_batched_gemv is bit-exact
per row but doesn't amortize K/V loading across rows. Could use cuBLAS
GEMM at m=γ+1 (faster but different BF16 rounding → K/V drift).
### 3. Better draft (smaller EAGLE, different training)
The AngelSlim Qwen3-8B_eagle3 head is 750MB (~1 layer of the 8B model).
Alternatives:
- Smaller Qwen3 (0.6B) as draft: already tried, γ=1 gets 40% acceptance
but draft cost ~2.5ms (vs EAGLE's ~0.5ms).
- Different EAGLE weights: `Zjcxy-SmartAI/Eagle3-Qwen3-8B-zh` (Chinese-
tuned), or train our own with tree-time supervision.
## Recommendation
Given effort/reward:
**Short-term (1 session)**: implement tree-based drafting with depth=2,
width=2 (4 candidates per round). Reuse existing batched verify with
tree-aware masking. Expect acceptance to double (25% → 50%+).
**Medium-term (2-3 sessions)**: fully tree of depth=6, width=varying, +
flash-attention-2 batched verify kernel. This matches the vLLM
implementation and should approach 2× speedup.
**Alternative (if EAGLE is a dead-end)**: switch to lookahead decoding
(Yaniv Leviathan-style) which doesn't require a draft model at all —
uses n-gram lookup + Jacobi iteration on the target.
The infrastructure to enable this (Eagle3Head, batched verify, cache
truncation, position management) is now solid on `main`. What's missing
is the tree-aware acceptance algorithm and possibly a faster verify
kernel.
---
## Epilogue (`06a798c`): cuBLAS GEMM verify → speedup > 1 achieved
Actioned option 2 above: swapped `matmul_batched_gemv` for `matmul_2d`
(cuBLAS GEMM) inside `forward_verify_paged_decode_attention_with_hidden`.
Micro-benchmark (bench-verify-cost.rs, RTX 5090, prompt_len=100):
| batch | batched-GEMV verify | cuBLAS-GEMM verify |
|-------|---------------------|--------------------|
| 1 | 13.14 ms (1.05×) | 13.04 ms (1.04×) |
| 2 | 19.51 ms (1.56×) | 13.52 ms (1.08×) |
| 3 | 26.10 ms (2.09×) | 13.59 ms (1.09×) |
| 5 | 38.72 ms (3.10×) | 13.88 ms (1.11×) |
| 9 | 64.15 ms (5.14×) | 15.03 ms (1.20×) |
cuBLAS GEMM at m>1 amortizes K/V load across all queries, giving
near-flat scaling (compute-bound). GEMV loads K/V per row → linear.
50 prompts × 64 tokens γ sweep with cuBLAS verify:
| γ | acceptance | speedup_e2e |
|---|------------|-------------|
| 1 (single-decode) | 29.8% | 0.95× |
| **2** | **16.9%** | **1.10×** ← best |
| 3 | 11.6% | 1.06× |
| 4 | 8.9% | 1.02× |
| 5 | 7.2% | 0.96× |
| 6 | 6.0% | 0.93× |
| 8 | 4.5% | 0.86× |
Tradeoff: `matched=false`. cuBLAS GEMM at m>1 rounds BF16 differently
from custom GEMV at m=1. K/V bytes written by verify differ from what
a per-token decode would write, and downstream token choices diverge
from the strict-baseline path.
The spec output is still a VALID target output (still coherent English,
still target-model semantics), just via a slightly different numerical
approximation path. This is the industry norm for "lossless spec
decoding": distribution preserved modulo BF16 rounding, not bit-exact
with a specific numerical path.
`speedup_e2e = 1.10×` is a real, measurable win at γ=2 on 50×64 prompts.
Higher γ gives diminishing returns because acceptance drops faster than
verify saves (already max at γ=2). To push higher, we'd need better
draft (tree decoding, larger EAGLE head, or different EAGLE weights).
---
## Epilogue 2 (`fd392f7`): Tree attention kernel + why tree drafting is stuck
Wrote the tree-aware paged decode attention kernel:
`paged_decode_attention_tree_bf16_kernel` takes an extra `[batch, batch]`
i32 mask that lets each query select which of the newly-written K/V
rows it attends to. Positions before `tree_start` always attended.
Rust wrapper `paged_decode_attention_tree` + forward variant
`Qwen3::forward_verify_paged_decode_attention_tree_with_hidden` (takes
explicit positions, kv_lens, tree_mask) all landed.
Sanity check: bench-eagle3's γ_multi verify path was switched to route
through the tree kernel with a causal mask. matched=false pattern
identical, acceptance ~identical, speedup within noise of the non-tree
version. Kernel is correct.
### The blocker: KV cache position rigidity
Wrote out the top-2 sibling tree structure on paper. Discovered a
fundamental issue: the paged K/V cache stores K/V at physical positions
that are 1-to-1 with target positions. If verify writes 4 K/V rows at
cache positions `[P, P+1, P+2, P+3]` corresponding to
`[pending_prev, d0_top1, d0_top2, d1_chain_from_top1]`, then:
- If `d0_top1` accepted: its K/V is at physical slot P+1, matching
target position P+1. Continuing decode from position P+1 reads the
right K/V. ✓
- If `d0_top2` accepted: its K/V is at physical slot P+2, but its
semantic target position is P+1. Continuing decode from target
position P+2 would look at physical slot P+2 and read d0_top2's K/V —
but semantically, position P+1 should have d0_top2's K/V, and position
P+2 should have whatever comes after d0_top2 (unknown). Continuing
decode reads the wrong K/V. ✗
Fixing this requires one of:
1. **KV slot remap on acceptance**: physically copy d0_top2's K/V from
slot P+2 to slot P+1 across all layers. Costs one full-layer memcpy
per acceptance of a non-top-1 sibling. Doable but adds ~2ms per event.
2. **Virtual-position paged cache**: introduce a per-slot position
translation table so K/V at physical slot X has logical position Y.
Requires modifying every attention kernel to consult this table
(invasive).
3. **Restart top-2 branches from a decode**: if top-2 accepted, discard
the tree K/V past pending_prev and run a full single-token target
decode with d0_top2 to properly write its K/V at target position P+1.
Costs ~1 full decode per accepted top-2, which likely eats the win.
Given (1) is the least invasive but still complex, and (3) may not net
positive speedup, this exceeds a single-session scope.
**Concluding numbers on xserv main**:
- Best speedup: **1.10×** at γ=2 (cuBLAS-GEMM verify, no tree).
- Tree kernel + wrapper ready and correctness-verified.
- Full tree drafting requires KV remap work (Phase 27+ scope).
Everything lands cleanly on `main`. Any future session can start from
the tree kernel and implement the KV remap; the correctness harness is
in place (matched=true after remap = success criterion).