diff --git a/crates/xserv-model/src/bin/bench-eagle3.rs b/crates/xserv-model/src/bin/bench-eagle3.rs
index bd1bae7..2b0f47d 100644
--- a/crates/xserv-model/src/bin/bench-eagle3.rs
+++ b/crates/xserv-model/src/bin/bench-eagle3.rs
@@ -151,13 +151,19 @@ fn main() {
         baseline_tokens += baseline.ids.len();
         drop(baseline_cache);
 
-        // Speculative with EAGLE, γ from CLI. Verify uses the tree kernel with
-        // a causal mask (equivalent to non-tree behavior); a real tree
-        // (siblings sharing target positions) would require KV cache slot
-        // remap after acceptance, which is out of scope for this iteration.
-        let _ = use_tree; // reserved for future tree drafting
+        // Speculative with EAGLE.
         let mut target_cache = new_cache(&target_config, max_seq_len, device);
-        let spec = if gamma == 1 {
+        let spec = if use_tree {
+            run_eagle_tree(
+                &target,
+                &mut eagle,
+                &embed_tokens,
+                &mut target_cache,
+                &tokenizer,
+                &ids,
+                gen_tokens,
+            )
+        } else if gamma == 1 {
             run_eagle_gamma1(
                 &target,
                 &mut eagle,
@@ -609,6 +615,188 @@ fn run_eagle_gamma_multi(
     )
 }
 
+/// Tree drafting with top-2 siblings at slot 0 + chain from top-1.
+///
+/// Tree structure per round (4 verify tokens):
+///   [pending_prev, d0_top1, d0_top2, d1_chain_from_top1]
+///   positions:  [P, P+1, P+1, P+2]
+///   tree_mask:  [[1,0,0,0],[1,1,0,0],[1,0,1,0],[1,1,0,1]]
+///
+/// The verify pass leaves the four tokens' K/V at cache positions
+/// [P, P+1, P+2, P+3] in input order, so an accepted token that is not at its
+/// canonical position is moved with `copy_kv_position` before truncating.
+///
+/// Acceptance:
+///   - d0_top1 matches target → check d1 → commit 2 or 3 tokens.
+///   - d0_top2 matches target → copy_kv_position(P+2 → P+1) + commit 2 tokens.
+///   - Neither → commit pending_prev only (1 token).
+///
+/// Termination: the loop ends when `pending_prev` or the last committed token
+/// is EOS, or when fewer than four output slots remain (a round can commit up
+/// to three tokens plus the trailing `pending_prev`). In the latter case only
+/// `pending_prev` is appended, so the run may end one or two tokens short of
+/// `gen_tokens`.
+#[allow(clippy::too_many_arguments)]
+fn run_eagle_tree(
+    target: &Qwen3,
+    eagle: &mut Eagle3Head,
+    embed_tokens: &Tensor,
+    cache: &mut PagedKVCache,
+    tokenizer: &Tokenizer,
+    prompt_ids: &[u32],
+    gen_tokens: usize,
+) -> RunStats {
+    use xserv_model::eagle3::EAGLE_HOOK_LAYERS;
+    let slot = 0;
+    cache.register_sequence(slot).unwrap();
+    eagle.reset();
+    let t0 = Instant::now();
+
+    // Prefill + seed decode for initial hooks.
+    let prefill_logits = target.forward_prefill_paged(prompt_ids, slot, cache);
+    let first_token = last_argmax(&prefill_logits);
+    let (seed_logits, seed_hooks_v) = target_decode_with_hidden(target, first_token, cache, slot);
+    let mut generated = vec![first_token];
+    let mut pending_prev = last_argmax(&seed_logits);
+    let mut seed_hooks: [Tensor; 3] = seed_hooks_v;
+    let mut target_steps = 2usize; // prefill + seed decode
+    let mut accepted_total = 0usize;
+    let mut proposed_total = 0usize;
+    // Set when a committed token is EOS, so nothing is appended after it.
+    let mut stopped_on_eos = false;
+
+    while generated.len() + 1 < gen_tokens && !tokenizer.is_eos(pending_prev) {
+        let p = cache.seq_len(slot);
+        if generated.len() + 4 > gen_tokens {
+            // Not enough room for a full tree round. pending_prev is committed
+            // exactly once by the post-loop step; pushing it here as well
+            // would emit the same token twice.
+            break;
+        }
+
+        let eagle_len_before = eagle.current_len();
+
+        // Draft: EAGLE step 0 → top-2 candidates for position P+1.
+        let (d0_top1, l0, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p);
+        let top2 = top_k_target_ids(&l0, 2, eagle);
+        let d0_top2 = if top2[0] == d0_top1 { top2[1] } else { top2[0] };
+
+        // Draft: EAGLE step 1 → chain from d0_top1 for position P+2.
+        let (d1, _, _) = eagle.step_recursive(aux0, embed_tokens, d0_top1, p + 1);
+
+        proposed_total += 3; // 3 candidates proposed (d0_top1, d0_top2, d1)
+
+        // Verify with tree: 4 tokens.
+        let verify_input = vec![pending_prev, d0_top1, d0_top2, d1];
+        let positions_v: Vec<u32> = vec![p as u32, (p + 1) as u32, (p + 1) as u32, (p + 2) as u32];
+        let kv_lens_v: Vec<i32> = vec![
+            (p + 1) as i32,
+            (p + 2) as i32,
+            (p + 3) as i32,
+            (p + 4) as i32,
+        ];
+        // Row i flags the verify tokens on token i's path (its ancestors + itself).
+        #[rustfmt::skip]
+        let tree_mask = vec![
+            1, 0, 0, 0,
+            1, 1, 0, 0,
+            1, 0, 1, 0,
+            1, 1, 0, 1,
+        ];
+
+        let (verify_logits, verify_hooks) = target
+            .forward_verify_paged_decode_attention_tree_with_hidden(
+                &verify_input,
+                &positions_v,
+                &kv_lens_v,
+                &tree_mask,
+                slot,
+                cache,
+                &EAGLE_HOOK_LAYERS,
+            );
+        target_steps += 1;
+        // cache.seq_len is now p + 4.
+
+        let va = argmax_rows(&verify_logits);
+        // va[0] = target at P+1; va[1] = target at P+2 given ..d0_top1;
+        // va[2] = target at P+2 given ..d0_top2; va[3] = target at P+3 given ..d1.
+
+        // Hidden states of verify row `row`, one tensor per hooked layer: the
+        // state that produced va[row], used to seed the next draft round.
+        let hooks_at = |row| -> [Tensor; 3] {
+            std::array::from_fn(|layer| verify_hooks[layer].narrow(0, row, 1).contiguous())
+        };
+
+        if d0_top1 == va[0] {
+            // Top-1 path accepted at slot 0.
+            accepted_total += 1;
+            // The chain only counts if d0_top1 is not EOS; otherwise d1 would
+            // be committed after the end of the sequence.
+            if d1 == va[1] && !tokenizer.is_eos(d0_top1) {
+                // Chain also accepted: commit pending_prev + d0_top1 + d1.
+                accepted_total += 1;
+                // K/V layout: [P]=pending_prev, [P+1]=d0_top1, [P+2]=d0_top2, [P+3]=d1.
+                // d1 belongs at cache position P+2 (not P+3). Remap:
+                cache.copy_kv_position(slot, p + 3, p + 2);
+                cache.truncate_sequence(slot, p + 3).unwrap();
+                generated.extend([pending_prev, d0_top1, d1]);
+                pending_prev = va[3];
+                seed_hooks = hooks_at(3);
+                eagle.truncate_to(eagle_len_before + 2); // keep pending_prev + d0_top1
+            } else {
+                // Only d0_top1 accepted, d1 rejected.
+                // K/V [P]=pp, [P+1]=d0_top1 — correct. Drop [P+2], [P+3].
+                cache.truncate_sequence(slot, p + 2).unwrap();
+                generated.extend([pending_prev, d0_top1]);
+                pending_prev = va[1];
+                seed_hooks = hooks_at(1);
+                eagle.truncate_to(eagle_len_before + 2);
+            }
+        } else if d0_top2 == va[0] {
+            // Top-2 sibling accepted! K/V for d0_top2 is at cache position P+2.
+            // Copy it to P+1 (canonical position for the accepted token).
+            accepted_total += 1;
+            cache.copy_kv_position(slot, p + 2, p + 1);
+            cache.truncate_sequence(slot, p + 2).unwrap();
+            generated.extend([pending_prev, d0_top2]);
+            pending_prev = va[2];
+            seed_hooks = hooks_at(2);
+            // EAGLE cache: only step 0 was relevant (pending_prev). Truncate to 1.
+            eagle.truncate_to(eagle_len_before + 1);
+        } else {
+            // Both rejected. Commit only pending_prev.
+            cache.truncate_sequence(slot, p + 1).unwrap();
+            generated.push(pending_prev);
+            pending_prev = va[0];
+            seed_hooks = hooks_at(0);
+            eagle.truncate_to(eagle_len_before + 1);
+        }
+
+        // The length limit is enforced by the loop condition; only EOS needs
+        // an explicit exit here.
+        if tokenizer.is_eos(*generated.last().unwrap()) {
+            stopped_on_eos = true;
+            break;
+        }
+    }
+
+    // Commit the trailing target-verified token, unless the output already
+    // ends with EOS or is full.
+    if !stopped_on_eos && generated.len() < gen_tokens {
+        generated.push(pending_prev);
+    }
+
+    finalize(
+        generated,
+        t0,
+        target_steps,
+        accepted_total,
+        proposed_total,
+        cache,
+        slot,
+    )
+}
+
 fn finalize(
     generated: Vec<u32>,
     t0: Instant,