server: tensor-parallel HTTP engine (--tp N)
tp_engine: rank-0 coordinator owns the scheduler and broadcasts per-token commands (Register/Prefill/Decode/Free) to worker rank threads; the sampled token always comes from rank 0, so it's correct for greedy and stochastic sampling. Serial single-request path (sufficient for the quality benchmark). --tp N selects it; TP=1 keeps the existing single-GPU Engine unchanged. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -13,6 +13,7 @@ xserv-tensor = { path = "../xserv-tensor" }
|
||||
xserv-kernels = { path = "../xserv-kernels" }
|
||||
xserv-model = { path = "../xserv-model" }
|
||||
xserv-tokenizer = { path = "../xserv-tokenizer" }
|
||||
xserv-distributed = { path = "../xserv-distributed" }
|
||||
half.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
mod api;
|
||||
mod engine;
|
||||
mod tp_engine;
|
||||
|
||||
use axum::{routing::{get, post}, Extension, Router};
|
||||
use std::path::PathBuf;
|
||||
@@ -18,7 +19,7 @@ pub struct AppState {
|
||||
async fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N]");
|
||||
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N]");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
@@ -45,6 +46,12 @@ async fn main() {
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(8);
|
||||
let tp: usize = args.iter()
|
||||
.position(|a| a == "--tp")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1)
|
||||
.max(1);
|
||||
let model_config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
let model_max_seq_len = model_config.max_seq_len();
|
||||
if model_max_seq_len == 0 {
|
||||
@@ -69,8 +76,13 @@ async fn main() {
|
||||
|
||||
let model_dir_clone = model_dir.clone();
|
||||
std::thread::spawn(move || {
|
||||
if tp <= 1 {
|
||||
let mut engine = engine::Engine::load_with_swap(&model_dir_clone, max_batch, max_seq_len, swap_space_gb);
|
||||
engine.run(rx);
|
||||
} else {
|
||||
// Tensor-parallel path: rank-0 coordinator + worker rank threads.
|
||||
tp_engine::run_tp(&model_dir_clone, tp, max_seq_len, rx);
|
||||
}
|
||||
});
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
|
||||
195
crates/xserv-server/src/tp_engine.rs
Normal file
195
crates/xserv-server/src/tp_engine.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
//! Tensor-parallel inference engine for the HTTP server.
|
||||
//!
|
||||
//! Serial coordinator model: one rank-0 coordinator thread (the caller) drives
|
||||
//! generation and owns the scheduler; ranks 1..world are worker threads. For
|
||||
//! each step the coordinator broadcasts a command (Register/Prefill/Decode/Free)
|
||||
//! to the workers and runs the same op on its own shard; the per-layer NCCL
|
||||
//! AllReduces keep all ranks in lockstep. Only the coordinator samples — the
|
||||
//! chosen token is carried in the next Decode command, so this is correct for
|
||||
//! both greedy and stochastic sampling.
|
||||
//!
|
||||
//! Requests are processed one at a time (sufficient for the quality benchmark,
|
||||
//! which issues serial requests). Continuous batching across ranks is future
|
||||
//! work; the single-GPU `Engine` still handles TP=1.
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::mpsc;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
use xserv_distributed::{TpContext, UniqueId};
|
||||
use xserv_model::loader;
|
||||
use xserv_model::{sample, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
use crate::engine::{GenerateEvent, GenerateRequest};
|
||||
|
||||
#[derive(Clone)]
|
||||
enum TpCommand {
|
||||
Register(usize),
|
||||
Free(usize),
|
||||
Prefill { tokens: Vec<u32>, slot: usize },
|
||||
Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
struct RankCtx {
|
||||
model: Qwen3,
|
||||
cache: PagedKVCache,
|
||||
}
|
||||
|
||||
fn build_rank(
|
||||
model_dir: &Path,
|
||||
config: &ModelConfig,
|
||||
rank: usize,
|
||||
world: usize,
|
||||
device: u32,
|
||||
max_seq_len: usize,
|
||||
tp: Option<Arc<TpContext>>,
|
||||
) -> RankCtx {
|
||||
let weights = loader::load_model_dir(model_dir, Device::Cpu);
|
||||
let model = Qwen3::from_weights_tp(config.clone(), weights, rank, world, device, tp);
|
||||
let local_kv = config.num_kv_heads() / world;
|
||||
let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE);
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
let cache = PagedKVCache::new_tp(
|
||||
config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device,
|
||||
);
|
||||
RankCtx { model, cache }
|
||||
}
|
||||
|
||||
fn worker_loop(
|
||||
rank: usize,
|
||||
world: usize,
|
||||
id: UniqueId,
|
||||
model_dir: PathBuf,
|
||||
config: ModelConfig,
|
||||
max_seq_len: usize,
|
||||
cmd_rx: mpsc::Receiver<TpCommand>,
|
||||
ack_tx: mpsc::Sender<()>,
|
||||
) {
|
||||
let tp = Arc::new(TpContext::init(rank, world, id, rank as u32));
|
||||
let mut rc = build_rank(&model_dir, &config, rank, world, rank as u32, max_seq_len, Some(tp));
|
||||
while let Ok(cmd) = cmd_rx.recv() {
|
||||
match cmd {
|
||||
TpCommand::Register(slot) => {
|
||||
let _ = rc.cache.register_sequence(slot);
|
||||
}
|
||||
TpCommand::Free(slot) => rc.cache.free_sequence(slot),
|
||||
TpCommand::Prefill { tokens, slot } => {
|
||||
let _ = rc.model.forward_prefill_paged(&tokens, slot, &mut rc.cache);
|
||||
}
|
||||
TpCommand::Decode { tokens, positions, slots } => {
|
||||
let _ = rc.model.forward_decode_paged(&tokens, &positions, &slots, &mut rc.cache);
|
||||
}
|
||||
TpCommand::Shutdown => {
|
||||
let _ = ack_tx.send(());
|
||||
break;
|
||||
}
|
||||
}
|
||||
let _ = ack_tx.send(());
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the TP coordinator (rank 0) on the calling thread. Spawns worker ranks
|
||||
/// internally and consumes generation requests from `rx`.
|
||||
pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver<GenerateRequest>) {
|
||||
assert!(world >= 2, "run_tp requires world >= 2");
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
assert!(
|
||||
config.num_kv_heads() % world == 0,
|
||||
"num_kv_heads {} not divisible by tp {world}",
|
||||
config.num_kv_heads()
|
||||
);
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
let id = xserv_distributed::get_unique_id();
|
||||
|
||||
// Spawn worker ranks 1..world.
|
||||
let (ack_tx, ack_rx) = mpsc::channel::<()>();
|
||||
let mut cmd_txs: Vec<mpsc::Sender<TpCommand>> = Vec::new();
|
||||
for rank in 1..world {
|
||||
let (ctx_tx, ctx_rx) = mpsc::channel::<TpCommand>();
|
||||
cmd_txs.push(ctx_tx);
|
||||
let ack_tx = ack_tx.clone();
|
||||
let model_dir = model_dir.to_path_buf();
|
||||
let config = config.clone();
|
||||
thread::spawn(move || {
|
||||
worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx);
|
||||
});
|
||||
}
|
||||
|
||||
// Rank 0 (this thread).
|
||||
let tp = Arc::new(TpContext::init(0, world, id, 0));
|
||||
let mut rc = build_rank(model_dir, &config, 0, world, 0, max_seq_len, Some(tp));
|
||||
eprintln!("[tp-engine] ready (tp={world}, max_seq_len={max_seq_len})");
|
||||
|
||||
let eos = tokenizer.eos_token_id();
|
||||
let n_workers = world - 1;
|
||||
let broadcast = |txs: &[mpsc::Sender<TpCommand>], cmd: TpCommand| {
|
||||
for t in txs {
|
||||
let _ = t.send(cmd.clone());
|
||||
}
|
||||
};
|
||||
let wait_acks = |rx: &mpsc::Receiver<()>| {
|
||||
for _ in 0..n_workers {
|
||||
let _ = rx.recv();
|
||||
}
|
||||
};
|
||||
|
||||
let slot = 0usize;
|
||||
while let Ok(req) = rx.recv() {
|
||||
broadcast(&cmd_txs, TpCommand::Register(slot));
|
||||
rc.cache.register_sequence(slot).expect("register slot");
|
||||
wait_acks(&ack_rx);
|
||||
|
||||
// Prefill.
|
||||
broadcast(&cmd_txs, TpCommand::Prefill { tokens: req.prompt_tokens.clone(), slot });
|
||||
let logits = rc.model.forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache);
|
||||
wait_acks(&ack_rx);
|
||||
let mut next = sample(&logits, &req.sampling);
|
||||
|
||||
let mut decode_buf: Vec<u8> = Vec::new();
|
||||
let mut generated = 1usize;
|
||||
emit_text(&tokenizer, &req, next, eos, &mut decode_buf);
|
||||
|
||||
let finish = loop {
|
||||
if eos == Some(next) {
|
||||
break "stop";
|
||||
}
|
||||
if generated >= req.max_tokens {
|
||||
break "length";
|
||||
}
|
||||
let pos = rc.cache.seq_len(slot);
|
||||
broadcast(&cmd_txs, TpCommand::Decode { tokens: vec![next], positions: vec![pos], slots: vec![slot] });
|
||||
let logits = rc.model.forward_decode_paged(&[next], &[pos], &[slot], &mut rc.cache);
|
||||
wait_acks(&ack_rx);
|
||||
next = sample(&logits, &req.sampling);
|
||||
generated += 1;
|
||||
emit_text(&tokenizer, &req, next, eos, &mut decode_buf);
|
||||
};
|
||||
|
||||
let tail = tokenizer.flush_decode_stream(&mut decode_buf);
|
||||
if !tail.is_empty() {
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token { id: next, text: tail });
|
||||
}
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Done { finish_reason: finish.to_string() });
|
||||
|
||||
broadcast(&cmd_txs, TpCommand::Free(slot));
|
||||
rc.cache.free_sequence(slot);
|
||||
wait_acks(&ack_rx);
|
||||
}
|
||||
|
||||
broadcast(&cmd_txs, TpCommand::Shutdown);
|
||||
}
|
||||
|
||||
/// Stream a token's decoded text to the client (EOS contributes no text).
|
||||
fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, eos: Option<u32>, buf: &mut Vec<u8>) {
|
||||
if eos == Some(token_id) {
|
||||
return;
|
||||
}
|
||||
let text = tokenizer.decode_token_stream(token_id, buf);
|
||||
if !text.is_empty() {
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user