diff --git a/crates/xserv-model/src/bin/check-eagle3.rs b/crates/xserv-model/src/bin/check-eagle3.rs new file mode 100644 index 0000000..fdb3484 --- /dev/null +++ b/crates/xserv-model/src/bin/check-eagle3.rs @@ -0,0 +1,152 @@ +//! EAGLE3 sanity check: load weights, run one draft step, print top-5 predictions. +//! +//! This verifies that: +//! - Eagle3Head weights load without shape mismatches +//! - Target hidden states can be captured via decode_core_with_hidden +//! - Eagle3Head::step produces a valid token id (in target vocab) +//! +//! Does NOT measure speedup — that requires a full γ≥2 speculative loop, which +//! is more complex integration work. + +use std::path::PathBuf; + +use xserv_model::eagle3::{EAGLE_HOOK_LAYERS, Eagle3Head}; +use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader}; +use xserv_tensor::{DType, Device, Tensor}; +use xserv_tokenizer::Tokenizer; + +fn main() { + let args: Vec = std::env::args().collect(); + if args.len() < 3 { + eprintln!("Usage: check-eagle3 [prompt]"); + std::process::exit(1); + } + let target_dir = PathBuf::from(&args[1]); + let eagle_dir = PathBuf::from(&args[2]); + let prompt = args + .get(3) + .cloned() + .unwrap_or_else(|| "The capital of France is".to_string()); + let device: u32 = 0; + + xserv_cuda::device::set_device(device).unwrap(); + + let target_config = ModelConfig::from_file(&target_dir.join("config.json")); + eprintln!("Loading target Qwen3-8B..."); + let target_weights = loader::load_model_dir(&target_dir, Device::Cuda(device)); + let target = Qwen3::from_weights(target_config.clone(), target_weights); + xserv_cuda::allocator::cached_trim(); + + eprintln!("Loading EAGLE3 head from {}", eagle_dir.display()); + let eagle = Eagle3Head::load(&eagle_dir, device); + xserv_cuda::allocator::cached_trim(); + + let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json")); + let embed_tokens = target.embed_tokens_tensor(); + + let ids = tokenizer.encode(&prompt); + let max_seq_len = 512; + + let num_blocks = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE + 2; + let mut cache = PagedKVCache::new( + &target_config, + num_blocks, + 0, + 1, + num_blocks, + DType::BF16, + device, + ); + cache.register_sequence(0).unwrap(); + + // Prefill target. + let logits = target.forward_prefill_paged(&ids, 0, &mut cache); + let target_first = *xserv_kernels::argmax_bf16_to_host(&logits).last().unwrap(); + let target_first_text = tokenizer.decode(&[target_first]); + println!("Prompt: {:?}", prompt); + println!( + "Target argmax after prefill: {} ({:?})", + target_first, target_first_text + ); + + // Now run one target decode step with target_first to get hidden states at the + // hook layers. + let pos = cache.seq_len(0); + target.decode_prepare(&[pos], &[0], &mut cache); + let ids_gpu = upload_u32(&[target_first]); + let pos_gpu = upload_u32(&[pos as u32]); + let (target_next_logits, hooks) = target.decode_core_with_hidden( + ids_gpu.as_ptr() as *const std::ffi::c_void, + pos_gpu.as_ptr() as *const std::ffi::c_void, + 1, + &[0], + &mut cache, + &EAGLE_HOOK_LAYERS, + ); + let target_next = xserv_kernels::argmax_bf16_single(&target_next_logits); + let target_next_text = tokenizer.decode(&[target_next]); + println!( + "Target argmax after 1 decode step: {} ({:?})", + target_next, target_next_text + ); + + for (i, h) in hooks.iter().enumerate() { + println!( + "hook[{}] (layer {}): shape={:?} dtype={:?}", + i, + EAGLE_HOOK_LAYERS[i], + h.shape(), + h.dtype() + ); + } + + // Ask EAGLE what it thinks the NEXT token is (given target_first as prev_token + // and the hidden states from the position where target_first lives). + // EAGLE should predict target_next (or close to it) to be useful. + let (eagle_pred, eagle_logits) = eagle.step(&hooks, embed_tokens, target_first, pos); + let eagle_pred_text = tokenizer.decode(&[eagle_pred]); + println!( + "EAGLE draft prediction: {} ({:?})", + eagle_pred, eagle_pred_text + ); + + if eagle_pred == target_next { + println!("MATCH: EAGLE agrees with target on next token."); + } else { + println!( + "MISMATCH: EAGLE draft={} vs target={} (this is fine per-step; check top-5 below)", + eagle_pred, target_next + ); + } + + // Show top-5 from eagle logits (in draft vocab space, mapped to target). + print_top5(&eagle_logits, "EAGLE draft top-5", &eagle, &tokenizer); +} + +fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer { + let bytes = unsafe { std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4) }; + let mut buf = xserv_cuda::allocator::cached_alloc(bytes.len()).unwrap(); + buf.copy_from_host(bytes).unwrap(); + buf +} + +fn print_top5(logits: &Tensor, label: &str, eagle: &Eagle3Head, tokenizer: &Tokenizer) { + use half::bf16; + let cpu = logits.to_device(Device::Cpu); + let data = cpu.as_slice::(); + let mut vals: Vec<(usize, f32)> = data + .iter() + .enumerate() + .map(|(i, v)| (i, v.to_f32())) + .collect(); + vals.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + println!("{label}:"); + for (i, val) in vals.iter().take(5) { + let target_id = eagle.map_draft_to_target(*i as u32); + let text = tokenizer.decode(&[target_id]); + println!( + " draft_id={} target_id={} val={:.3} text={:?}", + i, target_id, val, text + ); + } +} diff --git a/crates/xserv-model/src/eagle3.rs b/crates/xserv-model/src/eagle3.rs new file mode 100644 index 0000000..27f8964 --- /dev/null +++ b/crates/xserv-model/src/eagle3.rs @@ -0,0 +1,312 @@ +//! EAGLE3 speculative draft head for Qwen3-8B (Phase 25). +//! +//! Loads the AngelSlim/Qwen3-8B_eagle3 pytorch_model.bin and provides a +//! single-step forward pass that takes 3 target hidden states + the previous +//! token and returns a draft token in the target vocabulary. +//! +//! Architecture (from weights): +//! - fc: [hidden, 3*hidden] → fuse 3 target hidden states +//! - midlayer: 1 decoder layer (attn input dim = 2*hidden) +//! - norm + lm_head: → [draft_vocab_size=32000] +//! - d2t: draft_id → target_id offset mapping + +use std::collections::HashMap; +use std::path::Path; +use xserv_kernels::*; +use xserv_tensor::{DType, Device, Tensor}; + +pub const EAGLE_HOOK_LAYERS: [usize; 3] = [11, 23, 35]; +const DRAFT_VOCAB_SIZE: usize = 32000; + +fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor { + assert_eq!(a.ndim(), 2); + assert_eq!(b.ndim(), 2); + matmul(a, b, GemmBackend::CuBlas) +} + +pub struct Eagle3Head { + fc_wt: Tensor, // [hidden, 3*hidden] transposed for matmul + hidden_norm: Tensor, // [hidden] + input_layernorm: Tensor, // [hidden] + q_proj_wt: Tensor, // [num_heads*head_dim, 2*hidden] + k_proj_wt: Tensor, // [num_kv_heads*head_dim, 2*hidden] + v_proj_wt: Tensor, // [num_kv_heads*head_dim, 2*hidden] + o_proj_wt: Tensor, // [hidden, num_heads*head_dim] + gate_proj_wt: Tensor, // [intermediate, hidden] + up_proj_wt: Tensor, // [intermediate, hidden] + down_proj_wt: Tensor, // [hidden, intermediate] + post_attention_layernorm: Tensor, // [hidden] + norm: Tensor, // [hidden] final + lm_head_wt: Tensor, // [draft_vocab, hidden] + d2t: Vec, // [draft_vocab] offset mapping + hidden_size: usize, + num_heads: usize, + num_kv_heads: usize, + head_dim: usize, + rope_cache: RopeCache, +} + +impl Eagle3Head { + pub fn load(dir: &Path, device: u32) -> Self { + let (weights, d2t) = load_eagle3_weights(dir, device); + let hidden_size = 4096; + let num_heads = 32; + let num_kv_heads = 8; + let head_dim = 128; + let intermediate_size = 12288; + let max_seq_len = 2048; + let rope_theta = 1_000_000.0f32; + + let get = |name: &str| -> Tensor { + weights + .get(name) + .unwrap_or_else(|| panic!("missing eagle3 weight: {name}")) + .clone() + }; + + let fc_wt = get("fc.weight").transpose(0, 1).contiguous(); + let q_proj_wt = get("midlayer.self_attn.q_proj.weight") + .transpose(0, 1) + .contiguous(); + let k_proj_wt = get("midlayer.self_attn.k_proj.weight") + .transpose(0, 1) + .contiguous(); + let v_proj_wt = get("midlayer.self_attn.v_proj.weight") + .transpose(0, 1) + .contiguous(); + let o_proj_wt = get("midlayer.self_attn.o_proj.weight") + .transpose(0, 1) + .contiguous(); + let gate_proj_wt = get("midlayer.mlp.gate_proj.weight") + .transpose(0, 1) + .contiguous(); + let up_proj_wt = get("midlayer.mlp.up_proj.weight") + .transpose(0, 1) + .contiguous(); + let down_proj_wt = get("midlayer.mlp.down_proj.weight") + .transpose(0, 1) + .contiguous(); + let hidden_norm = get("midlayer.hidden_norm.weight"); + let input_layernorm = get("midlayer.input_layernorm.weight"); + let post_attention_layernorm = get("midlayer.post_attention_layernorm.weight"); + let norm = get("norm.weight"); + let lm_head_wt = get("lm_head.weight").transpose(0, 1).contiguous(); + + assert_eq!(d2t.len(), DRAFT_VOCAB_SIZE); + + let rope_cache = RopeCache::new(max_seq_len, head_dim, rope_theta); + + Self { + fc_wt, + hidden_norm, + input_layernorm, + q_proj_wt, + k_proj_wt, + v_proj_wt, + o_proj_wt, + gate_proj_wt, + up_proj_wt, + down_proj_wt, + post_attention_layernorm, + norm, + lm_head_wt, + d2t, + hidden_size, + num_heads, + num_kv_heads, + head_dim, + rope_cache, + } + } + + /// One draft step: produce a token in target vocabulary space. + /// + /// - `target_hidden`: 3 tensors [1, hidden_size] from target hook layers + /// - `embed_table`: the target model's embed_tokens (shared, not copied) + /// - `prev_token`: the previous committed token + /// - `position`: the decode position for RoPE + /// + /// Returns (draft_token_in_target_vocab, draft_logits_tensor). + pub fn step( + &self, + target_hidden: &[Tensor; 3], + embed_table: &Tensor, + prev_token: u32, + position: usize, + ) -> (u32, Tensor) { + let eps = 1e-6f32; + + // 1. Fuse target hidden states: concat [h_low, h_mid, h_high] → fc + let h_cat = concat_hidden(target_hidden); + let fused_h = matmul_2d(&h_cat, &self.fc_wt); // [1, hidden] + + // 2. Embed previous token (shared with target) + let emb = embedding(embed_table, &[prev_token]); // [1, hidden] + + // 3. Concat normalized: [norm(emb), norm(fused_h)] → [1, 2*hidden] + let emb_normed = rmsnorm(&emb, &self.input_layernorm, eps); + let h_normed = rmsnorm(&fused_h, &self.hidden_norm, eps); + let attn_in = concat_last_dim(&emb_normed, &h_normed); // [1, 8192] + + // 4. Self-attention (no KV cache for simplicity in v0 — single query) + let q = matmul_2d(&attn_in, &self.q_proj_wt); // [1, num_heads*head_dim] + let k = matmul_2d(&attn_in, &self.k_proj_wt); // [1, num_kv*head_dim] + let v = matmul_2d(&attn_in, &self.v_proj_wt); // [1, num_kv*head_dim] + + let q_3d = q.reshape(&[1, self.num_heads, self.head_dim]); + let k_3d = k.reshape(&[1, self.num_kv_heads, self.head_dim]); + let positions = [position as u32]; + rope_inplace(&q_3d, &self.rope_cache, &positions); + rope_inplace(&k_3d, &self.rope_cache, &positions); + + // Single-token attention: Q·K^T / sqrt(d) → softmax → V + // With seq_len=1, attention is trivial: output = V (weight=1.0) + let attn_out = v.reshape(&[1, self.num_kv_heads, self.head_dim]); + let attn_out = if self.num_heads != self.num_kv_heads { + repeat_kv_for_single_token(&attn_out, self.num_heads / self.num_kv_heads) + } else { + attn_out + }; + let attn_merged = attn_out.reshape(&[1, self.num_heads * self.head_dim]); + let attn_proj = matmul_2d(&attn_merged, &self.o_proj_wt); // [1, hidden] + + // Residual from embedding + let x = add(&attn_proj, &emb); + + // 5. MLP + let normed = rmsnorm(&x, &self.post_attention_layernorm, eps); + let gate = matmul_2d(&normed, &self.gate_proj_wt); + let up = matmul_2d(&normed, &self.up_proj_wt); + let mlp_out = silu_mul(&gate, &up); + let down = matmul_2d(&mlp_out, &self.down_proj_wt); + let x = add(&x, &down); + + // 6. Final norm + lm_head + let x = rmsnorm(&x, &self.norm, eps); + let logits = matmul_2d(&x, &self.lm_head_wt); // [1, 32000] + + // 7. Argmax in draft vocab → map to target vocab + let draft_id = argmax_bf16_single(&logits); + let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32; + + (target_id, logits) + } + + /// Map a draft-vocab token id to the full target-vocab id via d2t. + pub fn map_draft_to_target(&self, draft_id: u32) -> u32 { + (draft_id as i64 + self.d2t[draft_id as usize]) as u32 + } +} + +fn d2d(dst: *mut u8, src: *const u8, bytes: usize) { + unsafe { + xserv_cuda::ffi::cudaMemcpy(dst, src, bytes, xserv_cuda::ffi::CUDA_MEMCPY_D2D); + } +} + +fn concat_hidden(hidden: &[Tensor; 3]) -> Tensor { + let h = hidden[0].shape()[1]; + let dtype = hidden[0].dtype(); + let device = hidden[0].device(); + let elem_bytes = dtype.size_bytes(); + let out = Tensor::empty(&[1, 3 * h], dtype, device); + for (i, t) in hidden.iter().enumerate() { + assert!(t.is_contiguous()); + let dst = unsafe { (out.data_ptr() as *mut u8).add(i * h * elem_bytes) }; + d2d(dst, t.data_ptr() as *const u8, h * elem_bytes); + } + out +} + +fn concat_last_dim(a: &Tensor, b: &Tensor) -> Tensor { + let da = a.shape()[1]; + let db = b.shape()[1]; + let dtype = a.dtype(); + let device = a.device(); + let elem_bytes = dtype.size_bytes(); + let out = Tensor::empty(&[1, da + db], dtype, device); + d2d( + out.data_ptr() as *mut u8, + a.data_ptr() as *const u8, + da * elem_bytes, + ); + let dst = unsafe { (out.data_ptr() as *mut u8).add(da * elem_bytes) }; + d2d(dst, b.data_ptr() as *const u8, db * elem_bytes); + out +} + +fn repeat_kv_for_single_token(kv: &Tensor, repeats: usize) -> Tensor { + if repeats == 1 { + return kv.clone(); + } + let nkv = kv.shape()[1]; + let d = kv.shape()[2]; + let dtype = kv.dtype(); + let device = kv.device(); + let head_bytes = d * dtype.size_bytes(); + let out = Tensor::empty(&[1, nkv * repeats, d], dtype, device); + for h in 0..nkv { + let src = unsafe { (kv.data_ptr() as *const u8).add(h * head_bytes) }; + for r in 0..repeats { + let dst = unsafe { (out.data_ptr() as *mut u8).add((h * repeats + r) * head_bytes) }; + d2d(dst, src, head_bytes); + } + } + out +} + +/// Load EAGLE3 weights from safetensors, handling int64 d2t specially. +fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap, Vec) { + let st_path = dir.join("model.safetensors"); + assert!( + st_path.exists(), + "Eagle3 model.safetensors not found in {}. Convert with:\n\ + python3 -c \"import torch; from safetensors.torch import save_file; \ + sd=torch.load('pytorch_model.bin', map_location='cpu', weights_only=False); \ + save_file(sd, 'model.safetensors')\"", + dir.display() + ); + + let data = std::fs::read(&st_path) + .unwrap_or_else(|e| panic!("failed to read {}: {e}", st_path.display())); + let st = safetensors::SafeTensors::deserialize(&data) + .unwrap_or_else(|e| panic!("failed to parse {}: {e}", st_path.display())); + + let mut tensors = HashMap::new(); + let mut d2t_vec: Vec = Vec::new(); + + for (name, view) in st.tensors() { + if name == "t2d" { + continue; + } + if name == "d2t" { + let raw = view.data(); + assert_eq!(view.dtype(), safetensors::Dtype::I64); + let n = raw.len() / 8; + d2t_vec = (0..n) + .map(|i| i64::from_le_bytes(raw[i * 8..(i + 1) * 8].try_into().unwrap())) + .collect(); + continue; + } + let dtype = match view.dtype() { + safetensors::Dtype::BF16 => DType::BF16, + safetensors::Dtype::F32 => DType::F32, + safetensors::Dtype::F16 => DType::F16, + other => { + eprintln!("eagle3: skipping {name} with unsupported dtype {other:?}"); + continue; + } + }; + let shape: Vec = view.shape().to_vec(); + let raw = view.data(); + let t = crate::loader::make_tensor(raw, &shape, dtype); + let t = t.to_device(Device::Cuda(device)); + tensors.insert(name.to_string(), t); + } + + assert!( + !d2t_vec.is_empty(), + "d2t tensor not found in eagle3 weights" + ); + (tensors, d2t_vec) +} diff --git a/crates/xserv-model/src/lib.rs b/crates/xserv-model/src/lib.rs index 7112a8c..8af2412 100644 --- a/crates/xserv-model/src/lib.rs +++ b/crates/xserv-model/src/lib.rs @@ -1,5 +1,6 @@ pub mod config; pub mod decode_graph; +pub mod eagle3; pub mod gpt2; pub mod gpt_oss; pub mod gpt_oss_graph; diff --git a/crates/xserv-model/src/loader.rs b/crates/xserv-model/src/loader.rs index dd2c5e5..b5ddd97 100644 --- a/crates/xserv-model/src/loader.rs +++ b/crates/xserv-model/src/loader.rs @@ -68,7 +68,7 @@ pub fn load_model_dir(dir: &Path, device: Device) -> HashMap { all_tensors } -fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor { +pub(crate) fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor { match dtype { DType::F32 => { let floats: &[f32] = unsafe { diff --git a/crates/xserv-model/src/qwen3.rs b/crates/xserv-model/src/qwen3.rs index 3de4beb..1806cb0 100644 --- a/crates/xserv-model/src/qwen3.rs +++ b/crates/xserv-model/src/qwen3.rs @@ -825,6 +825,111 @@ impl Qwen3 { matmul_2d(&x, &self.lm_head_t) } + /// Like `decode_core` but also captures hidden states at 3 specified layer + /// indices (after residual+MLP output). Used by EAGLE3 speculative drafting + /// to feed the draft head with low/mid/high target representations. + pub fn decode_core_with_hidden( + &self, + ids_gpu: *const std::ffi::c_void, + pos_gpu: *const std::ffi::c_void, + batch: usize, + seq_slots: &[usize], + paged_cache: &mut PagedKVCache, + hook_layers: &[usize; 3], + ) -> (Tensor, [Tensor; 3]) { + let num_heads = self.local_num_heads; + let num_kv_heads = self.local_num_kv_heads; + let head_dim = self.config.head_dim(); + let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32; + + let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32; + let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32; + let max_blocks = paged_cache.max_blocks_per_seq(); + + let mut x = embedding_device_ids(&self.embed_tokens, ids_gpu, batch); + let mut hooks: [Option; 3] = [None, None, None]; + + for (layer_idx, layer) in self.layers.iter().enumerate() { + let residual = x.clone(); + let normed = rmsnorm(&x, &layer.input_norm, eps); + + let qkv = matmul_2d(&normed, &layer.qkv_proj_wt); + let q_dim = num_heads * head_dim; + let kv_dim = num_kv_heads * head_dim; + let q_all = qkv.narrow(1, 0, q_dim); + let k_all = qkv.narrow(1, q_dim, kv_dim); + let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim); + + let q_flat = q_all.contiguous().reshape(&[batch * num_heads, head_dim]); + let k_flat = k_all + .contiguous() + .reshape(&[batch * num_kv_heads, head_dim]); + let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps); + let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps); + + let q_3d = q_normed.reshape(&[batch, num_heads, head_dim]); + let k_3d = k_normed.reshape(&[batch, num_kv_heads, head_dim]); + rope_inplace_device_pos(&q_3d, &self.rope_cache, pos_gpu); + rope_inplace_device_pos(&k_3d, &self.rope_cache, pos_gpu); + + let v_3d = v_all.contiguous().reshape(&[batch, num_kv_heads, head_dim]); + + paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, batch); + + let q_4d = q_3d.reshape(&[batch, num_heads, 1, head_dim]); + let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void; + let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void; + let attn_out = xserv_kernels::paged_decode_attention( + &q_4d, + k_pool_ptr, + v_pool_ptr, + bt_ptr, + cl_ptr, + batch, + num_heads, + num_kv_heads, + head_dim, + max_blocks, + ); + + let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]); + let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt); + self.all_reduce(&attn_proj); + + let (normed, x_new) = + xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); + let residual = x_new.clone(); + + let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt); + let ffn_dim = gate_up.shape()[1] / 2; + let gate = gate_up.narrow(1, 0, ffn_dim).contiguous(); + let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous(); + let hidden_states = xserv_kernels::silu_mul(&gate, &up); + let down = matmul_2d(&hidden_states, &layer.down_proj_wt); + self.all_reduce(&down); + x = add_any(&residual, &down); + + for (h_idx, &h_layer) in hook_layers.iter().enumerate() { + if layer_idx == h_layer { + hooks[h_idx] = Some(x.clone()); + } + } + } + + for &slot in seq_slots { + paged_cache.advance_seq_len(slot, 1); + } + + let x = rmsnorm(&x, &self.norm, eps); + let logits = matmul_2d(&x, &self.lm_head_t); + let hidden_arr = [ + hooks[0].take().expect("hook layer 0 not reached"), + hooks[1].take().expect("hook layer 1 not reached"), + hooks[2].take().expect("hook layer 2 not reached"), + ]; + (logits, hidden_arr) + } + /// Paged prefill: write a sequence of `new_tokens` K/V into the paged /// cache for `slot`, run flash attention via gathered contiguous K/V. /// Returns logits [new_tokens, vocab_size]. @@ -1074,6 +1179,12 @@ impl Qwen3 { matmul_2d(&x, &self.lm_head_t) } + /// Reference to the target's token embedding table. Shared (not copied) + /// with speculative draft heads like EAGLE3. + pub fn embed_tokens_tensor(&self) -> &Tensor { + &self.embed_tokens + } + /// Extract weight pointers for CUDA Graph capture. pub fn layer_weight_ptrs(&self) -> Vec { self.layers