eagle3: γ≥2 correctness fixes + per-slot diagnostic
Two subtle bugs found and fixed in the γ≥2 speculative loop: 1. Wrong position handling: cache.truncate_sequence(round_pos - 1) was dropping the K/V of pending_prev, then verify OVERWROTE that slot with the wrong token. Removed the truncate: verify now starts at cache.seq_len (== position of pending_prev) and writes γ+1 tokens forward. Also fixed EAGLE draft positions: pending_prev is at position p, so step 0 uses position=p (not p+1). 2. EAGLE KV cache accumulated rejected drafts' K/V: each round writes γ entries to EAGLE's cache regardless of how many drafts were accepted. Added eagle.truncate_to(new_len) API. After each round, truncate to eagle_len_before + k + 1 (pending_prev + k accepted drafts). Also expose Eagle3Head::current_len() getter and Eagle3Head::truncate_to(). Additionally: return the PRE-norm hidden state as aux (matching vllm's llama_eagle3.py default `norm_output=False`). Was returning the normed version. Result: matched=true across the full γ sweep. speedup_e2e remains <1: γ=1 (single-decode verify): accept=22.7%, speedup=0.95x γ=1 (batched verify): accept=20.6%, speedup=0.75x γ=2: accept=12.6%, speedup=0.59x γ=4: accept=7.6%, speedup=0.41x γ=8: accept=4.1%, speedup=0.27x Per-slot diagnostic shows d[0]≈15%, d[1]≈8%, d[2..γ-1] varies. d[0] is lower than γ=1's 20% because batched verify introduces small numerical differences vs single-token decode. Larger γ hurts because: - verify_cost scales roughly linearly with γ+1 (batched matmul at batch=γ+1 costs ~γ+1× a single decode). - accepted tokens per round grows sub-linearly (recursive EAGLE degrades). - speedup ≈ (1 + accepted_avg) / verify_cost → below 1 across the sweep. Path forward for speedup > 1 requires EITHER: (a) faster batched verify (closer to single-decode cost per query row via better GPU utilization), OR (b) better draft accuracy (tree-based drafting to explore multiple candidates per position, larger EAGLE head, or a differently-trained EAGLE variant).
This commit is contained in:
@@ -384,53 +384,59 @@ fn run_eagle_gamma_multi(
|
|||||||
eagle.reset();
|
eagle.reset();
|
||||||
let t0 = Instant::now();
|
let t0 = Instant::now();
|
||||||
|
|
||||||
// Prefill. Advance one target decode step to seed hidden hooks.
|
// Prefill: cache holds prompt K/V. Argmax of last row = first_token (target
|
||||||
|
// for position prompt_len).
|
||||||
let prefill_logits = target.forward_prefill_paged(prompt_ids, slot, cache);
|
let prefill_logits = target.forward_prefill_paged(prompt_ids, slot, cache);
|
||||||
let first_token = last_argmax(&prefill_logits);
|
let first_token = last_argmax(&prefill_logits);
|
||||||
let mut generated = vec![first_token];
|
|
||||||
|
|
||||||
// Seed: run one target decode with first_token to get hooks at that position.
|
// Seed decode with first_token: writes its K/V at prompt_len, returns hooks
|
||||||
let (logits, mut seed_hooks) = target_decode_with_hidden(target, first_token, cache, slot);
|
// for that position + logits whose argmax is target's answer for
|
||||||
let mut prev_token = last_argmax(&logits);
|
// prompt_len+1 (this becomes our first pending_prev).
|
||||||
generated.push(prev_token);
|
let (seed_logits, seed_hooks_v) = target_decode_with_hidden(target, first_token, cache, slot);
|
||||||
let mut target_steps = 2; // 1 prefill + 1 decode seed
|
let mut generated = vec![first_token];
|
||||||
|
let mut pending_prev = last_argmax(&seed_logits);
|
||||||
|
let mut seed_hooks: [Tensor; 3] = seed_hooks_v;
|
||||||
|
let mut target_steps = 2; // 1 prefill + 1 seed decode
|
||||||
let mut accepted_total = 0usize;
|
let mut accepted_total = 0usize;
|
||||||
let mut proposed_total = 0usize;
|
let mut proposed_total = 0usize;
|
||||||
|
let mut per_slot_correct: Vec<usize> = vec![0; gamma];
|
||||||
|
let mut per_slot_total: Vec<usize> = vec![0; gamma];
|
||||||
|
|
||||||
while generated.len() < gen_tokens && !tokenizer.is_eos(prev_token) {
|
while generated.len() + 1 < gen_tokens && !tokenizer.is_eos(pending_prev) {
|
||||||
let round_pos = cache.seq_len(slot); // position of next token to predict
|
let p = cache.seq_len(slot); // position where pending_prev will be written
|
||||||
let remaining = gen_tokens - generated.len();
|
let remaining = gen_tokens - generated.len() - 1; // -1 because pending_prev counts
|
||||||
let round_gamma = gamma.min(remaining);
|
let round_gamma = gamma.min(remaining);
|
||||||
|
if round_gamma == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
// Draft γ tokens recursively.
|
// Draft γ tokens for positions p+1..p+γ. EAGLE step k takes the token
|
||||||
|
// at position p+k as input (pending_prev at p for k=0, previous draft
|
||||||
|
// d[k-1] at p+k for k>=1) and predicts token at position p+k+1.
|
||||||
|
// Snapshot EAGLE's cache len so we can roll back rejected drafts' K/V.
|
||||||
|
let eagle_len_before = eagle.current_len();
|
||||||
let mut drafts: Vec<u32> = Vec::with_capacity(round_gamma);
|
let mut drafts: Vec<u32> = Vec::with_capacity(round_gamma);
|
||||||
let (d0, _, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, prev_token, round_pos);
|
let (d0, _, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p);
|
||||||
drafts.push(d0);
|
drafts.push(d0);
|
||||||
let mut prev_aux = aux0;
|
let mut prev_aux = aux0;
|
||||||
let mut prev_draft = d0;
|
let mut prev_draft = d0;
|
||||||
for k in 1..round_gamma {
|
for k in 1..round_gamma {
|
||||||
let (dk, _, auxk) =
|
let (dk, _, auxk) = eagle.step_recursive(prev_aux, embed_tokens, prev_draft, p + k);
|
||||||
eagle.step_recursive(prev_aux, embed_tokens, prev_draft, round_pos + k);
|
|
||||||
drafts.push(dk);
|
drafts.push(dk);
|
||||||
prev_aux = auxk;
|
prev_aux = auxk;
|
||||||
prev_draft = dk;
|
prev_draft = dk;
|
||||||
}
|
}
|
||||||
proposed_total += round_gamma;
|
proposed_total += round_gamma;
|
||||||
|
|
||||||
// Verify: target forward on [prev_token, d0..d_{γ-1}] (length γ+1) to
|
// Verify: run target on [pending_prev, d[0]..d[γ-1]] at positions
|
||||||
// get γ+1 logits. verify_argmax[i] predicts position round_pos+i.
|
// [p..p+γ]. Writes γ+1 K/V rows. verify_argmax[i] predicts position
|
||||||
// Accept d[i] iff d[i] == verify_argmax[i]. The extra verify_argmax[γ]
|
// p+i+1, i.e., verifies d[i].
|
||||||
// provides a free correction for the all-accepted case.
|
|
||||||
let mut verify_input: Vec<u32> = Vec::with_capacity(round_gamma + 1);
|
let mut verify_input: Vec<u32> = Vec::with_capacity(round_gamma + 1);
|
||||||
verify_input.push(prev_token);
|
verify_input.push(pending_prev);
|
||||||
for &d in drafts.iter() {
|
for &d in drafts.iter() {
|
||||||
verify_input.push(d);
|
verify_input.push(d);
|
||||||
}
|
}
|
||||||
// Roll back the seed decode (we wrote prev_token at round_pos-1).
|
let (verify_logits, verify_hooks) = target
|
||||||
cache
|
|
||||||
.truncate_sequence(slot, round_pos - 1)
|
|
||||||
.expect("truncate before verify");
|
|
||||||
let (verify_logits, _verify_hooks) = target
|
|
||||||
.forward_verify_paged_decode_attention_with_hidden(
|
.forward_verify_paged_decode_attention_with_hidden(
|
||||||
&verify_input,
|
&verify_input,
|
||||||
slot,
|
slot,
|
||||||
@@ -438,46 +444,135 @@ fn run_eagle_gamma_multi(
|
|||||||
&EAGLE_HOOK_LAYERS,
|
&EAGLE_HOOK_LAYERS,
|
||||||
);
|
);
|
||||||
target_steps += 1;
|
target_steps += 1;
|
||||||
|
// cache.seq_len is now p + γ + 1.
|
||||||
|
|
||||||
// Longest matching prefix: d[i] accepted iff d[i] == argmax(verify_logits[i]).
|
|
||||||
let verify_argmax = argmax_rows(&verify_logits);
|
let verify_argmax = argmax_rows(&verify_logits);
|
||||||
let mut accepted = 0usize;
|
// Longest matching prefix of d[i] == verify_argmax[i].
|
||||||
while accepted < round_gamma && drafts[accepted] == verify_argmax[accepted] {
|
// Also count per-slot: was d[i] correct independent of longest-prefix?
|
||||||
accepted += 1;
|
let mut k = 0usize;
|
||||||
|
while k < round_gamma && drafts[k] == verify_argmax[k] {
|
||||||
|
k += 1;
|
||||||
|
}
|
||||||
|
accepted_total += k;
|
||||||
|
// Per-slot diagnostic: independent acceptance rate per position.
|
||||||
|
for (i, &d) in drafts.iter().enumerate() {
|
||||||
|
if d == verify_argmax[i] {
|
||||||
|
per_slot_correct[i] += 1;
|
||||||
|
}
|
||||||
|
per_slot_total[i] += 1;
|
||||||
}
|
}
|
||||||
accepted_total += accepted;
|
|
||||||
|
|
||||||
// Commit accepted tokens (already in cache at positions round_pos..round_pos+accepted-1).
|
// EAGLE wrote γ K/V entries this round at slots
|
||||||
// Truncate cache to that length.
|
// eagle_len_before..eagle_len_before+γ. Slot i (in the new range) holds
|
||||||
cache
|
// K/V for input at target position p+i (pending_prev at slot 0, drafts
|
||||||
.truncate_sequence(slot, round_pos + accepted)
|
// at slots 1..γ-1... wait no: step_with_aux input pending_prev writes
|
||||||
.expect("truncate to accepted");
|
// at slot eagle_len_before with RoPE position p; step_recursive for
|
||||||
for i in 0..accepted {
|
// d[0] writes at slot eagle_len_before+1 with RoPE position p+1; etc.
|
||||||
|
// So slot eagle_len_before+i holds pending_prev (i=0) or d[i-1]
|
||||||
|
// (i>=1). Accepted commits pending_prev + d[0..k-1] = k+1 EAGLE entries
|
||||||
|
// at slots eagle_len_before..eagle_len_before+k. Slots
|
||||||
|
// eagle_len_before+k+1..+γ-1 hold rejected drafts' K/V — discard.
|
||||||
|
let keep_len = eagle_len_before + k + 1;
|
||||||
|
if keep_len < eagle.current_len() {
|
||||||
|
eagle.truncate_to(keep_len);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit pending_prev + accepted drafts. Positions p..p+k filled.
|
||||||
|
generated.push(pending_prev);
|
||||||
|
for i in 0..k {
|
||||||
generated.push(drafts[i]);
|
generated.push(drafts[i]);
|
||||||
if generated.len() >= gen_tokens || tokenizer.is_eos(drafts[i]) {
|
if tokenizer.is_eos(drafts[i]) {
|
||||||
break;
|
// Truncate cache to keep only committed K/V.
|
||||||
|
cache
|
||||||
|
.truncate_sequence(slot, p + i + 2)
|
||||||
|
.expect("truncate on EOS");
|
||||||
|
return finalize(
|
||||||
|
generated,
|
||||||
|
t0,
|
||||||
|
target_steps,
|
||||||
|
accepted_total,
|
||||||
|
proposed_total,
|
||||||
|
cache,
|
||||||
|
slot,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if generated.len() >= gen_tokens {
|
||||||
|
cache
|
||||||
|
.truncate_sequence(slot, p + i + 2)
|
||||||
|
.expect("truncate on gen_tokens");
|
||||||
|
return finalize(
|
||||||
|
generated,
|
||||||
|
t0,
|
||||||
|
target_steps,
|
||||||
|
accepted_total,
|
||||||
|
proposed_total,
|
||||||
|
cache,
|
||||||
|
slot,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if generated.len() >= gen_tokens {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if !generated.is_empty() && tokenizer.is_eos(*generated.last().unwrap()) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Correction: verify_argmax[accepted] is target's answer for position
|
// Truncate cache to drop unaccepted drafts' K/V. Keep positions 0..p+k.
|
||||||
// round_pos+accepted. Write its K/V via one target decode. This also
|
// Cache now has committed tokens through position p+k (pending_prev at
|
||||||
// gives us seed_hooks for the next round.
|
// p, drafts[0..k] at p+1..p+k).
|
||||||
let correction = verify_argmax[accepted];
|
cache
|
||||||
let (_, new_hooks) = target_decode_with_hidden(target, correction, cache, slot);
|
.truncate_sequence(slot, p + k + 1)
|
||||||
target_steps += 1;
|
.expect("truncate to accepted");
|
||||||
generated.push(correction);
|
|
||||||
prev_token = correction;
|
// New pending_prev = verify_argmax[k]. seed_hooks = row k of verify_hooks.
|
||||||
// Extract the position-`accepted` slice of verify_hooks as the seed for next round?
|
pending_prev = verify_argmax[k];
|
||||||
// Actually the correction just wrote K/V AT position round_pos+accepted, and
|
let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3);
|
||||||
// its hidden state = new_hooks. We use new_hooks as the seed for the next round.
|
for h in verify_hooks.iter() {
|
||||||
seed_hooks = new_hooks;
|
new_hooks.push(h.narrow(0, k, 1).contiguous());
|
||||||
|
}
|
||||||
|
seed_hooks = [
|
||||||
|
new_hooks.remove(0),
|
||||||
|
new_hooks.remove(0),
|
||||||
|
new_hooks.remove(0),
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Loop ended. If pending_prev wasn't committed yet, commit it now.
|
||||||
|
if generated.len() < gen_tokens && !tokenizer.is_eos(pending_prev) {
|
||||||
|
generated.push(pending_prev);
|
||||||
|
} else if generated.len() < gen_tokens {
|
||||||
|
generated.push(pending_prev);
|
||||||
|
}
|
||||||
|
|
||||||
|
eprint!("[per-slot d[i] correct/total: ");
|
||||||
|
for i in 0..gamma {
|
||||||
|
let rate = if per_slot_total[i] > 0 {
|
||||||
|
per_slot_correct[i] as f64 / per_slot_total[i] as f64
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
eprint!(
|
||||||
|
"{}={}/{}({:.2}) ",
|
||||||
|
i, per_slot_correct[i], per_slot_total[i], rate
|
||||||
|
);
|
||||||
|
}
|
||||||
|
eprintln!("]");
|
||||||
|
|
||||||
|
finalize(
|
||||||
|
generated,
|
||||||
|
t0,
|
||||||
|
target_steps,
|
||||||
|
accepted_total,
|
||||||
|
proposed_total,
|
||||||
|
cache,
|
||||||
|
slot,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn finalize(
|
||||||
|
generated: Vec<u32>,
|
||||||
|
t0: Instant,
|
||||||
|
target_steps: usize,
|
||||||
|
accepted: usize,
|
||||||
|
proposed: usize,
|
||||||
|
cache: &mut PagedKVCache,
|
||||||
|
slot: usize,
|
||||||
|
) -> RunStats {
|
||||||
sync_device();
|
sync_device();
|
||||||
let total_s = t0.elapsed().as_secs_f64();
|
let total_s = t0.elapsed().as_secs_f64();
|
||||||
cache.free_sequence(slot);
|
cache.free_sequence(slot);
|
||||||
@@ -485,8 +580,8 @@ fn run_eagle_gamma_multi(
|
|||||||
ids: generated,
|
ids: generated,
|
||||||
total_s,
|
total_s,
|
||||||
target_steps,
|
target_steps,
|
||||||
accepted: accepted_total,
|
accepted,
|
||||||
proposed: proposed_total,
|
proposed,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -150,6 +150,18 @@ impl Eagle3Head {
|
|||||||
self.current_len = 0;
|
self.current_len = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Truncate the internal KV cache to `new_len` entries. Used to discard
|
||||||
|
/// K/V of rejected drafts after a speculative round.
|
||||||
|
pub fn truncate_to(&mut self, new_len: usize) {
|
||||||
|
assert!(new_len <= self.current_len);
|
||||||
|
self.current_len = new_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Current number of committed K/V entries in the internal EAGLE cache.
|
||||||
|
pub fn current_len(&self) -> usize {
|
||||||
|
self.current_len
|
||||||
|
}
|
||||||
|
|
||||||
/// One draft step: produce a token in target vocabulary space.
|
/// One draft step: produce a token in target vocabulary space.
|
||||||
///
|
///
|
||||||
/// - `target_hidden`: 3 tensors [1, hidden_size] from target hook layers
|
/// - `target_hidden`: 3 tensors [1, hidden_size] from target hook layers
|
||||||
@@ -248,12 +260,14 @@ impl Eagle3Head {
|
|||||||
let hidden = silu_mul(&gate, &up);
|
let hidden = silu_mul(&gate, &up);
|
||||||
let down = matmul_2d(&hidden, &self.down_proj_wt);
|
let down = matmul_2d(&hidden, &self.down_proj_wt);
|
||||||
|
|
||||||
let (x, _) = add_rmsnorm(&down, &residual, &self.norm, eps);
|
let (x, prenorm) = add_rmsnorm(&down, &residual, &self.norm, eps);
|
||||||
let logits = matmul_2d(&x, &self.lm_head_wt);
|
let logits = matmul_2d(&x, &self.lm_head_wt);
|
||||||
|
|
||||||
let draft_id = argmax_bf16_single(&logits);
|
let draft_id = argmax_bf16_single(&logits);
|
||||||
let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32;
|
let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32;
|
||||||
(target_id, logits, x)
|
// aux for recursive drafting = PRE-norm hidden (default norm_output=False
|
||||||
|
// in vllm/llama_eagle3.py). Feeding the pre-norm state matches training.
|
||||||
|
(target_id, logits, prenorm)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Write new K/V rows (shape [1, num_kv_heads, head_dim]) at position
|
/// Write new K/V rows (shape [1, num_kv_heads, head_dim]) at position
|
||||||
|
|||||||
Reference in New Issue
Block a user