server: tensor-parallel HTTP engine (--tp N)
tp_engine: rank-0 coordinator owns the scheduler and broadcasts per-token commands (Register/Prefill/Decode/Free) to worker rank threads; the sampled token always comes from rank 0, so it's correct for greedy and stochastic sampling. Serial single-request path (sufficient for the quality benchmark). --tp N selects it; TP=1 keeps the existing single-GPU Engine unchanged. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -13,6 +13,7 @@ xserv-tensor = { path = "../xserv-tensor" }
|
|||||||
xserv-kernels = { path = "../xserv-kernels" }
|
xserv-kernels = { path = "../xserv-kernels" }
|
||||||
xserv-model = { path = "../xserv-model" }
|
xserv-model = { path = "../xserv-model" }
|
||||||
xserv-tokenizer = { path = "../xserv-tokenizer" }
|
xserv-tokenizer = { path = "../xserv-tokenizer" }
|
||||||
|
xserv-distributed = { path = "../xserv-distributed" }
|
||||||
half.workspace = true
|
half.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
mod api;
|
mod api;
|
||||||
mod engine;
|
mod engine;
|
||||||
|
mod tp_engine;
|
||||||
|
|
||||||
use axum::{routing::{get, post}, Extension, Router};
|
use axum::{routing::{get, post}, Extension, Router};
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
@@ -18,7 +19,7 @@ pub struct AppState {
|
|||||||
async fn main() {
|
async fn main() {
|
||||||
let args: Vec<String> = std::env::args().collect();
|
let args: Vec<String> = std::env::args().collect();
|
||||||
if args.len() < 2 {
|
if args.len() < 2 {
|
||||||
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N]");
|
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N]");
|
||||||
std::process::exit(1);
|
std::process::exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -45,6 +46,12 @@ async fn main() {
|
|||||||
.and_then(|i| args.get(i + 1))
|
.and_then(|i| args.get(i + 1))
|
||||||
.and_then(|s| s.parse().ok())
|
.and_then(|s| s.parse().ok())
|
||||||
.unwrap_or(8);
|
.unwrap_or(8);
|
||||||
|
let tp: usize = args.iter()
|
||||||
|
.position(|a| a == "--tp")
|
||||||
|
.and_then(|i| args.get(i + 1))
|
||||||
|
.and_then(|s| s.parse().ok())
|
||||||
|
.unwrap_or(1)
|
||||||
|
.max(1);
|
||||||
let model_config = ModelConfig::from_file(&model_dir.join("config.json"));
|
let model_config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||||
let model_max_seq_len = model_config.max_seq_len();
|
let model_max_seq_len = model_config.max_seq_len();
|
||||||
if model_max_seq_len == 0 {
|
if model_max_seq_len == 0 {
|
||||||
@@ -69,8 +76,13 @@ async fn main() {
|
|||||||
|
|
||||||
let model_dir_clone = model_dir.clone();
|
let model_dir_clone = model_dir.clone();
|
||||||
std::thread::spawn(move || {
|
std::thread::spawn(move || {
|
||||||
let mut engine = engine::Engine::load_with_swap(&model_dir_clone, max_batch, max_seq_len, swap_space_gb);
|
if tp <= 1 {
|
||||||
engine.run(rx);
|
let mut engine = engine::Engine::load_with_swap(&model_dir_clone, max_batch, max_seq_len, swap_space_gb);
|
||||||
|
engine.run(rx);
|
||||||
|
} else {
|
||||||
|
// Tensor-parallel path: rank-0 coordinator + worker rank threads.
|
||||||
|
tp_engine::run_tp(&model_dir_clone, tp, max_seq_len, rx);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let state = Arc::new(AppState {
|
let state = Arc::new(AppState {
|
||||||
|
|||||||
195
crates/xserv-server/src/tp_engine.rs
Normal file
195
crates/xserv-server/src/tp_engine.rs
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
//! Tensor-parallel inference engine for the HTTP server.
|
||||||
|
//!
|
||||||
|
//! Serial coordinator model: one rank-0 coordinator thread (the caller) drives
|
||||||
|
//! generation and owns the scheduler; ranks 1..world are worker threads. For
|
||||||
|
//! each step the coordinator broadcasts a command (Register/Prefill/Decode/Free)
|
||||||
|
//! to the workers and runs the same op on its own shard; the per-layer NCCL
|
||||||
|
//! AllReduces keep all ranks in lockstep. Only the coordinator samples — the
|
||||||
|
//! chosen token is carried in the next Decode command, so this is correct for
|
||||||
|
//! both greedy and stochastic sampling.
|
||||||
|
//!
|
||||||
|
//! Requests are processed one at a time (sufficient for the quality benchmark,
|
||||||
|
//! which issues serial requests). Continuous batching across ranks is future
|
||||||
|
//! work; the single-GPU `Engine` still handles TP=1.
|
||||||
|
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::sync::mpsc;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::thread;
|
||||||
|
|
||||||
|
use xserv_distributed::{TpContext, UniqueId};
|
||||||
|
use xserv_model::loader;
|
||||||
|
use xserv_model::{sample, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
|
||||||
|
use xserv_tensor::{DType, Device};
|
||||||
|
use xserv_tokenizer::Tokenizer;
|
||||||
|
|
||||||
|
use crate::engine::{GenerateEvent, GenerateRequest};
|
||||||
|
|
||||||
|
/// Command sent from the rank-0 coordinator to every worker rank.
///
/// Each worker applies the command to its own model shard / KV cache and then
/// acknowledges it on the shared ack channel.
#[derive(Clone)]
enum TpCommand {
    /// Register the given cache slot for a new sequence.
    Register(usize),
    /// Release the given cache slot.
    Free(usize),
    /// Run the prefill forward pass over `tokens` for the sequence in `slot`.
    Prefill {
        tokens: Vec<u32>,
        slot: usize,
    },
    /// Run one decode forward pass; the three vectors are passed through
    /// unchanged to `forward_decode_paged`.
    Decode {
        tokens: Vec<u32>,
        positions: Vec<usize>,
        slots: Vec<usize>,
    },
    /// Acknowledge and leave the worker loop.
    Shutdown,
}
|
||||||
|
|
||||||
|
struct RankCtx {
|
||||||
|
model: Qwen3,
|
||||||
|
cache: PagedKVCache,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_rank(
|
||||||
|
model_dir: &Path,
|
||||||
|
config: &ModelConfig,
|
||||||
|
rank: usize,
|
||||||
|
world: usize,
|
||||||
|
device: u32,
|
||||||
|
max_seq_len: usize,
|
||||||
|
tp: Option<Arc<TpContext>>,
|
||||||
|
) -> RankCtx {
|
||||||
|
let weights = loader::load_model_dir(model_dir, Device::Cpu);
|
||||||
|
let model = Qwen3::from_weights_tp(config.clone(), weights, rank, world, device, tp);
|
||||||
|
let local_kv = config.num_kv_heads() / world;
|
||||||
|
let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE);
|
||||||
|
let total_blocks = max_blocks_per_seq + 8;
|
||||||
|
let cache = PagedKVCache::new_tp(
|
||||||
|
config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device,
|
||||||
|
);
|
||||||
|
RankCtx { model, cache }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn worker_loop(
|
||||||
|
rank: usize,
|
||||||
|
world: usize,
|
||||||
|
id: UniqueId,
|
||||||
|
model_dir: PathBuf,
|
||||||
|
config: ModelConfig,
|
||||||
|
max_seq_len: usize,
|
||||||
|
cmd_rx: mpsc::Receiver<TpCommand>,
|
||||||
|
ack_tx: mpsc::Sender<()>,
|
||||||
|
) {
|
||||||
|
let tp = Arc::new(TpContext::init(rank, world, id, rank as u32));
|
||||||
|
let mut rc = build_rank(&model_dir, &config, rank, world, rank as u32, max_seq_len, Some(tp));
|
||||||
|
while let Ok(cmd) = cmd_rx.recv() {
|
||||||
|
match cmd {
|
||||||
|
TpCommand::Register(slot) => {
|
||||||
|
let _ = rc.cache.register_sequence(slot);
|
||||||
|
}
|
||||||
|
TpCommand::Free(slot) => rc.cache.free_sequence(slot),
|
||||||
|
TpCommand::Prefill { tokens, slot } => {
|
||||||
|
let _ = rc.model.forward_prefill_paged(&tokens, slot, &mut rc.cache);
|
||||||
|
}
|
||||||
|
TpCommand::Decode { tokens, positions, slots } => {
|
||||||
|
let _ = rc.model.forward_decode_paged(&tokens, &positions, &slots, &mut rc.cache);
|
||||||
|
}
|
||||||
|
TpCommand::Shutdown => {
|
||||||
|
let _ = ack_tx.send(());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let _ = ack_tx.send(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the TP coordinator (rank 0) on the calling thread. Spawns worker ranks
|
||||||
|
/// internally and consumes generation requests from `rx`.
|
||||||
|
pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver<GenerateRequest>) {
|
||||||
|
assert!(world >= 2, "run_tp requires world >= 2");
|
||||||
|
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||||
|
assert!(
|
||||||
|
config.num_kv_heads() % world == 0,
|
||||||
|
"num_kv_heads {} not divisible by tp {world}",
|
||||||
|
config.num_kv_heads()
|
||||||
|
);
|
||||||
|
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||||
|
let id = xserv_distributed::get_unique_id();
|
||||||
|
|
||||||
|
// Spawn worker ranks 1..world.
|
||||||
|
let (ack_tx, ack_rx) = mpsc::channel::<()>();
|
||||||
|
let mut cmd_txs: Vec<mpsc::Sender<TpCommand>> = Vec::new();
|
||||||
|
for rank in 1..world {
|
||||||
|
let (ctx_tx, ctx_rx) = mpsc::channel::<TpCommand>();
|
||||||
|
cmd_txs.push(ctx_tx);
|
||||||
|
let ack_tx = ack_tx.clone();
|
||||||
|
let model_dir = model_dir.to_path_buf();
|
||||||
|
let config = config.clone();
|
||||||
|
thread::spawn(move || {
|
||||||
|
worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rank 0 (this thread).
|
||||||
|
let tp = Arc::new(TpContext::init(0, world, id, 0));
|
||||||
|
let mut rc = build_rank(model_dir, &config, 0, world, 0, max_seq_len, Some(tp));
|
||||||
|
eprintln!("[tp-engine] ready (tp={world}, max_seq_len={max_seq_len})");
|
||||||
|
|
||||||
|
let eos = tokenizer.eos_token_id();
|
||||||
|
let n_workers = world - 1;
|
||||||
|
let broadcast = |txs: &[mpsc::Sender<TpCommand>], cmd: TpCommand| {
|
||||||
|
for t in txs {
|
||||||
|
let _ = t.send(cmd.clone());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let wait_acks = |rx: &mpsc::Receiver<()>| {
|
||||||
|
for _ in 0..n_workers {
|
||||||
|
let _ = rx.recv();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let slot = 0usize;
|
||||||
|
while let Ok(req) = rx.recv() {
|
||||||
|
broadcast(&cmd_txs, TpCommand::Register(slot));
|
||||||
|
rc.cache.register_sequence(slot).expect("register slot");
|
||||||
|
wait_acks(&ack_rx);
|
||||||
|
|
||||||
|
// Prefill.
|
||||||
|
broadcast(&cmd_txs, TpCommand::Prefill { tokens: req.prompt_tokens.clone(), slot });
|
||||||
|
let logits = rc.model.forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache);
|
||||||
|
wait_acks(&ack_rx);
|
||||||
|
let mut next = sample(&logits, &req.sampling);
|
||||||
|
|
||||||
|
let mut decode_buf: Vec<u8> = Vec::new();
|
||||||
|
let mut generated = 1usize;
|
||||||
|
emit_text(&tokenizer, &req, next, eos, &mut decode_buf);
|
||||||
|
|
||||||
|
let finish = loop {
|
||||||
|
if eos == Some(next) {
|
||||||
|
break "stop";
|
||||||
|
}
|
||||||
|
if generated >= req.max_tokens {
|
||||||
|
break "length";
|
||||||
|
}
|
||||||
|
let pos = rc.cache.seq_len(slot);
|
||||||
|
broadcast(&cmd_txs, TpCommand::Decode { tokens: vec![next], positions: vec![pos], slots: vec![slot] });
|
||||||
|
let logits = rc.model.forward_decode_paged(&[next], &[pos], &[slot], &mut rc.cache);
|
||||||
|
wait_acks(&ack_rx);
|
||||||
|
next = sample(&logits, &req.sampling);
|
||||||
|
generated += 1;
|
||||||
|
emit_text(&tokenizer, &req, next, eos, &mut decode_buf);
|
||||||
|
};
|
||||||
|
|
||||||
|
let tail = tokenizer.flush_decode_stream(&mut decode_buf);
|
||||||
|
if !tail.is_empty() {
|
||||||
|
let _ = req.sender.blocking_send(GenerateEvent::Token { id: next, text: tail });
|
||||||
|
}
|
||||||
|
let _ = req.sender.blocking_send(GenerateEvent::Done { finish_reason: finish.to_string() });
|
||||||
|
|
||||||
|
broadcast(&cmd_txs, TpCommand::Free(slot));
|
||||||
|
rc.cache.free_sequence(slot);
|
||||||
|
wait_acks(&ack_rx);
|
||||||
|
}
|
||||||
|
|
||||||
|
broadcast(&cmd_txs, TpCommand::Shutdown);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stream a token's decoded text to the client (EOS contributes no text).
|
||||||
|
fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, eos: Option<u32>, buf: &mut Vec<u8>) {
|
||||||
|
if eos == Some(token_id) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let text = tokenizer.decode_token_stream(token_id, buf);
|
||||||
|
if !text.is_empty() {
|
||||||
|
let _ = req.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user