eagle3: γ≥2 recursive drafting + batched verify with hooks
Adds infrastructure for γ≥2 EAGLE speculative decoding:
qwen3.rs:
- New forward_verify_paged_decode_attention_with_hidden: same as the
existing verify but also captures target hidden states at 3 hook
layers, one per verify position. Needed to seed next round's EAGLE.
eagle3.rs:
- step split into step (unchanged public API) + step_with_aux (also
returns final hidden state) + step_recursive (takes fused_h directly,
no fc+3-hidden combine). This mirrors the EAGLE3 paper: γ=1 uses
target hooks + fc; γ≥2 uses previous EAGLE aux as fused_h for
subsequent drafts, approximating target hidden.
bench-eagle3.rs:
- New run_eagle_gamma_multi function with --gamma CLI (default 2).
- Per round: recursive EAGLE γ drafts, verify [prev_token, d0..d_{γ-1}]
in one target forward, accept longest prefix, correction via 1 more
target decode.
- max_seqs bumped to 16 in the paged cache so verify can batch up to
16 rows.
γ=2 test result (5 prompts × 32 tokens, dash5):
matched=false — sequences diverge
acceptance_rate = 29.8% at γ=2 (~1.1 tokens accepted per draft)
speedup_e2e = 0.52x (SLOWER than baseline)
The divergence bug is in the verify's re-writing of prev_token's K/V
at position round_pos-1. In principle matmul_batched_gemv at row-0
should be bit-exact with the seed decode's launch_gemv_bf16, but the
sequence output diverges so something is off. Investigation pending
(likely the correction decode step or seed_hooks position offset).
γ=1 path still works correctly (matched=true, acceptance 20%,
speedup 0.95x) from the previous commit. The γ≥2 path is scaffolded
but not yet correct — next step is to debug the verify-write path,
then measure real speedup.
This commit is contained in:
@@ -93,6 +93,7 @@ fn main() {
|
||||
let prompt_count = arg_usize(&args, "--prompts", PROMPTS.len()).min(PROMPTS.len());
|
||||
let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN);
|
||||
let device = arg_usize(&args, "--device", 0) as u32;
|
||||
let gamma = arg_usize(&args, "--gamma", 2).max(1);
|
||||
|
||||
xserv_cuda::device::set_device(device).unwrap();
|
||||
let info = xserv_cuda::device::device_info(device).unwrap();
|
||||
@@ -122,7 +123,9 @@ fn main() {
|
||||
let _ = run_baseline(&target, &mut cache, &tokenizer, &ids, 4);
|
||||
drop(cache);
|
||||
}
|
||||
eprintln!("Warmup done. Running {prompt_count} prompts, gen_tokens={gen_tokens}\n");
|
||||
eprintln!(
|
||||
"Warmup done. Running {prompt_count} prompts, gen_tokens={gen_tokens}, gamma={gamma}\n"
|
||||
);
|
||||
|
||||
let mut baseline_total_s = 0.0f64;
|
||||
let mut baseline_tokens = 0usize;
|
||||
@@ -147,9 +150,10 @@ fn main() {
|
||||
baseline_tokens += baseline.ids.len();
|
||||
drop(baseline_cache);
|
||||
|
||||
// Speculative with EAGLE γ=1.
|
||||
// Speculative with EAGLE, γ from CLI.
|
||||
let mut target_cache = new_cache(&target_config, max_seq_len, device);
|
||||
let spec = run_eagle_gamma1(
|
||||
let spec = if gamma == 1 {
|
||||
run_eagle_gamma1(
|
||||
&target,
|
||||
&mut eagle,
|
||||
&embed_tokens,
|
||||
@@ -157,7 +161,19 @@ fn main() {
|
||||
&tokenizer,
|
||||
&ids,
|
||||
gen_tokens,
|
||||
);
|
||||
)
|
||||
} else {
|
||||
run_eagle_gamma_multi(
|
||||
&target,
|
||||
&mut eagle,
|
||||
&embed_tokens,
|
||||
&mut target_cache,
|
||||
&tokenizer,
|
||||
&ids,
|
||||
gen_tokens,
|
||||
gamma,
|
||||
)
|
||||
};
|
||||
spec_total_s += spec.total_s;
|
||||
spec_tokens += spec.ids.len();
|
||||
spec_accepted += spec.accepted;
|
||||
@@ -336,6 +352,148 @@ fn run_eagle_gamma1(
|
||||
}
|
||||
}
|
||||
|
||||
/// γ≥2 speculative with recursive EAGLE drafting + batched target verify.
|
||||
///
|
||||
/// Per round (state entering: prev_token committed at position pos-1, and
|
||||
/// prev_seed_hooks = target hidden states at position pos-1 from previous
|
||||
/// round's verify):
|
||||
///
|
||||
/// 1. EAGLE γ recursive drafts using step_with_aux + step_recursive.
|
||||
/// 2. Verify: target forward on [prev_token, d0..d_{γ-2}] to get γ logits.
|
||||
/// verify_argmax[i] = target's answer at position pos+i. Longest matching
|
||||
/// prefix accepts d[i] iff d[i] == verify_argmax[i].
|
||||
/// 3. Correction: verify_argmax[accepted] IS target's correction for the
|
||||
/// first rejected position (or verify_argmax[γ-1] if all γ accepted).
|
||||
/// Commit accepted tokens + correction.
|
||||
/// 4. Seed next round with hooks at the position of the last committed token.
|
||||
/// Take hooks from verify_hooks at index `accepted` (0-indexed).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_eagle_gamma_multi(
|
||||
target: &Qwen3,
|
||||
eagle: &mut Eagle3Head,
|
||||
embed_tokens: &Tensor,
|
||||
cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
prompt_ids: &[u32],
|
||||
gen_tokens: usize,
|
||||
gamma: usize,
|
||||
) -> RunStats {
|
||||
use xserv_model::eagle3::EAGLE_HOOK_LAYERS;
|
||||
let slot = 0;
|
||||
cache.register_sequence(slot).unwrap();
|
||||
eagle.reset();
|
||||
let t0 = Instant::now();
|
||||
|
||||
// Prefill. Advance one target decode step to seed hidden hooks.
|
||||
let prefill_logits = target.forward_prefill_paged(prompt_ids, slot, cache);
|
||||
let first_token = last_argmax(&prefill_logits);
|
||||
let mut generated = vec![first_token];
|
||||
|
||||
// Seed: run one target decode with first_token to get hooks at that position.
|
||||
let (logits, mut seed_hooks) = target_decode_with_hidden(target, first_token, cache, slot);
|
||||
let mut prev_token = last_argmax(&logits);
|
||||
generated.push(prev_token);
|
||||
let mut target_steps = 2; // 1 prefill + 1 decode seed
|
||||
let mut accepted_total = 0usize;
|
||||
let mut proposed_total = 0usize;
|
||||
|
||||
while generated.len() < gen_tokens && !tokenizer.is_eos(prev_token) {
|
||||
let round_pos = cache.seq_len(slot); // position of next token to predict
|
||||
let remaining = gen_tokens - generated.len();
|
||||
let round_gamma = gamma.min(remaining);
|
||||
|
||||
// Draft γ tokens recursively.
|
||||
let mut drafts: Vec<u32> = Vec::with_capacity(round_gamma);
|
||||
let (d0, _, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, prev_token, round_pos);
|
||||
drafts.push(d0);
|
||||
let mut prev_aux = aux0;
|
||||
let mut prev_draft = d0;
|
||||
for k in 1..round_gamma {
|
||||
let (dk, _, auxk) =
|
||||
eagle.step_recursive(prev_aux, embed_tokens, prev_draft, round_pos + k);
|
||||
drafts.push(dk);
|
||||
prev_aux = auxk;
|
||||
prev_draft = dk;
|
||||
}
|
||||
proposed_total += round_gamma;
|
||||
|
||||
// Verify: target forward on [prev_token, d0..d_{γ-1}] (length γ+1) to
|
||||
// get γ+1 logits. verify_argmax[i] predicts position round_pos+i.
|
||||
// Accept d[i] iff d[i] == verify_argmax[i]. The extra verify_argmax[γ]
|
||||
// provides a free correction for the all-accepted case.
|
||||
let mut verify_input: Vec<u32> = Vec::with_capacity(round_gamma + 1);
|
||||
verify_input.push(prev_token);
|
||||
for &d in drafts.iter() {
|
||||
verify_input.push(d);
|
||||
}
|
||||
// Roll back the seed decode (we wrote prev_token at round_pos-1).
|
||||
cache
|
||||
.truncate_sequence(slot, round_pos - 1)
|
||||
.expect("truncate before verify");
|
||||
let (verify_logits, _verify_hooks) = target
|
||||
.forward_verify_paged_decode_attention_with_hidden(
|
||||
&verify_input,
|
||||
slot,
|
||||
cache,
|
||||
&EAGLE_HOOK_LAYERS,
|
||||
);
|
||||
target_steps += 1;
|
||||
|
||||
// Longest matching prefix: d[i] accepted iff d[i] == argmax(verify_logits[i]).
|
||||
let verify_argmax = argmax_rows(&verify_logits);
|
||||
let mut accepted = 0usize;
|
||||
while accepted < round_gamma && drafts[accepted] == verify_argmax[accepted] {
|
||||
accepted += 1;
|
||||
}
|
||||
accepted_total += accepted;
|
||||
|
||||
// Commit accepted tokens (already in cache at positions round_pos..round_pos+accepted-1).
|
||||
// Truncate cache to that length.
|
||||
cache
|
||||
.truncate_sequence(slot, round_pos + accepted)
|
||||
.expect("truncate to accepted");
|
||||
for i in 0..accepted {
|
||||
generated.push(drafts[i]);
|
||||
if generated.len() >= gen_tokens || tokenizer.is_eos(drafts[i]) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if generated.len() >= gen_tokens {
|
||||
break;
|
||||
}
|
||||
if !generated.is_empty() && tokenizer.is_eos(*generated.last().unwrap()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Correction: verify_argmax[accepted] is target's answer for position
|
||||
// round_pos+accepted. Write its K/V via one target decode. This also
|
||||
// gives us seed_hooks for the next round.
|
||||
let correction = verify_argmax[accepted];
|
||||
let (_, new_hooks) = target_decode_with_hidden(target, correction, cache, slot);
|
||||
target_steps += 1;
|
||||
generated.push(correction);
|
||||
prev_token = correction;
|
||||
// Extract the position-`accepted` slice of verify_hooks as the seed for next round?
|
||||
// Actually the correction just wrote K/V AT position round_pos+accepted, and
|
||||
// its hidden state = new_hooks. We use new_hooks as the seed for the next round.
|
||||
seed_hooks = new_hooks;
|
||||
}
|
||||
sync_device();
|
||||
let total_s = t0.elapsed().as_secs_f64();
|
||||
cache.free_sequence(slot);
|
||||
RunStats {
|
||||
ids: generated,
|
||||
total_s,
|
||||
target_steps,
|
||||
accepted: accepted_total,
|
||||
proposed: proposed_total,
|
||||
}
|
||||
}
|
||||
|
||||
fn argmax_rows(logits: &Tensor) -> Vec<u32> {
|
||||
xserv_kernels::argmax_bf16_to_host(logits)
|
||||
}
|
||||
|
||||
fn target_decode_with_hidden(
|
||||
target: &Qwen3,
|
||||
token: u32,
|
||||
@@ -381,5 +539,5 @@ fn arg_usize(args: &[String], flag: &str, default: usize) -> usize {
|
||||
|
||||
fn new_cache(config: &ModelConfig, max_seq_len: usize, device: u32) -> PagedKVCache {
|
||||
let num_blocks = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE + 2;
|
||||
PagedKVCache::new(config, num_blocks, 0, 1, num_blocks, DType::BF16, device)
|
||||
PagedKVCache::new(config, num_blocks, 0, 16, num_blocks, DType::BF16, device)
|
||||
}
|
||||
|
||||
@@ -165,6 +165,44 @@ impl Eagle3Head {
|
||||
prev_token: u32,
|
||||
position: usize,
|
||||
) -> (u32, Tensor) {
|
||||
let (id, logits, _) = self.step_with_aux(target_hidden, embed_table, prev_token, position);
|
||||
(id, logits)
|
||||
}
|
||||
|
||||
/// Like `step`, but also returns the final hidden state (aux) usable as
|
||||
/// the fused_h for a subsequent recursive draft step via `step_recursive`.
|
||||
pub fn step_with_aux(
|
||||
&mut self,
|
||||
target_hidden: &[Tensor; 3],
|
||||
embed_table: &Tensor,
|
||||
prev_token: u32,
|
||||
position: usize,
|
||||
) -> (u32, Tensor, Tensor) {
|
||||
// Fuse 3 target hidden states into fused_h via fc.
|
||||
let h_cat = concat_hidden(target_hidden);
|
||||
let fused_h = matmul_2d(&h_cat, &self.fc_wt);
|
||||
self.forward_from_fused(fused_h, embed_table, prev_token, position)
|
||||
}
|
||||
|
||||
/// Recursive draft step: reuses the previous EAGLE step's aux as fused_h,
|
||||
/// bypassing the fc+3-hidden fusion. Used for γ≥2 chained drafts.
|
||||
pub fn step_recursive(
|
||||
&mut self,
|
||||
fused_h: Tensor,
|
||||
embed_table: &Tensor,
|
||||
prev_token: u32,
|
||||
position: usize,
|
||||
) -> (u32, Tensor, Tensor) {
|
||||
self.forward_from_fused(fused_h, embed_table, prev_token, position)
|
||||
}
|
||||
|
||||
fn forward_from_fused(
|
||||
&mut self,
|
||||
fused_h: Tensor,
|
||||
embed_table: &Tensor,
|
||||
prev_token: u32,
|
||||
position: usize,
|
||||
) -> (u32, Tensor, Tensor) {
|
||||
let eps = 1e-6f32;
|
||||
assert!(
|
||||
self.current_len < self.max_seq_len,
|
||||
@@ -173,20 +211,12 @@ impl Eagle3Head {
|
||||
self.max_seq_len
|
||||
);
|
||||
|
||||
// 1. Fuse target hidden states: concat [h_low, h_mid, h_high] → fc
|
||||
let h_cat = concat_hidden(target_hidden);
|
||||
let fused_h = matmul_2d(&h_cat, &self.fc_wt); // [1, hidden]
|
||||
|
||||
// 2. Embed previous token (shared with target)
|
||||
let emb = embedding(embed_table, &[prev_token]); // [1, hidden]
|
||||
|
||||
// 3. Norm both, concat, remember residual = fused_h (pre-norm).
|
||||
let emb = embedding(embed_table, &[prev_token]);
|
||||
let residual = fused_h.clone();
|
||||
let emb_normed = rmsnorm(&emb, &self.input_layernorm, eps);
|
||||
let h_normed = rmsnorm(&fused_h, &self.hidden_norm, eps);
|
||||
let attn_in = concat_last_dim(&emb_normed, &h_normed); // [1, 2*hidden]
|
||||
let attn_in = concat_last_dim(&emb_normed, &h_normed);
|
||||
|
||||
// 4. Q/K/V projection then RoPE (position from caller).
|
||||
let q = matmul_2d(&attn_in, &self.q_proj_wt);
|
||||
let k = matmul_2d(&attn_in, &self.k_proj_wt);
|
||||
let v = matmul_2d(&attn_in, &self.v_proj_wt);
|
||||
@@ -197,8 +227,6 @@ impl Eagle3Head {
|
||||
rope_inplace(&q_3d, &self.rope_cache, &positions);
|
||||
rope_inplace(&k_3d, &self.rope_cache, &positions);
|
||||
|
||||
// 5. Append new K/V to the internal cache at slot `current_len`, then
|
||||
// build a contiguous view [1, num_kv_heads, current_len+1, head_dim].
|
||||
let v_3d = v.reshape(&[1, self.num_kv_heads, self.head_dim]);
|
||||
self.append_to_kv_cache(&k_3d, &v_3d);
|
||||
self.current_len += 1;
|
||||
@@ -206,31 +234,26 @@ impl Eagle3Head {
|
||||
let k_view = self.k_cache.narrow(2, 0, kv_len).contiguous();
|
||||
let v_view = self.v_cache.narrow(2, 0, kv_len).contiguous();
|
||||
|
||||
// 6. Attention: q [1, num_q_heads, 1, head_dim] × k/v [1, num_kv_heads, kv_len, head_dim]
|
||||
let q_4d = q_3d.reshape(&[1, self.num_heads, 1, self.head_dim]);
|
||||
let attn_out = decode_attention(&q_4d, &k_view, &v_view);
|
||||
|
||||
// 7. Merge heads and o_proj.
|
||||
let attn_merged = attn_out.reshape(&[1, self.num_heads * self.head_dim]);
|
||||
let attn_proj = matmul_2d(&attn_merged, &self.o_proj_wt);
|
||||
|
||||
// 8. Post-attn fused add_rmsnorm.
|
||||
let (mlp_in, residual) =
|
||||
add_rmsnorm(&attn_proj, &residual, &self.post_attention_layernorm, eps);
|
||||
|
||||
// 9. MLP.
|
||||
let gate = matmul_2d(&mlp_in, &self.gate_proj_wt);
|
||||
let up = matmul_2d(&mlp_in, &self.up_proj_wt);
|
||||
let hidden = silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden, &self.down_proj_wt);
|
||||
|
||||
// 10. Final fused add_rmsnorm → lm_head.
|
||||
let (x, _) = add_rmsnorm(&down, &residual, &self.norm, eps);
|
||||
let logits = matmul_2d(&x, &self.lm_head_wt);
|
||||
|
||||
let draft_id = argmax_bf16_single(&logits);
|
||||
let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32;
|
||||
(target_id, logits)
|
||||
(target_id, logits, x)
|
||||
}
|
||||
|
||||
/// Write new K/V rows (shape [1, num_kv_heads, head_dim]) at position
|
||||
|
||||
@@ -1113,6 +1113,123 @@ impl Qwen3 {
|
||||
matmul_batched_gemv(&x, &self.lm_head_t)
|
||||
}
|
||||
|
||||
/// Like `forward_verify_paged_decode_attention`, but also captures hidden
|
||||
/// states at 3 layer indices (per position). Returns
|
||||
/// (logits [new_tokens, vocab], hooks [3][new_tokens, hidden]). Used by
|
||||
/// EAGLE3 speculative γ≥2 verify path so we can seed the next round's
|
||||
/// EAGLE draft with target's real hidden states at the accepted position.
|
||||
pub fn forward_verify_paged_decode_attention_with_hidden(
|
||||
&self,
|
||||
token_ids: &[u32],
|
||||
slot: usize,
|
||||
paged_cache: &mut PagedKVCache,
|
||||
hook_layers: &[usize; 3],
|
||||
) -> (Tensor, [Tensor; 3]) {
|
||||
let new_tokens = token_ids.len();
|
||||
let pos_offset = paged_cache.seq_len(slot);
|
||||
let num_heads = self.local_num_heads;
|
||||
let num_kv_heads = self.local_num_kv_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
|
||||
paged_cache.advance_seq_len(slot, new_tokens);
|
||||
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
|
||||
.map(|p| p as u32)
|
||||
.collect();
|
||||
let kv_lens: Vec<i32> = (0..new_tokens)
|
||||
.map(|i| (pos_offset + i + 1) as i32)
|
||||
.collect();
|
||||
let slots = vec![slot; new_tokens];
|
||||
paged_cache.sync_active_batch_with_lens(&slots, &kv_lens);
|
||||
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
|
||||
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
|
||||
let max_blocks = paged_cache.max_blocks_per_seq();
|
||||
|
||||
let mut x = embedding(&self.embed_tokens, token_ids);
|
||||
let mut hooks: [Option<Tensor>; 3] = [None, None, None];
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
let qkv = matmul_batched_gemv(&normed, &layer.qkv_proj_wt);
|
||||
let q_dim = num_heads * head_dim;
|
||||
let kv_dim = num_kv_heads * head_dim;
|
||||
let q_all = qkv.narrow(1, 0, q_dim);
|
||||
let k_all = qkv.narrow(1, q_dim, kv_dim);
|
||||
let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);
|
||||
|
||||
let q_flat = q_all
|
||||
.contiguous()
|
||||
.reshape(&[new_tokens * num_heads, head_dim]);
|
||||
let k_flat = k_all
|
||||
.contiguous()
|
||||
.reshape(&[new_tokens * num_kv_heads, head_dim]);
|
||||
let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
|
||||
let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
|
||||
|
||||
let q_3d = q_normed.reshape(&[new_tokens, num_heads, head_dim]);
|
||||
let k_3d = k_normed.reshape(&[new_tokens, num_kv_heads, head_dim]);
|
||||
rope_inplace(&q_3d, &self.rope_cache, &positions);
|
||||
rope_inplace(&k_3d, &self.rope_cache, &positions);
|
||||
|
||||
let v_3d = v_all
|
||||
.contiguous()
|
||||
.reshape(&[new_tokens, num_kv_heads, head_dim]);
|
||||
paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, new_tokens);
|
||||
|
||||
let q_decode = q_3d.reshape(&[new_tokens, num_heads, 1, head_dim]);
|
||||
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let attn_out = xserv_kernels::paged_decode_attention(
|
||||
&q_decode,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
bt_ptr,
|
||||
cl_ptr,
|
||||
new_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
max_blocks,
|
||||
);
|
||||
|
||||
let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]);
|
||||
let attn_proj = matmul_batched_gemv(&attn_merged, &layer.o_proj_wt);
|
||||
self.all_reduce(&attn_proj);
|
||||
|
||||
let (normed, x_new) =
|
||||
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate_up = matmul_batched_gemv(&normed, &layer.gate_up_proj_wt);
|
||||
let ffn_dim = gate_up.shape()[1] / 2;
|
||||
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
|
||||
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_batched_gemv(&hidden_states, &layer.down_proj_wt);
|
||||
self.all_reduce(&down);
|
||||
x = add_any(&residual, &down);
|
||||
|
||||
for (h_idx, &h_layer) in hook_layers.iter().enumerate() {
|
||||
if layer_idx == h_layer {
|
||||
hooks[h_idx] = Some(x.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let x = rmsnorm(&x, &self.norm, eps);
|
||||
let logits = matmul_batched_gemv(&x, &self.lm_head_t);
|
||||
let hidden_arr = [
|
||||
hooks[0].take().expect("hook layer 0 not reached"),
|
||||
hooks[1].take().expect("hook layer 1 not reached"),
|
||||
hooks[2].take().expect("hook layer 2 not reached"),
|
||||
];
|
||||
(logits, hidden_arr)
|
||||
}
|
||||
|
||||
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
|
||||
pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
|
||||
Reference in New Issue
Block a user