diff --git a/crates/xserv-model/src/bin/bench-eagle3.rs b/crates/xserv-model/src/bin/bench-eagle3.rs index bb595b0..30a2e15 100644 --- a/crates/xserv-model/src/bin/bench-eagle3.rs +++ b/crates/xserv-model/src/bin/bench-eagle3.rs @@ -400,6 +400,8 @@ fn run_eagle_gamma_multi( let mut accepted_total = 0usize; let mut proposed_total = 0usize; let mut per_slot_correct: Vec = vec![0; gamma]; + let mut per_slot_covered: Vec = vec![0; gamma]; + let mut per_slot_top3: Vec = vec![0; gamma]; let mut per_slot_total: Vec = vec![0; gamma]; while generated.len() + 1 < gen_tokens && !tokenizer.is_eos(pending_prev) { @@ -416,13 +418,16 @@ fn run_eagle_gamma_multi( // Snapshot EAGLE's cache len so we can roll back rejected drafts' K/V. let eagle_len_before = eagle.current_len(); let mut drafts: Vec = Vec::with_capacity(round_gamma); - let (d0, _, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p); + let mut draft_logits: Vec = Vec::with_capacity(round_gamma); + let (d0, l0, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p); drafts.push(d0); + draft_logits.push(l0); let mut prev_aux = aux0; let mut prev_draft = d0; for k in 1..round_gamma { - let (dk, _, auxk) = eagle.step_recursive(prev_aux, embed_tokens, prev_draft, p + k); + let (dk, lk, auxk) = eagle.step_recursive(prev_aux, embed_tokens, prev_draft, p + k); drafts.push(dk); + draft_logits.push(lk); prev_aux = auxk; prev_draft = dk; } @@ -454,12 +459,22 @@ fn run_eagle_gamma_multi( k += 1; } accepted_total += k; - // Per-slot diagnostic: independent acceptance rate per position. + // Per-slot diagnostic: independent acceptance rate per position AND + // whether verify's argmax is even representable in EAGLE's draft vocab + // (coverage upper-bound on acceptance). for (i, &d) in drafts.iter().enumerate() { if d == verify_argmax[i] { per_slot_correct[i] += 1; } + if eagle.target_id_in_draft_vocab(verify_argmax[i]) { + per_slot_covered[i] += 1; + } per_slot_total[i] += 1; + // Top-3 diagnostic: is verify_argmax[i] in EAGLE's top-3 for slot i? + let top3 = top_k_target_ids(&draft_logits[i], 3, eagle); + if top3.contains(&verify_argmax[i]) { + per_slot_top3[i] += 1; + } } // EAGLE wrote γ K/V entries this round at slots @@ -539,16 +554,27 @@ fn run_eagle_gamma_multi( generated.push(pending_prev); } - eprint!("[per-slot d[i] correct/total: "); + eprint!("[per-slot i=correct/top3/covered/total (c,t3,cov): "); for i in 0..gamma { - let rate = if per_slot_total[i] > 0 { - per_slot_correct[i] as f64 / per_slot_total[i] as f64 + let (c_rate, t3_rate, cov_rate) = if per_slot_total[i] > 0 { + ( + per_slot_correct[i] as f64 / per_slot_total[i] as f64, + per_slot_top3[i] as f64 / per_slot_total[i] as f64, + per_slot_covered[i] as f64 / per_slot_total[i] as f64, + ) } else { - 0.0 + (0.0, 0.0, 0.0) }; eprint!( - "{}={}/{}({:.2}) ", - i, per_slot_correct[i], per_slot_total[i], rate + "{}={}/{}/{}/{}({:.2},{:.2},{:.2}) ", + i, + per_slot_correct[i], + per_slot_top3[i], + per_slot_covered[i], + per_slot_total[i], + c_rate, + t3_rate, + cov_rate ); } eprintln!("]"); @@ -585,6 +611,22 @@ fn finalize( } } +/// Return top-k target-vocab ids from a [1, draft_vocab_size] BF16 logits +/// tensor. Uses partial sort on CPU (draft_vocab=32k is small enough). +fn top_k_target_ids(logits: &Tensor, k: usize, eagle: &Eagle3Head) -> Vec { + use half::bf16; + let cpu = logits.to_device(Device::Cpu); + let data = cpu.as_slice::(); + let mut idx: Vec = (0..data.len()).collect(); + idx.select_nth_unstable_by(k, |&a, &b| { + data[b].to_f32().partial_cmp(&data[a].to_f32()).unwrap() + }); + idx[..k] + .iter() + .map(|&i| eagle.map_draft_to_target(i as u32)) + .collect() +} + fn argmax_rows(logits: &Tensor) -> Vec { xserv_kernels::argmax_bf16_to_host(logits) } diff --git a/crates/xserv-model/src/eagle3.rs b/crates/xserv-model/src/eagle3.rs index 50d4765..d74ecc8 100644 --- a/crates/xserv-model/src/eagle3.rs +++ b/crates/xserv-model/src/eagle3.rs @@ -43,6 +43,10 @@ pub struct Eagle3Head { norm: Tensor, // [hidden] final lm_head_wt: Tensor, // [draft_vocab, hidden] d2t: Vec, // [draft_vocab] offset mapping + /// t2d[target_id] = true iff target_id has a corresponding draft-vocab id + /// (i.e. can potentially be produced by EAGLE). Used to measure the + /// coverage cap on acceptance. + t2d: Vec, hidden_size: usize, num_heads: usize, num_kv_heads: usize, @@ -59,7 +63,7 @@ pub struct Eagle3Head { impl Eagle3Head { pub fn load(dir: &Path, device: u32) -> Self { - let (weights, d2t) = load_eagle3_weights(dir, device); + let (weights, d2t, t2d) = load_eagle3_weights(dir, device); let hidden_size = 4096; let num_heads = 32; let num_kv_heads = 8; @@ -133,6 +137,7 @@ impl Eagle3Head { norm, lm_head_wt, d2t, + t2d, hidden_size, num_heads, num_kv_heads, @@ -290,6 +295,12 @@ impl Eagle3Head { pub fn map_draft_to_target(&self, draft_id: u32) -> u32 { (draft_id as i64 + self.d2t[draft_id as usize]) as u32 } + + /// Returns true iff `target_id` is representable in the draft vocabulary + /// (i.e., EAGLE could in principle produce it). + pub fn target_id_in_draft_vocab(&self, target_id: u32) -> bool { + self.t2d.get(target_id as usize).copied().unwrap_or(false) + } } fn d2d(dst: *mut u8, src: *const u8, bytes: usize) { @@ -349,8 +360,8 @@ fn repeat_kv_for_single_token(kv: &Tensor, repeats: usize) -> Tensor { out } -/// Load EAGLE3 weights from safetensors, handling int64 d2t specially. -fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap, Vec) { +/// Load EAGLE3 weights from safetensors, handling int64 d2t + bool t2d specially. +fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap, Vec, Vec) { let st_path = dir.join("model.safetensors"); assert!( st_path.exists(), @@ -368,9 +379,13 @@ fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap, Vec let mut tensors = HashMap::new(); let mut d2t_vec: Vec = Vec::new(); + let mut t2d_vec: Vec = Vec::new(); for (name, view) in st.tensors() { if name == "t2d" { + let raw = view.data(); + assert_eq!(view.dtype(), safetensors::Dtype::BOOL); + t2d_vec = raw.iter().map(|&b| b != 0).collect(); continue; } if name == "d2t" { @@ -402,5 +417,9 @@ fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap, Vec !d2t_vec.is_empty(), "d2t tensor not found in eagle3 weights" ); - (tensors, d2t_vec) + assert!( + !t2d_vec.is_empty(), + "t2d tensor not found in eagle3 weights" + ); + (tensors, d2t_vec, t2d_vec) }