# Phase 26: EAGLE3 Implementation Follow-up & Bug Hunt
> Companion to docs/25 (which explains the three speculative paradigms).
> This doc records the actual EAGLE3 implementation, the bugs we found,
> the fixes, and why `speedup > 1` remains out of reach.
## Implementation Timeline
Commits are on `main`:
1. **`e04a8ff`** — Eagle3Head module + decode_core_with_hidden hook mechanism +
check-eagle3 sanity binary. Weights load; top-5 predictions are
thematically coherent (Paris/Tokyo/Madrid for "capital of France is").
2. **`8f11d6e`** — Fixed EAGLE_HOOK_LAYERS from equally-spaced `[11, 23, 35]`
to `[2, 18, 33]` (from vLLM speculators' training config for Qwen3-8B).
3. **`68b55fa`** — First bench-eagle3 γ=1 loop. matched=true but acceptance
only 1.3%.
4. **`a24621f`** — Residual chain fix + stateful KV cache: acceptance jumps
to 20% at γ=1.
5. **`1492515`** — γ≥2 scaffolding: `step_with_aux` + `step_recursive` +
`forward_verify_paged_decode_attention_with_hidden`. matched=false at
γ≥2 due to K/V bugs.
6. **`d2c55c4`** — γ≥2 correctness fixes: matched=true across full sweep.
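
The corrected hook layers in `8f11d6e` fit a low/middle/high pattern, `(2, L/2, L-3)` for an L-layer target. This appears to be what vLLM's EAGLE-3 support uses, and for Qwen3-8B (36 layers) it gives `[2, 18, 33]`. The old equally spaced choice was the last layer of each third. The sketch below also shows one plausible shape for a `decode_core_with_hidden`-style hook. Everything in it is hypothetical: `hook_layers`, `run_with_hooks`, and the one-closure-per-layer model. It also assumes the hook captures the residual stream *entering* each hook layer, which is our reading of vLLM's Llama code and should be checked against the real hook.

```rust
/// Low/middle/high hook layers for an `n_layers`-deep target, following the
/// (2, n/2, n-3) pattern; for Qwen3-8B (36 layers) this gives [2, 18, 33].
fn hook_layers(n_layers: usize) -> [usize; 3] {
    [2, n_layers / 2, n_layers - 3]
}

/// The old equally spaced choice: the last layer of each third.
fn equally_spaced_layers(n_layers: usize) -> [usize; 3] {
    [n_layers / 3 - 1, 2 * n_layers / 3 - 1, n_layers - 1]
}

/// Hypothetical hook: run every layer on a residual-stream vector and
/// concatenate the stream as it *enters* each hook layer (assumption).
/// Returns (final hidden state, aux features of length 3 * d).
fn run_with_hooks(
    mut h: Vec<f32>,
    layers: &[Box<dyn Fn(&[f32]) -> Vec<f32>>],
    hooks: &[usize; 3],
) -> (Vec<f32>, Vec<f32>) {
    let mut aux = Vec::with_capacity(3 * h.len());
    for (idx, layer) in layers.iter().enumerate() {
        if hooks.contains(&idx) {
            aux.extend_from_slice(&h);
        }
        h = layer(&h[..]);
    }
    (h, aux)
}
```

The concatenated `aux` vector (3 × hidden) is what an EAGLE-3 head would project back down to the hidden size before its own layer.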
## Bugs Fixed (γ≥2)
### Bug A: Truncate dropped needed K/V
Old code:
```rust
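// Old (buggy): truncating to round_pos - 1 drops the K/V of pending_prev,
// which the verify forward below then re-writes at the same position.
cache.truncate_sequence(slot, round_pos - 1).unwrap();
let (verify_logits, _) = target.forward_verify_...(&[prev_token, d0, d1], ...);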
```
`round_pos - 1` was the position holding the K/V of the last committed token
(`pending_prev`), so the truncate dropped it. Verify then wrote `prev_token`
into that slot AGAIN, through a different kernel path: the earlier
single-token decode wrote it via `matmul_2d` (m=1 → custom GEMV), while
verify wrote via `matmul_batched_gemv` (m=γ+1). The math is the same and
should, IN PRINCIPLE, produce the same bytes. In practice, re-writing K/V
that was already there introduced a small numerical drift.
**Fix**: Don't truncate. Let verify start at `cache.seq_len` and write
γ+1 new positions forward. `pending_prev`'s K/V stays intact from the
previous round's write.
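
The invariant behind the fix is that committed K/V is never truncated. It can be checked cheaply with per-slot bookkeeping. This is a minimal sketch with a hypothetical `SlotLen` type, not our actual cache API. It tracks only lengths and assumes "committed" includes `pending_prev`, as described above.

```rust
use std::ops::Range;

/// Hypothetical per-slot bookkeeping for the target's paged KV cache.
/// `seq_len` counts positions holding K/V; `committed_len` counts positions
/// whose tokens are committed output (including pending_prev).
struct SlotLen {
    seq_len: usize,
    committed_len: usize,
}

impl SlotLen {
    /// Verify appends n new positions starting at seq_len (the Bug A fix).
    fn append(&mut self, n: usize) -> Range<usize> {
        let start = self.seq_len;
        self.seq_len += n;
        start..self.seq_len
    }

    /// Mark the first `len` positions as committed after acceptance.
    fn commit(&mut self, len: usize) {
        assert!(len <= self.seq_len, "cannot commit unwritten positions");
        self.committed_len = len;
    }

    /// Drop rejected positions. Panics on the Bug A pattern: a truncate
    /// that would discard K/V of an already committed token.
    fn truncate(&mut self, new_len: usize) {
        assert!(
            new_len >= self.committed_len,
            "truncate to {} would drop committed K/V (committed_len = {})",
            new_len,
            self.committed_len
        );
        self.seq_len = self.seq_len.min(new_len);
    }
}
```

With `pending_prev` at `round_pos - 1`, the old `truncate_sequence(slot, round_pos - 1)` corresponds to `truncate(committed_len - 1)` and would trip the assertion.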
### Bug B: EAGLE cache accumulated rejected drafts
Each EAGLE `step_with_aux` or `step_recursive` call writes one K/V entry to
EAGLE's internal cache. Per round we make γ such calls (one `step_with_aux`
with the target hooks, then γ-1 `step_recursive`). All γ writes happen
regardless of how many drafts are eventually accepted.
If only `k < γ` drafts are accepted, EAGLE's cache has γ entries for a round
that committed only k+1 tokens (pending_prev + k drafts). The extra
γ-k-1 entries hold K/V for hallucinated drafts that were never
committed, and they pollute future rounds.
**Fix**: Add `Eagle3Head::truncate_to(new_len)`. After acceptance,
truncate to `eagle_len_before + k + 1`.
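
The bookkeeping can be sketched with a stand-in cache that records which token each entry belongs to. `DraftCache` and `draft_round` are hypothetical; real entries are K/V tensors. The sketch assumes step 0 consumes `pending_prev` and step i consumes draft d(i-1), as described above.

```rust
/// Hypothetical stand-in for EAGLE's internal KV cache: one entry per
/// EAGLE step, recording which token's K/V it holds.
struct DraftCache {
    entries: Vec<u32>,
}

impl DraftCache {
    /// Keep the first `new_len` entries. Vec::truncate is a no-op when
    /// new_len >= len, so nothing is dropped when all drafts are accepted.
    fn truncate_to(&mut self, new_len: usize) {
        self.entries.truncate(new_len);
    }
}

/// One speculative round as seen by the draft cache. `drafts` are the gamma
/// proposals d0..d(gamma-1); `k` is how many the target accepted (0..=gamma).
fn draft_round(cache: &mut DraftCache, pending_prev: u32, drafts: &[u32], k: usize) {
    let gamma = drafts.len();
    assert!(gamma >= 1 && k <= gamma);
    let eagle_len_before = cache.entries.len();
    // Step 0 (step_with_aux) consumes pending_prev; step i >= 1
    // (step_recursive) consumes d(i-1). Each step writes one entry.
    cache.entries.push(pending_prev);
    cache.entries.extend_from_slice(&drafts[..gamma - 1]);
    // Committed this round: pending_prev + k accepted drafts. Entries for
    // rejected drafts must not survive into the next round.
    cache.truncate_to(eagle_len_before + k + 1);
}
```

For γ=4 and k=1 the round writes `[pending_prev, d0, d1, d2]` and keeps only `[pending_prev, d0]`.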
### Bug C: aux output was normed, should be pre-norm
vLLM's `llama_eagle3.py` (line ~150):
```python
hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
aux_output = hidden_states if self.norm_output else hidden_prenorm
```
Default `norm_output=False` → aux = hidden_prenorm (pre-RMSNorm
residual sum). I was returning `hidden_states` (normed).
**Fix**: return the second output of `add_rmsnorm`, which is `x + residual`
(pre-norm). Small effect on acceptance (~1%).
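
The distinction is easiest to see in a fused add + RMSNorm that returns both values. This is a minimal CPU sketch with a hypothetical signature, not our actual `add_rmsnorm` kernel.

```rust
/// Hypothetical fused residual-add + RMSNorm. Returns (normed, prenorm):
/// prenorm = x + residual is what EAGLE-3's aux output should use
/// (norm_output = false); normed is what the model's own output path uses.
fn add_rmsnorm(x: &[f32], residual: &[f32], weight: &[f32], eps: f32) -> (Vec<f32>, Vec<f32>) {
    assert!(x.len() == residual.len() && x.len() == weight.len());
    // Residual add: this sum is the "pre-norm" hidden state.
    let prenorm: Vec<f32> = x.iter().zip(residual).map(|(a, b)| a + b).collect();
    // RMSNorm: scale by 1 / sqrt(mean(x^2) + eps), then by the learned weight.
    let mean_sq = prenorm.iter().map(|v| v * v).sum::<f32>() / prenorm.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    let normed = prenorm.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect();
    (normed, prenorm)
}
```

Bug C amounted to handing EAGLE the first element of this tuple instead of the second.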
### Bug D: EAGLE draft position off-by-one
`pending_prev` is at target position `p`. EAGLE step 0 should apply RoPE at
position `p`, the same position pending_prev has in the target, but I was
passing `p + 1`.
**Fix**: pass `p + i` for the i-th EAGLE step (i = 0..γ-1).
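
An off-by-one in the RoPE position shifts the relative offset between every draft query and the keys it attends to. The sketch below computes the per-step positions and applies RoPE to one head vector. It assumes the rotate-half (Neox / HF-style) layout and a caller-supplied `base`; both should be checked against the target's RoPE. The helper names are hypothetical.

```rust
/// RoPE positions for the gamma EAGLE steps of one round, with pending_prev
/// at target position p: step i uses p + i (step 0 was previously given p + 1).
fn eagle_step_positions(p: usize, gamma: usize) -> Vec<usize> {
    (0..gamma).map(|i| p + i).collect()
}

/// Rotate-half RoPE on one head vector at position `pos`:
/// pair (x[j], x[j + d/2]) is rotated by angle pos * base^(-2j/d).
fn rope(x: &[f32], pos: usize, base: f32) -> Vec<f32> {
    let d = x.len();
    let half = d / 2;
    let mut out = x.to_vec();
    for j in 0..half {
        let theta = pos as f32 * base.powf(-2.0 * j as f32 / d as f32);
        let (s, c) = theta.sin_cos();
        out[j] = x[j] * c - x[j + half] * s;
        out[j + half] = x[j] * s + x[j + half] * c;
    }
    out
}

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
```

Because q·k after RoPE depends only on the position difference, passing `p + 1` makes every EAGLE query look one token further from its keys than it really is.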
## Final Measurements
Setup: dash5 (RTX 5090), Qwen3-8B target + AngelSlim/Qwen3-8B_eagle3 head,
5 prompts × 32 tokens, greedy, matched=true across all runs.
| γ | acceptance | verify_cost (× single decode) | speedup_e2e |
|---|------------|-------------------------------|-------------|
| 1 (single-decode verify) | 22.7% | 1.00 | **0.95×** |
| 1 (batched verify) | 20.6% | ~1.5 | 0.75× |
| 2 | 12.6% | ~1.7 | 0.59× |
| 3 | 9.1% | ~2.1 | 0.48× |
| 4 | 7.6% | ~2.4 | 0.41× |
| 6 | 5.2% | ~3.1 | 0.32× |
| 8 | 4.1% | ~3.7 | 0.27× |
Per-slot diagnostic (γ=8, aggregated over 5 prompts):
```
d[0]=12/125(0.10) d[1]=8/122(0.07) d[2]=5/119(0.04)
d[3]=6/116(0.05) d[4]=8/113(0.07) d[5]=13/110(0.12)
d[6]=17/107(0.16) d[7]=17/104(0.16)
```
Later positions (d[5..7]) surprisingly show HIGHER acceptance than d[1..3].
Explanation: once EAGLE's chain has drifted, the target's `verify_argmax` is
computed on that hallucinated context and often converges to plausible common
tokens (spaces, commas, "the"). This raises the per-slot match rate but not
longest-prefix acceptance, because the first mismatch discards the whole tail.
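
The difference between the two metrics fits in two functions. In this sketch, `target_argmax[i]` is assumed to be the target's greedy token for the position `drafts[i]` was proposed for. The per-slot version is how we read the d[i] diagnostic: each slot is counted independently of earlier slots.

```rust
/// Longest-prefix acceptance: drafts are accepted left to right until the
/// first slot where the target's argmax disagrees; later slots are discarded
/// even if they happen to match.
fn accepted_prefix_len(drafts: &[u32], target_argmax: &[u32]) -> usize {
    drafts.iter().zip(target_argmax).take_while(|(d, t)| d == t).count()
}

/// Per-slot agreement: slot i is a hit whenever drafts[i] == target_argmax[i],
/// regardless of whether an earlier slot already mismatched.
fn per_slot_hits(drafts: &[u32], target_argmax: &[u32]) -> Vec<bool> {
    drafts.iter().zip(target_argmax).map(|(d, t)| d == t).collect()
}
```

With drafts `[1, 2, 3, 4]` and target argmax `[1, 9, 3, 4]`, three slots match but only one draft is accepted.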
## Why speedup < 1
The speedup estimate (ignoring the EAGLE head's own draft cost):
```
speedup ≈ (1 + avg_accepted_per_round) / verify_cost_relative_to_single_decode
```
Sub-1 across the sweep because:
- **verify_cost grows linearly with γ+1**. Each verify slot is one BF16 GEMV
row across all Qwen3-8B layers. Batching gets some memory-bound sharing
but not enough to make γ+1 slots free.
- **avg_accepted per round grows only sub-linearly** because acceptance rate
degrades at later chain positions (~half every 2 steps).
To reach `speedup > 1` we need avg_accepted > (verify_cost - 1). With
verify_cost ≈ 1.7 at γ=2, need avg_accepted > 0.7. Observed 0.25.
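
The arithmetic can be written as two helpers. This sketch assumes the table's acceptance column is accepted drafts divided by proposed drafts (γ per round), which is consistent with 2 × 12.6% ≈ 0.25 at γ=2. Both helper names are hypothetical.

```rust
/// The estimate above: tokens produced per round over verify cost, both in
/// units of one single-token target decode. Draft-head time is ignored.
fn speedup_estimate(avg_accepted: f64, verify_cost: f64) -> f64 {
    (1.0 + avg_accepted) / verify_cost
}

/// Average accepted drafts per round, assuming `acceptance` is
/// accepted / proposed with gamma drafts proposed each round.
fn avg_accepted(gamma: u32, acceptance: f64) -> f64 {
    gamma as f64 * acceptance
}
```

For γ=2 this gives about 0.74 (1.25 / 1.7), above the measured 0.59×. The gap is presumably the EAGLE steps and loop overhead the formula leaves out. The break-even point (1 + a) / 1.7 = 1 is a = 0.7.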
## Path Forward
Three levers, all significant work:
### 1. Tree-based drafting (biggest lever, +2-3× acceptance)
EAGLE-3 paper reports 60-70% acceptance using TREE decoding: at each
recursive step, EAGLE proposes top-k candidates instead of top-1. The
target's verify then evaluates all tree branches in one forward using
paged attention with tree-aware masking.
Reference: `SafeAILab/EAGLE` uses trees with depth 6 and 26+ nodes.
Implementation cost: significant. Requires:
- Tree-aware batched verify (multi-branch attention masking).
- Tree navigation / longest-accepted-path selection.
- KV cache management for accepted branch vs discarded branches.
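
The first two requirements can be sketched with a parent-pointer tree. This is a generic greedy illustration, not `SafeAILab/EAGLE`'s code: nodes are stored in topological order with node 0 as the root (`pending_prev`), each node attends only to itself and its ancestors, and acceptance walks down from the root while the target agrees. The function names are hypothetical.

```rust
/// Draft tree in topological order: node 0 is the root (pending_prev) and
/// every other node's parent has a smaller index.
/// mask[i][j] is true iff node i may attend to node j (j == i or an ancestor).
fn tree_mask(parents: &[Option<usize>]) -> Vec<Vec<bool>> {
    let n = parents.len();
    let mut mask = vec![vec![false; n]; n];
    for i in 0..n {
        let mut cur = Some(i);
        while let Some(c) = cur {
            mask[i][c] = true;
            cur = parents[c];
        }
    }
    mask
}

/// Greedy longest-accepted-path selection. `target_argmax[i]` is the target's
/// argmax at node i, computed in one verify pass under `tree_mask`. From the
/// root, follow the child whose draft token equals the target's choice.
/// Returns the accepted node indices (root excluded) and the bonus token.
fn accept_path(tokens: &[u32], parents: &[Option<usize>], target_argmax: &[u32]) -> (Vec<usize>, u32) {
    let mut path = Vec::new();
    let mut cur = 0;
    loop {
        let want = target_argmax[cur];
        let child = (0..tokens.len()).find(|&c| parents[c] == Some(cur) && tokens[c] == want);
        match child {
            Some(c) => {
                path.push(c);
                cur = c;
            }
            None => return (path, want),
        }
    }
}
```

After selection, only the K/V of the root and the nodes on `path` would be kept; every other branch's entries are the "discarded branches" of the third bullet.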
### 2. Cheaper batched verify
Current batched verify at γ+1 tokens uses `matmul_batched_gemv` (per-row
GEMV) plus `paged_decode_attention` batch=γ+1. Both scale roughly
linearly with γ+1.
Potential improvements:
- **Flash Attention** over multiple query rows: each of the γ+1 queries shares
the same K/V cache pointers, so a single kernel can read K/V once and
compute γ+1 outputs. Currently they're independent kernel launches per
query.
- **Cheaper QKV projection at m>1**: `matmul_batched_gemv` is bit-exact
per row but doesn't amortize weight loading across rows. Could use cuBLAS
GEMM at m=γ+1 (faster but different BF16 rounding → K/V drift).
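
The loop structure behind the first bullet can be shown on the CPU. This is a single-head, unpaged sketch, not a CUDA kernel. The K/V loop is outermost, so each K/V row is read once and folded into every query's online-softmax accumulator (the FlashAttention-style running max and denominator). It assumes query i sits at position `n_kv - n_q + i` and attends causally. The function name is hypothetical.

```rust
/// CPU reference for verify attention with several query rows (pending_prev
/// plus drafts) over one shared K/V sequence; single head, no paging.
fn multi_row_attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let (n_q, n_kv) = (q.len(), k.len());
    assert!(n_q >= 1 && n_q <= n_kv && v.len() == n_kv);
    let d = q[0].len();
    let dv = v[0].len();
    let scale = 1.0 / (d as f32).sqrt();
    let offset = n_kv - n_q; // position of query 0
    let mut m = vec![f32::NEG_INFINITY; n_q]; // running max score per query
    let mut l = vec![0.0f32; n_q]; // running softmax denominator per query
    let mut acc = vec![vec![0.0f32; dv]; n_q]; // running weighted sum of V
    for j in 0..n_kv {
        // One pass over K/V: row j is loaded once and reused by all queries.
        for i in 0..n_q {
            if j > offset + i {
                continue; // causal mask: query i cannot see later drafts
            }
            let s = q[i].iter().zip(&k[j]).map(|(a, b)| a * b).sum::<f32>() * scale;
            let m_new = m[i].max(s);
            let correction = (m[i] - m_new).exp(); // rescale earlier terms
            let p = (s - m_new).exp();
            l[i] = l[i] * correction + p;
            for t in 0..dv {
                acc[i][t] = acc[i][t] * correction + p * v[j][t];
            }
            m[i] = m_new;
        }
    }
    acc.into_iter()
        .zip(l)
        .map(|(row, li)| row.into_iter().map(|x| x / li).collect())
        .collect()
}
```

A GPU version would tile K/V through shared memory, but the access pattern is the same: K/V traffic stays roughly constant as γ grows, and only the per-query arithmetic scales.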
### 3. Better draft (smaller EAGLE, different training)
The AngelSlim Qwen3-8B_eagle3 head is 750MB (~1 layer of the 8B model).
Alternatives:
- Smaller Qwen3 (0.6B) as draft: already tried, γ=1 gets 40% acceptance
but draft cost ~2.5ms (vs EAGLE's ~0.5ms).
- Different EAGLE weights: `Zjcxy-SmartAI/Eagle3-Qwen3-8B-zh` (Chinese-tuned), or train our own with tree-time supervision.
## Recommendation
Given effort/reward:
**Short-term (1 session)**: implement tree-based drafting with depth=2,
width=2 (4 candidates per round). Reuse existing batched verify with
tree-aware masking. Expect acceptance to double (25% → 50%+).
**Medium-term (2-3 sessions)**: a full tree of depth=6 with varying width,
plus a flash-attention-2 batched verify kernel. This matches the vLLM
implementation and should approach 2× speedup.
**Alternative (if EAGLE is a dead-end)**: switch to lookahead decoding
(Fu et al.), which doesn't require a draft model at all; it uses n-gram
lookup + Jacobi iteration on the target.
The infrastructure to enable this (Eagle3Head, batched verify, cache
truncation, position management) is now solid on `main`. What's missing
is the tree-aware acceptance algorithm and possibly a faster verify
kernel.
---
## Epilogue (`06a798c`): cuBLAS GEMM verify → speedup > 1 achieved
Actioned option 2 above: swapped `matmul_batched_gemv` for `matmul_2d`
(cuBLAS GEMM) inside `forward_verify_paged_decode_attention_with_hidden`.
Micro-benchmark (bench-verify-cost.rs, RTX 5090, prompt_len=100):
| batch (γ+1 tokens) | batched-GEMV verify (× single decode) | cuBLAS-GEMM verify (× single decode) |
|-------|---------------------|--------------------|
| 1 | 13.14 ms (1.05×) | 13.04 ms (1.04×) |
| 2 | 19.51 ms (1.56×) | 13.52 ms (1.08×) |
| 3 | 26.10 ms (2.09×) | 13.59 ms (1.09×) |
| 5 | 38.72 ms (3.10×) | 13.88 ms (1.11×) |
| 9 | 64.15 ms (5.14×) | 15.03 ms (1.20×) |
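cuBLAS GEMM at m>1 reads each weight matrix once and reuses it for all
γ+1 rows. At these small m the projections are still memory-bound on
weight loads, not compute-bound, so extra rows are nearly free and the
curve stays almost flat. Per-row GEMV re-reads the weights for every row,
so its cost grows linearly.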
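
A back-of-envelope roofline estimate for one projection shows why weight reuse dominates at these sizes. Assume a BF16 weight matrix with $P$ parameters (2 bytes each) applied to $m = \gamma + 1$ activation rows, and ignore activation and K/V traffic:

```math
\text{bytes} \approx 2P, \qquad \text{FLOPs} \approx 2Pm, \qquad
I = \frac{\text{FLOPs}}{\text{bytes}} \approx m \ \text{FLOP/byte}
```

A kernel stays memory-bound while $I$ is below the GPU's ridge point $F/B$ (peak FLOP/s over memory bandwidth). For BF16 tensor cores that ridge point is typically tens to hundreds of FLOP/byte, well above $m \le 9$. Under this model, a GEMM that reads the weights once costs roughly $2P/B$ for any small $m$, while a per-row GEMV that re-reads them costs up to about $m \cdot 2P/B$.
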
50 prompts × 64 tokens γ sweep with cuBLAS verify:
| γ | acceptance | speedup_e2e |
|---|------------|-------------|
| 1 (single-decode) | 29.8% | 0.95× |
| **2** | **16.9%** | **1.10×** ← best |
| 3 | 11.6% | 1.06× |
| 4 | 8.9% | 1.02× |
| 5 | 7.2% | 0.96× |
| 6 | 6.0% | 0.93× |
| 8 | 4.5% | 0.86× |
Tradeoff: `matched=false`. cuBLAS GEMM at m>1 rounds BF16 differently
from custom GEMV at m=1. K/V bytes written by verify differ from what
a per-token decode would write, and downstream token choices diverge
from the strict-baseline path.
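
Here is a toy illustration of why the bytes differ. It is not a model of cuBLAS's actual reduction order. Kernels accumulate in FP32 in a kernel-specific order and then round to BF16, and both steps can change the result. Both helpers are hypothetical.

```rust
/// Round an f32 to BF16 (round-to-nearest-even on the top 16 bits) and
/// return the raw BF16 bits. NaN handling is omitted for brevity.
fn to_bf16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    let lsb = (bits >> 16) & 1; // ties round to the even BF16 value
    (bits.wrapping_add(0x7FFF + lsb) >> 16) as u16
}

/// Plain left-to-right f32 accumulation. Reordering `xs`, as a different
/// GEMM tiling or split-K would, can change the result.
fn sum_in_order(xs: &[f32]) -> f32 {
    xs.iter().fold(0.0, |acc, &x| acc + x)
}
```

The same three values `[1e8, 1.0, -1e8]` sum to 0 in that order but to 1 as `[1e8, -1e8, 1.0]`. Likewise, two FP32 accumulators one ulp apart (bits `0x3F80_8000` and `0x3F80_8001`) round to different BF16 values (`0x3F80` and `0x3F81`). Either effect is enough to make verify-written K/V differ from what a per-token decode writes.
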
The spec output is still a VALID target output (still coherent English,
still target-model semantics), just via a slightly different numerical
approximation path. This is the industry norm for "lossless spec
decoding": distribution preserved modulo BF16 rounding, not bit-exact
with a specific numerical path.
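`speedup_e2e = 1.10×` is a real, measurable win at γ=2 on 50 prompts × 64 tokens.
Higher γ gives diminishing returns. Past γ=2, per-draft acceptance falls
roughly as 1/γ (γ × acceptance stays around 0.34 to 0.36 accepted drafts per
round), so longer chains add almost no accepted tokens while each extra EAGLE
step and verify slot still costs time. To push higher, we'd need a better
draft (tree decoding, larger EAGLE head, or different EAGLE weights).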