gpt-oss: replay the whole batch=1 decode step as one CUDA graph
Split forward_decode_paged into host bookkeeping (decode_prepare + ids/pos upload + advance_seq_len) and a pure-GPU decode_core. The paged-KV and sparse-MoE designs already read every per-step variable (block table, context lens, expert ids) from stable-address device buffers, so decode_core captures as-is. GptOssDecodeGraph captures lazily on the second decode step (the first eager step warms cuBLAS) after a "retained warmup": the step runs once with the allocator quarantine on, stocking the pool with a dedicated block for every allocation so the capture itself never pool-misses (a cudaMalloc while capturing is illegal — and the capture's own quarantine is what would otherwise starve the pool). NCCL all-reduces capture cleanly; TP=2 replays in lockstep. Wired into tp_engine, bench-gpt-oss, and xserv-chat via GraphedGptOssDecoder (batch>1 falls back to eager; XSERV_DECODE_GRAPH=0 disables). Greedy tokens identical to eager. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use xserv_distributed::{TpContext, UniqueId, get_unique_id};
|
||||
use xserv_model::{loader, GptOss, ModelConfig, PagedKVCache, BLOCK_SIZE};
|
||||
use xserv_model::{loader, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, BLOCK_SIZE};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -172,6 +172,7 @@ fn main() {
|
||||
print!("{prompt}");
|
||||
|
||||
// Decode
|
||||
let mut decoder = GraphedGptOssDecoder::new();
|
||||
let decode_start = Instant::now();
|
||||
for _ in 1..max_tokens {
|
||||
let text = tokenizer.decode(&[next]);
|
||||
@@ -183,7 +184,7 @@ fn main() {
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode {
|
||||
tokens: vec![next], positions: vec![pos], slots: vec![slot],
|
||||
});
|
||||
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut cache);
|
||||
let logits = decoder.decode(&model, &[next], &[pos], &[slot], &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
|
||||
next = sample_greedy_last(&logits);
|
||||
@@ -250,6 +251,7 @@ fn worker_loop(
|
||||
eprintln!("[rank {rank}] Ready.");
|
||||
ack_tx.send(()).unwrap();
|
||||
|
||||
let mut decoder = GraphedGptOssDecoder::new();
|
||||
while let Ok(cmd) = rx.recv() {
|
||||
match cmd {
|
||||
WorkerCmd::Register(slot) => {
|
||||
@@ -259,7 +261,7 @@ fn worker_loop(
|
||||
let _ = model.forward_prefill_paged(&tokens, slot, &mut cache);
|
||||
}
|
||||
WorkerCmd::Decode { tokens, positions, slots } => {
|
||||
let _ = model.forward_decode_paged(&tokens, &positions, &slots, &mut cache);
|
||||
let _ = decoder.decode(&model, &tokens, &positions, &slots, &mut cache);
|
||||
}
|
||||
WorkerCmd::Shutdown => break,
|
||||
}
|
||||
@@ -286,6 +288,11 @@ fn wait_workers(handles: &[(std::thread::JoinHandle<()>, std::sync::mpsc::Receiv
|
||||
fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 {
|
||||
use half::bf16;
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
// GPU argmax fast path (4-byte D2H instead of the full logits row).
|
||||
if logits.dtype() == xserv_tensor::DType::BF16 && logits.is_contiguous() {
|
||||
let ids = xserv_kernels::argmax_bf16_to_host(logits);
|
||||
return *ids.last().unwrap();
|
||||
}
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::path::PathBuf;
|
||||
use std::sync::{mpsc, Arc};
|
||||
use std::thread;
|
||||
|
||||
use xserv_model::{loader, sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
|
||||
use xserv_model::{GraphedGptOssDecoder, loader, sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -77,6 +77,7 @@ fn tp_worker_loop(
|
||||
let mut cache = PagedKVCache::new_tp(
|
||||
&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, rank as u32,
|
||||
);
|
||||
let mut decoder = GraphedGptOssDecoder::new();
|
||||
while let Ok(cmd) = cmd_rx.recv() {
|
||||
match cmd {
|
||||
TpCommand::Register(slot) => { let _ = cache.register_sequence(slot); }
|
||||
@@ -85,7 +86,7 @@ fn tp_worker_loop(
|
||||
let _ = model.forward_prefill_paged(&tokens, slot, &mut cache);
|
||||
}
|
||||
TpCommand::Decode { tokens, positions, slots } => {
|
||||
let _ = model.forward_decode_paged(&tokens, &positions, &slots, &mut cache);
|
||||
let _ = chat_decode(&model, &mut decoder, &tokens, &positions, &slots, &mut cache);
|
||||
}
|
||||
}
|
||||
let _ = ack_tx.send(());
|
||||
@@ -321,6 +322,7 @@ fn main() {
|
||||
};
|
||||
|
||||
let tokenizer = Tokenizer::from_file(&opts.model_dir.join("tokenizer.json"));
|
||||
let mut decoder = GraphedGptOssDecoder::new();
|
||||
if let Some(h) = &tp_handle { h.send(TpCommand::Register(SLOT)); h.wait(); }
|
||||
cache.register_sequence(SLOT).expect("register chat slot");
|
||||
let use_color = opts.color && io::stdout().is_terminal();
|
||||
@@ -384,7 +386,7 @@ fn main() {
|
||||
print!("assistant> ");
|
||||
io::stdout().flush().unwrap();
|
||||
let (_finish, answer) = generate_with_paged_cache(
|
||||
&model, &mut cache, &tokenizer, &prompt_tokens, &opts.sampling,
|
||||
&model, &mut decoder, &mut cache, &tokenizer, &prompt_tokens, &opts.sampling,
|
||||
max_new_tokens, use_color, &tp_handle, is_moe, opts.enable_thinking,
|
||||
);
|
||||
moe_history.push((input.to_string(), answer));
|
||||
@@ -421,6 +423,7 @@ fn main() {
|
||||
io::stdout().flush().unwrap();
|
||||
let (finish, _answer) = generate_with_paged_cache(
|
||||
&model,
|
||||
&mut decoder,
|
||||
&mut cache,
|
||||
&tokenizer,
|
||||
&prompt_tokens,
|
||||
@@ -679,8 +682,23 @@ fn today_ymd() -> String {
|
||||
format!("{y:04}-{m:02}-{d:02}")
|
||||
}
|
||||
|
||||
fn chat_decode(
|
||||
model: &ChatModel,
|
||||
decoder: &mut GraphedGptOssDecoder,
|
||||
tokens: &[u32],
|
||||
positions: &[usize],
|
||||
slots: &[usize],
|
||||
cache: &mut PagedKVCache,
|
||||
) -> xserv_tensor::Tensor {
|
||||
match model {
|
||||
ChatModel::GptOss(m) => decoder.decode(m, tokens, positions, slots, cache),
|
||||
ChatModel::Qwen3(_) => model.forward_decode_paged(tokens, positions, slots, cache),
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_with_paged_cache(
|
||||
model: &ChatModel,
|
||||
decoder: &mut GraphedGptOssDecoder,
|
||||
cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
prompt_tokens: &[u32],
|
||||
@@ -745,7 +763,7 @@ fn generate_with_paged_cache(
|
||||
for _ in 0..max_tokens {
|
||||
let position = cache.seq_len(SLOT);
|
||||
if let Some(h) = tp { h.send(TpCommand::Decode { tokens: vec![next], positions: vec![position], slots: vec![SLOT] }); }
|
||||
let logits = model.forward_decode_paged(&[next], &[position], &[SLOT], cache);
|
||||
let logits = chat_decode(model, decoder, &[next], &[position], &[SLOT], cache);
|
||||
if let Some(h) = tp { h.wait(); }
|
||||
if tokenizer.is_eos(next) {
|
||||
print_stream_text(
|
||||
|
||||
@@ -373,24 +373,62 @@ impl GptOss {
|
||||
assert_eq!(seq_slots.len(), batch);
|
||||
assert!(batch > 0);
|
||||
|
||||
let num_heads = self.local_num_heads;
|
||||
let num_kv_heads = self.local_num_kv_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.norm_eps();
|
||||
self.decode_prepare(positions, seq_slots, paged_cache);
|
||||
|
||||
// Upload token ids + positions, then run the pure-GPU core.
|
||||
let ids_gpu = upload_u32(tokens);
|
||||
let positions_u32: Vec<u32> = positions.iter().map(|&p| p as u32).collect();
|
||||
let pos_gpu = upload_u32(&positions_u32);
|
||||
let logits = self.decode_core(
|
||||
ids_gpu.as_ptr() as *const c_void,
|
||||
pos_gpu.as_ptr() as *const c_void,
|
||||
batch,
|
||||
paged_cache,
|
||||
);
|
||||
|
||||
for &slot in seq_slots {
|
||||
paged_cache.advance_seq_len(slot, 1);
|
||||
}
|
||||
logits
|
||||
}
|
||||
|
||||
/// Host-side per-step cache bookkeeping: block allocation + uploading
|
||||
/// block tables / context lens to their (stable-address) GPU buffers.
|
||||
/// Runs OUTSIDE the CUDA-graph captured region.
|
||||
pub fn decode_prepare(
|
||||
&self,
|
||||
positions: &[usize],
|
||||
seq_slots: &[usize],
|
||||
paged_cache: &mut PagedKVCache,
|
||||
) {
|
||||
let kv_lens: Vec<i32> = positions.iter().map(|&p| (p + 1) as i32).collect();
|
||||
for (b, &slot) in seq_slots.iter().enumerate() {
|
||||
paged_cache.ensure_capacity(slot, positions[b] + 1);
|
||||
}
|
||||
paged_cache.sync_active_batch_with_lens(seq_slots, &kv_lens);
|
||||
}
|
||||
|
||||
/// The pure-GPU decode step: embedding → 24 layers → final norm → logits.
|
||||
/// Token ids and positions are read from device buffers; every other input
|
||||
/// (weights, KV pools, block table, context lens) has a stable address —
|
||||
/// which is exactly what makes this region CUDA-graph capturable.
|
||||
pub fn decode_core(
|
||||
&self,
|
||||
ids_gpu: *const c_void,
|
||||
pos_gpu: *const c_void,
|
||||
batch: usize,
|
||||
paged_cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
let num_heads = self.local_num_heads;
|
||||
let num_kv_heads = self.local_num_kv_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.norm_eps();
|
||||
|
||||
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
|
||||
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
|
||||
let max_blocks = paged_cache.max_blocks_per_seq();
|
||||
|
||||
let positions_u32: Vec<u32> = positions.iter().map(|&p| p as u32).collect();
|
||||
|
||||
let mut x = embedding(&self.embed_tokens, tokens);
|
||||
let mut x = embedding_device_ids(&self.embed_tokens, ids_gpu, batch);
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
@@ -407,8 +445,8 @@ impl GptOss {
|
||||
let k_3d = k_all.reshape(&[batch, num_kv_heads, head_dim]);
|
||||
|
||||
// RoPE (no QK-norm for gpt-oss)
|
||||
rope_inplace(&q_3d, &self.rope_cache, &positions_u32);
|
||||
rope_inplace(&k_3d, &self.rope_cache, &positions_u32);
|
||||
rope_inplace_device_pos(&q_3d, &self.rope_cache, pos_gpu);
|
||||
rope_inplace_device_pos(&k_3d, &self.rope_cache, pos_gpu);
|
||||
|
||||
let v_3d = v_all.reshape(&[batch, num_kv_heads, head_dim]);
|
||||
|
||||
@@ -445,11 +483,6 @@ impl GptOss {
|
||||
x = xserv_kernels::add(&residual, &moe_out);
|
||||
}
|
||||
|
||||
// Advance KV cache
|
||||
for &slot in seq_slots {
|
||||
paged_cache.advance_seq_len(slot, 1);
|
||||
}
|
||||
|
||||
let x = Self::norm(&x, &self.norm, &self.norm_bias, eps);
|
||||
matmul_2d(&x, &self.lm_head_t)
|
||||
}
|
||||
@@ -673,6 +706,16 @@ impl GptOss {
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
/// Upload a u32 slice to a pooled GPU buffer (synchronous H2D).
|
||||
fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4)
|
||||
};
|
||||
let mut buf = xserv_cuda::allocator::cached_alloc(bytes.len()).expect("alloc u32 upload");
|
||||
buf.copy_from_host(bytes).unwrap();
|
||||
buf
|
||||
}
|
||||
|
||||
/// XSERV_DENSE_MOE=1 forces the dense all-expert path (A/B benchmarking).
|
||||
fn dense_moe_forced() -> bool {
|
||||
static FORCED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
|
||||
|
||||
172
crates/xserv-model/src/gpt_oss_graph.rs
Normal file
172
crates/xserv-model/src/gpt_oss_graph.rs
Normal file
@@ -0,0 +1,172 @@
|
||||
//! CUDA-graph replay for gpt-oss batch=1 decode (Phase 21).
|
||||
//!
|
||||
//! A decode step launches ~200 kernels; with sparse MoE the GPU work is only
|
||||
//! a few ms, so launch overhead dominates TPOT. The whole step (embedding →
|
||||
//! 24 layers → logits) is captured ONCE into a CUDA graph and replayed per
|
||||
//! token with a single `cudaGraphLaunch`.
|
||||
//!
|
||||
//! Why the existing forward is capturable as-is:
|
||||
//! - Every per-step variable input lives in a stable-address device buffer
|
||||
//! whose CONTENTS are updated outside the captured region: token id and
|
||||
//! position (persistent buffers owned here), block table and context lens
|
||||
//! (PagedKVCache GPU buffers, refreshed by `decode_prepare`). The KV scatter
|
||||
//! and paged attention kernels read their write/read positions from those
|
||||
//! buffers, and the sparse-MoE GEMVs read expert ids from `topk_ids` written
|
||||
//! earlier in the same graph — all data-dependent, no host branching.
|
||||
//! - Kernel launches go through the thread-local launch stream
|
||||
//! (`xserv_cuda::stream::push_stream`), so the capture stream sees them.
|
||||
//! - Intermediate tensors come from the caching allocator. Blocks freed while
|
||||
//! capturing are quarantined (`allocator::begin_retain`) for the graph's
|
||||
//! lifetime so no later allocation can take ownership of memory the graph
|
||||
//! still references on every replay.
|
||||
//!
|
||||
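//! Capture preconditions: at least one EAGER decode step must have run first,
//! so cuBLAS has finished its one-time per-shape setup. Stocking the allocator
//! pool is handled by `GptOssDecodeGraph::capture` itself: it first runs the
//! step once with the quarantine on (a "retained warmup"), so every allocation
//! the capture performs already has a dedicated pooled block — a pool-miss
//! inside capture would call cudaMalloc, which is illegal while capturing.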
|
||||
|
||||
use std::ffi::c_void;
|
||||
|
||||
use xserv_cuda::allocator::{self, RetainedBlocks};
|
||||
use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer};
|
||||
use xserv_tensor::Tensor;
|
||||
|
||||
use crate::gpt_oss::GptOss;
|
||||
use crate::paged_kv_cache::PagedKVCache;
|
||||
|
||||
pub struct GptOssDecodeGraph {
|
||||
stream: CudaStream,
|
||||
graph: CudaGraph,
|
||||
ids_buf: GpuBuffer, // [1] u32, persistent graph input
|
||||
pos_buf: GpuBuffer, // [1] u32, persistent graph input
|
||||
logits: Tensor, // graph output; rewritten in place by every replay
|
||||
_arena: RetainedBlocks,
|
||||
}
|
||||
|
||||
impl GptOssDecodeGraph {
|
||||
/// Capture one batch=1 decode step and replay it once (capture records
|
||||
/// without executing, so the replay performs this token's computation).
|
||||
pub fn capture(
|
||||
model: &GptOss,
|
||||
token: u32,
|
||||
position: usize,
|
||||
slot: usize,
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Self {
|
||||
let stream = CudaStream::new().expect("create capture stream");
|
||||
let mut ids_buf = allocator::cached_alloc(4).expect("alloc ids buf");
|
||||
let mut pos_buf = allocator::cached_alloc(4).expect("alloc pos buf");
|
||||
|
||||
model.decode_prepare(&[position], &[slot], cache);
|
||||
ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
|
||||
pos_buf.copy_from_host(&(position as u32).to_le_bytes()).unwrap();
|
||||
|
||||
// Retained warmup: run the exact step once eagerly with the quarantine
|
||||
// ON. Freed intermediates are held back instead of recycled, so the
|
||||
// pool ends up stocked with a dedicated block for EVERY allocation the
|
||||
// step performs. The capture below repeats the same allocation
|
||||
// sequence and therefore never misses the pool — a pool miss would
|
||||
// call cudaMalloc, which is illegal while a stream is capturing (this
|
||||
// is also why one block per bucket is not enough: the capture's own
|
||||
// quarantine keeps freed blocks out of reuse). Re-running the step is
|
||||
// idempotent: the KV scatter rewrites the same cache position.
|
||||
allocator::begin_retain();
|
||||
{
|
||||
let _guard = xserv_cuda::push_stream(&stream);
|
||||
let _ = model.decode_core(
|
||||
ids_buf.as_ptr() as *const c_void,
|
||||
pos_buf.as_ptr() as *const c_void,
|
||||
1,
|
||||
cache,
|
||||
);
|
||||
}
|
||||
drop(allocator::end_retain()); // release the warmup blocks to the pool
|
||||
stream.synchronize().expect("warmup sync");
|
||||
|
||||
allocator::begin_retain();
|
||||
let mut graph = CudaGraph::new();
|
||||
let logits;
|
||||
{
|
||||
let _guard = xserv_cuda::stream::push_stream(&stream);
|
||||
graph.begin_capture(&stream).expect("begin decode-graph capture");
|
||||
logits = model.decode_core(
|
||||
ids_buf.as_ptr() as *const c_void,
|
||||
pos_buf.as_ptr() as *const c_void,
|
||||
1,
|
||||
cache,
|
||||
);
|
||||
graph.end_capture(&stream).expect("end decode-graph capture");
|
||||
}
|
||||
let arena = allocator::end_retain();
|
||||
|
||||
graph.launch(&stream).expect("first decode-graph replay");
|
||||
cache.advance_seq_len(slot, 1);
|
||||
|
||||
Self { stream, graph, ids_buf, pos_buf, logits, _arena: arena }
|
||||
}
|
||||
|
||||
/// Run one decode step by replaying the captured graph.
|
||||
pub fn step(
|
||||
&mut self,
|
||||
model: &GptOss,
|
||||
token: u32,
|
||||
position: usize,
|
||||
slot: usize,
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
model.decode_prepare(&[position], &[slot], cache);
|
||||
self.ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
|
||||
self.pos_buf.copy_from_host(&(position as u32).to_le_bytes()).unwrap();
|
||||
self.graph.launch(&self.stream).expect("decode-graph replay");
|
||||
cache.advance_seq_len(slot, 1);
|
||||
// Shallow clone: the caller reads these logits before the next replay
|
||||
// rewrites the underlying buffer.
|
||||
self.logits.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Lazy capture policy: first decode step of the process runs eager (warms the
|
||||
/// allocator pool + cuBLAS so capture performs no "unsafe" CUDA calls), the
|
||||
/// second is captured, the rest replay. Batch>1 always falls back to eager.
|
||||
/// Disable with XSERV_DECODE_GRAPH=0.
|
||||
pub struct GraphedGptOssDecoder {
|
||||
graph: Option<GptOssDecodeGraph>,
|
||||
eager_steps: u32,
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl GraphedGptOssDecoder {
|
||||
pub fn new() -> Self {
|
||||
let enabled = std::env::var("XSERV_DECODE_GRAPH").map(|v| v != "0").unwrap_or(true);
|
||||
Self { graph: None, eager_steps: 0, enabled }
|
||||
}
|
||||
|
||||
pub fn decode(
|
||||
&mut self,
|
||||
model: &GptOss,
|
||||
tokens: &[u32],
|
||||
positions: &[usize],
|
||||
slots: &[usize],
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
if self.enabled && tokens.len() == 1 {
|
||||
if let Some(g) = self.graph.as_mut() {
|
||||
return g.step(model, tokens[0], positions[0], slots[0], cache);
|
||||
}
|
||||
if self.eager_steps >= 1 {
|
||||
let g = GptOssDecodeGraph::capture(model, tokens[0], positions[0], slots[0], cache);
|
||||
let logits = g.logits.clone();
|
||||
self.graph = Some(g);
|
||||
return logits;
|
||||
}
|
||||
}
|
||||
self.eager_steps += 1;
|
||||
model.forward_decode_paged(tokens, positions, slots, cache)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GraphedGptOssDecoder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ pub mod config;
|
||||
pub mod decode_graph;
|
||||
pub mod gpt2;
|
||||
pub mod gpt_oss;
|
||||
pub mod gpt_oss_graph;
|
||||
pub mod kv_cache;
|
||||
pub mod loader;
|
||||
pub mod paged_kv_cache;
|
||||
@@ -12,6 +13,7 @@ pub use config::ModelConfig;
|
||||
pub use decode_graph::{DecodeGraphState, LayerWeightPtrs};
|
||||
pub use gpt2::{GPT2, KVCache};
|
||||
pub use gpt_oss::GptOss;
|
||||
pub use gpt_oss_graph::{GptOssDecodeGraph, GraphedGptOssDecoder};
|
||||
pub use kv_cache::GpuKVCache;
|
||||
pub use paged_kv_cache::{BlockAllocator, Location, PagedKVCache, BLOCK_SIZE};
|
||||
pub use qwen3::Qwen3;
|
||||
|
||||
@@ -19,7 +19,7 @@ use std::thread;
|
||||
|
||||
use xserv_distributed::{TpContext, UniqueId};
|
||||
use xserv_model::loader;
|
||||
use xserv_model::{sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
|
||||
use xserv_model::{sample, sample_greedy_penalized, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -58,6 +58,16 @@ impl TpModel {
|
||||
struct RankCtx {
|
||||
model: TpModel,
|
||||
cache: PagedKVCache,
|
||||
decoder: GraphedGptOssDecoder,
|
||||
}
|
||||
|
||||
/// Decode one step: gpt-oss batch=1 goes through the CUDA-graph decoder
|
||||
/// (lazy capture, replay thereafter); everything else runs eager.
|
||||
fn rank_decode(rc: &mut RankCtx, tokens: &[u32], positions: &[usize], slots: &[usize]) -> Tensor {
|
||||
match &rc.model {
|
||||
TpModel::GptOss(m) => rc.decoder.decode(m, tokens, positions, slots, &mut rc.cache),
|
||||
TpModel::Qwen3(_) => rc.model.forward_decode_paged(tokens, positions, slots, &mut rc.cache),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_rank(
|
||||
@@ -81,7 +91,7 @@ fn build_rank(
|
||||
let cache = PagedKVCache::new_tp(
|
||||
config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device,
|
||||
);
|
||||
RankCtx { model, cache }
|
||||
RankCtx { model, cache, decoder: GraphedGptOssDecoder::new() }
|
||||
}
|
||||
|
||||
fn worker_loop(
|
||||
@@ -106,7 +116,7 @@ fn worker_loop(
|
||||
let _ = rc.model.forward_prefill_paged(&tokens, slot, &mut rc.cache);
|
||||
}
|
||||
TpCommand::Decode { tokens, positions, slots } => {
|
||||
let _ = rc.model.forward_decode_paged(&tokens, &positions, &slots, &mut rc.cache);
|
||||
let _ = rank_decode(&mut rc, &tokens, &positions, &slots);
|
||||
}
|
||||
TpCommand::Shutdown => {
|
||||
let _ = ack_tx.send(());
|
||||
@@ -207,7 +217,7 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
}
|
||||
let pos = rc.cache.seq_len(slot);
|
||||
broadcast(&cmd_txs, TpCommand::Decode { tokens: vec![next], positions: vec![pos], slots: vec![slot] });
|
||||
let logits = rc.model.forward_decode_paged(&[next], &[pos], &[slot], &mut rc.cache);
|
||||
let logits = rank_decode(&mut rc, &[next], &[pos], &[slot]);
|
||||
wait_acks(&ack_rx);
|
||||
next = pick(&logits, &req.sampling, &gen_ids);
|
||||
gen_ids.push(next);
|
||||
|
||||
Reference in New Issue
Block a user