gpt-oss: replay the whole batch=1 decode step as one CUDA graph

Split forward_decode_paged into host bookkeeping (decode_prepare +
ids/pos upload + advance_seq_len) and a pure-GPU decode_core. The
paged-KV and sparse-MoE designs already read every per-step variable
(block table, context lens, expert ids) from stable-address device
buffers, so decode_core captures as-is.

GptOssDecodeGraph captures lazily on the second decode step (the
first eager step warms cuBLAS) after a "retained warmup": the step
runs once with the allocator quarantine on, stocking the pool with a
dedicated block for every allocation so the capture itself never
pool-misses (a cudaMalloc while capturing is illegal — and the
capture's own quarantine is what would otherwise starve the pool).
NCCL all-reduces capture cleanly; TP=2 replays in lockstep.

Wired into tp_engine, bench-gpt-oss, and xserv-chat via
GraphedGptOssDecoder (batch>1 falls back to eager;
XSERV_DECODE_GRAPH=0 disables). Greedy tokens identical to eager.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
2026-06-12 20:12:37 +08:00
parent 4088f49b7d
commit 34224c7c93
6 changed files with 277 additions and 25 deletions

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use std::time::Instant;
use xserv_distributed::{TpContext, UniqueId, get_unique_id};
use xserv_model::{loader, GptOss, ModelConfig, PagedKVCache, BLOCK_SIZE};
use xserv_model::{loader, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, BLOCK_SIZE};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -172,6 +172,7 @@ fn main() {
print!("{prompt}");
// Decode
let mut decoder = GraphedGptOssDecoder::new();
let decode_start = Instant::now();
for _ in 1..max_tokens {
let text = tokenizer.decode(&[next]);
@@ -183,7 +184,7 @@ fn main() {
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode {
tokens: vec![next], positions: vec![pos], slots: vec![slot],
});
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut cache);
let logits = decoder.decode(&model, &[next], &[pos], &[slot], &mut cache);
wait_workers(&worker_handles);
next = sample_greedy_last(&logits);
@@ -250,6 +251,7 @@ fn worker_loop(
eprintln!("[rank {rank}] Ready.");
ack_tx.send(()).unwrap();
let mut decoder = GraphedGptOssDecoder::new();
while let Ok(cmd) = rx.recv() {
match cmd {
WorkerCmd::Register(slot) => {
@@ -259,7 +261,7 @@ fn worker_loop(
let _ = model.forward_prefill_paged(&tokens, slot, &mut cache);
}
WorkerCmd::Decode { tokens, positions, slots } => {
let _ = model.forward_decode_paged(&tokens, &positions, &slots, &mut cache);
let _ = decoder.decode(&model, &tokens, &positions, &slots, &mut cache);
}
WorkerCmd::Shutdown => break,
}
@@ -286,6 +288,11 @@ fn wait_workers(handles: &[(std::thread::JoinHandle<()>, std::sync::mpsc::Receiv
fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 {
use half::bf16;
assert_eq!(logits.ndim(), 2);
// GPU argmax fast path (4-byte D2H instead of the full logits row).
if logits.dtype() == xserv_tensor::DType::BF16 && logits.is_contiguous() {
let ids = xserv_kernels::argmax_bf16_to_host(logits);
return *ids.last().unwrap();
}
let logits_cpu = logits.to_device(Device::Cpu);
let vocab_size = logits.shape()[1];
let seq_len = logits.shape()[0];

View File

@@ -4,7 +4,7 @@ use std::path::PathBuf;
use std::sync::{mpsc, Arc};
use std::thread;
use xserv_model::{loader, sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
use xserv_model::{GraphedGptOssDecoder, loader, sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -77,6 +77,7 @@ fn tp_worker_loop(
let mut cache = PagedKVCache::new_tp(
&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, rank as u32,
);
let mut decoder = GraphedGptOssDecoder::new();
while let Ok(cmd) = cmd_rx.recv() {
match cmd {
TpCommand::Register(slot) => { let _ = cache.register_sequence(slot); }
@@ -85,7 +86,7 @@ fn tp_worker_loop(
let _ = model.forward_prefill_paged(&tokens, slot, &mut cache);
}
TpCommand::Decode { tokens, positions, slots } => {
let _ = model.forward_decode_paged(&tokens, &positions, &slots, &mut cache);
let _ = chat_decode(&model, &mut decoder, &tokens, &positions, &slots, &mut cache);
}
}
let _ = ack_tx.send(());
@@ -321,6 +322,7 @@ fn main() {
};
let tokenizer = Tokenizer::from_file(&opts.model_dir.join("tokenizer.json"));
let mut decoder = GraphedGptOssDecoder::new();
if let Some(h) = &tp_handle { h.send(TpCommand::Register(SLOT)); h.wait(); }
cache.register_sequence(SLOT).expect("register chat slot");
let use_color = opts.color && io::stdout().is_terminal();
@@ -384,7 +386,7 @@ fn main() {
print!("assistant> ");
io::stdout().flush().unwrap();
let (_finish, answer) = generate_with_paged_cache(
&model, &mut cache, &tokenizer, &prompt_tokens, &opts.sampling,
&model, &mut decoder, &mut cache, &tokenizer, &prompt_tokens, &opts.sampling,
max_new_tokens, use_color, &tp_handle, is_moe, opts.enable_thinking,
);
moe_history.push((input.to_string(), answer));
@@ -421,6 +423,7 @@ fn main() {
io::stdout().flush().unwrap();
let (finish, _answer) = generate_with_paged_cache(
&model,
&mut decoder,
&mut cache,
&tokenizer,
&prompt_tokens,
@@ -679,8 +682,23 @@ fn today_ymd() -> String {
format!("{y:04}-{m:02}-{d:02}")
}
fn chat_decode(
model: &ChatModel,
decoder: &mut GraphedGptOssDecoder,
tokens: &[u32],
positions: &[usize],
slots: &[usize],
cache: &mut PagedKVCache,
) -> xserv_tensor::Tensor {
match model {
ChatModel::GptOss(m) => decoder.decode(m, tokens, positions, slots, cache),
ChatModel::Qwen3(_) => model.forward_decode_paged(tokens, positions, slots, cache),
}
}
fn generate_with_paged_cache(
model: &ChatModel,
decoder: &mut GraphedGptOssDecoder,
cache: &mut PagedKVCache,
tokenizer: &Tokenizer,
prompt_tokens: &[u32],
@@ -745,7 +763,7 @@ fn generate_with_paged_cache(
for _ in 0..max_tokens {
let position = cache.seq_len(SLOT);
if let Some(h) = tp { h.send(TpCommand::Decode { tokens: vec![next], positions: vec![position], slots: vec![SLOT] }); }
let logits = model.forward_decode_paged(&[next], &[position], &[SLOT], cache);
let logits = chat_decode(model, decoder, &[next], &[position], &[SLOT], cache);
if let Some(h) = tp { h.wait(); }
if tokenizer.is_eos(next) {
print_stream_text(

View File

@@ -373,24 +373,62 @@ impl GptOss {
assert_eq!(seq_slots.len(), batch);
assert!(batch > 0);
let num_heads = self.local_num_heads;
let num_kv_heads = self.local_num_kv_heads;
let head_dim = self.config.head_dim();
let eps = self.norm_eps();
self.decode_prepare(positions, seq_slots, paged_cache);
// Upload token ids + positions, then run the pure-GPU core.
let ids_gpu = upload_u32(tokens);
let positions_u32: Vec<u32> = positions.iter().map(|&p| p as u32).collect();
let pos_gpu = upload_u32(&positions_u32);
let logits = self.decode_core(
ids_gpu.as_ptr() as *const c_void,
pos_gpu.as_ptr() as *const c_void,
batch,
paged_cache,
);
for &slot in seq_slots {
paged_cache.advance_seq_len(slot, 1);
}
logits
}
/// Host-side per-step cache bookkeeping: block allocation + uploading
/// block tables / context lens to their (stable-address) GPU buffers.
/// Runs OUTSIDE the CUDA-graph captured region.
pub fn decode_prepare(
&self,
positions: &[usize],
seq_slots: &[usize],
paged_cache: &mut PagedKVCache,
) {
let kv_lens: Vec<i32> = positions.iter().map(|&p| (p + 1) as i32).collect();
for (b, &slot) in seq_slots.iter().enumerate() {
paged_cache.ensure_capacity(slot, positions[b] + 1);
}
paged_cache.sync_active_batch_with_lens(seq_slots, &kv_lens);
}
/// The pure-GPU decode step: embedding → 24 layers → final norm → logits.
/// Token ids and positions are read from device buffers; every other input
/// (weights, KV pools, block table, context lens) has a stable address —
/// which is exactly what makes this region CUDA-graph capturable.
pub fn decode_core(
&self,
ids_gpu: *const c_void,
pos_gpu: *const c_void,
batch: usize,
paged_cache: &mut PagedKVCache,
) -> Tensor {
let num_heads = self.local_num_heads;
let num_kv_heads = self.local_num_kv_heads;
let head_dim = self.config.head_dim();
let eps = self.norm_eps();
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
let max_blocks = paged_cache.max_blocks_per_seq();
let positions_u32: Vec<u32> = positions.iter().map(|&p| p as u32).collect();
let mut x = embedding(&self.embed_tokens, tokens);
let mut x = embedding_device_ids(&self.embed_tokens, ids_gpu, batch);
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
@@ -407,8 +445,8 @@ impl GptOss {
let k_3d = k_all.reshape(&[batch, num_kv_heads, head_dim]);
// RoPE (no QK-norm for gpt-oss)
rope_inplace(&q_3d, &self.rope_cache, &positions_u32);
rope_inplace(&k_3d, &self.rope_cache, &positions_u32);
rope_inplace_device_pos(&q_3d, &self.rope_cache, pos_gpu);
rope_inplace_device_pos(&k_3d, &self.rope_cache, pos_gpu);
let v_3d = v_all.reshape(&[batch, num_kv_heads, head_dim]);
@@ -445,11 +483,6 @@ impl GptOss {
x = xserv_kernels::add(&residual, &moe_out);
}
// Advance KV cache
for &slot in seq_slots {
paged_cache.advance_seq_len(slot, 1);
}
let x = Self::norm(&x, &self.norm, &self.norm_bias, eps);
matmul_2d(&x, &self.lm_head_t)
}
@@ -673,6 +706,16 @@ impl GptOss {
// --- Helpers ---
/// Upload a u32 slice to a pooled GPU buffer (synchronous H2D).
fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer {
let bytes = unsafe {
std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4)
};
let mut buf = xserv_cuda::allocator::cached_alloc(bytes.len()).expect("alloc u32 upload");
buf.copy_from_host(bytes).unwrap();
buf
}
/// XSERV_DENSE_MOE=1 forces the dense all-expert path (A/B benchmarking).
fn dense_moe_forced() -> bool {
static FORCED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

View File

@@ -0,0 +1,172 @@
//! CUDA-graph replay for gpt-oss batch=1 decode (Phase 21).
//!
//! A decode step launches ~200 kernels; with sparse MoE the GPU work is only
//! a few ms, so launch overhead dominates TPOT. The whole step (embedding →
//! 24 layers → logits) is captured ONCE into a CUDA graph and replayed per
//! token with a single `cudaGraphLaunch`.
//!
//! Why the existing forward is capturable as-is:
//! - Every per-step variable input lives in a stable-address device buffer
//! whose CONTENTS are updated outside the captured region: token id and
//! position (persistent buffers owned here), block table and context lens
//! (PagedKVCache GPU buffers, refreshed by `decode_prepare`). The KV scatter
//! and paged attention kernels read their write/read positions from those
//! buffers, and the sparse-MoE GEMVs read expert ids from `topk_ids` written
//! earlier in the same graph — all data-dependent, no host branching.
//! - Kernel launches go through the thread-local launch stream
//! (`xserv_cuda::stream::push_stream`), so the capture stream sees them.
//! - Intermediate tensors come from the caching allocator. Blocks freed while
//! capturing are quarantined (`allocator::begin_retain`) for the graph's
//! lifetime so no later allocation can take ownership of memory the graph
//! still references on every replay.
//!
//! Capture preconditions: at least one EAGER decode step must have run first,
//! so the allocator pool already holds every bucket size the step needs
//! (a pool-miss inside capture would call cudaMalloc — illegal while
//! capturing) and cuBLAS has finished its one-time per-shape setup.
use std::ffi::c_void;
use xserv_cuda::allocator::{self, RetainedBlocks};
use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer};
use xserv_tensor::Tensor;
use crate::gpt_oss::GptOss;
use crate::paged_kv_cache::PagedKVCache;
pub struct GptOssDecodeGraph {
stream: CudaStream,
graph: CudaGraph,
ids_buf: GpuBuffer, // [1] u32, persistent graph input
pos_buf: GpuBuffer, // [1] u32, persistent graph input
logits: Tensor, // graph output; rewritten in place by every replay
_arena: RetainedBlocks,
}
impl GptOssDecodeGraph {
/// Capture one batch=1 decode step and replay it once (capture records
/// without executing, so the replay performs this token's computation).
pub fn capture(
model: &GptOss,
token: u32,
position: usize,
slot: usize,
cache: &mut PagedKVCache,
) -> Self {
let stream = CudaStream::new().expect("create capture stream");
let mut ids_buf = allocator::cached_alloc(4).expect("alloc ids buf");
let mut pos_buf = allocator::cached_alloc(4).expect("alloc pos buf");
model.decode_prepare(&[position], &[slot], cache);
ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
pos_buf.copy_from_host(&(position as u32).to_le_bytes()).unwrap();
// Retained warmup: run the exact step once eagerly with the quarantine
// ON. Freed intermediates are held back instead of recycled, so the
// pool ends up stocked with a dedicated block for EVERY allocation the
// step performs. The capture below repeats the same allocation
// sequence and therefore never misses the pool — a pool miss would
// call cudaMalloc, which is illegal while a stream is capturing (this
// is also why one block per bucket is not enough: the capture's own
// quarantine keeps freed blocks out of reuse). Re-running the step is
// idempotent: the KV scatter rewrites the same cache position.
allocator::begin_retain();
{
let _guard = xserv_cuda::push_stream(&stream);
let _ = model.decode_core(
ids_buf.as_ptr() as *const c_void,
pos_buf.as_ptr() as *const c_void,
1,
cache,
);
}
drop(allocator::end_retain()); // release the warmup blocks to the pool
stream.synchronize().expect("warmup sync");
allocator::begin_retain();
let mut graph = CudaGraph::new();
let logits;
{
let _guard = xserv_cuda::stream::push_stream(&stream);
graph.begin_capture(&stream).expect("begin decode-graph capture");
logits = model.decode_core(
ids_buf.as_ptr() as *const c_void,
pos_buf.as_ptr() as *const c_void,
1,
cache,
);
graph.end_capture(&stream).expect("end decode-graph capture");
}
let arena = allocator::end_retain();
graph.launch(&stream).expect("first decode-graph replay");
cache.advance_seq_len(slot, 1);
Self { stream, graph, ids_buf, pos_buf, logits, _arena: arena }
}
/// Run one decode step by replaying the captured graph.
pub fn step(
&mut self,
model: &GptOss,
token: u32,
position: usize,
slot: usize,
cache: &mut PagedKVCache,
) -> Tensor {
model.decode_prepare(&[position], &[slot], cache);
self.ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
self.pos_buf.copy_from_host(&(position as u32).to_le_bytes()).unwrap();
self.graph.launch(&self.stream).expect("decode-graph replay");
cache.advance_seq_len(slot, 1);
// Shallow clone: the caller reads these logits before the next replay
// rewrites the underlying buffer.
self.logits.clone()
}
}
/// Lazy capture policy: first decode step of the process runs eager (warms the
/// allocator pool + cuBLAS so capture performs no "unsafe" CUDA calls), the
/// second is captured, the rest replay. Batch>1 always falls back to eager.
/// Disable with XSERV_DECODE_GRAPH=0.
pub struct GraphedGptOssDecoder {
graph: Option<GptOssDecodeGraph>,
eager_steps: u32,
enabled: bool,
}
impl GraphedGptOssDecoder {
pub fn new() -> Self {
let enabled = std::env::var("XSERV_DECODE_GRAPH").map(|v| v != "0").unwrap_or(true);
Self { graph: None, eager_steps: 0, enabled }
}
pub fn decode(
&mut self,
model: &GptOss,
tokens: &[u32],
positions: &[usize],
slots: &[usize],
cache: &mut PagedKVCache,
) -> Tensor {
if self.enabled && tokens.len() == 1 {
if let Some(g) = self.graph.as_mut() {
return g.step(model, tokens[0], positions[0], slots[0], cache);
}
if self.eager_steps >= 1 {
let g = GptOssDecodeGraph::capture(model, tokens[0], positions[0], slots[0], cache);
let logits = g.logits.clone();
self.graph = Some(g);
return logits;
}
}
self.eager_steps += 1;
model.forward_decode_paged(tokens, positions, slots, cache)
}
}
impl Default for GraphedGptOssDecoder {
fn default() -> Self {
Self::new()
}
}

View File

@@ -2,6 +2,7 @@ pub mod config;
pub mod decode_graph;
pub mod gpt2;
pub mod gpt_oss;
pub mod gpt_oss_graph;
pub mod kv_cache;
pub mod loader;
pub mod paged_kv_cache;
@@ -12,6 +13,7 @@ pub use config::ModelConfig;
pub use decode_graph::{DecodeGraphState, LayerWeightPtrs};
pub use gpt2::{GPT2, KVCache};
pub use gpt_oss::GptOss;
pub use gpt_oss_graph::{GptOssDecodeGraph, GraphedGptOssDecoder};
pub use kv_cache::GpuKVCache;
pub use paged_kv_cache::{BlockAllocator, Location, PagedKVCache, BLOCK_SIZE};
pub use qwen3::Qwen3;

View File

@@ -19,7 +19,7 @@ use std::thread;
use xserv_distributed::{TpContext, UniqueId};
use xserv_model::loader;
use xserv_model::{sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
use xserv_model::{sample, sample_greedy_penalized, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
use xserv_tensor::{DType, Device, Tensor};
use xserv_tokenizer::Tokenizer;
@@ -58,6 +58,16 @@ impl TpModel {
struct RankCtx {
model: TpModel,
cache: PagedKVCache,
decoder: GraphedGptOssDecoder,
}
/// Decode one step: gpt-oss batch=1 goes through the CUDA-graph decoder
/// (lazy capture, replay thereafter); everything else runs eager.
fn rank_decode(rc: &mut RankCtx, tokens: &[u32], positions: &[usize], slots: &[usize]) -> Tensor {
match &rc.model {
TpModel::GptOss(m) => rc.decoder.decode(m, tokens, positions, slots, &mut rc.cache),
TpModel::Qwen3(_) => rc.model.forward_decode_paged(tokens, positions, slots, &mut rc.cache),
}
}
fn build_rank(
@@ -81,7 +91,7 @@ fn build_rank(
let cache = PagedKVCache::new_tp(
config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device,
);
RankCtx { model, cache }
RankCtx { model, cache, decoder: GraphedGptOssDecoder::new() }
}
fn worker_loop(
@@ -106,7 +116,7 @@ fn worker_loop(
let _ = rc.model.forward_prefill_paged(&tokens, slot, &mut rc.cache);
}
TpCommand::Decode { tokens, positions, slots } => {
let _ = rc.model.forward_decode_paged(&tokens, &positions, &slots, &mut rc.cache);
let _ = rank_decode(&mut rc, &tokens, &positions, &slots);
}
TpCommand::Shutdown => {
let _ = ack_tx.send(());
@@ -207,7 +217,7 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
}
let pos = rc.cache.seq_len(slot);
broadcast(&cmd_txs, TpCommand::Decode { tokens: vec![next], positions: vec![pos], slots: vec![slot] });
let logits = rc.model.forward_decode_paged(&[next], &[pos], &[slot], &mut rc.cache);
let logits = rank_decode(&mut rc, &[next], &[pos], &[slot]);
wait_acks(&ack_rx);
next = pick(&logits, &req.sampling, &gen_ids);
gen_ids.push(next);