server: tensor-parallel HTTP engine (--tp N)

tp_engine: rank-0 coordinator owns the scheduler and broadcasts per-token
commands (Register/Prefill/Decode/Free) to worker rank threads; the sampled
token always comes from rank 0, so it's correct for greedy and stochastic
sampling. Serial single-request path (sufficient for the quality benchmark).
--tp N selects it; TP=1 keeps the existing single-GPU Engine unchanged.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-29 11:10:33 +08:00
parent f17011129e
commit 95eb61d639
3 changed files with 211 additions and 3 deletions

View File

@@ -13,6 +13,7 @@ xserv-tensor = { path = "../xserv-tensor" }
xserv-kernels = { path = "../xserv-kernels" }
xserv-model = { path = "../xserv-model" }
xserv-tokenizer = { path = "../xserv-tokenizer" }
xserv-distributed = { path = "../xserv-distributed" }
half.workspace = true
serde.workspace = true
serde_json.workspace = true

View File

@@ -1,5 +1,6 @@
mod api;
mod engine;
mod tp_engine;
use axum::{routing::{get, post}, Extension, Router};
use std::path::PathBuf;
@@ -18,7 +19,7 @@ pub struct AppState {
async fn main() {
let args: Vec<String> = std::env::args().collect();
if args.len() < 2 {
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N]");
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N]");
std::process::exit(1);
}
@@ -45,6 +46,12 @@ async fn main() {
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(8);
let tp: usize = args.iter()
.position(|a| a == "--tp")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(1)
.max(1);
let model_config = ModelConfig::from_file(&model_dir.join("config.json"));
let model_max_seq_len = model_config.max_seq_len();
if model_max_seq_len == 0 {
@@ -69,8 +76,13 @@ async fn main() {
let model_dir_clone = model_dir.clone();
std::thread::spawn(move || {
if tp <= 1 {
let mut engine = engine::Engine::load_with_swap(&model_dir_clone, max_batch, max_seq_len, swap_space_gb);
engine.run(rx);
} else {
// Tensor-parallel path: rank-0 coordinator + worker rank threads.
tp_engine::run_tp(&model_dir_clone, tp, max_seq_len, rx);
}
});
let state = Arc::new(AppState {

View File

@@ -0,0 +1,195 @@
//! Tensor-parallel inference engine for the HTTP server.
//!
//! Serial coordinator model: one rank-0 coordinator thread (the caller) drives
//! generation and owns the scheduler; ranks 1..world are worker threads. For
//! each step the coordinator broadcasts a command (Register/Prefill/Decode/Free)
//! to the workers and runs the same op on its own shard; the per-layer NCCL
//! AllReduces keep all ranks in lockstep. Only the coordinator samples — the
//! chosen token is carried in the next Decode command, so this is correct for
//! both greedy and stochastic sampling.
//!
//! Requests are processed one at a time (sufficient for the quality benchmark,
//! which issues serial requests). Continuous batching across ranks is future
//! work; the single-GPU `Engine` still handles TP=1.
use std::path::{Path, PathBuf};
use std::sync::mpsc;
use std::sync::Arc;
use std::thread;
use xserv_distributed::{TpContext, UniqueId};
use xserv_model::loader;
use xserv_model::{sample, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
use crate::engine::{GenerateEvent, GenerateRequest};
#[derive(Clone)]
enum TpCommand {
Register(usize),
Free(usize),
Prefill { tokens: Vec<u32>, slot: usize },
Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
Shutdown,
}
struct RankCtx {
model: Qwen3,
cache: PagedKVCache,
}
fn build_rank(
model_dir: &Path,
config: &ModelConfig,
rank: usize,
world: usize,
device: u32,
max_seq_len: usize,
tp: Option<Arc<TpContext>>,
) -> RankCtx {
let weights = loader::load_model_dir(model_dir, Device::Cpu);
let model = Qwen3::from_weights_tp(config.clone(), weights, rank, world, device, tp);
let local_kv = config.num_kv_heads() / world;
let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE);
let total_blocks = max_blocks_per_seq + 8;
let cache = PagedKVCache::new_tp(
config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device,
);
RankCtx { model, cache }
}
fn worker_loop(
rank: usize,
world: usize,
id: UniqueId,
model_dir: PathBuf,
config: ModelConfig,
max_seq_len: usize,
cmd_rx: mpsc::Receiver<TpCommand>,
ack_tx: mpsc::Sender<()>,
) {
let tp = Arc::new(TpContext::init(rank, world, id, rank as u32));
let mut rc = build_rank(&model_dir, &config, rank, world, rank as u32, max_seq_len, Some(tp));
while let Ok(cmd) = cmd_rx.recv() {
match cmd {
TpCommand::Register(slot) => {
let _ = rc.cache.register_sequence(slot);
}
TpCommand::Free(slot) => rc.cache.free_sequence(slot),
TpCommand::Prefill { tokens, slot } => {
let _ = rc.model.forward_prefill_paged(&tokens, slot, &mut rc.cache);
}
TpCommand::Decode { tokens, positions, slots } => {
let _ = rc.model.forward_decode_paged(&tokens, &positions, &slots, &mut rc.cache);
}
TpCommand::Shutdown => {
let _ = ack_tx.send(());
break;
}
}
let _ = ack_tx.send(());
}
}
/// Run the TP coordinator (rank 0) on the calling thread. Spawns worker ranks
/// internally and consumes generation requests from `rx`.
pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver<GenerateRequest>) {
assert!(world >= 2, "run_tp requires world >= 2");
let config = ModelConfig::from_file(&model_dir.join("config.json"));
assert!(
config.num_kv_heads() % world == 0,
"num_kv_heads {} not divisible by tp {world}",
config.num_kv_heads()
);
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
let id = xserv_distributed::get_unique_id();
// Spawn worker ranks 1..world.
let (ack_tx, ack_rx) = mpsc::channel::<()>();
let mut cmd_txs: Vec<mpsc::Sender<TpCommand>> = Vec::new();
for rank in 1..world {
let (ctx_tx, ctx_rx) = mpsc::channel::<TpCommand>();
cmd_txs.push(ctx_tx);
let ack_tx = ack_tx.clone();
let model_dir = model_dir.to_path_buf();
let config = config.clone();
thread::spawn(move || {
worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx);
});
}
// Rank 0 (this thread).
let tp = Arc::new(TpContext::init(0, world, id, 0));
let mut rc = build_rank(model_dir, &config, 0, world, 0, max_seq_len, Some(tp));
eprintln!("[tp-engine] ready (tp={world}, max_seq_len={max_seq_len})");
let eos = tokenizer.eos_token_id();
let n_workers = world - 1;
let broadcast = |txs: &[mpsc::Sender<TpCommand>], cmd: TpCommand| {
for t in txs {
let _ = t.send(cmd.clone());
}
};
let wait_acks = |rx: &mpsc::Receiver<()>| {
for _ in 0..n_workers {
let _ = rx.recv();
}
};
let slot = 0usize;
while let Ok(req) = rx.recv() {
broadcast(&cmd_txs, TpCommand::Register(slot));
rc.cache.register_sequence(slot).expect("register slot");
wait_acks(&ack_rx);
// Prefill.
broadcast(&cmd_txs, TpCommand::Prefill { tokens: req.prompt_tokens.clone(), slot });
let logits = rc.model.forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache);
wait_acks(&ack_rx);
let mut next = sample(&logits, &req.sampling);
let mut decode_buf: Vec<u8> = Vec::new();
let mut generated = 1usize;
emit_text(&tokenizer, &req, next, eos, &mut decode_buf);
let finish = loop {
if eos == Some(next) {
break "stop";
}
if generated >= req.max_tokens {
break "length";
}
let pos = rc.cache.seq_len(slot);
broadcast(&cmd_txs, TpCommand::Decode { tokens: vec![next], positions: vec![pos], slots: vec![slot] });
let logits = rc.model.forward_decode_paged(&[next], &[pos], &[slot], &mut rc.cache);
wait_acks(&ack_rx);
next = sample(&logits, &req.sampling);
generated += 1;
emit_text(&tokenizer, &req, next, eos, &mut decode_buf);
};
let tail = tokenizer.flush_decode_stream(&mut decode_buf);
if !tail.is_empty() {
let _ = req.sender.blocking_send(GenerateEvent::Token { id: next, text: tail });
}
let _ = req.sender.blocking_send(GenerateEvent::Done { finish_reason: finish.to_string() });
broadcast(&cmd_txs, TpCommand::Free(slot));
rc.cache.free_sequence(slot);
wait_acks(&ack_rx);
}
broadcast(&cmd_txs, TpCommand::Shutdown);
}
/// Stream a token's decoded text to the client (EOS contributes no text).
fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, eos: Option<u32>, buf: &mut Vec<u8>) {
if eos == Some(token_id) {
return;
}
let text = tokenizer.decode_token_stream(token_id, buf);
if !text.is_empty() {
let _ = req.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
}
}