diff --git a/crates/xserv-model/src/bin/bench-speculative.rs b/crates/xserv-model/src/bin/bench-speculative.rs index a8fef1a..723dab9 100644 --- a/crates/xserv-model/src/bin/bench-speculative.rs +++ b/crates/xserv-model/src/bin/bench-speculative.rs @@ -10,6 +10,7 @@ use half::bf16; use std::path::{Path, PathBuf}; use std::time::Instant; +use xserv_model::qwen3_graph::GraphedQwen3Decoder; use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader}; use xserv_tensor::{DType, Device, Tensor}; use xserv_tokenizer::Tokenizer; @@ -222,12 +223,14 @@ fn main() { let mut target_verify_cache = new_cache_with_rows(&target_config, max_seq_len, device, gamma); let mut draft_cache = new_cache(&draft_config, max_seq_len, device); + let mut draft_decoder = GraphedQwen3Decoder::new(); let _ = run_speculative( &target, &draft, &mut target_cache, &mut target_verify_cache, &mut draft_cache, + &mut draft_decoder, &tokenizer, &warm_ids, warm_tokens, @@ -248,6 +251,21 @@ fn main() { ); let mut totals = Totals::default(); + + // Persistent per-benchmark caches so the draft CUDA graph (Phase 24) can be + // captured once and replayed across every prompt. Freeing and re-registering + // slot 0 between prompts keeps block_table_gpu / context_lens_gpu addresses + // stable, which is exactly what the graph captured. + let mut target_cache = new_cache_with_rows( + &target_config, + max_seq_len, + device, + if use_verify_logits { gamma } else { 1 }, + ); + let mut target_verify_cache = new_cache_with_rows(&target_config, max_seq_len, device, gamma); + let mut draft_cache = new_cache(&draft_config, max_seq_len, device); + let mut draft_decoder = GraphedQwen3Decoder::new(); + for (i, prompt) in PROMPTS.iter().take(prompt_count).enumerate() { let ids = tokenizer.encode(prompt); validate_length_budget(&ids, gen_tokens, max_seq_len, prompt); @@ -255,21 +273,13 @@ fn main() { let baseline = run_baseline(&target, &mut baseline_cache, &tokenizer, &ids, gen_tokens); drop(baseline_cache); - let mut target_cache = new_cache_with_rows( - &target_config, - max_seq_len, - device, - if use_verify_logits { gamma } else { 1 }, - ); - let mut target_verify_cache = - new_cache_with_rows(&target_config, max_seq_len, device, gamma); - let mut draft_cache = new_cache(&draft_config, max_seq_len, device); let spec = run_speculative( &target, &draft, &mut target_cache, &mut target_verify_cache, &mut draft_cache, + &mut draft_decoder, &tokenizer, &ids, gen_tokens, @@ -438,6 +448,7 @@ fn run_speculative( target_cache: &mut PagedKVCache, target_verify_cache: &mut PagedKVCache, draft_cache: &mut PagedKVCache, + draft_decoder: &mut GraphedQwen3Decoder, tokenizer: &Tokenizer, prompt_ids: &[u32], gen_tokens: usize, @@ -504,7 +515,7 @@ fn run_speculative( break; } let pos = draft_cache.seq_len(slot); - let logits = draft.forward_decode_paged(&[token], &[pos], &[slot], draft_cache); + let logits = draft_decoder.decode(draft, &[token], &[pos], &[slot], draft_cache); draft_next = last_argmax(&logits); } proposed_total += draft_tokens.len(); @@ -572,6 +583,7 @@ fn run_speculative( .unwrap(); replay_draft_tokens( draft, + draft_decoder, draft_cache, slot, &draft_tokens[..accepted], @@ -588,7 +600,7 @@ fn run_speculative( commit_steps += 1; let pos = draft_cache.seq_len(slot); - let logits = draft.forward_decode_paged(&[correction], &[pos], &[slot], draft_cache); + let logits = draft_decoder.decode(draft, &[correction], &[pos], &[slot], draft_cache); draft_next = last_argmax(&logits); correction_steps += 1; continue; @@ -690,6 +702,7 @@ fn run_speculative( .unwrap(); replay_draft_tokens( draft, + draft_decoder, draft_cache, slot, &draft_tokens[..accepted], @@ -709,7 +722,7 @@ fn run_speculative( mirror_steps += 1; let pos = draft_cache.seq_len(slot); - let logits = draft.forward_decode_paged(&[correction], &[pos], &[slot], draft_cache); + let logits = draft_decoder.decode(draft, &[correction], &[pos], &[slot], draft_cache); draft_next = last_argmax(&logits); correction_steps += 1; } @@ -745,6 +758,7 @@ fn advance_target_cache(target: &Qwen3, cache: &mut PagedKVCache, slot: usize, t fn replay_draft_tokens( draft: &Qwen3, + draft_decoder: &mut GraphedQwen3Decoder, cache: &mut PagedKVCache, slot: usize, tokens: &[u32], @@ -752,7 +766,7 @@ fn replay_draft_tokens( ) { for &token in tokens { let pos = cache.seq_len(slot); - let logits = draft.forward_decode_paged(&[token], &[pos], &[slot], cache); + let logits = draft_decoder.decode(draft, &[token], &[pos], &[slot], cache); *next = last_argmax(&logits); } } diff --git a/crates/xserv-model/src/lib.rs b/crates/xserv-model/src/lib.rs index 922aa33..7112a8c 100644 --- a/crates/xserv-model/src/lib.rs +++ b/crates/xserv-model/src/lib.rs @@ -7,6 +7,7 @@ pub mod kv_cache; pub mod loader; pub mod paged_kv_cache; pub mod qwen3; +pub mod qwen3_graph; pub mod sampling; pub use config::ModelConfig; diff --git a/crates/xserv-model/src/qwen3.rs b/crates/xserv-model/src/qwen3.rs index 933eee4..3de4beb 100644 --- a/crates/xserv-model/src/qwen3.rs +++ b/crates/xserv-model/src/qwen3.rs @@ -701,45 +701,72 @@ impl Qwen3 { assert_eq!(seq_slots.len(), batch); assert!(batch > 0); - // TP: this rank owns a slice of the heads (local_* == full when world==1). - let num_heads = self.local_num_heads; - let num_kv_heads = self.local_num_kv_heads; - let head_dim = self.config.head_dim(); - let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32; + self.decode_prepare(positions, seq_slots, paged_cache); - // Ensure all slots have enough physical blocks for this token, then - // upload block tables + context_lens once for the whole forward (the - // tables are identical across layers; only the layer's K/V pool changes). + let ids_gpu = upload_u32(tokens); + let positions_u32: Vec = positions.iter().map(|&p| p as u32).collect(); + let pos_gpu = upload_u32(&positions_u32); + let logits = self.decode_core( + ids_gpu.as_ptr() as *const std::ffi::c_void, + pos_gpu.as_ptr() as *const std::ffi::c_void, + batch, + seq_slots, + paged_cache, + ); + logits + } + + /// Host-side per-step cache bookkeeping: block allocation + uploading block + /// tables / context lens to their (stable-address) GPU buffers. Runs + /// OUTSIDE any CUDA-graph captured region. + pub fn decode_prepare( + &self, + positions: &[usize], + seq_slots: &[usize], + paged_cache: &mut PagedKVCache, + ) { let kv_lens: Vec = positions.iter().map(|&p| (p + 1) as i32).collect(); for (b, &slot) in seq_slots.iter().enumerate() { paged_cache.ensure_capacity(slot, positions[b] + 1); } paged_cache.sync_active_batch_with_lens(seq_slots, &kv_lens); + } + + /// Pure-GPU decode step: embedding → all layers → final norm → logits. + /// Token ids and positions are read from device buffers; every other input + /// (weights, KV pools, block table, context lens) has a stable address — + /// which makes this region CUDA-graph capturable. + pub fn decode_core( + &self, + ids_gpu: *const std::ffi::c_void, + pos_gpu: *const std::ffi::c_void, + batch: usize, + seq_slots: &[usize], + paged_cache: &mut PagedKVCache, + ) -> Tensor { + let num_heads = self.local_num_heads; + let num_kv_heads = self.local_num_kv_heads; + let head_dim = self.config.head_dim(); + let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32; let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32; let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32; let max_blocks = paged_cache.max_blocks_per_seq(); - // RoPE expects `[num_tokens, H, D]` with `num_tokens` positions — - // matches our `[B, H, D]` exactly, so we upload once here. - let positions_u32: Vec = positions.iter().map(|&p| p as u32).collect(); - - // Batched embedding: [B, hidden] - let mut x = embedding(&self.embed_tokens, tokens); + let mut x = embedding_device_ids(&self.embed_tokens, ids_gpu, batch); for (layer_idx, layer) in self.layers.iter().enumerate() { let residual = x.clone(); let normed = rmsnorm(&x, &layer.input_norm, eps); // Fused QKV projection: one GEMV instead of three. - let qkv = matmul_2d(&normed, &layer.qkv_proj_wt); // [B, (H+2*KV)*D] + let qkv = matmul_2d(&normed, &layer.qkv_proj_wt); let q_dim = num_heads * head_dim; let kv_dim = num_kv_heads * head_dim; - let q_all = qkv.narrow(1, 0, q_dim); // [B, H*D] (view) - let k_all = qkv.narrow(1, q_dim, kv_dim); // [B, KV*D] (view) + let q_all = qkv.narrow(1, 0, q_dim); + let k_all = qkv.narrow(1, q_dim, kv_dim); let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim); - // Per-head RMSNorm on contiguous copies (narrow views are strided). let q_flat = q_all.contiguous().reshape(&[batch * num_heads, head_dim]); let k_flat = k_all .contiguous() @@ -749,16 +776,13 @@ impl Qwen3 { let q_3d = q_normed.reshape(&[batch, num_heads, head_dim]); let k_3d = k_normed.reshape(&[batch, num_kv_heads, head_dim]); - rope_inplace(&q_3d, &self.rope_cache, &positions_u32); - rope_inplace(&k_3d, &self.rope_cache, &positions_u32); + rope_inplace_device_pos(&q_3d, &self.rope_cache, pos_gpu); + rope_inplace_device_pos(&k_3d, &self.rope_cache, pos_gpu); let v_3d = v_all.contiguous().reshape(&[batch, num_kv_heads, head_dim]); - // Single batched scatter for all sequences in the batch. paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, batch); - // Paged attention reads Q as [B, H, 1, D] — a contiguous view - // of [B, H, D]. let q_4d = q_3d.reshape(&[batch, num_heads, 1, head_dim]); let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void; let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void; @@ -775,27 +799,24 @@ impl Qwen3 { max_blocks, ); - // attn_out shape [B, H, 1, D] is contiguous-equivalent to [B, H*D]. let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]); let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt); - self.all_reduce(&attn_proj); // TP: sum partial attention outputs + self.all_reduce(&attn_proj); let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); let residual = x_new.clone(); - // Fused gate+up projection: one GEMV instead of two. - let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt); // [B, 2*ffn] + let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt); let ffn_dim = gate_up.shape()[1] / 2; let gate = gate_up.narrow(1, 0, ffn_dim).contiguous(); let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous(); let hidden_states = xserv_kernels::silu_mul(&gate, &up); let down = matmul_2d(&hidden_states, &layer.down_proj_wt); - self.all_reduce(&down); // TP: sum partial MLP outputs + self.all_reduce(&down); x = add_any(&residual, &down); } - // Advance logical seq_len now that all layers have been written. for &slot in seq_slots { paged_cache.advance_seq_len(slot, 1); } @@ -1261,6 +1282,14 @@ fn row_view(t: &Tensor, row: usize) -> Tensor { ) } +/// Upload a u32 slice to a pooled GPU buffer (synchronous H2D). +fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer { + let bytes = unsafe { std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4) }; + let mut buf = xserv_cuda::allocator::cached_alloc(bytes.len()).expect("alloc u32 upload"); + buf.copy_from_host(bytes).unwrap(); + buf +} + /// Concatenate row tensors [1, cols] into a single [B, cols] tensor via D2D memcpy. fn concat_rows(rows: &[Tensor]) -> Tensor { assert!(!rows.is_empty()); diff --git a/crates/xserv-model/src/qwen3_graph.rs b/crates/xserv-model/src/qwen3_graph.rs new file mode 100644 index 0000000..0a98d35 --- /dev/null +++ b/crates/xserv-model/src/qwen3_graph.rs @@ -0,0 +1,185 @@ +//! CUDA-graph replay for Qwen3 batch=1 decode (Phase 24 / speculative draft). +//! +//! Same pattern as `gpt_oss_graph.rs`, but for the Qwen3 dense decode path used +//! by speculative decoding's draft model. A Qwen3-0.6B decode step is ~140 +//! kernel launches; wrapping the whole step into one `cudaGraphLaunch` cuts +//! the ~4× γ draft cost per speculative round. +//! +//! See `gpt_oss_graph.rs` for the design commentary; the capture preconditions, +//! retained-warmup mechanism, and quarantine lifetime are all identical here. + +use std::ffi::c_void; + +use xserv_cuda::allocator::{self, RetainedBlocks}; +use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer}; +use xserv_tensor::Tensor; + +use crate::paged_kv_cache::PagedKVCache; +use crate::qwen3::Qwen3; + +pub struct Qwen3DecodeGraph { + stream: CudaStream, + graph: CudaGraph, + ids_buf: GpuBuffer, // [1] u32, persistent graph input + pos_buf: GpuBuffer, // [1] u32, persistent graph input + logits: Tensor, // graph output; rewritten in place by every replay + _arena: RetainedBlocks, +} + +impl Qwen3DecodeGraph { + /// Capture one batch=1 decode step and replay it once. + pub fn capture( + model: &Qwen3, + token: u32, + position: usize, + slot: usize, + cache: &mut PagedKVCache, + ) -> Self { + let stream = CudaStream::new().expect("create capture stream"); + let mut ids_buf = allocator::cached_alloc(4).expect("alloc ids buf"); + let mut pos_buf = allocator::cached_alloc(4).expect("alloc pos buf"); + + model.decode_prepare(&[position], &[slot], cache); + ids_buf.copy_from_host(&token.to_le_bytes()).unwrap(); + pos_buf + .copy_from_host(&(position as u32).to_le_bytes()) + .unwrap(); + + // Retained warmup: run the exact step once eagerly with the quarantine + // ON to stock the pool. See gpt_oss_graph.rs:66-86 for the full + // rationale. Re-running the step is idempotent: the KV scatter + // overwrites the same cache position and advance_seq_len is *inside* + // decode_core, so we roll it back afterwards. + let seq_len_before = cache.seq_len(slot); + allocator::begin_retain(); + { + let _guard = xserv_cuda::push_stream(&stream); + let _ = model.decode_core( + ids_buf.as_ptr() as *const c_void, + pos_buf.as_ptr() as *const c_void, + 1, + &[slot], + cache, + ); + } + drop(allocator::end_retain()); + stream.synchronize().expect("warmup sync"); + // decode_core advanced seq_len; roll back so capture starts from the + // same logical state as the eager warmup. + cache + .truncate_sequence(slot, seq_len_before) + .expect("rollback after warmup"); + + allocator::begin_retain(); + let mut graph = CudaGraph::new(); + let logits; + { + let _guard = xserv_cuda::stream::push_stream(&stream); + graph + .begin_capture(&stream) + .expect("begin decode-graph capture"); + logits = model.decode_core( + ids_buf.as_ptr() as *const c_void, + pos_buf.as_ptr() as *const c_void, + 1, + &[slot], + cache, + ); + graph + .end_capture(&stream) + .expect("end decode-graph capture"); + } + let arena = allocator::end_retain(); + + // The capture path called advance_seq_len (host-side) but the actual + // GPU compute has not yet run. Roll back and let the first replay + // advance it exactly once with real K/V writes. + cache + .truncate_sequence(slot, seq_len_before) + .expect("rollback after capture"); + + graph.launch(&stream).expect("first decode-graph replay"); + cache.advance_seq_len(slot, 1); + + Self { + stream, + graph, + ids_buf, + pos_buf, + logits, + _arena: arena, + } + } + + /// Run one decode step by replaying the captured graph. + pub fn step( + &mut self, + model: &Qwen3, + token: u32, + position: usize, + slot: usize, + cache: &mut PagedKVCache, + ) -> Tensor { + model.decode_prepare(&[position], &[slot], cache); + self.ids_buf.copy_from_host(&token.to_le_bytes()).unwrap(); + self.pos_buf + .copy_from_host(&(position as u32).to_le_bytes()) + .unwrap(); + self.graph + .launch(&self.stream) + .expect("decode-graph replay"); + cache.advance_seq_len(slot, 1); + self.logits.clone() + } +} + +/// Lazy capture policy: first decode step of the process runs eager, the +/// second is captured, the rest replay. Batch>1 always falls back to eager. +/// Disable with `XSERV_DECODE_GRAPH=0`. +pub struct GraphedQwen3Decoder { + graph: Option, + eager_steps: u32, + enabled: bool, +} + +impl GraphedQwen3Decoder { + pub fn new() -> Self { + let enabled = std::env::var("XSERV_DECODE_GRAPH") + .map(|v| v != "0") + .unwrap_or(true); + Self { + graph: None, + eager_steps: 0, + enabled, + } + } + + pub fn decode( + &mut self, + model: &Qwen3, + tokens: &[u32], + positions: &[usize], + slots: &[usize], + cache: &mut PagedKVCache, + ) -> Tensor { + if self.enabled && tokens.len() == 1 { + if let Some(g) = self.graph.as_mut() { + return g.step(model, tokens[0], positions[0], slots[0], cache); + } + if self.eager_steps >= 1 { + let g = Qwen3DecodeGraph::capture(model, tokens[0], positions[0], slots[0], cache); + let logits = g.logits.clone(); + self.graph = Some(g); + return logits; + } + } + self.eager_steps += 1; + model.forward_decode_paged(tokens, positions, slots, cache) + } +} + +impl Default for GraphedQwen3Decoder { + fn default() -> Self { + Self::new() + } +}