server: pipeline-parallel HTTP engine (--pp N)
pp_engine::run_pp: stage-0 coordinator (scheduler/tokenizer/sampling + stop logic) on the calling thread, worker stage threads for 1..P. Each step the coordinator embeds + runs its layers, then the hidden state is handed stage->stage over NCCL P2P; the last stage samples and returns the token to stage 0 over an in-process channel. v1 is serial (one request, one token/step) — correctness first; throughput via microbatch overlap is future work. main: wire --pp N (mutually exclusive with --tp). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
mod api;
|
||||
mod engine;
|
||||
mod pp_engine;
|
||||
mod tp_engine;
|
||||
|
||||
use axum::{routing::{get, post}, Extension, Router};
|
||||
@@ -19,7 +20,7 @@ pub struct AppState {
|
||||
async fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N]");
|
||||
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N] [--pp N]");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
@@ -52,6 +53,16 @@ async fn main() {
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1)
|
||||
.max(1);
|
||||
let pp: usize = args.iter()
|
||||
.position(|a| a == "--pp")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1)
|
||||
.max(1);
|
||||
if tp > 1 && pp > 1 {
|
||||
eprintln!("--tp and --pp cannot be combined yet (2D TP×PP is future work)");
|
||||
std::process::exit(1);
|
||||
}
|
||||
let model_config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
let model_max_seq_len = model_config.max_seq_len();
|
||||
if model_max_seq_len == 0 {
|
||||
@@ -76,7 +87,10 @@ async fn main() {
|
||||
|
||||
let model_dir_clone = model_dir.clone();
|
||||
std::thread::spawn(move || {
|
||||
if tp <= 1 {
|
||||
if pp > 1 {
|
||||
// Pipeline-parallel path: stage-0 coordinator + worker stage threads.
|
||||
pp_engine::run_pp(&model_dir_clone, pp, max_seq_len, rx);
|
||||
} else if tp <= 1 {
|
||||
let mut engine = engine::Engine::load_with_swap(&model_dir_clone, max_batch, max_seq_len, swap_space_gb);
|
||||
engine.run(rx);
|
||||
} else {
|
||||
|
||||
264
crates/xserv-server/src/pp_engine.rs
Normal file
264
crates/xserv-server/src/pp_engine.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
//! Pipeline-parallel inference engine for the HTTP server (Phase 18).
|
||||
//!
|
||||
//! Layer-wise split: stage `s` holds layers `[s*L, (s+1)*L)`. Stage 0 owns the
|
||||
//! token embedding and acts as the coordinator (scheduler + tokenizer + response
|
||||
//! sender + stop logic); the last stage owns `norm`/`lm_head` and does sampling.
|
||||
//! Hidden states are handed off stage->stage via NCCL P2P (`PpContext`); the
|
||||
//! sampled token id (a single u32) is returned last-stage -> stage0 over an
|
||||
//! in-process channel (same process, so no NCCL needed for that).
|
||||
//!
|
||||
//! v1 is serial: one request at a time, one token per step, the pipeline is
|
||||
//! filled and drained each step (stage0's decode step t+1 depends on the token
|
||||
//! the last stage sampled at step t). This gives correctness + per-GPU memory
|
||||
//! savings; throughput via microbatch/1F1B overlap is future work
|
||||
//! (see docs/18-pipeline-parallelism.md).
|
||||
|
||||
use std::ffi::c_void;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::mpsc;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
use half::bf16;
|
||||
use xserv_distributed::{PpContext, UniqueId};
|
||||
use xserv_model::loader;
|
||||
use xserv_model::sampling::SamplingParams;
|
||||
use xserv_model::{sample, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
use crate::engine::{GenerateEvent, GenerateRequest};
|
||||
|
||||
/// Control messages from the coordinator (stage 0) to a worker stage. The heavy
|
||||
/// hidden-state tensors do NOT travel here — they go GPU->GPU over NCCL. Only
|
||||
/// tiny control info (slot ids, token count, sampling params) is sent.
|
||||
#[derive(Clone)]
|
||||
enum PpCommand {
|
||||
Register(usize),
|
||||
Free(usize),
|
||||
/// Receive `[n_tokens, hidden]` from the previous stage, run this stage's
|
||||
/// layers; if last stage, sample with `sampling` and return the token.
|
||||
Prefill { n_tokens: usize, slot: usize, sampling: SamplingParams },
|
||||
/// Receive `[1, hidden]`, run this stage's layers; last stage samples.
|
||||
Decode { slot: usize, sampling: SamplingParams },
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
struct StageCtx {
|
||||
model: Qwen3,
|
||||
cache: PagedKVCache,
|
||||
pp: Arc<PpContext>,
|
||||
hidden: usize,
|
||||
device: u32,
|
||||
}
|
||||
|
||||
/// Build this stage: NCCL init, load + slice weights, size a per-stage KV pool
|
||||
/// for THIS stage's layers only (so per-GPU KV is ~1/P).
|
||||
fn build_stage(
|
||||
model_dir: &Path,
|
||||
config: &ModelConfig,
|
||||
stage: usize,
|
||||
world: usize,
|
||||
device: u32,
|
||||
max_seq_len: usize,
|
||||
id: UniqueId,
|
||||
) -> StageCtx {
|
||||
let pp = Arc::new(PpContext::init(stage, world, id, device));
|
||||
let weights = loader::load_model_dir(model_dir, Device::Cpu);
|
||||
let model = Qwen3::from_weights_pp(config.clone(), weights, stage, world, device);
|
||||
|
||||
// The KV cache only needs this stage's layers; build it from a config clone
|
||||
// whose layer count is the per-stage count (heads are NOT split under PP).
|
||||
let per_stage = config.num_layers() / world;
|
||||
let mut stage_config = config.clone();
|
||||
stage_config.num_hidden_layers = Some(per_stage);
|
||||
|
||||
let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE);
|
||||
let total_blocks = max_blocks_per_seq + 8; // v1 serial: one active sequence
|
||||
let cache = PagedKVCache::new(
|
||||
&stage_config, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device,
|
||||
);
|
||||
StageCtx { model, cache, pp, hidden: config.hidden(), device }
|
||||
}
|
||||
|
||||
/// Allocate a zeroed `[n, hidden]` device tensor and receive into it from `peer`.
|
||||
fn recv_hidden(sc: &StageCtx, n: usize, peer: usize) -> Tensor {
|
||||
let zeros = vec![bf16::ZERO; n * sc.hidden];
|
||||
let x = Tensor::from_slice(&zeros, &[n, sc.hidden]).to_device(Device::Cuda(sc.device));
|
||||
let ptr = x.storage().gpu_buffer().as_ptr() as *mut c_void;
|
||||
sc.pp.recv_bf16_ptr(ptr, n * sc.hidden, peer);
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
x
|
||||
}
|
||||
|
||||
/// Send the `[*, hidden]` hidden state to `peer`, then synchronize so NCCL has
|
||||
/// finished reading `x` before it is dropped/reused.
|
||||
fn send_hidden(sc: &StageCtx, x: &Tensor, peer: usize) {
|
||||
let ptr = x.storage().gpu_buffer().as_ptr() as *const c_void;
|
||||
sc.pp.send_bf16_ptr(ptr, x.numel(), peer);
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
|
||||
fn worker_loop(
|
||||
stage: usize,
|
||||
world: usize,
|
||||
id: UniqueId,
|
||||
model_dir: PathBuf,
|
||||
config: ModelConfig,
|
||||
max_seq_len: usize,
|
||||
cmd_rx: mpsc::Receiver<PpCommand>,
|
||||
ack_tx: mpsc::Sender<()>,
|
||||
token_tx: mpsc::Sender<u32>,
|
||||
) {
|
||||
let mut sc = build_stage(&model_dir, &config, stage, world, stage as u32, max_seq_len, id);
|
||||
let is_last = stage == world - 1;
|
||||
let prev = stage - 1;
|
||||
let next = stage + 1;
|
||||
|
||||
while let Ok(cmd) = cmd_rx.recv() {
|
||||
match cmd {
|
||||
PpCommand::Register(slot) => {
|
||||
let _ = sc.cache.register_sequence(slot);
|
||||
let _ = ack_tx.send(());
|
||||
}
|
||||
PpCommand::Free(slot) => {
|
||||
sc.cache.free_sequence(slot);
|
||||
let _ = ack_tx.send(());
|
||||
}
|
||||
PpCommand::Prefill { n_tokens, slot, sampling } => {
|
||||
let x = recv_hidden(&sc, n_tokens, prev);
|
||||
let x = sc.model.forward_layers_prefill(x, slot, &mut sc.cache);
|
||||
if is_last {
|
||||
let logits = sc.model.head(&x);
|
||||
let _ = token_tx.send(sample(&logits, &sampling));
|
||||
} else {
|
||||
send_hidden(&sc, &x, next);
|
||||
}
|
||||
}
|
||||
PpCommand::Decode { slot, sampling } => {
|
||||
let x = recv_hidden(&sc, 1, prev);
|
||||
let x = sc.model.forward_layers_decode(x, &[slot], &mut sc.cache);
|
||||
if is_last {
|
||||
let logits = sc.model.head(&x);
|
||||
let _ = token_tx.send(sample(&logits, &sampling));
|
||||
} else {
|
||||
send_hidden(&sc, &x, next);
|
||||
}
|
||||
}
|
||||
PpCommand::Shutdown => {
|
||||
let _ = ack_tx.send(());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the PP coordinator (stage 0) on the calling thread. Spawns worker stages
|
||||
/// 1..world and consumes generation requests from `rx`.
|
||||
pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver<GenerateRequest>) {
|
||||
assert!(world >= 2, "run_pp requires world >= 2");
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
assert!(
|
||||
config.num_layers() % world == 0,
|
||||
"num_layers {} not divisible by pp {world}",
|
||||
config.num_layers()
|
||||
);
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
let id = xserv_distributed::get_unique_id();
|
||||
|
||||
// Worker stages 1..world. Each gets a control channel; all share one ack
|
||||
// channel and one token channel (only the last stage actually sends tokens).
|
||||
let (ack_tx, ack_rx) = mpsc::channel::<()>();
|
||||
let (token_tx, token_rx) = mpsc::channel::<u32>();
|
||||
let mut cmd_txs: Vec<mpsc::Sender<PpCommand>> = Vec::new();
|
||||
for stage in 1..world {
|
||||
let (ctx_tx, ctx_rx) = mpsc::channel::<PpCommand>();
|
||||
cmd_txs.push(ctx_tx);
|
||||
let ack_tx = ack_tx.clone();
|
||||
let token_tx = token_tx.clone();
|
||||
let model_dir = model_dir.to_path_buf();
|
||||
let config = config.clone();
|
||||
thread::spawn(move || {
|
||||
worker_loop(stage, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx, token_tx);
|
||||
});
|
||||
}
|
||||
|
||||
// Stage 0 (this thread): coordinator + embedding + first layers.
|
||||
let mut sc = build_stage(model_dir, &config, 0, world, 0, max_seq_len, id);
|
||||
eprintln!("[pp-engine] ready (pp={world}, max_seq_len={max_seq_len})");
|
||||
|
||||
let eos = tokenizer.eos_token_id();
|
||||
let n_workers = world - 1;
|
||||
let next_peer = 1usize;
|
||||
let broadcast = |txs: &[mpsc::Sender<PpCommand>], cmd: PpCommand| {
|
||||
for t in txs {
|
||||
let _ = t.send(cmd.clone());
|
||||
}
|
||||
};
|
||||
let wait_acks = |rx: &mpsc::Receiver<()>| {
|
||||
for _ in 0..n_workers {
|
||||
let _ = rx.recv();
|
||||
}
|
||||
};
|
||||
|
||||
let slot = 0usize;
|
||||
while let Ok(req) = rx.recv() {
|
||||
broadcast(&cmd_txs, PpCommand::Register(slot));
|
||||
sc.cache.register_sequence(slot).expect("register slot");
|
||||
wait_acks(&ack_rx);
|
||||
|
||||
// Prefill: embed prompt, run stage-0 layers, push hidden into the pipe.
|
||||
broadcast(&cmd_txs, PpCommand::Prefill {
|
||||
n_tokens: req.prompt_tokens.len(),
|
||||
slot,
|
||||
sampling: req.sampling.clone(),
|
||||
});
|
||||
let x = sc.model.embed(&req.prompt_tokens);
|
||||
let x = sc.model.forward_layers_prefill(x, slot, &mut sc.cache);
|
||||
send_hidden(&sc, &x, next_peer);
|
||||
let mut next = token_rx.recv().expect("prefill token");
|
||||
|
||||
let mut decode_buf: Vec<u8> = Vec::new();
|
||||
let mut generated = 1usize;
|
||||
emit_text(&tokenizer, &req, next, eos, &mut decode_buf);
|
||||
|
||||
let finish = loop {
|
||||
if eos == Some(next) {
|
||||
break "stop";
|
||||
}
|
||||
if generated >= req.max_tokens {
|
||||
break "length";
|
||||
}
|
||||
broadcast(&cmd_txs, PpCommand::Decode { slot, sampling: req.sampling.clone() });
|
||||
let x = sc.model.embed(&[next]);
|
||||
let x = sc.model.forward_layers_decode(x, &[slot], &mut sc.cache);
|
||||
send_hidden(&sc, &x, next_peer);
|
||||
next = token_rx.recv().expect("decode token");
|
||||
generated += 1;
|
||||
emit_text(&tokenizer, &req, next, eos, &mut decode_buf);
|
||||
};
|
||||
|
||||
let tail = tokenizer.flush_decode_stream(&mut decode_buf);
|
||||
if !tail.is_empty() {
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token { id: next, text: tail });
|
||||
}
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Done { finish_reason: finish.to_string() });
|
||||
|
||||
broadcast(&cmd_txs, PpCommand::Free(slot));
|
||||
sc.cache.free_sequence(slot);
|
||||
wait_acks(&ack_rx);
|
||||
}
|
||||
|
||||
broadcast(&cmd_txs, PpCommand::Shutdown);
|
||||
}
|
||||
|
||||
/// Stream a token's decoded text to the client (EOS contributes no text).
|
||||
fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, eos: Option<u32>, buf: &mut Vec<u8>) {
|
||||
if eos == Some(token_id) {
|
||||
return;
|
||||
}
|
||||
let text = tokenizer.decode_token_stream(token_id, buf);
|
||||
if !text.is_empty() {
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user