//! Tensor-parallel inference engine for the HTTP server. //! //! Serial coordinator model: one rank-0 coordinator thread (the caller) drives //! generation and owns the scheduler; ranks 1..world are worker threads. For //! each step the coordinator broadcasts a command (Register/Prefill/Decode/Free) //! to the workers and runs the same op on its own shard; the per-layer NCCL //! AllReduces keep all ranks in lockstep. Only the coordinator samples — the //! chosen token is carried in the next Decode command, so this is correct for //! both greedy and stochastic sampling. //! //! Requests are processed one at a time (sufficient for the quality benchmark, //! which issues serial requests). Continuous batching across ranks is future //! work; the single-GPU `Engine` still handles TP=1. use std::path::{Path, PathBuf}; use std::sync::Arc; use std::sync::mpsc; use std::thread; use xserv_distributed::{TpContext, UniqueId}; use xserv_model::loader; use xserv_model::{ BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, sample, sample_greedy_penalized, }; use xserv_tensor::{DType, Device, Tensor}; use xserv_tokenizer::Tokenizer; use crate::engine::{GenerateEvent, GenerateRequest}; #[derive(Clone)] enum TpCommand { Register(usize), Free(usize), Prefill { tokens: Vec, slot: usize, }, Decode { tokens: Vec, positions: Vec, slots: Vec, }, Shutdown, } enum TpModel { Qwen3(Qwen3), GptOss(GptOss), } impl TpModel { fn forward_prefill_paged( &self, tokens: &[u32], slot: usize, cache: &mut PagedKVCache, ) -> Tensor { match self { TpModel::Qwen3(m) => m.forward_prefill_paged(tokens, slot, cache), TpModel::GptOss(m) => m.forward_prefill_paged(tokens, slot, cache), } } fn forward_decode_paged( &self, tokens: &[u32], positions: &[usize], slots: &[usize], cache: &mut PagedKVCache, ) -> Tensor { match self { TpModel::Qwen3(m) => m.forward_decode_paged(tokens, positions, slots, cache), TpModel::GptOss(m) => m.forward_decode_paged(tokens, positions, slots, cache), } } } struct RankCtx { model: TpModel, cache: PagedKVCache, decoder: GraphedGptOssDecoder, } /// Decode one step: gpt-oss batch=1 goes through the CUDA-graph decoder /// (lazy capture, replay thereafter); everything else runs eager. fn rank_decode(rc: &mut RankCtx, tokens: &[u32], positions: &[usize], slots: &[usize]) -> Tensor { match &rc.model { TpModel::GptOss(m) => rc .decoder .decode(m, tokens, positions, slots, &mut rc.cache), TpModel::Qwen3(_) => rc .model .forward_decode_paged(tokens, positions, slots, &mut rc.cache), } } fn build_rank( model_dir: &Path, config: &ModelConfig, rank: usize, world: usize, device: u32, max_seq_len: usize, tp: Option>, ) -> RankCtx { let weights = loader::load_model_dir(model_dir, Device::Cpu); let model = if config.is_moe() { TpModel::GptOss(GptOss::from_weights_tp( config.clone(), weights, rank, world, device, tp, )) } else { TpModel::Qwen3(Qwen3::from_weights_tp( config.clone(), weights, rank, world, device, tp, )) }; let local_kv = config.num_kv_heads() / world; let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; let total_blocks = max_blocks_per_seq + 8; let cache = PagedKVCache::new_tp( config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device, ); RankCtx { model, cache, decoder: GraphedGptOssDecoder::new(), } } fn worker_loop( rank: usize, world: usize, id: UniqueId, model_dir: PathBuf, config: ModelConfig, max_seq_len: usize, cmd_rx: mpsc::Receiver, ack_tx: mpsc::Sender<()>, ) { let tp = Arc::new(TpContext::init(rank, world, id, rank as u32)); let mut rc = build_rank( &model_dir, &config, rank, world, rank as u32, max_seq_len, Some(tp), ); while let Ok(cmd) = cmd_rx.recv() { match cmd { TpCommand::Register(slot) => { let _ = rc.cache.register_sequence(slot); } TpCommand::Free(slot) => rc.cache.free_sequence(slot), TpCommand::Prefill { tokens, slot } => { let _ = rc.model.forward_prefill_paged(&tokens, slot, &mut rc.cache); } TpCommand::Decode { tokens, positions, slots, } => { let _ = rank_decode(&mut rc, &tokens, &positions, &slots); } TpCommand::Shutdown => { let _ = ack_tx.send(()); break; } } let _ = ack_tx.send(()); } } /// Run the TP coordinator (rank 0) on the calling thread. Spawns worker ranks /// internally and consumes generation requests from `rx`. pub fn run_tp( model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver, ) { // world=1 is a valid single-rank configuration (gpt-oss has no // single-GPU engine path; NCCL init and all_reduce no-op at world=1). assert!(world >= 1, "run_tp requires world >= 1"); let config = ModelConfig::from_file(&model_dir.join("config.json")); assert!( config.num_kv_heads() % world == 0, "num_kv_heads {} not divisible by tp {world}", config.num_kv_heads() ); let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json")); let id = xserv_distributed::get_unique_id(); // Spawn worker ranks 1..world. let (ack_tx, ack_rx) = mpsc::channel::<()>(); let mut cmd_txs: Vec> = Vec::new(); for rank in 1..world { let (ctx_tx, ctx_rx) = mpsc::channel::(); cmd_txs.push(ctx_tx); let ack_tx = ack_tx.clone(); let model_dir = model_dir.to_path_buf(); let config = config.clone(); thread::spawn(move || { worker_loop( rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx, ); }); } // Rank 0 (this thread). let tp = Arc::new(TpContext::init(0, world, id, 0)); let mut rc = build_rank(model_dir, &config, 0, world, 0, max_seq_len, Some(tp)); eprintln!("[tp-engine] ready (tp={world}, max_seq_len={max_seq_len})"); // Optional repetition penalty to break greedy repetition loops (reasoning // models loop under pure greedy when numerics diverge from the reference). // Off by default; XSERV_REP_PENALTY>1 enables it over the last // XSERV_REP_WINDOW generated tokens. Applied only on the greedy path. let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY") .ok() .and_then(|s| s.parse().ok()) .unwrap_or(1.0); let rep_window: usize = std::env::var("XSERV_REP_WINDOW") .ok() .and_then(|s| s.parse().ok()) .unwrap_or(128); let pick = |logits: &Tensor, sp: &xserv_model::SamplingParams, history: &[u32]| -> u32 { if rep_penalty > 1.0 && sp.temperature == 0.0 { let start = history.len().saturating_sub(rep_window); sample_greedy_penalized(logits, &history[start..], rep_penalty) } else { sample(logits, sp) } }; let n_workers = world - 1; let broadcast = |txs: &[mpsc::Sender], cmd: TpCommand| { for t in txs { let _ = t.send(cmd.clone()); } }; let wait_acks = |rx: &mpsc::Receiver<()>| { for _ in 0..n_workers { let _ = rx.recv(); } }; let slot = 0usize; while let Ok(req) = rx.recv() { broadcast(&cmd_txs, TpCommand::Register(slot)); rc.cache.register_sequence(slot).expect("register slot"); wait_acks(&ack_rx); // Prefill. broadcast( &cmd_txs, TpCommand::Prefill { tokens: req.prompt_tokens.clone(), slot, }, ); let logits = rc .model .forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache); wait_acks(&ack_rx); let mut gen_ids: Vec = Vec::new(); let mut next = pick(&logits, &req.sampling, &gen_ids); gen_ids.push(next); let mut decode_buf: Vec = Vec::new(); let mut generated = 1usize; let mut stalled = !emit_text(&tokenizer, &req, next, &mut decode_buf); let finish = loop { if stalled { break "error"; } if tokenizer.is_eos(next) { break "stop"; } if generated >= req.max_tokens { break "length"; } let pos = rc.cache.seq_len(slot); broadcast( &cmd_txs, TpCommand::Decode { tokens: vec![next], positions: vec![pos], slots: vec![slot], }, ); let logits = rank_decode(&mut rc, &[next], &[pos], &[slot]); wait_acks(&ack_rx); next = pick(&logits, &req.sampling, &gen_ids); gen_ids.push(next); generated += 1; stalled = !emit_text(&tokenizer, &req, next, &mut decode_buf); }; let tail = tokenizer.flush_decode_stream(&mut decode_buf); if !tail.is_empty() { let _ = req.sender.try_send(GenerateEvent::Token { id: next, text: tail, }); } let _ = req.sender.try_send(GenerateEvent::Done { finish_reason: finish.to_string(), }); broadcast(&cmd_txs, TpCommand::Free(slot)); rc.cache.free_sequence(slot); wait_acks(&ack_rx); } broadcast(&cmd_txs, TpCommand::Shutdown); } /// Stream a token's decoded text to the client (EOS contributes no text). /// Returns false if the send would block (client too slow) or the client is /// gone — the caller stops generating so the serial coordinator thread is free /// to admit the next request instead of blocking on one slow consumer. fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, buf: &mut Vec) -> bool { if tokenizer.is_eos(token_id) { return true; } let text = tokenizer.decode_token_stream(token_id, buf); if !text.is_empty() { return req .sender .try_send(GenerateEvent::Token { id: token_id, text }) .is_ok(); } true }