All three engines emitted tokens with blocking_send on the single decode/coordinator OS thread. A streaming client that drains slower than generation fills its 64-slot channel, and blocking_send then blocks the whole thread: under continuous batching one slow consumer stalls every other running sequence (and in the serial TP/PP path it blocks admission of the next request too). The whole point of continuous batching is defeated. Fix: switch to try_send. engine.rs sets a client_stalled flag on Full/Closed, reaped by is_finished() next iteration; tp_engine/pp_engine emit_text returns bool and the decode loop breaks with finish_reason "error". When the sequence/request is dropped its sender drops too, closing the channel so the client receive loop ends rather than hanging. A slow client now only loses its own sequence, never the batch. Verified on dash5: gpt-oss FP8 TP=1 streaming via tp_engine still streams correctly (SSE chunks, coherent content, no hang); bench-gpt-oss TP=2 5.9ms TPOT unchanged. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
362 lines
11 KiB
Rust
362 lines
11 KiB
Rust
//! Tensor-parallel inference engine for the HTTP server.
//!
//! Serial coordinator model: one rank-0 coordinator thread (the caller) drives
//! generation and owns the scheduler; ranks 1..world are worker threads. For
//! each step the coordinator broadcasts a command (Register/Prefill/Decode/Free)
//! to the workers and runs the same op on its own shard; the per-layer NCCL
//! AllReduces keep all ranks in lockstep. Only the coordinator samples — the
//! chosen token is carried in the next Decode command, so this is correct for
//! both greedy and stochastic sampling.
//!
//! Requests are processed one at a time (sufficient for the quality benchmark,
//! which issues serial requests). Continuous batching across ranks is future
//! work. The single-GPU `Engine` still handles TP=1 where it has a path for
//! the model; gpt-oss has no single-GPU engine path, so it runs through this
//! engine with `world = 1`.
|
|
|
|
use std::path::{Path, PathBuf};
|
|
use std::sync::Arc;
|
|
use std::sync::mpsc;
|
|
use std::thread;
|
|
|
|
use xserv_distributed::{TpContext, UniqueId};
|
|
use xserv_model::loader;
|
|
use xserv_model::{
|
|
BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, sample,
|
|
sample_greedy_penalized,
|
|
};
|
|
use xserv_tensor::{DType, Device, Tensor};
|
|
use xserv_tokenizer::Tokenizer;
|
|
|
|
use crate::engine::{GenerateEvent, GenerateRequest};
|
|
|
|
/// Command the coordinator sends to every worker rank over its command channel.
///
/// The coordinator clones one value per worker when broadcasting, hence `Clone`.
#[derive(Clone)]
enum TpCommand {
    /// Register the given sequence slot in the worker's KV cache.
    Register(usize),
    /// Release the given sequence slot from the worker's KV cache.
    Free(usize),
    /// Run the prefill forward pass for `tokens` on sequence slot `slot`.
    Prefill {
        tokens: Vec<u32>,
        slot: usize,
    },
    /// Run one decode forward pass; the three vectors are passed through to
    /// the model's decode call as-is.
    Decode {
        tokens: Vec<u32>,
        positions: Vec<usize>,
        slots: Vec<usize>,
    },
    /// Acknowledge and leave the worker loop.
    Shutdown,
}
|
|
|
|
enum TpModel {
|
|
Qwen3(Qwen3),
|
|
GptOss(GptOss),
|
|
}
|
|
|
|
impl TpModel {
|
|
fn forward_prefill_paged(
|
|
&self,
|
|
tokens: &[u32],
|
|
slot: usize,
|
|
cache: &mut PagedKVCache,
|
|
) -> Tensor {
|
|
match self {
|
|
TpModel::Qwen3(m) => m.forward_prefill_paged(tokens, slot, cache),
|
|
TpModel::GptOss(m) => m.forward_prefill_paged(tokens, slot, cache),
|
|
}
|
|
}
|
|
|
|
fn forward_decode_paged(
|
|
&self,
|
|
tokens: &[u32],
|
|
positions: &[usize],
|
|
slots: &[usize],
|
|
cache: &mut PagedKVCache,
|
|
) -> Tensor {
|
|
match self {
|
|
TpModel::Qwen3(m) => m.forward_decode_paged(tokens, positions, slots, cache),
|
|
TpModel::GptOss(m) => m.forward_decode_paged(tokens, positions, slots, cache),
|
|
}
|
|
}
|
|
}
|
|
|
|
struct RankCtx {
|
|
model: TpModel,
|
|
cache: PagedKVCache,
|
|
decoder: GraphedGptOssDecoder,
|
|
}
|
|
|
|
/// Decode one step: gpt-oss batch=1 goes through the CUDA-graph decoder
|
|
/// (lazy capture, replay thereafter); everything else runs eager.
|
|
fn rank_decode(rc: &mut RankCtx, tokens: &[u32], positions: &[usize], slots: &[usize]) -> Tensor {
|
|
match &rc.model {
|
|
TpModel::GptOss(m) => rc
|
|
.decoder
|
|
.decode(m, tokens, positions, slots, &mut rc.cache),
|
|
TpModel::Qwen3(_) => rc
|
|
.model
|
|
.forward_decode_paged(tokens, positions, slots, &mut rc.cache),
|
|
}
|
|
}
|
|
|
|
fn build_rank(
|
|
model_dir: &Path,
|
|
config: &ModelConfig,
|
|
rank: usize,
|
|
world: usize,
|
|
device: u32,
|
|
max_seq_len: usize,
|
|
tp: Option<Arc<TpContext>>,
|
|
) -> RankCtx {
|
|
let weights = loader::load_model_dir(model_dir, Device::Cpu);
|
|
let model = if config.is_moe() {
|
|
TpModel::GptOss(GptOss::from_weights_tp(
|
|
config.clone(),
|
|
weights,
|
|
rank,
|
|
world,
|
|
device,
|
|
tp,
|
|
))
|
|
} else {
|
|
TpModel::Qwen3(Qwen3::from_weights_tp(
|
|
config.clone(),
|
|
weights,
|
|
rank,
|
|
world,
|
|
device,
|
|
tp,
|
|
))
|
|
};
|
|
let local_kv = config.num_kv_heads() / world;
|
|
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
|
let total_blocks = max_blocks_per_seq + 8;
|
|
let cache = PagedKVCache::new_tp(
|
|
config,
|
|
local_kv,
|
|
total_blocks,
|
|
0,
|
|
4,
|
|
max_blocks_per_seq,
|
|
DType::BF16,
|
|
device,
|
|
);
|
|
RankCtx {
|
|
model,
|
|
cache,
|
|
decoder: GraphedGptOssDecoder::new(),
|
|
}
|
|
}
|
|
|
|
fn worker_loop(
|
|
rank: usize,
|
|
world: usize,
|
|
id: UniqueId,
|
|
model_dir: PathBuf,
|
|
config: ModelConfig,
|
|
max_seq_len: usize,
|
|
cmd_rx: mpsc::Receiver<TpCommand>,
|
|
ack_tx: mpsc::Sender<()>,
|
|
) {
|
|
let tp = Arc::new(TpContext::init(rank, world, id, rank as u32));
|
|
let mut rc = build_rank(
|
|
&model_dir,
|
|
&config,
|
|
rank,
|
|
world,
|
|
rank as u32,
|
|
max_seq_len,
|
|
Some(tp),
|
|
);
|
|
while let Ok(cmd) = cmd_rx.recv() {
|
|
match cmd {
|
|
TpCommand::Register(slot) => {
|
|
let _ = rc.cache.register_sequence(slot);
|
|
}
|
|
TpCommand::Free(slot) => rc.cache.free_sequence(slot),
|
|
TpCommand::Prefill { tokens, slot } => {
|
|
let _ = rc.model.forward_prefill_paged(&tokens, slot, &mut rc.cache);
|
|
}
|
|
TpCommand::Decode {
|
|
tokens,
|
|
positions,
|
|
slots,
|
|
} => {
|
|
let _ = rank_decode(&mut rc, &tokens, &positions, &slots);
|
|
}
|
|
TpCommand::Shutdown => {
|
|
let _ = ack_tx.send(());
|
|
break;
|
|
}
|
|
}
|
|
let _ = ack_tx.send(());
|
|
}
|
|
}
|
|
|
|
/// Run the TP coordinator (rank 0) on the calling thread. Spawns worker ranks
|
|
/// internally and consumes generation requests from `rx`.
|
|
pub fn run_tp(
|
|
model_dir: &Path,
|
|
world: usize,
|
|
max_seq_len: usize,
|
|
rx: mpsc::Receiver<GenerateRequest>,
|
|
) {
|
|
// world=1 is a valid single-rank configuration (gpt-oss has no
|
|
// single-GPU engine path; NCCL init and all_reduce no-op at world=1).
|
|
assert!(world >= 1, "run_tp requires world >= 1");
|
|
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
|
assert!(
|
|
config.num_kv_heads() % world == 0,
|
|
"num_kv_heads {} not divisible by tp {world}",
|
|
config.num_kv_heads()
|
|
);
|
|
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
|
let id = xserv_distributed::get_unique_id();
|
|
|
|
// Spawn worker ranks 1..world.
|
|
let (ack_tx, ack_rx) = mpsc::channel::<()>();
|
|
let mut cmd_txs: Vec<mpsc::Sender<TpCommand>> = Vec::new();
|
|
for rank in 1..world {
|
|
let (ctx_tx, ctx_rx) = mpsc::channel::<TpCommand>();
|
|
cmd_txs.push(ctx_tx);
|
|
let ack_tx = ack_tx.clone();
|
|
let model_dir = model_dir.to_path_buf();
|
|
let config = config.clone();
|
|
thread::spawn(move || {
|
|
worker_loop(
|
|
rank,
|
|
world,
|
|
id,
|
|
model_dir,
|
|
config,
|
|
max_seq_len,
|
|
ctx_rx,
|
|
ack_tx,
|
|
);
|
|
});
|
|
}
|
|
|
|
// Rank 0 (this thread).
|
|
let tp = Arc::new(TpContext::init(0, world, id, 0));
|
|
let mut rc = build_rank(model_dir, &config, 0, world, 0, max_seq_len, Some(tp));
|
|
eprintln!("[tp-engine] ready (tp={world}, max_seq_len={max_seq_len})");
|
|
|
|
// Optional repetition penalty to break greedy repetition loops (reasoning
|
|
// models loop under pure greedy when numerics diverge from the reference).
|
|
// Off by default; XSERV_REP_PENALTY>1 enables it over the last
|
|
// XSERV_REP_WINDOW generated tokens. Applied only on the greedy path.
|
|
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY")
|
|
.ok()
|
|
.and_then(|s| s.parse().ok())
|
|
.unwrap_or(1.0);
|
|
let rep_window: usize = std::env::var("XSERV_REP_WINDOW")
|
|
.ok()
|
|
.and_then(|s| s.parse().ok())
|
|
.unwrap_or(128);
|
|
let pick = |logits: &Tensor, sp: &xserv_model::SamplingParams, history: &[u32]| -> u32 {
|
|
if rep_penalty > 1.0 && sp.temperature == 0.0 {
|
|
let start = history.len().saturating_sub(rep_window);
|
|
sample_greedy_penalized(logits, &history[start..], rep_penalty)
|
|
} else {
|
|
sample(logits, sp)
|
|
}
|
|
};
|
|
|
|
let n_workers = world - 1;
|
|
let broadcast = |txs: &[mpsc::Sender<TpCommand>], cmd: TpCommand| {
|
|
for t in txs {
|
|
let _ = t.send(cmd.clone());
|
|
}
|
|
};
|
|
let wait_acks = |rx: &mpsc::Receiver<()>| {
|
|
for _ in 0..n_workers {
|
|
let _ = rx.recv();
|
|
}
|
|
};
|
|
|
|
let slot = 0usize;
|
|
while let Ok(req) = rx.recv() {
|
|
broadcast(&cmd_txs, TpCommand::Register(slot));
|
|
rc.cache.register_sequence(slot).expect("register slot");
|
|
wait_acks(&ack_rx);
|
|
|
|
// Prefill.
|
|
broadcast(
|
|
&cmd_txs,
|
|
TpCommand::Prefill {
|
|
tokens: req.prompt_tokens.clone(),
|
|
slot,
|
|
},
|
|
);
|
|
let logits = rc
|
|
.model
|
|
.forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache);
|
|
wait_acks(&ack_rx);
|
|
let mut gen_ids: Vec<u32> = Vec::new();
|
|
let mut next = pick(&logits, &req.sampling, &gen_ids);
|
|
gen_ids.push(next);
|
|
|
|
let mut decode_buf: Vec<u8> = Vec::new();
|
|
let mut generated = 1usize;
|
|
let mut stalled = !emit_text(&tokenizer, &req, next, &mut decode_buf);
|
|
|
|
let finish = loop {
|
|
if stalled {
|
|
break "error";
|
|
}
|
|
if tokenizer.is_eos(next) {
|
|
break "stop";
|
|
}
|
|
if generated >= req.max_tokens {
|
|
break "length";
|
|
}
|
|
let pos = rc.cache.seq_len(slot);
|
|
broadcast(
|
|
&cmd_txs,
|
|
TpCommand::Decode {
|
|
tokens: vec![next],
|
|
positions: vec![pos],
|
|
slots: vec![slot],
|
|
},
|
|
);
|
|
let logits = rank_decode(&mut rc, &[next], &[pos], &[slot]);
|
|
wait_acks(&ack_rx);
|
|
next = pick(&logits, &req.sampling, &gen_ids);
|
|
gen_ids.push(next);
|
|
generated += 1;
|
|
stalled = !emit_text(&tokenizer, &req, next, &mut decode_buf);
|
|
};
|
|
|
|
let tail = tokenizer.flush_decode_stream(&mut decode_buf);
|
|
if !tail.is_empty() {
|
|
let _ = req.sender.try_send(GenerateEvent::Token {
|
|
id: next,
|
|
text: tail,
|
|
});
|
|
}
|
|
let _ = req.sender.try_send(GenerateEvent::Done {
|
|
finish_reason: finish.to_string(),
|
|
});
|
|
|
|
broadcast(&cmd_txs, TpCommand::Free(slot));
|
|
rc.cache.free_sequence(slot);
|
|
wait_acks(&ack_rx);
|
|
}
|
|
|
|
broadcast(&cmd_txs, TpCommand::Shutdown);
|
|
}
|
|
|
|
/// Stream a token's decoded text to the client (EOS contributes no text).
|
|
/// Returns false if the send would block (client too slow) or the client is
|
|
/// gone — the caller stops generating so the serial coordinator thread is free
|
|
/// to admit the next request instead of blocking on one slow consumer.
|
|
fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, buf: &mut Vec<u8>) -> bool {
|
|
if tokenizer.is_eos(token_id) {
|
|
return true;
|
|
}
|
|
let text = tokenizer.decode_token_stream(token_id, buf);
|
|
if !text.is_empty() {
|
|
return req
|
|
.sender
|
|
.try_send(GenerateEvent::Token { id: token_id, text })
|
|
.is_ok();
|
|
}
|
|
true
|
|
}
|