speculative: Qwen3 decode graph + gamma sweep (Phase 24 step 2)
- Split Qwen3::forward_decode_paged into decode_prepare (host-side block allocation + table upload) and decode_core (pure-GPU compute reading token ids and positions from device buffers via embedding_device_ids + rope_inplace_device_pos). This makes the entire Qwen3 decode step CUDA-graph-capturable, mirroring the gpt_oss.rs architecture.
- Add qwen3_graph.rs: Qwen3DecodeGraph + GraphedQwen3Decoder, a port of the gpt_oss_graph.rs whole-step capture pattern. Lazy policy: the first decode runs eager (warms pool + cuBLAS), the second captures, the rest replay. Batch>1 always falls back to eager.
- Wire GraphedQwen3Decoder into bench-speculative's draft decode path; all 4 draft.forward_decode_paged call sites + replay_draft_tokens now route through the graphed decoder. Per-benchmark caches persist across prompts for graph reuse.
- Gamma sweep result (10 prompts × 32 tokens, --use-verify-logits): γ=1 → 0.57×, γ=2 → 0.57×, γ=4 → 0.49×, γ=6 → 0.41×, γ=8 → 0.36×. All matched=true, verify_decode_mismatches=0. Acceptance drops sharply with γ (66% → 40% → 25%) because Qwen3-0.6B is too inaccurate a draft for Qwen3-8B. Speedup is still <1.

Current ceiling analysis: verify costs ~13ms (same as one target decode), so speculative decoding only wins if acceptance × (tokens/round) >> (draft_cost + verify_cost) / baseline_decode. With this draft model, the crossover requires either (a) a much smaller verify cost (batch-GEMM path, which trades correctness), or (b) a fundamentally better drafter (EAGLE-style heads, or n-gram lookup).
This commit is contained in:
@@ -10,6 +10,7 @@ use half::bf16;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Instant;
|
||||
|
||||
use xserv_model::qwen3_graph::GraphedQwen3Decoder;
|
||||
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
@@ -222,12 +223,14 @@ fn main() {
|
||||
let mut target_verify_cache =
|
||||
new_cache_with_rows(&target_config, max_seq_len, device, gamma);
|
||||
let mut draft_cache = new_cache(&draft_config, max_seq_len, device);
|
||||
let mut draft_decoder = GraphedQwen3Decoder::new();
|
||||
let _ = run_speculative(
|
||||
&target,
|
||||
&draft,
|
||||
&mut target_cache,
|
||||
&mut target_verify_cache,
|
||||
&mut draft_cache,
|
||||
&mut draft_decoder,
|
||||
&tokenizer,
|
||||
&warm_ids,
|
||||
warm_tokens,
|
||||
@@ -248,6 +251,21 @@ fn main() {
|
||||
);
|
||||
|
||||
let mut totals = Totals::default();
|
||||
|
||||
// Persistent per-benchmark caches so the draft CUDA graph (Phase 24) can be
|
||||
// captured once and replayed across every prompt. Freeing and re-registering
|
||||
// slot 0 between prompts keeps block_table_gpu / context_lens_gpu addresses
|
||||
// stable, which is exactly what the graph captured.
|
||||
let mut target_cache = new_cache_with_rows(
|
||||
&target_config,
|
||||
max_seq_len,
|
||||
device,
|
||||
if use_verify_logits { gamma } else { 1 },
|
||||
);
|
||||
let mut target_verify_cache = new_cache_with_rows(&target_config, max_seq_len, device, gamma);
|
||||
let mut draft_cache = new_cache(&draft_config, max_seq_len, device);
|
||||
let mut draft_decoder = GraphedQwen3Decoder::new();
|
||||
|
||||
for (i, prompt) in PROMPTS.iter().take(prompt_count).enumerate() {
|
||||
let ids = tokenizer.encode(prompt);
|
||||
validate_length_budget(&ids, gen_tokens, max_seq_len, prompt);
|
||||
@@ -255,21 +273,13 @@ fn main() {
|
||||
let baseline = run_baseline(&target, &mut baseline_cache, &tokenizer, &ids, gen_tokens);
|
||||
drop(baseline_cache);
|
||||
|
||||
let mut target_cache = new_cache_with_rows(
|
||||
&target_config,
|
||||
max_seq_len,
|
||||
device,
|
||||
if use_verify_logits { gamma } else { 1 },
|
||||
);
|
||||
let mut target_verify_cache =
|
||||
new_cache_with_rows(&target_config, max_seq_len, device, gamma);
|
||||
let mut draft_cache = new_cache(&draft_config, max_seq_len, device);
|
||||
let spec = run_speculative(
|
||||
&target,
|
||||
&draft,
|
||||
&mut target_cache,
|
||||
&mut target_verify_cache,
|
||||
&mut draft_cache,
|
||||
&mut draft_decoder,
|
||||
&tokenizer,
|
||||
&ids,
|
||||
gen_tokens,
|
||||
@@ -438,6 +448,7 @@ fn run_speculative(
|
||||
target_cache: &mut PagedKVCache,
|
||||
target_verify_cache: &mut PagedKVCache,
|
||||
draft_cache: &mut PagedKVCache,
|
||||
draft_decoder: &mut GraphedQwen3Decoder,
|
||||
tokenizer: &Tokenizer,
|
||||
prompt_ids: &[u32],
|
||||
gen_tokens: usize,
|
||||
@@ -504,7 +515,7 @@ fn run_speculative(
|
||||
break;
|
||||
}
|
||||
let pos = draft_cache.seq_len(slot);
|
||||
let logits = draft.forward_decode_paged(&[token], &[pos], &[slot], draft_cache);
|
||||
let logits = draft_decoder.decode(draft, &[token], &[pos], &[slot], draft_cache);
|
||||
draft_next = last_argmax(&logits);
|
||||
}
|
||||
proposed_total += draft_tokens.len();
|
||||
@@ -572,6 +583,7 @@ fn run_speculative(
|
||||
.unwrap();
|
||||
replay_draft_tokens(
|
||||
draft,
|
||||
draft_decoder,
|
||||
draft_cache,
|
||||
slot,
|
||||
&draft_tokens[..accepted],
|
||||
@@ -588,7 +600,7 @@ fn run_speculative(
|
||||
commit_steps += 1;
|
||||
|
||||
let pos = draft_cache.seq_len(slot);
|
||||
let logits = draft.forward_decode_paged(&[correction], &[pos], &[slot], draft_cache);
|
||||
let logits = draft_decoder.decode(draft, &[correction], &[pos], &[slot], draft_cache);
|
||||
draft_next = last_argmax(&logits);
|
||||
correction_steps += 1;
|
||||
continue;
|
||||
@@ -690,6 +702,7 @@ fn run_speculative(
|
||||
.unwrap();
|
||||
replay_draft_tokens(
|
||||
draft,
|
||||
draft_decoder,
|
||||
draft_cache,
|
||||
slot,
|
||||
&draft_tokens[..accepted],
|
||||
@@ -709,7 +722,7 @@ fn run_speculative(
|
||||
mirror_steps += 1;
|
||||
|
||||
let pos = draft_cache.seq_len(slot);
|
||||
let logits = draft.forward_decode_paged(&[correction], &[pos], &[slot], draft_cache);
|
||||
let logits = draft_decoder.decode(draft, &[correction], &[pos], &[slot], draft_cache);
|
||||
draft_next = last_argmax(&logits);
|
||||
correction_steps += 1;
|
||||
}
|
||||
@@ -745,6 +758,7 @@ fn advance_target_cache(target: &Qwen3, cache: &mut PagedKVCache, slot: usize, t
|
||||
|
||||
fn replay_draft_tokens(
|
||||
draft: &Qwen3,
|
||||
draft_decoder: &mut GraphedQwen3Decoder,
|
||||
cache: &mut PagedKVCache,
|
||||
slot: usize,
|
||||
tokens: &[u32],
|
||||
@@ -752,7 +766,7 @@ fn replay_draft_tokens(
|
||||
) {
|
||||
for &token in tokens {
|
||||
let pos = cache.seq_len(slot);
|
||||
let logits = draft.forward_decode_paged(&[token], &[pos], &[slot], cache);
|
||||
let logits = draft_decoder.decode(draft, &[token], &[pos], &[slot], cache);
|
||||
*next = last_argmax(&logits);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ pub mod kv_cache;
|
||||
pub mod loader;
|
||||
pub mod paged_kv_cache;
|
||||
pub mod qwen3;
|
||||
pub mod qwen3_graph;
|
||||
pub mod sampling;
|
||||
|
||||
pub use config::ModelConfig;
|
||||
|
||||
@@ -701,45 +701,72 @@ impl Qwen3 {
|
||||
assert_eq!(seq_slots.len(), batch);
|
||||
assert!(batch > 0);
|
||||
|
||||
// TP: this rank owns a slice of the heads (local_* == full when world==1).
|
||||
let num_heads = self.local_num_heads;
|
||||
let num_kv_heads = self.local_num_kv_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
self.decode_prepare(positions, seq_slots, paged_cache);
|
||||
|
||||
// Ensure all slots have enough physical blocks for this token, then
|
||||
// upload block tables + context_lens once for the whole forward (the
|
||||
// tables are identical across layers; only the layer's K/V pool changes).
|
||||
let ids_gpu = upload_u32(tokens);
|
||||
let positions_u32: Vec<u32> = positions.iter().map(|&p| p as u32).collect();
|
||||
let pos_gpu = upload_u32(&positions_u32);
|
||||
let logits = self.decode_core(
|
||||
ids_gpu.as_ptr() as *const std::ffi::c_void,
|
||||
pos_gpu.as_ptr() as *const std::ffi::c_void,
|
||||
batch,
|
||||
seq_slots,
|
||||
paged_cache,
|
||||
);
|
||||
logits
|
||||
}
|
||||
|
||||
/// Host-side per-step cache bookkeeping: block allocation + uploading block
|
||||
/// tables / context lens to their (stable-address) GPU buffers. Runs
|
||||
/// OUTSIDE any CUDA-graph captured region.
|
||||
pub fn decode_prepare(
|
||||
&self,
|
||||
positions: &[usize],
|
||||
seq_slots: &[usize],
|
||||
paged_cache: &mut PagedKVCache,
|
||||
) {
|
||||
let kv_lens: Vec<i32> = positions.iter().map(|&p| (p + 1) as i32).collect();
|
||||
for (b, &slot) in seq_slots.iter().enumerate() {
|
||||
paged_cache.ensure_capacity(slot, positions[b] + 1);
|
||||
}
|
||||
paged_cache.sync_active_batch_with_lens(seq_slots, &kv_lens);
|
||||
}
|
||||
|
||||
/// Pure-GPU decode step: embedding → all layers → final norm → logits.
|
||||
/// Token ids and positions are read from device buffers; every other input
|
||||
/// (weights, KV pools, block table, context lens) has a stable address —
|
||||
/// which makes this region CUDA-graph capturable.
|
||||
pub fn decode_core(
|
||||
&self,
|
||||
ids_gpu: *const std::ffi::c_void,
|
||||
pos_gpu: *const std::ffi::c_void,
|
||||
batch: usize,
|
||||
seq_slots: &[usize],
|
||||
paged_cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
let num_heads = self.local_num_heads;
|
||||
let num_kv_heads = self.local_num_kv_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
|
||||
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
|
||||
let max_blocks = paged_cache.max_blocks_per_seq();
|
||||
|
||||
// RoPE expects `[num_tokens, H, D]` with `num_tokens` positions —
|
||||
// matches our `[B, H, D]` exactly, so we upload once here.
|
||||
let positions_u32: Vec<u32> = positions.iter().map(|&p| p as u32).collect();
|
||||
|
||||
// Batched embedding: [B, hidden]
|
||||
let mut x = embedding(&self.embed_tokens, tokens);
|
||||
let mut x = embedding_device_ids(&self.embed_tokens, ids_gpu, batch);
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
// Fused QKV projection: one GEMV instead of three.
|
||||
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt); // [B, (H+2*KV)*D]
|
||||
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
|
||||
let q_dim = num_heads * head_dim;
|
||||
let kv_dim = num_kv_heads * head_dim;
|
||||
let q_all = qkv.narrow(1, 0, q_dim); // [B, H*D] (view)
|
||||
let k_all = qkv.narrow(1, q_dim, kv_dim); // [B, KV*D] (view)
|
||||
let q_all = qkv.narrow(1, 0, q_dim);
|
||||
let k_all = qkv.narrow(1, q_dim, kv_dim);
|
||||
let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);
|
||||
|
||||
// Per-head RMSNorm on contiguous copies (narrow views are strided).
|
||||
let q_flat = q_all.contiguous().reshape(&[batch * num_heads, head_dim]);
|
||||
let k_flat = k_all
|
||||
.contiguous()
|
||||
@@ -749,16 +776,13 @@ impl Qwen3 {
|
||||
|
||||
let q_3d = q_normed.reshape(&[batch, num_heads, head_dim]);
|
||||
let k_3d = k_normed.reshape(&[batch, num_kv_heads, head_dim]);
|
||||
rope_inplace(&q_3d, &self.rope_cache, &positions_u32);
|
||||
rope_inplace(&k_3d, &self.rope_cache, &positions_u32);
|
||||
rope_inplace_device_pos(&q_3d, &self.rope_cache, pos_gpu);
|
||||
rope_inplace_device_pos(&k_3d, &self.rope_cache, pos_gpu);
|
||||
|
||||
let v_3d = v_all.contiguous().reshape(&[batch, num_kv_heads, head_dim]);
|
||||
|
||||
// Single batched scatter for all sequences in the batch.
|
||||
paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, batch);
|
||||
|
||||
// Paged attention reads Q as [B, H, 1, D] — a contiguous view
|
||||
// of [B, H, D].
|
||||
let q_4d = q_3d.reshape(&[batch, num_heads, 1, head_dim]);
|
||||
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
@@ -775,27 +799,24 @@ impl Qwen3 {
|
||||
max_blocks,
|
||||
);
|
||||
|
||||
// attn_out shape [B, H, 1, D] is contiguous-equivalent to [B, H*D].
|
||||
let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
self.all_reduce(&attn_proj); // TP: sum partial attention outputs
|
||||
self.all_reduce(&attn_proj);
|
||||
|
||||
let (normed, x_new) =
|
||||
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
// Fused gate+up projection: one GEMV instead of two.
|
||||
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt); // [B, 2*ffn]
|
||||
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
|
||||
let ffn_dim = gate_up.shape()[1] / 2;
|
||||
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
|
||||
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
self.all_reduce(&down); // TP: sum partial MLP outputs
|
||||
self.all_reduce(&down);
|
||||
x = add_any(&residual, &down);
|
||||
}
|
||||
|
||||
// Advance logical seq_len now that all layers have been written.
|
||||
for &slot in seq_slots {
|
||||
paged_cache.advance_seq_len(slot, 1);
|
||||
}
|
||||
@@ -1261,6 +1282,14 @@ fn row_view(t: &Tensor, row: usize) -> Tensor {
|
||||
)
|
||||
}
|
||||
|
||||
/// Upload a u32 slice to a pooled GPU buffer (synchronous H2D).
|
||||
fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer {
|
||||
let bytes = unsafe { std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4) };
|
||||
let mut buf = xserv_cuda::allocator::cached_alloc(bytes.len()).expect("alloc u32 upload");
|
||||
buf.copy_from_host(bytes).unwrap();
|
||||
buf
|
||||
}
|
||||
|
||||
/// Concatenate row tensors [1, cols] into a single [B, cols] tensor via D2D memcpy.
|
||||
fn concat_rows(rows: &[Tensor]) -> Tensor {
|
||||
assert!(!rows.is_empty());
|
||||
|
||||
185
crates/xserv-model/src/qwen3_graph.rs
Normal file
185
crates/xserv-model/src/qwen3_graph.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
//! CUDA-graph replay for Qwen3 batch=1 decode (Phase 24 / speculative draft).
|
||||
//!
|
||||
//! Same pattern as `gpt_oss_graph.rs`, but for the Qwen3 dense decode path used
|
||||
//! by speculative decoding's draft model. A Qwen3-0.6B decode step is ~140
|
||||
//! kernel launches; wrapping the whole step into one `cudaGraphLaunch` cuts
|
||||
//! the ~4× γ draft cost per speculative round.
|
||||
//!
|
||||
//! See `gpt_oss_graph.rs` for the design commentary; the capture preconditions,
|
||||
//! retained-warmup mechanism, and quarantine lifetime are all identical here.
|
||||
|
||||
use std::ffi::c_void;
|
||||
|
||||
use xserv_cuda::allocator::{self, RetainedBlocks};
|
||||
use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer};
|
||||
use xserv_tensor::Tensor;
|
||||
|
||||
use crate::paged_kv_cache::PagedKVCache;
|
||||
use crate::qwen3::Qwen3;
|
||||
|
||||
pub struct Qwen3DecodeGraph {
|
||||
stream: CudaStream,
|
||||
graph: CudaGraph,
|
||||
ids_buf: GpuBuffer, // [1] u32, persistent graph input
|
||||
pos_buf: GpuBuffer, // [1] u32, persistent graph input
|
||||
logits: Tensor, // graph output; rewritten in place by every replay
|
||||
_arena: RetainedBlocks,
|
||||
}
|
||||
|
||||
impl Qwen3DecodeGraph {
|
||||
/// Capture one batch=1 decode step and replay it once.
|
||||
pub fn capture(
|
||||
model: &Qwen3,
|
||||
token: u32,
|
||||
position: usize,
|
||||
slot: usize,
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Self {
|
||||
let stream = CudaStream::new().expect("create capture stream");
|
||||
let mut ids_buf = allocator::cached_alloc(4).expect("alloc ids buf");
|
||||
let mut pos_buf = allocator::cached_alloc(4).expect("alloc pos buf");
|
||||
|
||||
model.decode_prepare(&[position], &[slot], cache);
|
||||
ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
|
||||
pos_buf
|
||||
.copy_from_host(&(position as u32).to_le_bytes())
|
||||
.unwrap();
|
||||
|
||||
// Retained warmup: run the exact step once eagerly with the quarantine
|
||||
// ON to stock the pool. See gpt_oss_graph.rs:66-86 for the full
|
||||
// rationale. Re-running the step is idempotent: the KV scatter
|
||||
// overwrites the same cache position and advance_seq_len is *inside*
|
||||
// decode_core, so we roll it back afterwards.
|
||||
let seq_len_before = cache.seq_len(slot);
|
||||
allocator::begin_retain();
|
||||
{
|
||||
let _guard = xserv_cuda::push_stream(&stream);
|
||||
let _ = model.decode_core(
|
||||
ids_buf.as_ptr() as *const c_void,
|
||||
pos_buf.as_ptr() as *const c_void,
|
||||
1,
|
||||
&[slot],
|
||||
cache,
|
||||
);
|
||||
}
|
||||
drop(allocator::end_retain());
|
||||
stream.synchronize().expect("warmup sync");
|
||||
// decode_core advanced seq_len; roll back so capture starts from the
|
||||
// same logical state as the eager warmup.
|
||||
cache
|
||||
.truncate_sequence(slot, seq_len_before)
|
||||
.expect("rollback after warmup");
|
||||
|
||||
allocator::begin_retain();
|
||||
let mut graph = CudaGraph::new();
|
||||
let logits;
|
||||
{
|
||||
let _guard = xserv_cuda::stream::push_stream(&stream);
|
||||
graph
|
||||
.begin_capture(&stream)
|
||||
.expect("begin decode-graph capture");
|
||||
logits = model.decode_core(
|
||||
ids_buf.as_ptr() as *const c_void,
|
||||
pos_buf.as_ptr() as *const c_void,
|
||||
1,
|
||||
&[slot],
|
||||
cache,
|
||||
);
|
||||
graph
|
||||
.end_capture(&stream)
|
||||
.expect("end decode-graph capture");
|
||||
}
|
||||
let arena = allocator::end_retain();
|
||||
|
||||
// The capture path called advance_seq_len (host-side) but the actual
|
||||
// GPU compute has not yet run. Roll back and let the first replay
|
||||
// advance it exactly once with real K/V writes.
|
||||
cache
|
||||
.truncate_sequence(slot, seq_len_before)
|
||||
.expect("rollback after capture");
|
||||
|
||||
graph.launch(&stream).expect("first decode-graph replay");
|
||||
cache.advance_seq_len(slot, 1);
|
||||
|
||||
Self {
|
||||
stream,
|
||||
graph,
|
||||
ids_buf,
|
||||
pos_buf,
|
||||
logits,
|
||||
_arena: arena,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run one decode step by replaying the captured graph.
|
||||
pub fn step(
|
||||
&mut self,
|
||||
model: &Qwen3,
|
||||
token: u32,
|
||||
position: usize,
|
||||
slot: usize,
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
model.decode_prepare(&[position], &[slot], cache);
|
||||
self.ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
|
||||
self.pos_buf
|
||||
.copy_from_host(&(position as u32).to_le_bytes())
|
||||
.unwrap();
|
||||
self.graph
|
||||
.launch(&self.stream)
|
||||
.expect("decode-graph replay");
|
||||
cache.advance_seq_len(slot, 1);
|
||||
self.logits.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Lazy capture policy: first decode step of the process runs eager, the
|
||||
/// second is captured, the rest replay. Batch>1 always falls back to eager.
|
||||
/// Disable with `XSERV_DECODE_GRAPH=0`.
|
||||
pub struct GraphedQwen3Decoder {
|
||||
graph: Option<Qwen3DecodeGraph>,
|
||||
eager_steps: u32,
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl GraphedQwen3Decoder {
|
||||
pub fn new() -> Self {
|
||||
let enabled = std::env::var("XSERV_DECODE_GRAPH")
|
||||
.map(|v| v != "0")
|
||||
.unwrap_or(true);
|
||||
Self {
|
||||
graph: None,
|
||||
eager_steps: 0,
|
||||
enabled,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode(
|
||||
&mut self,
|
||||
model: &Qwen3,
|
||||
tokens: &[u32],
|
||||
positions: &[usize],
|
||||
slots: &[usize],
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
if self.enabled && tokens.len() == 1 {
|
||||
if let Some(g) = self.graph.as_mut() {
|
||||
return g.step(model, tokens[0], positions[0], slots[0], cache);
|
||||
}
|
||||
if self.eager_steps >= 1 {
|
||||
let g = Qwen3DecodeGraph::capture(model, tokens[0], positions[0], slots[0], cache);
|
||||
let logits = g.logits.clone();
|
||||
self.graph = Some(g);
|
||||
return logits;
|
||||
}
|
||||
}
|
||||
self.eager_steps += 1;
|
||||
model.forward_decode_paged(tokens, positions, slots, cache)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GraphedQwen3Decoder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user