diff --git a/crates/xserv-model/src/bin/bench-eagle3.rs b/crates/xserv-model/src/bin/bench-eagle3.rs index a07db72..bb595b0 100644 --- a/crates/xserv-model/src/bin/bench-eagle3.rs +++ b/crates/xserv-model/src/bin/bench-eagle3.rs @@ -384,53 +384,59 @@ fn run_eagle_gamma_multi( eagle.reset(); let t0 = Instant::now(); - // Prefill. Advance one target decode step to seed hidden hooks. + // Prefill: cache holds prompt K/V. Argmax of last row = first_token (target + // for position prompt_len). let prefill_logits = target.forward_prefill_paged(prompt_ids, slot, cache); let first_token = last_argmax(&prefill_logits); - let mut generated = vec![first_token]; - // Seed: run one target decode with first_token to get hooks at that position. - let (logits, mut seed_hooks) = target_decode_with_hidden(target, first_token, cache, slot); - let mut prev_token = last_argmax(&logits); - generated.push(prev_token); - let mut target_steps = 2; // 1 prefill + 1 decode seed + // Seed decode with first_token: writes its K/V at prompt_len, returns hooks + // for that position + logits whose argmax is target's answer for + // prompt_len+1 (this becomes our first pending_prev). + let (seed_logits, seed_hooks_v) = target_decode_with_hidden(target, first_token, cache, slot); + let mut generated = vec![first_token]; + let mut pending_prev = last_argmax(&seed_logits); + let mut seed_hooks: [Tensor; 3] = seed_hooks_v; + let mut target_steps = 2; // 1 prefill + 1 seed decode let mut accepted_total = 0usize; let mut proposed_total = 0usize; + let mut per_slot_correct: Vec = vec![0; gamma]; + let mut per_slot_total: Vec = vec![0; gamma]; - while generated.len() < gen_tokens && !tokenizer.is_eos(prev_token) { - let round_pos = cache.seq_len(slot); // position of next token to predict - let remaining = gen_tokens - generated.len(); + while generated.len() + 1 < gen_tokens && !tokenizer.is_eos(pending_prev) { + let p = cache.seq_len(slot); // position where pending_prev will be written + let remaining = gen_tokens - generated.len() - 1; // -1 because pending_prev counts let round_gamma = gamma.min(remaining); + if round_gamma == 0 { + break; + } - // Draft γ tokens recursively. + // Draft γ tokens for positions p+1..p+γ. EAGLE step k takes the token + // at position p+k as input (pending_prev at p for k=0, previous draft + // d[k-1] at p+k for k>=1) and predicts token at position p+k+1. + // Snapshot EAGLE's cache len so we can roll back rejected drafts' K/V. + let eagle_len_before = eagle.current_len(); let mut drafts: Vec = Vec::with_capacity(round_gamma); - let (d0, _, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, prev_token, round_pos); + let (d0, _, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p); drafts.push(d0); let mut prev_aux = aux0; let mut prev_draft = d0; for k in 1..round_gamma { - let (dk, _, auxk) = - eagle.step_recursive(prev_aux, embed_tokens, prev_draft, round_pos + k); + let (dk, _, auxk) = eagle.step_recursive(prev_aux, embed_tokens, prev_draft, p + k); drafts.push(dk); prev_aux = auxk; prev_draft = dk; } proposed_total += round_gamma; - // Verify: target forward on [prev_token, d0..d_{γ-1}] (length γ+1) to - // get γ+1 logits. verify_argmax[i] predicts position round_pos+i. - // Accept d[i] iff d[i] == verify_argmax[i]. The extra verify_argmax[γ] - // provides a free correction for the all-accepted case. + // Verify: run target on [pending_prev, d[0]..d[γ-1]] at positions + // [p..p+γ]. Writes γ+1 K/V rows. verify_argmax[i] predicts position + // p+i+1, i.e., verifies d[i]. let mut verify_input: Vec = Vec::with_capacity(round_gamma + 1); - verify_input.push(prev_token); + verify_input.push(pending_prev); for &d in drafts.iter() { verify_input.push(d); } - // Roll back the seed decode (we wrote prev_token at round_pos-1). - cache - .truncate_sequence(slot, round_pos - 1) - .expect("truncate before verify"); - let (verify_logits, _verify_hooks) = target + let (verify_logits, verify_hooks) = target .forward_verify_paged_decode_attention_with_hidden( &verify_input, slot, @@ -438,46 +444,135 @@ fn run_eagle_gamma_multi( &EAGLE_HOOK_LAYERS, ); target_steps += 1; + // cache.seq_len is now p + γ + 1. - // Longest matching prefix: d[i] accepted iff d[i] == argmax(verify_logits[i]). let verify_argmax = argmax_rows(&verify_logits); - let mut accepted = 0usize; - while accepted < round_gamma && drafts[accepted] == verify_argmax[accepted] { - accepted += 1; + // Longest matching prefix of d[i] == verify_argmax[i]. + // Also count per-slot: was d[i] correct independent of longest-prefix? + let mut k = 0usize; + while k < round_gamma && drafts[k] == verify_argmax[k] { + k += 1; + } + accepted_total += k; + // Per-slot diagnostic: independent acceptance rate per position. + for (i, &d) in drafts.iter().enumerate() { + if d == verify_argmax[i] { + per_slot_correct[i] += 1; + } + per_slot_total[i] += 1; } - accepted_total += accepted; - // Commit accepted tokens (already in cache at positions round_pos..round_pos+accepted-1). - // Truncate cache to that length. - cache - .truncate_sequence(slot, round_pos + accepted) - .expect("truncate to accepted"); - for i in 0..accepted { + // EAGLE wrote γ K/V entries this round at slots + // eagle_len_before..eagle_len_before+γ. Slot i (in the new range) holds + // K/V for input at target position p+i (pending_prev at slot 0, drafts + // at slots 1..γ-1... wait no: step_with_aux input pending_prev writes + // at slot eagle_len_before with RoPE position p; step_recursive for + // d[0] writes at slot eagle_len_before+1 with RoPE position p+1; etc. + // So slot eagle_len_before+i holds pending_prev (i=0) or d[i-1] + // (i>=1). Accepted commits pending_prev + d[0..k-1] = k+1 EAGLE entries + // at slots eagle_len_before..eagle_len_before+k. Slots + // eagle_len_before+k+1..+γ-1 hold rejected drafts' K/V — discard. + let keep_len = eagle_len_before + k + 1; + if keep_len < eagle.current_len() { + eagle.truncate_to(keep_len); + } + + // Commit pending_prev + accepted drafts. Positions p..p+k filled. + generated.push(pending_prev); + for i in 0..k { generated.push(drafts[i]); - if generated.len() >= gen_tokens || tokenizer.is_eos(drafts[i]) { - break; + if tokenizer.is_eos(drafts[i]) { + // Truncate cache to keep only committed K/V. + cache + .truncate_sequence(slot, p + i + 2) + .expect("truncate on EOS"); + return finalize( + generated, + t0, + target_steps, + accepted_total, + proposed_total, + cache, + slot, + ); + } + if generated.len() >= gen_tokens { + cache + .truncate_sequence(slot, p + i + 2) + .expect("truncate on gen_tokens"); + return finalize( + generated, + t0, + target_steps, + accepted_total, + proposed_total, + cache, + slot, + ); } } - if generated.len() >= gen_tokens { - break; - } - if !generated.is_empty() && tokenizer.is_eos(*generated.last().unwrap()) { - break; - } - // Correction: verify_argmax[accepted] is target's answer for position - // round_pos+accepted. Write its K/V via one target decode. This also - // gives us seed_hooks for the next round. - let correction = verify_argmax[accepted]; - let (_, new_hooks) = target_decode_with_hidden(target, correction, cache, slot); - target_steps += 1; - generated.push(correction); - prev_token = correction; - // Extract the position-`accepted` slice of verify_hooks as the seed for next round? - // Actually the correction just wrote K/V AT position round_pos+accepted, and - // its hidden state = new_hooks. We use new_hooks as the seed for the next round. - seed_hooks = new_hooks; + // Truncate cache to drop unaccepted drafts' K/V. Keep positions 0..p+k. + // Cache now has committed tokens through position p+k (pending_prev at + // p, drafts[0..k] at p+1..p+k). + cache + .truncate_sequence(slot, p + k + 1) + .expect("truncate to accepted"); + + // New pending_prev = verify_argmax[k]. seed_hooks = row k of verify_hooks. + pending_prev = verify_argmax[k]; + let mut new_hooks: Vec = Vec::with_capacity(3); + for h in verify_hooks.iter() { + new_hooks.push(h.narrow(0, k, 1).contiguous()); + } + seed_hooks = [ + new_hooks.remove(0), + new_hooks.remove(0), + new_hooks.remove(0), + ]; } + + // Loop ended. If pending_prev wasn't committed yet, commit it now. + if generated.len() < gen_tokens && !tokenizer.is_eos(pending_prev) { + generated.push(pending_prev); + } else if generated.len() < gen_tokens { + generated.push(pending_prev); + } + + eprint!("[per-slot d[i] correct/total: "); + for i in 0..gamma { + let rate = if per_slot_total[i] > 0 { + per_slot_correct[i] as f64 / per_slot_total[i] as f64 + } else { + 0.0 + }; + eprint!( + "{}={}/{}({:.2}) ", + i, per_slot_correct[i], per_slot_total[i], rate + ); + } + eprintln!("]"); + + finalize( + generated, + t0, + target_steps, + accepted_total, + proposed_total, + cache, + slot, + ) +} + +fn finalize( + generated: Vec, + t0: Instant, + target_steps: usize, + accepted: usize, + proposed: usize, + cache: &mut PagedKVCache, + slot: usize, +) -> RunStats { sync_device(); let total_s = t0.elapsed().as_secs_f64(); cache.free_sequence(slot); @@ -485,8 +580,8 @@ fn run_eagle_gamma_multi( ids: generated, total_s, target_steps, - accepted: accepted_total, - proposed: proposed_total, + accepted, + proposed, } } diff --git a/crates/xserv-model/src/eagle3.rs b/crates/xserv-model/src/eagle3.rs index 0addc71..50d4765 100644 --- a/crates/xserv-model/src/eagle3.rs +++ b/crates/xserv-model/src/eagle3.rs @@ -150,6 +150,18 @@ impl Eagle3Head { self.current_len = 0; } + /// Truncate the internal KV cache to `new_len` entries. Used to discard + /// K/V of rejected drafts after a speculative round. + pub fn truncate_to(&mut self, new_len: usize) { + assert!(new_len <= self.current_len); + self.current_len = new_len; + } + + /// Current number of committed K/V entries in the internal EAGLE cache. + pub fn current_len(&self) -> usize { + self.current_len + } + /// One draft step: produce a token in target vocabulary space. /// /// - `target_hidden`: 3 tensors [1, hidden_size] from target hook layers @@ -248,12 +260,14 @@ impl Eagle3Head { let hidden = silu_mul(&gate, &up); let down = matmul_2d(&hidden, &self.down_proj_wt); - let (x, _) = add_rmsnorm(&down, &residual, &self.norm, eps); + let (x, prenorm) = add_rmsnorm(&down, &residual, &self.norm, eps); let logits = matmul_2d(&x, &self.lm_head_wt); let draft_id = argmax_bf16_single(&logits); let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32; - (target_id, logits, x) + // aux for recursive drafting = PRE-norm hidden (default norm_output=False + // in vllm/llama_eagle3.py). Feeding the pre-norm state matches training. + (target_id, logits, prenorm) } /// Write new K/V rows (shape [1, num_kv_heads, head_dim]) at position