Files
xserv/crates/xserv-server/src/tp_engine.rs
Gahow Wang 0314b4f3ac server: non-blocking stream send — stop one slow client stalling the batch
All three engines emitted tokens with blocking_send on the single
decode/coordinator OS thread. A streaming client that drains slower than
generation fills its 64-slot channel, and blocking_send then blocks the whole
thread: under continuous batching one slow consumer stalls every other running
sequence (and in the serial TP/PP path it blocks admission of the next request
too). The whole point of continuous batching is defeated.

Fix: switch to try_send. engine.rs sets a client_stalled flag on Full/Closed,
which is_finished() reaps on the next iteration; in tp_engine/pp_engine,
emit_text now returns a bool and the decode loop breaks with finish_reason
"error" when it is false. When the sequence/request is dropped, its sender
drops too, closing the channel so the client's receive loop ends rather than
hanging. A slow client now loses only its own sequence, never the batch.

Verified on dash5: gpt-oss FP8 TP=1 streaming via tp_engine still streams
correctly (SSE chunks, coherent content, no hang); bench-gpt-oss TP=2 5.9ms
TPOT unchanged.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-07-01 12:37:32 +08:00

362 lines
11 KiB
Rust

//! Tensor-parallel inference engine for the HTTP server.
//!
//! Serial coordinator model: one rank-0 coordinator thread (the caller) drives
//! generation and owns the scheduler; ranks 1..world are worker threads. For
//! each step the coordinator broadcasts a command (Register/Prefill/Decode/Free)
//! to the workers and runs the same op on its own shard; the per-layer NCCL
//! AllReduces keep all ranks in lockstep. Only the coordinator samples — the
//! chosen token is carried in the next Decode command, so this is correct for
//! both greedy and stochastic sampling.
//!
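//! Requests are processed one at a time (sufficient for the quality benchmark,
//! which issues serial requests). Continuous batching across ranks is future
//! work. The single-GPU `Engine` handles TP=1 for the models it supports;
//! gpt-oss has no single-GPU engine path, so it runs through this engine with
//! `world = 1`.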
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::mpsc;
use std::thread;
use xserv_distributed::{TpContext, UniqueId};
use xserv_model::loader;
use xserv_model::{
BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, sample,
sample_greedy_penalized,
};
use xserv_tensor::{DType, Device, Tensor};
use xserv_tokenizer::Tokenizer;
use crate::engine::{GenerateEvent, GenerateRequest};
/// Instruction the rank-0 coordinator broadcasts to every worker rank.
///
/// Each worker executes the command against its own shard and KV cache and
/// then sends one acknowledgement, which keeps all ranks in lockstep.
#[derive(Clone)]
enum TpCommand {
    /// Register KV-cache sequence slot `.0` on the worker.
    Register(usize),
    /// Release KV-cache sequence slot `.0` on the worker.
    Free(usize),
    /// Prefill the prompt `tokens` into sequence `slot`; the output is discarded
    /// on workers because only rank 0 samples.
    Prefill { tokens: Vec<u32>, slot: usize },
    /// One decode step. The three vectors are parallel: one entry per sequence
    /// (the serial coordinator always sends exactly one).
    Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
    /// Acknowledge, then leave the worker loop.
    Shutdown,
}
/// One rank's tensor-parallel model shard.
///
/// `build_rank` builds `GptOss` for MoE configs and `Qwen3` otherwise; the
/// methods below forward to the concrete model so callers do not have to
/// match on the architecture themselves.
enum TpModel {
    /// Shard of a non-MoE (Qwen3) model.
    Qwen3(Qwen3),
    /// Shard of a MoE (gpt-oss) model.
    GptOss(GptOss),
}
impl TpModel {
    /// Prefill `tokens` into KV-cache sequence `slot` and return the output
    /// tensor the coordinator samples the first generated token from.
    fn forward_prefill_paged(
        &self,
        tokens: &[u32],
        slot: usize,
        cache: &mut PagedKVCache,
    ) -> Tensor {
        match self {
            Self::Qwen3(inner) => inner.forward_prefill_paged(tokens, slot, cache),
            Self::GptOss(inner) => inner.forward_prefill_paged(tokens, slot, cache),
        }
    }
    /// Eager paged decode step. `tokens`, `positions` and `slots` are parallel
    /// per-sequence slices.
    fn forward_decode_paged(
        &self,
        tokens: &[u32],
        positions: &[usize],
        slots: &[usize],
        cache: &mut PagedKVCache,
    ) -> Tensor {
        match self {
            Self::Qwen3(inner) => inner.forward_decode_paged(tokens, positions, slots, cache),
            Self::GptOss(inner) => inner.forward_decode_paged(tokens, positions, slots, cache),
        }
    }
}
/// Everything one rank owns: its model shard, its slice of the paged KV
/// cache, and the CUDA-graph decoder used for gpt-oss decode steps.
// NOTE(review): fields drop in declaration order (model, cache, decoder);
// keep this order unless teardown dependencies between them are confirmed.
struct RankCtx {
    /// This rank's model shard.
    model: TpModel,
    /// This rank's paged KV cache; `build_rank` sizes it for only the
    /// `num_kv_heads / world` KV heads local to this rank.
    cache: PagedKVCache,
    /// Graphed decoder; `rank_decode` only uses it for `TpModel::GptOss`.
    decoder: GraphedGptOssDecoder,
}
/// Decode one step: gpt-oss batch=1 goes through the CUDA-graph decoder
/// (lazy capture, replay thereafter); everything else runs eager.
fn rank_decode(rc: &mut RankCtx, tokens: &[u32], positions: &[usize], slots: &[usize]) -> Tensor {
match &rc.model {
TpModel::GptOss(m) => rc
.decoder
.decode(m, tokens, positions, slots, &mut rc.cache),
TpModel::Qwen3(_) => rc
.model
.forward_decode_paged(tokens, positions, slots, &mut rc.cache),
}
}
/// Build one rank's context: load the checkpoint, construct the model shard
/// for `rank` of `world` on device index `device`, and allocate the rank's
/// paged KV cache.
///
/// MoE configs build a gpt-oss shard; all other configs build a Qwen3 shard.
/// The cache holds this rank's `num_kv_heads / world` KV heads and has room
/// for one sequence of up to `max_seq_len` tokens plus 8 spare blocks. `tp` is
/// handed to the model shard for its cross-rank collectives.
fn build_rank(
    model_dir: &Path,
    config: &ModelConfig,
    rank: usize,
    world: usize,
    device: u32,
    max_seq_len: usize,
    tp: Option<Arc<TpContext>>,
) -> RankCtx {
    // Every rank reads the checkpoint directory into host memory.
    // NOTE(review): presumably `from_weights_tp` keeps only this rank's shard
    // and uploads it to `device` — confirm if host RAM becomes a concern.
    let weights = loader::load_model_dir(model_dir, Device::Cpu);
    let model = if config.is_moe() {
        TpModel::GptOss(GptOss::from_weights_tp(
            config.clone(),
            weights,
            rank,
            world,
            device,
            tp,
        ))
    } else {
        TpModel::Qwen3(Qwen3::from_weights_tp(
            config.clone(),
            weights,
            rank,
            world,
            device,
            tp,
        ))
    };
    // `run_tp` asserts num_kv_heads % world == 0 before any rank is built,
    // so this split is exact.
    let local_kv = config.num_kv_heads() / world;
    // Round up to whole blocks; `div_ceil` cannot overflow the way
    // `max_seq_len + BLOCK_SIZE - 1` can.
    let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE);
    // Capacity for one full sequence plus 8 spare blocks.
    let total_blocks = max_blocks_per_seq + 8;
    // NOTE(review): `0` and `4` are positional `PagedKVCache::new_tp`
    // parameters whose meaning is not visible here — give them named
    // constants once confirmed.
    let cache = PagedKVCache::new_tp(
        config,
        local_kv,
        total_blocks,
        0,
        4,
        max_blocks_per_seq,
        DType::BF16,
        device,
    );
    RankCtx {
        model,
        cache,
        decoder: GraphedGptOssDecoder::new(),
    }
}
/// Body of a worker rank (1..world).
///
/// Joins the NCCL group, builds this rank's shard and KV cache on device
/// `rank`, then mirrors every command the coordinator broadcasts until it
/// receives `Shutdown` or the command channel disconnects. Each executed
/// command, including `Shutdown`, is acknowledged on `ack_tx`. Outputs and
/// registration results are discarded because only rank 0 samples.
fn worker_loop(
    rank: usize,
    world: usize,
    id: UniqueId,
    model_dir: PathBuf,
    config: ModelConfig,
    max_seq_len: usize,
    cmd_rx: mpsc::Receiver<TpCommand>,
    ack_tx: mpsc::Sender<()>,
) {
    let device = rank as u32;
    let group = TpContext::init(rank, world, id, device);
    let mut ctx = build_rank(
        &model_dir,
        &config,
        rank,
        world,
        device,
        max_seq_len,
        Some(Arc::new(group)),
    );
    // `iter()` blocks for each command and ends once the coordinator's sender
    // is dropped.
    for cmd in cmd_rx.iter() {
        let stop_after_ack = matches!(cmd, TpCommand::Shutdown);
        match cmd {
            TpCommand::Register(slot) => {
                let _ = ctx.cache.register_sequence(slot);
            }
            TpCommand::Free(slot) => ctx.cache.free_sequence(slot),
            TpCommand::Prefill { tokens, slot } => {
                let _ = ctx.model.forward_prefill_paged(&tokens, slot, &mut ctx.cache);
            }
            TpCommand::Decode {
                tokens,
                positions,
                slots,
            } => {
                let _ = rank_decode(&mut ctx, &tokens, &positions, &slots);
            }
            TpCommand::Shutdown => {}
        }
        // The coordinator counts one ack per worker per command.
        let _ = ack_tx.send(());
        if stop_after_ack {
            break;
        }
    }
}
/// Run the TP coordinator (rank 0) on the calling thread. Spawns worker ranks
/// internally and consumes generation requests from `rx`.
///
/// Each request is served to completion on KV slot 0 before the next one is
/// taken. The slot is registered on every rank, the prompt is prefilled, and
/// tokens are decoded one at a time until one of these ends generation:
/// - EOS (`"stop"`);
/// - `max_tokens`, or the `max_seq_len` positions the KV cache was sized for
///   (`"length"`);
/// - the client's stream rejects a send (`"error"`).
///
/// Prompts that are empty or longer than `max_seq_len` are rejected with
/// `"error"` before any rank is touched. Once every request sender is dropped,
/// the workers are sent `Shutdown` and the function returns.
///
/// # Panics
///
/// Panics if `world == 0`, if `num_kv_heads` is not divisible by `world`, or
/// if rank 0 cannot register the sequence slot.
pub fn run_tp(
    model_dir: &Path,
    world: usize,
    max_seq_len: usize,
    rx: mpsc::Receiver<GenerateRequest>,
) {
    // world=1 is a valid single-rank configuration (gpt-oss has no
    // single-GPU engine path; NCCL init and all_reduce no-op at world=1).
    assert!(world >= 1, "run_tp requires world >= 1");
    let config = ModelConfig::from_file(&model_dir.join("config.json"));
    assert!(
        config.num_kv_heads() % world == 0,
        "num_kv_heads {} not divisible by tp {world}",
        config.num_kv_heads()
    );
    let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
    let id = xserv_distributed::get_unique_id();
    // Spawn worker ranks 1..world. All workers share one ack channel; the
    // coordinator waits for `n_workers` acks after each broadcast.
    let (ack_tx, ack_rx) = mpsc::channel::<()>();
    let mut cmd_txs: Vec<mpsc::Sender<TpCommand>> = Vec::new();
    for rank in 1..world {
        let (ctx_tx, ctx_rx) = mpsc::channel::<TpCommand>();
        cmd_txs.push(ctx_tx);
        let ack_tx = ack_tx.clone();
        let model_dir = model_dir.to_path_buf();
        let config = config.clone();
        thread::spawn(move || {
            worker_loop(
                rank,
                world,
                id,
                model_dir,
                config,
                max_seq_len,
                ctx_rx,
                ack_tx,
            );
        });
    }
    // Rank 0 (this thread).
    let tp = Arc::new(TpContext::init(0, world, id, 0));
    let mut rc = build_rank(model_dir, &config, 0, world, 0, max_seq_len, Some(tp));
    eprintln!("[tp-engine] ready (tp={world}, max_seq_len={max_seq_len})");
    // Optional repetition penalty to break greedy repetition loops (reasoning
    // models loop under pure greedy when numerics diverge from the reference).
    // Off by default; XSERV_REP_PENALTY>1 enables it over the last
    // XSERV_REP_WINDOW generated tokens. Applied only on the greedy path.
    let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(1.0);
    let rep_window: usize = std::env::var("XSERV_REP_WINDOW")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(128);
    let pick = |logits: &Tensor, sp: &xserv_model::SamplingParams, history: &[u32]| -> u32 {
        if rep_penalty > 1.0 && sp.temperature == 0.0 {
            let start = history.len().saturating_sub(rep_window);
            sample_greedy_penalized(logits, &history[start..], rep_penalty)
        } else {
            sample(logits, sp)
        }
    };
    let n_workers = world - 1;
    let broadcast = |txs: &[mpsc::Sender<TpCommand>], cmd: TpCommand| {
        for t in txs {
            let _ = t.send(cmd.clone());
        }
    };
    let wait_acks = |rx: &mpsc::Receiver<()>| {
        for _ in 0..n_workers {
            let _ = rx.recv();
        }
    };
    let slot = 0usize;
    while let Ok(req) = rx.recv() {
        // `build_rank` sized each rank's per-sequence block table for
        // `max_seq_len` tokens, so a longer prompt cannot be prefilled on any
        // rank. An empty prompt gives prefill nothing to produce first-token
        // logits from. Reject both up front, before any rank touches the slot,
        // rather than risk a failure mid-collective.
        let prompt_len = req.prompt_tokens.len();
        if prompt_len == 0 || prompt_len > max_seq_len {
            eprintln!(
                "[tp-engine] rejecting request: prompt of {prompt_len} tokens (max_seq_len={max_seq_len})"
            );
            let _ = req.sender.try_send(GenerateEvent::Done {
                finish_reason: "error".to_string(),
            });
            continue;
        }
        broadcast(&cmd_txs, TpCommand::Register(slot));
        rc.cache.register_sequence(slot).expect("register slot");
        wait_acks(&ack_rx);
        // Prefill.
        broadcast(
            &cmd_txs,
            TpCommand::Prefill {
                tokens: req.prompt_tokens.clone(),
                slot,
            },
        );
        let logits = rc
            .model
            .forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache);
        wait_acks(&ack_rx);
        let mut gen_ids: Vec<u32> = Vec::new();
        let mut next = pick(&logits, &req.sampling, &gen_ids);
        gen_ids.push(next);
        let mut decode_buf: Vec<u8> = Vec::new();
        let mut generated = 1usize;
        let mut stalled = !emit_text(&tokenizer, &req, next, &mut decode_buf);
        let finish = loop {
            if stalled {
                break "error";
            }
            if tokenizer.is_eos(next) {
                break "stop";
            }
            if generated >= req.max_tokens {
                break "length";
            }
            let pos = rc.cache.seq_len(slot);
            // The decode step below places `next` at position `pos` (the current
            // sequence length). Stop once the sequence already fills the
            // `max_seq_len` positions the cache was sized for.
            if pos >= max_seq_len {
                break "length";
            }
            broadcast(
                &cmd_txs,
                TpCommand::Decode {
                    tokens: vec![next],
                    positions: vec![pos],
                    slots: vec![slot],
                },
            );
            let logits = rank_decode(&mut rc, &[next], &[pos], &[slot]);
            wait_acks(&ack_rx);
            next = pick(&logits, &req.sampling, &gen_ids);
            gen_ids.push(next);
            generated += 1;
            stalled = !emit_text(&tokenizer, &req, next, &mut decode_buf);
        };
        // Flush any bytes the streaming decoder is still holding.
        let tail = tokenizer.flush_decode_stream(&mut decode_buf);
        if !tail.is_empty() {
            let _ = req.sender.try_send(GenerateEvent::Token {
                id: next,
                text: tail,
            });
        }
        let _ = req.sender.try_send(GenerateEvent::Done {
            finish_reason: finish.to_string(),
        });
        broadcast(&cmd_txs, TpCommand::Free(slot));
        rc.cache.free_sequence(slot);
        wait_acks(&ack_rx);
    }
    broadcast(&cmd_txs, TpCommand::Shutdown);
}
/// Send the streamed text for `token_id` to the client.
///
/// Returns `true` in three cases:
/// - the token is EOS, which contributes no text;
/// - the streaming decoder produced no text yet (presumably still holding
///   partial bytes in `buf`);
/// - the text was queued on the client's stream.
///
/// Returns `false` when `try_send` fails because the client is not draining
/// fast enough or has gone away. The serial coordinator then abandons this
/// request instead of blocking on one slow consumer.
fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, buf: &mut Vec<u8>) -> bool {
    if tokenizer.is_eos(token_id) {
        return true;
    }
    let chunk = tokenizer.decode_token_stream(token_id, buf);
    // Nothing to send counts as success; otherwise report whether the
    // non-blocking send went through.
    chunk.is_empty()
        || req
            .sender
            .try_send(GenerateEvent::Token {
                id: token_id,
                text: chunk,
            })
            .is_ok()
}