diff --git a/crates/xserv-server/Cargo.toml b/crates/xserv-server/Cargo.toml index 9c9882c..40ed14a 100644 --- a/crates/xserv-server/Cargo.toml +++ b/crates/xserv-server/Cargo.toml @@ -13,6 +13,7 @@ xserv-tensor = { path = "../xserv-tensor" } xserv-kernels = { path = "../xserv-kernels" } xserv-model = { path = "../xserv-model" } xserv-tokenizer = { path = "../xserv-tokenizer" } +xserv-distributed = { path = "../xserv-distributed" } half.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/crates/xserv-server/src/main.rs b/crates/xserv-server/src/main.rs index 45a1ae3..6927161 100644 --- a/crates/xserv-server/src/main.rs +++ b/crates/xserv-server/src/main.rs @@ -1,5 +1,6 @@ mod api; mod engine; +mod tp_engine; use axum::{routing::{get, post}, Extension, Router}; use std::path::PathBuf; @@ -18,7 +19,7 @@ pub struct AppState { async fn main() { let args: Vec = std::env::args().collect(); if args.len() < 2 { - eprintln!("Usage: xserv-server [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N]"); + eprintln!("Usage: xserv-server [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N]"); std::process::exit(1); } @@ -45,6 +46,12 @@ async fn main() { .and_then(|i| args.get(i + 1)) .and_then(|s| s.parse().ok()) .unwrap_or(8); + let tp: usize = args.iter() + .position(|a| a == "--tp") + .and_then(|i| args.get(i + 1)) + .and_then(|s| s.parse().ok()) + .unwrap_or(1) + .max(1); let model_config = ModelConfig::from_file(&model_dir.join("config.json")); let model_max_seq_len = model_config.max_seq_len(); if model_max_seq_len == 0 { @@ -69,8 +76,13 @@ async fn main() { let model_dir_clone = model_dir.clone(); std::thread::spawn(move || { - let mut engine = engine::Engine::load_with_swap(&model_dir_clone, max_batch, max_seq_len, swap_space_gb); - engine.run(rx); + if tp <= 1 { + let mut engine = engine::Engine::load_with_swap(&model_dir_clone, max_batch, max_seq_len, swap_space_gb); + engine.run(rx); + } else { + // Tensor-parallel path: rank-0 coordinator + worker rank threads. + tp_engine::run_tp(&model_dir_clone, tp, max_seq_len, rx); + } }); let state = Arc::new(AppState { diff --git a/crates/xserv-server/src/tp_engine.rs b/crates/xserv-server/src/tp_engine.rs new file mode 100644 index 0000000..3e4a203 --- /dev/null +++ b/crates/xserv-server/src/tp_engine.rs @@ -0,0 +1,195 @@ +//! Tensor-parallel inference engine for the HTTP server. +//! +//! Serial coordinator model: one rank-0 coordinator thread (the caller) drives +//! generation and owns the scheduler; ranks 1..world are worker threads. For +//! each step the coordinator broadcasts a command (Register/Prefill/Decode/Free) +//! to the workers and runs the same op on its own shard; the per-layer NCCL +//! AllReduces keep all ranks in lockstep. Only the coordinator samples — the +//! chosen token is carried in the next Decode command, so this is correct for +//! both greedy and stochastic sampling. +//! +//! Requests are processed one at a time (sufficient for the quality benchmark, +//! which issues serial requests). Continuous batching across ranks is future +//! work; the single-GPU `Engine` still handles TP=1. + +use std::path::{Path, PathBuf}; +use std::sync::mpsc; +use std::sync::Arc; +use std::thread; + +use xserv_distributed::{TpContext, UniqueId}; +use xserv_model::loader; +use xserv_model::{sample, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE}; +use xserv_tensor::{DType, Device}; +use xserv_tokenizer::Tokenizer; + +use crate::engine::{GenerateEvent, GenerateRequest}; + +#[derive(Clone)] +enum TpCommand { + Register(usize), + Free(usize), + Prefill { tokens: Vec, slot: usize }, + Decode { tokens: Vec, positions: Vec, slots: Vec }, + Shutdown, +} + +struct RankCtx { + model: Qwen3, + cache: PagedKVCache, +} + +fn build_rank( + model_dir: &Path, + config: &ModelConfig, + rank: usize, + world: usize, + device: u32, + max_seq_len: usize, + tp: Option>, +) -> RankCtx { + let weights = loader::load_model_dir(model_dir, Device::Cpu); + let model = Qwen3::from_weights_tp(config.clone(), weights, rank, world, device, tp); + let local_kv = config.num_kv_heads() / world; + let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE); + let total_blocks = max_blocks_per_seq + 8; + let cache = PagedKVCache::new_tp( + config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device, + ); + RankCtx { model, cache } +} + +fn worker_loop( + rank: usize, + world: usize, + id: UniqueId, + model_dir: PathBuf, + config: ModelConfig, + max_seq_len: usize, + cmd_rx: mpsc::Receiver, + ack_tx: mpsc::Sender<()>, +) { + let tp = Arc::new(TpContext::init(rank, world, id, rank as u32)); + let mut rc = build_rank(&model_dir, &config, rank, world, rank as u32, max_seq_len, Some(tp)); + while let Ok(cmd) = cmd_rx.recv() { + match cmd { + TpCommand::Register(slot) => { + let _ = rc.cache.register_sequence(slot); + } + TpCommand::Free(slot) => rc.cache.free_sequence(slot), + TpCommand::Prefill { tokens, slot } => { + let _ = rc.model.forward_prefill_paged(&tokens, slot, &mut rc.cache); + } + TpCommand::Decode { tokens, positions, slots } => { + let _ = rc.model.forward_decode_paged(&tokens, &positions, &slots, &mut rc.cache); + } + TpCommand::Shutdown => { + let _ = ack_tx.send(()); + break; + } + } + let _ = ack_tx.send(()); + } +} + +/// Run the TP coordinator (rank 0) on the calling thread. Spawns worker ranks +/// internally and consumes generation requests from `rx`. +pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver) { + assert!(world >= 2, "run_tp requires world >= 2"); + let config = ModelConfig::from_file(&model_dir.join("config.json")); + assert!( + config.num_kv_heads() % world == 0, + "num_kv_heads {} not divisible by tp {world}", + config.num_kv_heads() + ); + let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json")); + let id = xserv_distributed::get_unique_id(); + + // Spawn worker ranks 1..world. + let (ack_tx, ack_rx) = mpsc::channel::<()>(); + let mut cmd_txs: Vec> = Vec::new(); + for rank in 1..world { + let (ctx_tx, ctx_rx) = mpsc::channel::(); + cmd_txs.push(ctx_tx); + let ack_tx = ack_tx.clone(); + let model_dir = model_dir.to_path_buf(); + let config = config.clone(); + thread::spawn(move || { + worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx); + }); + } + + // Rank 0 (this thread). + let tp = Arc::new(TpContext::init(0, world, id, 0)); + let mut rc = build_rank(model_dir, &config, 0, world, 0, max_seq_len, Some(tp)); + eprintln!("[tp-engine] ready (tp={world}, max_seq_len={max_seq_len})"); + + let eos = tokenizer.eos_token_id(); + let n_workers = world - 1; + let broadcast = |txs: &[mpsc::Sender], cmd: TpCommand| { + for t in txs { + let _ = t.send(cmd.clone()); + } + }; + let wait_acks = |rx: &mpsc::Receiver<()>| { + for _ in 0..n_workers { + let _ = rx.recv(); + } + }; + + let slot = 0usize; + while let Ok(req) = rx.recv() { + broadcast(&cmd_txs, TpCommand::Register(slot)); + rc.cache.register_sequence(slot).expect("register slot"); + wait_acks(&ack_rx); + + // Prefill. + broadcast(&cmd_txs, TpCommand::Prefill { tokens: req.prompt_tokens.clone(), slot }); + let logits = rc.model.forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache); + wait_acks(&ack_rx); + let mut next = sample(&logits, &req.sampling); + + let mut decode_buf: Vec = Vec::new(); + let mut generated = 1usize; + emit_text(&tokenizer, &req, next, eos, &mut decode_buf); + + let finish = loop { + if eos == Some(next) { + break "stop"; + } + if generated >= req.max_tokens { + break "length"; + } + let pos = rc.cache.seq_len(slot); + broadcast(&cmd_txs, TpCommand::Decode { tokens: vec![next], positions: vec![pos], slots: vec![slot] }); + let logits = rc.model.forward_decode_paged(&[next], &[pos], &[slot], &mut rc.cache); + wait_acks(&ack_rx); + next = sample(&logits, &req.sampling); + generated += 1; + emit_text(&tokenizer, &req, next, eos, &mut decode_buf); + }; + + let tail = tokenizer.flush_decode_stream(&mut decode_buf); + if !tail.is_empty() { + let _ = req.sender.blocking_send(GenerateEvent::Token { id: next, text: tail }); + } + let _ = req.sender.blocking_send(GenerateEvent::Done { finish_reason: finish.to_string() }); + + broadcast(&cmd_txs, TpCommand::Free(slot)); + rc.cache.free_sequence(slot); + wait_acks(&ack_rx); + } + + broadcast(&cmd_txs, TpCommand::Shutdown); +} + +/// Stream a token's decoded text to the client (EOS contributes no text). +fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, eos: Option, buf: &mut Vec) { + if eos == Some(token_id) { + return; + } + let text = tokenizer.decode_token_stream(token_id, buf); + if !text.is_empty() { + let _ = req.sender.blocking_send(GenerateEvent::Token { id: token_id, text }); + } +}