eagle3: γ≥2 correctness fixes + per-slot diagnostic

Two subtle bugs found and fixed in the γ≥2 speculative loop:

1. Wrong position handling: cache.truncate_sequence(round_pos - 1) was
   dropping the K/V of pending_prev, then verify OVERWROTE that slot with
   the wrong token. Removed the truncate: verify now starts at
   cache.seq_len (== position of pending_prev) and writes γ+1 tokens
   forward. Also fixed EAGLE draft positions: pending_prev is at position
   p, so step 0 uses position=p (not p+1).

2. EAGLE KV cache accumulated rejected drafts' K/V: each round writes γ
   entries to EAGLE's cache regardless of how many drafts were accepted.
   Added eagle.truncate_to(new_len) API. After each round, truncate to
   eagle_len_before + k + 1 (pending_prev + k accepted drafts).

Also expose Eagle3Head::current_len() getter and Eagle3Head::truncate_to().

Additionally: return the PRE-norm hidden state as aux (matching vllm's
llama_eagle3.py default `norm_output=False`). Previously the normed
version was returned.

Result: matched=true across the full γ sweep. speedup_e2e remains <1:

  γ=1 (single-decode verify): accept=22.7%, speedup=0.95x
  γ=1 (batched verify):       accept=20.6%, speedup=0.75x
  γ=2:                         accept=12.6%, speedup=0.59x
  γ=4:                         accept=7.6%,  speedup=0.41x
  γ=8:                         accept=4.1%,  speedup=0.27x

Per-slot diagnostic shows d[0]≈15%, d[1]≈8%; d[2..γ-1] varies. d[0] is
lower than γ=1's ≈20% because batched verify introduces small numerical
differences vs single-token decode.

Larger γ hurts because:
- verify_cost scales roughly linearly with γ+1 (batched matmul at
  batch=γ+1 costs ~γ+1× a single decode).
- accepted tokens per round grow sub-linearly (recursive EAGLE degrades).
- speedup ≈ (1 + accepted_avg) / verify_cost → below 1 across the sweep.

Reaching speedup > 1 requires EITHER: (a) faster batched verify
(closer to single-decode cost per query row via better GPU utilization),
OR (b) better draft accuracy (tree-based drafting to explore multiple
candidates per position, a larger EAGLE head, or a differently-trained
EAGLE variant).
This commit is contained in:
2026-07-01 19:16:31 +08:00
parent 14925154a3
commit d2c55c47b2
2 changed files with 168 additions and 59 deletions

View File

@@ -384,53 +384,59 @@ fn run_eagle_gamma_multi(
eagle.reset(); eagle.reset();
let t0 = Instant::now(); let t0 = Instant::now();
// Prefill. Advance one target decode step to seed hidden hooks. // Prefill: cache holds prompt K/V. Argmax of last row = first_token (target
// for position prompt_len).
let prefill_logits = target.forward_prefill_paged(prompt_ids, slot, cache); let prefill_logits = target.forward_prefill_paged(prompt_ids, slot, cache);
let first_token = last_argmax(&prefill_logits); let first_token = last_argmax(&prefill_logits);
let mut generated = vec![first_token];
// Seed: run one target decode with first_token to get hooks at that position. // Seed decode with first_token: writes its K/V at prompt_len, returns hooks
let (logits, mut seed_hooks) = target_decode_with_hidden(target, first_token, cache, slot); // for that position + logits whose argmax is target's answer for
let mut prev_token = last_argmax(&logits); // prompt_len+1 (this becomes our first pending_prev).
generated.push(prev_token); let (seed_logits, seed_hooks_v) = target_decode_with_hidden(target, first_token, cache, slot);
let mut target_steps = 2; // 1 prefill + 1 decode seed let mut generated = vec![first_token];
let mut pending_prev = last_argmax(&seed_logits);
let mut seed_hooks: [Tensor; 3] = seed_hooks_v;
let mut target_steps = 2; // 1 prefill + 1 seed decode
let mut accepted_total = 0usize; let mut accepted_total = 0usize;
let mut proposed_total = 0usize; let mut proposed_total = 0usize;
let mut per_slot_correct: Vec<usize> = vec![0; gamma];
let mut per_slot_total: Vec<usize> = vec![0; gamma];
while generated.len() < gen_tokens && !tokenizer.is_eos(prev_token) { while generated.len() + 1 < gen_tokens && !tokenizer.is_eos(pending_prev) {
let round_pos = cache.seq_len(slot); // position of next token to predict let p = cache.seq_len(slot); // position where pending_prev will be written
let remaining = gen_tokens - generated.len(); let remaining = gen_tokens - generated.len() - 1; // -1 because pending_prev counts
let round_gamma = gamma.min(remaining); let round_gamma = gamma.min(remaining);
if round_gamma == 0 {
break;
}
// Draft γ tokens recursively. // Draft γ tokens for positions p+1..p+γ. EAGLE step k takes the token
// at position p+k as input (pending_prev at p for k=0, previous draft
// d[k-1] at p+k for k>=1) and predicts token at position p+k+1.
// Snapshot EAGLE's cache len so we can roll back rejected drafts' K/V.
let eagle_len_before = eagle.current_len();
let mut drafts: Vec<u32> = Vec::with_capacity(round_gamma); let mut drafts: Vec<u32> = Vec::with_capacity(round_gamma);
let (d0, _, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, prev_token, round_pos); let (d0, _, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p);
drafts.push(d0); drafts.push(d0);
let mut prev_aux = aux0; let mut prev_aux = aux0;
let mut prev_draft = d0; let mut prev_draft = d0;
for k in 1..round_gamma { for k in 1..round_gamma {
let (dk, _, auxk) = let (dk, _, auxk) = eagle.step_recursive(prev_aux, embed_tokens, prev_draft, p + k);
eagle.step_recursive(prev_aux, embed_tokens, prev_draft, round_pos + k);
drafts.push(dk); drafts.push(dk);
prev_aux = auxk; prev_aux = auxk;
prev_draft = dk; prev_draft = dk;
} }
proposed_total += round_gamma; proposed_total += round_gamma;
// Verify: target forward on [prev_token, d0..d_{γ-1}] (length γ+1) to // Verify: run target on [pending_prev, d[0]..d[γ-1]] at positions
// get γ+1 logits. verify_argmax[i] predicts position round_pos+i. // [p..p+γ]. Writes γ+1 K/V rows. verify_argmax[i] predicts position
// Accept d[i] iff d[i] == verify_argmax[i]. The extra verify_argmax[γ] // p+i+1, i.e., verifies d[i].
// provides a free correction for the all-accepted case.
let mut verify_input: Vec<u32> = Vec::with_capacity(round_gamma + 1); let mut verify_input: Vec<u32> = Vec::with_capacity(round_gamma + 1);
verify_input.push(prev_token); verify_input.push(pending_prev);
for &d in drafts.iter() { for &d in drafts.iter() {
verify_input.push(d); verify_input.push(d);
} }
// Roll back the seed decode (we wrote prev_token at round_pos-1). let (verify_logits, verify_hooks) = target
cache
.truncate_sequence(slot, round_pos - 1)
.expect("truncate before verify");
let (verify_logits, _verify_hooks) = target
.forward_verify_paged_decode_attention_with_hidden( .forward_verify_paged_decode_attention_with_hidden(
&verify_input, &verify_input,
slot, slot,
@@ -438,46 +444,135 @@ fn run_eagle_gamma_multi(
&EAGLE_HOOK_LAYERS, &EAGLE_HOOK_LAYERS,
); );
target_steps += 1; target_steps += 1;
// cache.seq_len is now p + γ + 1.
// Longest matching prefix: d[i] accepted iff d[i] == argmax(verify_logits[i]).
let verify_argmax = argmax_rows(&verify_logits); let verify_argmax = argmax_rows(&verify_logits);
let mut accepted = 0usize; // Longest matching prefix of d[i] == verify_argmax[i].
while accepted < round_gamma && drafts[accepted] == verify_argmax[accepted] { // Also count per-slot: was d[i] correct independent of longest-prefix?
accepted += 1; let mut k = 0usize;
while k < round_gamma && drafts[k] == verify_argmax[k] {
k += 1;
}
accepted_total += k;
// Per-slot diagnostic: independent acceptance rate per position.
for (i, &d) in drafts.iter().enumerate() {
if d == verify_argmax[i] {
per_slot_correct[i] += 1;
}
per_slot_total[i] += 1;
} }
accepted_total += accepted;
// Commit accepted tokens (already in cache at positions round_pos..round_pos+accepted-1). // EAGLE wrote γ K/V entries this round at slots
// Truncate cache to that length. // eagle_len_before..eagle_len_before+γ. Slot i (in the new range) holds
cache // K/V for input at target position p+i (pending_prev at slot 0, drafts
.truncate_sequence(slot, round_pos + accepted) // at slots 1..γ-1... wait no: step_with_aux input pending_prev writes
.expect("truncate to accepted"); // at slot eagle_len_before with RoPE position p; step_recursive for
for i in 0..accepted { // d[0] writes at slot eagle_len_before+1 with RoPE position p+1; etc.
// So slot eagle_len_before+i holds pending_prev (i=0) or d[i-1]
// (i>=1). Accepted commits pending_prev + d[0..k-1] = k+1 EAGLE entries
// at slots eagle_len_before..eagle_len_before+k. Slots
// eagle_len_before+k+1..+γ-1 hold rejected drafts' K/V — discard.
let keep_len = eagle_len_before + k + 1;
if keep_len < eagle.current_len() {
eagle.truncate_to(keep_len);
}
// Commit pending_prev + accepted drafts. Positions p..p+k filled.
generated.push(pending_prev);
for i in 0..k {
generated.push(drafts[i]); generated.push(drafts[i]);
if generated.len() >= gen_tokens || tokenizer.is_eos(drafts[i]) { if tokenizer.is_eos(drafts[i]) {
break; // Truncate cache to keep only committed K/V.
cache
.truncate_sequence(slot, p + i + 2)
.expect("truncate on EOS");
return finalize(
generated,
t0,
target_steps,
accepted_total,
proposed_total,
cache,
slot,
);
}
if generated.len() >= gen_tokens {
cache
.truncate_sequence(slot, p + i + 2)
.expect("truncate on gen_tokens");
return finalize(
generated,
t0,
target_steps,
accepted_total,
proposed_total,
cache,
slot,
);
} }
} }
if generated.len() >= gen_tokens {
break;
}
if !generated.is_empty() && tokenizer.is_eos(*generated.last().unwrap()) {
break;
}
// Correction: verify_argmax[accepted] is target's answer for position // Truncate cache to drop unaccepted drafts' K/V. Keep positions 0..p+k.
// round_pos+accepted. Write its K/V via one target decode. This also // Cache now has committed tokens through position p+k (pending_prev at
// gives us seed_hooks for the next round. // p, drafts[0..k] at p+1..p+k).
let correction = verify_argmax[accepted]; cache
let (_, new_hooks) = target_decode_with_hidden(target, correction, cache, slot); .truncate_sequence(slot, p + k + 1)
target_steps += 1; .expect("truncate to accepted");
generated.push(correction);
prev_token = correction; // New pending_prev = verify_argmax[k]. seed_hooks = row k of verify_hooks.
// Extract the position-`accepted` slice of verify_hooks as the seed for next round? pending_prev = verify_argmax[k];
// Actually the correction just wrote K/V AT position round_pos+accepted, and let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3);
// its hidden state = new_hooks. We use new_hooks as the seed for the next round. for h in verify_hooks.iter() {
seed_hooks = new_hooks; new_hooks.push(h.narrow(0, k, 1).contiguous());
}
seed_hooks = [
new_hooks.remove(0),
new_hooks.remove(0),
new_hooks.remove(0),
];
} }
// Loop ended. If pending_prev wasn't committed yet, commit it now.
if generated.len() < gen_tokens && !tokenizer.is_eos(pending_prev) {
generated.push(pending_prev);
} else if generated.len() < gen_tokens {
generated.push(pending_prev);
}
eprint!("[per-slot d[i] correct/total: ");
for i in 0..gamma {
let rate = if per_slot_total[i] > 0 {
per_slot_correct[i] as f64 / per_slot_total[i] as f64
} else {
0.0
};
eprint!(
"{}={}/{}({:.2}) ",
i, per_slot_correct[i], per_slot_total[i], rate
);
}
eprintln!("]");
finalize(
generated,
t0,
target_steps,
accepted_total,
proposed_total,
cache,
slot,
)
}
fn finalize(
generated: Vec<u32>,
t0: Instant,
target_steps: usize,
accepted: usize,
proposed: usize,
cache: &mut PagedKVCache,
slot: usize,
) -> RunStats {
sync_device(); sync_device();
let total_s = t0.elapsed().as_secs_f64(); let total_s = t0.elapsed().as_secs_f64();
cache.free_sequence(slot); cache.free_sequence(slot);
@@ -485,8 +580,8 @@ fn run_eagle_gamma_multi(
ids: generated, ids: generated,
total_s, total_s,
target_steps, target_steps,
accepted: accepted_total, accepted,
proposed: proposed_total, proposed,
} }
} }

View File

@@ -150,6 +150,18 @@ impl Eagle3Head {
self.current_len = 0; self.current_len = 0;
} }
/// Truncate the internal KV cache to `new_len` entries. Used to discard
/// K/V of rejected drafts after a speculative round.
///
/// Only the length counter is updated here.
// NOTE(review): presumably rows past `new_len` are simply overwritten by
// later K/V writes — confirm against the write path.
///
/// # Panics
///
/// Panics if `new_len` is greater than the current length: truncation can
/// only shrink the cache, never grow it.
pub fn truncate_to(&mut self, new_len: usize) {
    // Report both values so a bad rollback length is diagnosable from the
    // panic message alone.
    assert!(
        new_len <= self.current_len,
        "Eagle3Head::truncate_to: new_len {} exceeds current_len {}",
        new_len,
        self.current_len
    );
    self.current_len = new_len;
}
/// Number of K/V entries currently tracked by the internal EAGLE cache.
///
/// Returns the `current_len` counter as-is; [`Eagle3Head::truncate_to`]
/// lowers it when rejected drafts are rolled back.
// NOTE(review): presumably this also counts entries written by draft steps
// that have not been verified yet, so "committed" may be too strong —
// confirm against the step functions.
pub fn current_len(&self) -> usize {
self.current_len
}
/// One draft step: produce a token in target vocabulary space. /// One draft step: produce a token in target vocabulary space.
/// ///
/// - `target_hidden`: 3 tensors [1, hidden_size] from target hook layers /// - `target_hidden`: 3 tensors [1, hidden_size] from target hook layers
@@ -248,12 +260,14 @@ impl Eagle3Head {
let hidden = silu_mul(&gate, &up); let hidden = silu_mul(&gate, &up);
let down = matmul_2d(&hidden, &self.down_proj_wt); let down = matmul_2d(&hidden, &self.down_proj_wt);
let (x, _) = add_rmsnorm(&down, &residual, &self.norm, eps); let (x, prenorm) = add_rmsnorm(&down, &residual, &self.norm, eps);
let logits = matmul_2d(&x, &self.lm_head_wt); let logits = matmul_2d(&x, &self.lm_head_wt);
let draft_id = argmax_bf16_single(&logits); let draft_id = argmax_bf16_single(&logits);
let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32; let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32;
(target_id, logits, x) // aux for recursive drafting = PRE-norm hidden (default norm_output=False
// in vllm/llama_eagle3.py). Feeding the pre-norm state matches training.
(target_id, logits, prenorm)
} }
/// Write new K/V rows (shape [1, num_kv_heads, head_dim]) at position /// Write new K/V rows (shape [1, num_kv_heads, head_dim]) at position