From 34224c7c936554265dfcc5f507e15ba0481dec52 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 12 Jun 2026 20:12:37 +0800 Subject: [PATCH] gpt-oss: replay the whole batch=1 decode step as one CUDA graph MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Split forward_decode_paged into host bookkeeping (decode_prepare + ids/pos upload + advance_seq_len) and a pure-GPU decode_core. The paged-KV and sparse-MoE designs already read every per-step variable (block table, context lens, expert ids) from stable-address device buffers, so decode_core captures as-is. GptOssDecodeGraph captures lazily on the second decode step (the first eager step warms cuBLAS) after a "retained warmup": the step runs once with the allocator quarantine on, stocking the pool with a dedicated block for every allocation so the capture itself never pool-misses (a cudaMalloc while capturing is illegal — and the capture's own quarantine is what would otherwise starve the pool). NCCL all-reduces capture cleanly; TP=2 replays in lockstep. Wired into tp_engine, bench-gpt-oss, and xserv-chat via GraphedGptOssDecoder (batch>1 falls back to eager; XSERV_DECODE_GRAPH=0 disables). Greedy tokens identical to eager. Co-Authored-By: Claude Fable 5 --- crates/xserv-model/src/bin/bench-gpt-oss.rs | 13 +- crates/xserv-model/src/bin/xserv-chat.rs | 26 ++- crates/xserv-model/src/gpt_oss.rs | 71 ++++++-- crates/xserv-model/src/gpt_oss_graph.rs | 172 ++++++++++++++++++++ crates/xserv-model/src/lib.rs | 2 + crates/xserv-server/src/tp_engine.rs | 18 +- 6 files changed, 277 insertions(+), 25 deletions(-) create mode 100644 crates/xserv-model/src/gpt_oss_graph.rs diff --git a/crates/xserv-model/src/bin/bench-gpt-oss.rs b/crates/xserv-model/src/bin/bench-gpt-oss.rs index f80f27c..de983a0 100644 --- a/crates/xserv-model/src/bin/bench-gpt-oss.rs +++ b/crates/xserv-model/src/bin/bench-gpt-oss.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use std::time::Instant; use xserv_distributed::{TpContext, UniqueId, get_unique_id}; -use xserv_model::{loader, GptOss, ModelConfig, PagedKVCache, BLOCK_SIZE}; +use xserv_model::{loader, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, BLOCK_SIZE}; use xserv_tensor::{DType, Device}; use xserv_tokenizer::Tokenizer; @@ -172,6 +172,7 @@ fn main() { print!("{prompt}"); // Decode + let mut decoder = GraphedGptOssDecoder::new(); let decode_start = Instant::now(); for _ in 1..max_tokens { let text = tokenizer.decode(&[next]); @@ -183,7 +184,7 @@ fn main() { broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode { tokens: vec![next], positions: vec![pos], slots: vec![slot], }); - let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut cache); + let logits = decoder.decode(&model, &[next], &[pos], &[slot], &mut cache); wait_workers(&worker_handles); next = sample_greedy_last(&logits); @@ -250,6 +251,7 @@ fn worker_loop( eprintln!("[rank {rank}] Ready."); ack_tx.send(()).unwrap(); + let mut decoder = GraphedGptOssDecoder::new(); while let Ok(cmd) = rx.recv() { match cmd { WorkerCmd::Register(slot) => { @@ -259,7 +261,7 @@ fn worker_loop( let _ = model.forward_prefill_paged(&tokens, slot, &mut cache); } WorkerCmd::Decode { tokens, positions, slots } => { - let _ = model.forward_decode_paged(&tokens, &positions, &slots, &mut cache); + let _ = decoder.decode(&model, &tokens, &positions, &slots, &mut cache); } WorkerCmd::Shutdown => break, } @@ -286,6 +288,11 @@ fn wait_workers(handles: &[(std::thread::JoinHandle<()>, std::sync::mpsc::Receiv fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 { use half::bf16; assert_eq!(logits.ndim(), 2); + // GPU argmax fast path (4-byte D2H instead of the full logits row). + if logits.dtype() == xserv_tensor::DType::BF16 && logits.is_contiguous() { + let ids = xserv_kernels::argmax_bf16_to_host(logits); + return *ids.last().unwrap(); + } let logits_cpu = logits.to_device(Device::Cpu); let vocab_size = logits.shape()[1]; let seq_len = logits.shape()[0]; diff --git a/crates/xserv-model/src/bin/xserv-chat.rs b/crates/xserv-model/src/bin/xserv-chat.rs index 51917d4..fbe8357 100644 --- a/crates/xserv-model/src/bin/xserv-chat.rs +++ b/crates/xserv-model/src/bin/xserv-chat.rs @@ -4,7 +4,7 @@ use std::path::PathBuf; use std::sync::{mpsc, Arc}; use std::thread; -use xserv_model::{loader, sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE}; +use xserv_model::{GraphedGptOssDecoder, loader, sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE}; use xserv_tensor::{DType, Device}; use xserv_tokenizer::Tokenizer; @@ -77,6 +77,7 @@ fn tp_worker_loop( let mut cache = PagedKVCache::new_tp( &config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, rank as u32, ); + let mut decoder = GraphedGptOssDecoder::new(); while let Ok(cmd) = cmd_rx.recv() { match cmd { TpCommand::Register(slot) => { let _ = cache.register_sequence(slot); } @@ -85,7 +86,7 @@ fn tp_worker_loop( let _ = model.forward_prefill_paged(&tokens, slot, &mut cache); } TpCommand::Decode { tokens, positions, slots } => { - let _ = model.forward_decode_paged(&tokens, &positions, &slots, &mut cache); + let _ = chat_decode(&model, &mut decoder, &tokens, &positions, &slots, &mut cache); } } let _ = ack_tx.send(()); @@ -321,6 +322,7 @@ fn main() { }; let tokenizer = Tokenizer::from_file(&opts.model_dir.join("tokenizer.json")); + let mut decoder = GraphedGptOssDecoder::new(); if let Some(h) = &tp_handle { h.send(TpCommand::Register(SLOT)); h.wait(); } cache.register_sequence(SLOT).expect("register chat slot"); let use_color = opts.color && io::stdout().is_terminal(); @@ -384,7 +386,7 @@ fn main() { print!("assistant> "); io::stdout().flush().unwrap(); let (_finish, answer) = generate_with_paged_cache( - &model, &mut cache, &tokenizer, &prompt_tokens, &opts.sampling, + &model, &mut decoder, &mut cache, &tokenizer, &prompt_tokens, &opts.sampling, max_new_tokens, use_color, &tp_handle, is_moe, opts.enable_thinking, ); moe_history.push((input.to_string(), answer)); @@ -421,6 +423,7 @@ fn main() { io::stdout().flush().unwrap(); let (finish, _answer) = generate_with_paged_cache( &model, + &mut decoder, &mut cache, &tokenizer, &prompt_tokens, @@ -679,8 +682,23 @@ fn today_ymd() -> String { format!("{y:04}-{m:02}-{d:02}") } +fn chat_decode( + model: &ChatModel, + decoder: &mut GraphedGptOssDecoder, + tokens: &[u32], + positions: &[usize], + slots: &[usize], + cache: &mut PagedKVCache, +) -> xserv_tensor::Tensor { + match model { + ChatModel::GptOss(m) => decoder.decode(m, tokens, positions, slots, cache), + ChatModel::Qwen3(_) => model.forward_decode_paged(tokens, positions, slots, cache), + } +} + fn generate_with_paged_cache( model: &ChatModel, + decoder: &mut GraphedGptOssDecoder, cache: &mut PagedKVCache, tokenizer: &Tokenizer, prompt_tokens: &[u32], @@ -745,7 +763,7 @@ fn generate_with_paged_cache( for _ in 0..max_tokens { let position = cache.seq_len(SLOT); if let Some(h) = tp { h.send(TpCommand::Decode { tokens: vec![next], positions: vec![position], slots: vec![SLOT] }); } - let logits = model.forward_decode_paged(&[next], &[position], &[SLOT], cache); + let logits = chat_decode(model, decoder, &[next], &[position], &[SLOT], cache); if let Some(h) = tp { h.wait(); } if tokenizer.is_eos(next) { print_stream_text( diff --git a/crates/xserv-model/src/gpt_oss.rs b/crates/xserv-model/src/gpt_oss.rs index 5250f54..4c4378c 100644 --- a/crates/xserv-model/src/gpt_oss.rs +++ b/crates/xserv-model/src/gpt_oss.rs @@ -373,24 +373,62 @@ impl GptOss { assert_eq!(seq_slots.len(), batch); assert!(batch > 0); - let num_heads = self.local_num_heads; - let num_kv_heads = self.local_num_kv_heads; - let head_dim = self.config.head_dim(); - let eps = self.norm_eps(); + self.decode_prepare(positions, seq_slots, paged_cache); + // Upload token ids + positions, then run the pure-GPU core. + let ids_gpu = upload_u32(tokens); + let positions_u32: Vec = positions.iter().map(|&p| p as u32).collect(); + let pos_gpu = upload_u32(&positions_u32); + let logits = self.decode_core( + ids_gpu.as_ptr() as *const c_void, + pos_gpu.as_ptr() as *const c_void, + batch, + paged_cache, + ); + + for &slot in seq_slots { + paged_cache.advance_seq_len(slot, 1); + } + logits + } + + /// Host-side per-step cache bookkeeping: block allocation + uploading + /// block tables / context lens to their (stable-address) GPU buffers. + /// Runs OUTSIDE the CUDA-graph captured region. + pub fn decode_prepare( + &self, + positions: &[usize], + seq_slots: &[usize], + paged_cache: &mut PagedKVCache, + ) { let kv_lens: Vec = positions.iter().map(|&p| (p + 1) as i32).collect(); for (b, &slot) in seq_slots.iter().enumerate() { paged_cache.ensure_capacity(slot, positions[b] + 1); } paged_cache.sync_active_batch_with_lens(seq_slots, &kv_lens); + } + + /// The pure-GPU decode step: embedding → 24 layers → final norm → logits. + /// Token ids and positions are read from device buffers; every other input + /// (weights, KV pools, block table, context lens) has a stable address — + /// which is exactly what makes this region CUDA-graph capturable. + pub fn decode_core( + &self, + ids_gpu: *const c_void, + pos_gpu: *const c_void, + batch: usize, + paged_cache: &mut PagedKVCache, + ) -> Tensor { + let num_heads = self.local_num_heads; + let num_kv_heads = self.local_num_kv_heads; + let head_dim = self.config.head_dim(); + let eps = self.norm_eps(); let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32; let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32; let max_blocks = paged_cache.max_blocks_per_seq(); - let positions_u32: Vec = positions.iter().map(|&p| p as u32).collect(); - - let mut x = embedding(&self.embed_tokens, tokens); + let mut x = embedding_device_ids(&self.embed_tokens, ids_gpu, batch); for (layer_idx, layer) in self.layers.iter().enumerate() { let residual = x.clone(); @@ -407,8 +445,8 @@ impl GptOss { let k_3d = k_all.reshape(&[batch, num_kv_heads, head_dim]); // RoPE (no QK-norm for gpt-oss) - rope_inplace(&q_3d, &self.rope_cache, &positions_u32); - rope_inplace(&k_3d, &self.rope_cache, &positions_u32); + rope_inplace_device_pos(&q_3d, &self.rope_cache, pos_gpu); + rope_inplace_device_pos(&k_3d, &self.rope_cache, pos_gpu); let v_3d = v_all.reshape(&[batch, num_kv_heads, head_dim]); @@ -445,11 +483,6 @@ impl GptOss { x = xserv_kernels::add(&residual, &moe_out); } - // Advance KV cache - for &slot in seq_slots { - paged_cache.advance_seq_len(slot, 1); - } - let x = Self::norm(&x, &self.norm, &self.norm_bias, eps); matmul_2d(&x, &self.lm_head_t) } @@ -673,6 +706,16 @@ impl GptOss { // --- Helpers --- +/// Upload a u32 slice to a pooled GPU buffer (synchronous H2D). +fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer { + let bytes = unsafe { + std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4) + }; + let mut buf = xserv_cuda::allocator::cached_alloc(bytes.len()).expect("alloc u32 upload"); + buf.copy_from_host(bytes).unwrap(); + buf +} + /// XSERV_DENSE_MOE=1 forces the dense all-expert path (A/B benchmarking). fn dense_moe_forced() -> bool { static FORCED: std::sync::OnceLock = std::sync::OnceLock::new(); diff --git a/crates/xserv-model/src/gpt_oss_graph.rs b/crates/xserv-model/src/gpt_oss_graph.rs new file mode 100644 index 0000000..3785427 --- /dev/null +++ b/crates/xserv-model/src/gpt_oss_graph.rs @@ -0,0 +1,172 @@ +//! CUDA-graph replay for gpt-oss batch=1 decode (Phase 21). +//! +//! A decode step launches ~200 kernels; with sparse MoE the GPU work is only +//! a few ms, so launch overhead dominates TPOT. The whole step (embedding → +//! 24 layers → logits) is captured ONCE into a CUDA graph and replayed per +//! token with a single `cudaGraphLaunch`. +//! +//! Why the existing forward is capturable as-is: +//! - Every per-step variable input lives in a stable-address device buffer +//! whose CONTENTS are updated outside the captured region: token id and +//! position (persistent buffers owned here), block table and context lens +//! (PagedKVCache GPU buffers, refreshed by `decode_prepare`). The KV scatter +//! and paged attention kernels read their write/read positions from those +//! buffers, and the sparse-MoE GEMVs read expert ids from `topk_ids` written +//! earlier in the same graph — all data-dependent, no host branching. +//! - Kernel launches go through the thread-local launch stream +//! (`xserv_cuda::stream::push_stream`), so the capture stream sees them. +//! - Intermediate tensors come from the caching allocator. Blocks freed while +//! capturing are quarantined (`allocator::begin_retain`) for the graph's +//! lifetime so no later allocation can take ownership of memory the graph +//! still references on every replay. +//! +//! Capture preconditions: at least one EAGER decode step must have run first, +//! so the allocator pool already holds every bucket size the step needs +//! (a pool-miss inside capture would call cudaMalloc — illegal while +//! capturing) and cuBLAS has finished its one-time per-shape setup. + +use std::ffi::c_void; + +use xserv_cuda::allocator::{self, RetainedBlocks}; +use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer}; +use xserv_tensor::Tensor; + +use crate::gpt_oss::GptOss; +use crate::paged_kv_cache::PagedKVCache; + +pub struct GptOssDecodeGraph { + stream: CudaStream, + graph: CudaGraph, + ids_buf: GpuBuffer, // [1] u32, persistent graph input + pos_buf: GpuBuffer, // [1] u32, persistent graph input + logits: Tensor, // graph output; rewritten in place by every replay + _arena: RetainedBlocks, +} + +impl GptOssDecodeGraph { + /// Capture one batch=1 decode step and replay it once (capture records + /// without executing, so the replay performs this token's computation). + pub fn capture( + model: &GptOss, + token: u32, + position: usize, + slot: usize, + cache: &mut PagedKVCache, + ) -> Self { + let stream = CudaStream::new().expect("create capture stream"); + let mut ids_buf = allocator::cached_alloc(4).expect("alloc ids buf"); + let mut pos_buf = allocator::cached_alloc(4).expect("alloc pos buf"); + + model.decode_prepare(&[position], &[slot], cache); + ids_buf.copy_from_host(&token.to_le_bytes()).unwrap(); + pos_buf.copy_from_host(&(position as u32).to_le_bytes()).unwrap(); + + // Retained warmup: run the exact step once eagerly with the quarantine + // ON. Freed intermediates are held back instead of recycled, so the + // pool ends up stocked with a dedicated block for EVERY allocation the + // step performs. The capture below repeats the same allocation + // sequence and therefore never misses the pool — a pool miss would + // call cudaMalloc, which is illegal while a stream is capturing (this + // is also why one block per bucket is not enough: the capture's own + // quarantine keeps freed blocks out of reuse). Re-running the step is + // idempotent: the KV scatter rewrites the same cache position. + allocator::begin_retain(); + { + let _guard = xserv_cuda::push_stream(&stream); + let _ = model.decode_core( + ids_buf.as_ptr() as *const c_void, + pos_buf.as_ptr() as *const c_void, + 1, + cache, + ); + } + drop(allocator::end_retain()); // release the warmup blocks to the pool + stream.synchronize().expect("warmup sync"); + + allocator::begin_retain(); + let mut graph = CudaGraph::new(); + let logits; + { + let _guard = xserv_cuda::stream::push_stream(&stream); + graph.begin_capture(&stream).expect("begin decode-graph capture"); + logits = model.decode_core( + ids_buf.as_ptr() as *const c_void, + pos_buf.as_ptr() as *const c_void, + 1, + cache, + ); + graph.end_capture(&stream).expect("end decode-graph capture"); + } + let arena = allocator::end_retain(); + + graph.launch(&stream).expect("first decode-graph replay"); + cache.advance_seq_len(slot, 1); + + Self { stream, graph, ids_buf, pos_buf, logits, _arena: arena } + } + + /// Run one decode step by replaying the captured graph. + pub fn step( + &mut self, + model: &GptOss, + token: u32, + position: usize, + slot: usize, + cache: &mut PagedKVCache, + ) -> Tensor { + model.decode_prepare(&[position], &[slot], cache); + self.ids_buf.copy_from_host(&token.to_le_bytes()).unwrap(); + self.pos_buf.copy_from_host(&(position as u32).to_le_bytes()).unwrap(); + self.graph.launch(&self.stream).expect("decode-graph replay"); + cache.advance_seq_len(slot, 1); + // Shallow clone: the caller reads these logits before the next replay + // rewrites the underlying buffer. + self.logits.clone() + } +} + +/// Lazy capture policy: first decode step of the process runs eager (warms the +/// allocator pool + cuBLAS so capture performs no "unsafe" CUDA calls), the +/// second is captured, the rest replay. Batch>1 always falls back to eager. +/// Disable with XSERV_DECODE_GRAPH=0. +pub struct GraphedGptOssDecoder { + graph: Option, + eager_steps: u32, + enabled: bool, +} + +impl GraphedGptOssDecoder { + pub fn new() -> Self { + let enabled = std::env::var("XSERV_DECODE_GRAPH").map(|v| v != "0").unwrap_or(true); + Self { graph: None, eager_steps: 0, enabled } + } + + pub fn decode( + &mut self, + model: &GptOss, + tokens: &[u32], + positions: &[usize], + slots: &[usize], + cache: &mut PagedKVCache, + ) -> Tensor { + if self.enabled && tokens.len() == 1 { + if let Some(g) = self.graph.as_mut() { + return g.step(model, tokens[0], positions[0], slots[0], cache); + } + if self.eager_steps >= 1 { + let g = GptOssDecodeGraph::capture(model, tokens[0], positions[0], slots[0], cache); + let logits = g.logits.clone(); + self.graph = Some(g); + return logits; + } + } + self.eager_steps += 1; + model.forward_decode_paged(tokens, positions, slots, cache) + } +} + +impl Default for GraphedGptOssDecoder { + fn default() -> Self { + Self::new() + } +} diff --git a/crates/xserv-model/src/lib.rs b/crates/xserv-model/src/lib.rs index 67c3f46..e97a80e 100644 --- a/crates/xserv-model/src/lib.rs +++ b/crates/xserv-model/src/lib.rs @@ -2,6 +2,7 @@ pub mod config; pub mod decode_graph; pub mod gpt2; pub mod gpt_oss; +pub mod gpt_oss_graph; pub mod kv_cache; pub mod loader; pub mod paged_kv_cache; @@ -12,6 +13,7 @@ pub use config::ModelConfig; pub use decode_graph::{DecodeGraphState, LayerWeightPtrs}; pub use gpt2::{GPT2, KVCache}; pub use gpt_oss::GptOss; +pub use gpt_oss_graph::{GptOssDecodeGraph, GraphedGptOssDecoder}; pub use kv_cache::GpuKVCache; pub use paged_kv_cache::{BlockAllocator, Location, PagedKVCache, BLOCK_SIZE}; pub use qwen3::Qwen3; diff --git a/crates/xserv-server/src/tp_engine.rs b/crates/xserv-server/src/tp_engine.rs index be6453e..14deba0 100644 --- a/crates/xserv-server/src/tp_engine.rs +++ b/crates/xserv-server/src/tp_engine.rs @@ -19,7 +19,7 @@ use std::thread; use xserv_distributed::{TpContext, UniqueId}; use xserv_model::loader; -use xserv_model::{sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE}; +use xserv_model::{sample, sample_greedy_penalized, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE}; use xserv_tensor::{DType, Device, Tensor}; use xserv_tokenizer::Tokenizer; @@ -58,6 +58,16 @@ impl TpModel { struct RankCtx { model: TpModel, cache: PagedKVCache, + decoder: GraphedGptOssDecoder, +} + +/// Decode one step: gpt-oss batch=1 goes through the CUDA-graph decoder +/// (lazy capture, replay thereafter); everything else runs eager. +fn rank_decode(rc: &mut RankCtx, tokens: &[u32], positions: &[usize], slots: &[usize]) -> Tensor { + match &rc.model { + TpModel::GptOss(m) => rc.decoder.decode(m, tokens, positions, slots, &mut rc.cache), + TpModel::Qwen3(_) => rc.model.forward_decode_paged(tokens, positions, slots, &mut rc.cache), + } } fn build_rank( @@ -81,7 +91,7 @@ fn build_rank( let cache = PagedKVCache::new_tp( config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device, ); - RankCtx { model, cache } + RankCtx { model, cache, decoder: GraphedGptOssDecoder::new() } } fn worker_loop( @@ -106,7 +116,7 @@ fn worker_loop( let _ = rc.model.forward_prefill_paged(&tokens, slot, &mut rc.cache); } TpCommand::Decode { tokens, positions, slots } => { - let _ = rc.model.forward_decode_paged(&tokens, &positions, &slots, &mut rc.cache); + let _ = rank_decode(&mut rc, &tokens, &positions, &slots); } TpCommand::Shutdown => { let _ = ack_tx.send(()); @@ -207,7 +217,7 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece } let pos = rc.cache.seq_len(slot); broadcast(&cmd_txs, TpCommand::Decode { tokens: vec![next], positions: vec![pos], slots: vec![slot] }); - let logits = rc.model.forward_decode_paged(&[next], &[pos], &[slot], &mut rc.cache); + let logits = rank_decode(&mut rc, &[next], &[pos], &[slot]); wait_acks(&ack_rx); next = pick(&logits, &req.sampling, &gen_ids); gen_ids.push(next);