server: VRAM-sized KV pool + vLLM-style swap scheduler
Fixes the paged-KV OOM at large --max-seq-len and adds elastic memory: - Size the GPU block pool to available VRAM (cudaMemGetInfo) instead of the worst-case blocks_per_seq * max_batch * 2 reservation, which OOM'd at 8192. - Scheduler tracks waiting/running/swapped sets: block-aware admission, swap-in of resumable sequences when blocks free, and preemption of the newest running sequence to host when the pool can't cover a decode step. - --swap-space-gb (default 8) sizes the pinned host swap pool; XSERV_MAX_KV_BLOCKS forces a small pool to exercise swapping. - api: poison-tolerant lock + clean 503 when the engine thread is gone, instead of cascading mutex-poison panics. Verified on RTX 5090: serves at --max-seq-len 8192 (previously OOM), and a forced 40-block pool drives 48 lossless swap-out/swap-in cycles under concurrency with coherent output. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -85,18 +85,20 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
|
||||
let model_name = state.model_name.clone();
|
||||
let created = unix_timestamp();
|
||||
|
||||
if let Some(response) = validate_request(&req, &model_name) {
|
||||
return response;
|
||||
}
|
||||
|
||||
let prompt = build_prompt(&req.messages);
|
||||
let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt);
|
||||
let prompt_token_count = prompt_tokens.len();
|
||||
|
||||
let max_seq_len = state.max_seq_len;
|
||||
if prompt_token_count >= max_seq_len {
|
||||
return (StatusCode::BAD_REQUEST, Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": format!("prompt is {} tokens, exceeds max_seq_len {}", prompt_token_count, max_seq_len),
|
||||
"type": "invalid_request_error"
|
||||
}
|
||||
}))).into_response();
|
||||
return bad_request(format!(
|
||||
"prompt is {} tokens, exceeds max_seq_len {}",
|
||||
prompt_token_count, max_seq_len
|
||||
));
|
||||
}
|
||||
let max_tokens = req.max_tokens.min(max_seq_len - prompt_token_count);
|
||||
|
||||
@@ -107,12 +109,9 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
|
||||
sampling: sampling_params(&req),
|
||||
sender: tx,
|
||||
};
|
||||
state
|
||||
.engine_sender
|
||||
.lock()
|
||||
.unwrap()
|
||||
.send(gen_req)
|
||||
.expect("engine channel closed");
|
||||
if let Err(resp) = submit_to_engine(&state, gen_req) {
|
||||
return resp;
|
||||
}
|
||||
|
||||
let mut content = String::new();
|
||||
let mut completion_token_count: usize = 0;
|
||||
@@ -156,17 +155,19 @@ fn chat_stream(
|
||||
let model_name = state.model_name.clone();
|
||||
let created = unix_timestamp();
|
||||
|
||||
if let Some(response) = validate_request(&req, &model_name) {
|
||||
return response;
|
||||
}
|
||||
|
||||
let prompt = build_prompt(&req.messages);
|
||||
let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt);
|
||||
|
||||
let max_seq_len = state.max_seq_len;
|
||||
if prompt_tokens.len() >= max_seq_len {
|
||||
return (StatusCode::BAD_REQUEST, Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": format!("prompt is {} tokens, exceeds max_seq_len {}", prompt_tokens.len(), max_seq_len),
|
||||
"type": "invalid_request_error"
|
||||
}
|
||||
}))).into_response();
|
||||
return bad_request(format!(
|
||||
"prompt is {} tokens, exceeds max_seq_len {}",
|
||||
prompt_tokens.len(), max_seq_len
|
||||
));
|
||||
}
|
||||
let max_tokens = req.max_tokens.min(max_seq_len - prompt_tokens.len());
|
||||
|
||||
@@ -177,12 +178,9 @@ fn chat_stream(
|
||||
sampling: sampling_params(&req),
|
||||
sender: engine_tx,
|
||||
};
|
||||
state
|
||||
.engine_sender
|
||||
.lock()
|
||||
.unwrap()
|
||||
.send(gen_req)
|
||||
.expect("engine channel closed");
|
||||
if let Err(resp) = submit_to_engine(&state, gen_req) {
|
||||
return resp;
|
||||
}
|
||||
|
||||
// SSE event channel: engine events -> SSE events
|
||||
let (sse_tx, sse_rx) = tokio::sync::mpsc::channel::<Result<Event, Infallible>>(64);
|
||||
@@ -228,6 +226,53 @@ fn chat_stream(
|
||||
Sse::new(ReceiverStream::new(sse_rx)).keep_alive(KeepAlive::default()).into_response()
|
||||
}
|
||||
|
||||
fn validate_request(req: &ChatRequest, model_name: &str) -> Option<Response> {
|
||||
if let Some(model) = &req.model {
|
||||
if model != model_name {
|
||||
return Some(bad_request(format!(
|
||||
"model '{model}' is not loaded; available model is '{model_name}'"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if req.max_tokens == 0 {
|
||||
return Some(bad_request("max_tokens must be greater than 0"));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Hand a request to the engine thread. Poison-tolerant (recovers the lock if a
|
||||
/// prior handler panicked) and returns a clean 503 instead of panicking when the
|
||||
/// engine thread is gone, so one dead engine doesn't cascade into every request.
|
||||
fn submit_to_engine(state: &AppState, req: GenerateRequest) -> Result<(), Response> {
|
||||
let sender = state.engine_sender.lock().unwrap_or_else(|e| e.into_inner());
|
||||
sender.send(req).map_err(|_| service_unavailable("inference engine is not available"))
|
||||
}
|
||||
|
||||
fn service_unavailable(message: impl Into<String>) -> Response {
|
||||
(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": message.into(), "type": "server_error" }
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn bad_request(message: impl Into<String>) -> Response {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": message.into(),
|
||||
"type": "invalid_request_error"
|
||||
}
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn make_chunk(
|
||||
id: &str,
|
||||
model: &str,
|
||||
@@ -295,5 +340,6 @@ fn build_prompt(messages: &[Message]) -> String {
|
||||
}
|
||||
}
|
||||
prompt.push_str("<|im_start|>assistant\n");
|
||||
prompt.push_str("<think>\n\n</think>\n\n");
|
||||
prompt
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::path::Path;
|
||||
use std::sync::mpsc;
|
||||
use std::sync::Once;
|
||||
use std::time::Instant;
|
||||
use xserv_model::{GpuKVCache, ModelConfig, Qwen3, SamplingParams, sample};
|
||||
use xserv_model::{ModelConfig, PagedKVCache, Qwen3, SamplingParams, sample, BLOCK_SIZE};
|
||||
use xserv_model::loader;
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
@@ -14,6 +14,7 @@ pub struct Engine {
|
||||
tokenizer: Tokenizer,
|
||||
max_batch_size: usize,
|
||||
max_seq_len: usize,
|
||||
paged_cache: PagedKVCache,
|
||||
}
|
||||
|
||||
pub struct GenerateRequest {
|
||||
@@ -34,15 +35,25 @@ struct Sequence {
|
||||
generated_tokens: Vec<u32>,
|
||||
max_tokens: usize,
|
||||
sampling: SamplingParams,
|
||||
kv_cache: Option<GpuKVCache>,
|
||||
seq_slot: Option<usize>,
|
||||
sender: tokio::sync::mpsc::Sender<GenerateEvent>,
|
||||
prefilled: bool,
|
||||
eos_token_id: Option<u32>,
|
||||
decode_buffer: Vec<u8>,
|
||||
created_at: Instant,
|
||||
}
|
||||
|
||||
impl Engine {
|
||||
pub fn load(model_dir: &Path, max_batch_size: usize, max_seq_len: usize) -> Self {
|
||||
Self::load_with_swap(model_dir, max_batch_size, max_seq_len, 8)
|
||||
}
|
||||
|
||||
pub fn load_with_swap(
|
||||
model_dir: &Path,
|
||||
max_batch_size: usize,
|
||||
max_seq_len: usize,
|
||||
swap_space_gb: usize,
|
||||
) -> Self {
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
eprintln!("[engine] Loading weights...");
|
||||
@@ -50,8 +61,55 @@ impl Engine {
|
||||
eprintln!("[engine] Loaded {} tensors", weights.len());
|
||||
let model = Qwen3::from_weights(config.clone(), weights);
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
eprintln!("[engine] Ready (max_batch_size={max_batch_size}, max_seq_len={max_seq_len})");
|
||||
Self { model, config, tokenizer, max_batch_size, max_seq_len }
|
||||
|
||||
// Tier-1 sizing: size the GPU block pool to *available VRAM* after the
|
||||
// weights are resident, not to worst-case max_batch * max_ctx. This is
|
||||
// what makes paged attention elastic — sequences share the pool on
|
||||
// demand, and overflow is swapped to host (Tier-2) rather than reserved.
|
||||
let bytes_per_block = PagedKVCache::bytes_per_block(&config, DType::BF16);
|
||||
let info = xserv_cuda::device::device_info(0).expect("device info");
|
||||
// Reserve headroom for activations, cuBLAS workspace and the [B, vocab]
|
||||
// logits buffer; the transpose peak during load is already behind us.
|
||||
const ACTIVATION_RESERVE: usize = 3 * 1024 * 1024 * 1024; // 3 GiB
|
||||
let util_num = 90; // use 90% of remaining free memory for KV
|
||||
let usable = info.free_memory.saturating_sub(ACTIVATION_RESERVE);
|
||||
let mut total_blocks = (usable * util_num / 100) / bytes_per_block;
|
||||
// Cap at a sane upper bound and ensure a floor.
|
||||
total_blocks = total_blocks.max(256);
|
||||
// Test hook: force a small GPU pool to exercise the swap path. Must stay
|
||||
// >= max_blocks_per_seq so a single max-length sequence still fits.
|
||||
if let Ok(v) = std::env::var("XSERV_MAX_KV_BLOCKS") {
|
||||
if let Ok(n) = v.parse::<usize>() {
|
||||
total_blocks = total_blocks.min(n);
|
||||
eprintln!("[engine] XSERV_MAX_KV_BLOCKS override: gpu_blocks={total_blocks}");
|
||||
}
|
||||
}
|
||||
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
// Slots must cover running + swapped sequences, so be generous (cheap:
|
||||
// each slot is just a block-table row of i32s).
|
||||
let max_seqs_slots = (max_batch_size * 8).max(32);
|
||||
// CPU swap pool: swap_space_gb of pinned host memory.
|
||||
let cpu_total_blocks = (swap_space_gb * 1024 * 1024 * 1024) / bytes_per_block;
|
||||
|
||||
let paged_cache = PagedKVCache::new(
|
||||
&config,
|
||||
total_blocks,
|
||||
cpu_total_blocks,
|
||||
max_seqs_slots,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
0,
|
||||
);
|
||||
|
||||
eprintln!(
|
||||
"[engine] Ready (max_batch={max_batch_size}, max_seq_len={max_seq_len}, \
|
||||
gpu_blocks={total_blocks} ({:.1} GiB), swap_blocks={cpu_total_blocks} ({swap_space_gb} GiB), \
|
||||
free_vram={:.1} GiB)",
|
||||
(total_blocks * bytes_per_block) as f64 / 1e9,
|
||||
info.free_memory as f64 / 1e9,
|
||||
);
|
||||
Self { model, config, tokenizer, max_batch_size, max_seq_len, paged_cache }
|
||||
}
|
||||
|
||||
pub fn tokenizer(&self) -> &Tokenizer { &self.tokenizer }
|
||||
@@ -59,54 +117,124 @@ impl Engine {
|
||||
pub fn max_seq_len(&self) -> usize { self.max_seq_len }
|
||||
|
||||
/// Main scheduler loop. Receives requests from channel, manages concurrent sequences.
|
||||
pub fn run(&self, rx: mpsc::Receiver<GenerateRequest>) {
|
||||
///
|
||||
/// Sequences move between three sets:
|
||||
/// waiting — admitted to the queue, no GPU slot yet
|
||||
/// running — KV resident on GPU, actively prefilling/decoding
|
||||
/// swapped — KV evicted to pinned host memory (preempted), paused
|
||||
/// When running sequences grow past the GPU block pool, the newest are
|
||||
/// swapped out to host (vLLM-style) and swapped back in when blocks free up.
|
||||
pub fn run(&mut self, rx: mpsc::Receiver<GenerateRequest>) {
|
||||
let mut waiting: VecDeque<Sequence> = VecDeque::new();
|
||||
let mut running: Vec<Sequence> = Vec::new();
|
||||
let mut swapped: Vec<Sequence> = Vec::new();
|
||||
let mut next_id: u64 = 0;
|
||||
|
||||
eprintln!("[scheduler] Listening for requests...");
|
||||
|
||||
loop {
|
||||
// Step 1: Remove finished sequences
|
||||
// Step 1: Remove finished sequences and return their slots.
|
||||
let finished_slots: Vec<usize> = running.iter()
|
||||
.filter(|s| is_finished(s))
|
||||
.filter_map(|s| s.seq_slot)
|
||||
.collect();
|
||||
for slot in finished_slots {
|
||||
self.paged_cache.free_sequence(slot);
|
||||
}
|
||||
running.retain(|seq| !is_finished(seq));
|
||||
|
||||
// Step 2: Admit new sequences from waiting queue
|
||||
while running.len() < self.max_batch_size {
|
||||
if let Some(seq) = waiting.pop_front() {
|
||||
// Step 2: Swap previously-evicted sequences back in when there is
|
||||
// room (oldest first). They resume decoding from where they paused.
|
||||
while running.len() < self.max_batch_size && !swapped.is_empty() {
|
||||
let slot = swapped[0].seq_slot.expect("swapped slot");
|
||||
if !self.paged_cache.can_swap_in(slot) { break; }
|
||||
self.paged_cache.swap_in(slot).expect("swap_in");
|
||||
let seq = swapped.remove(0);
|
||||
eprintln!("[scheduler] swapped in seq {} ({} blocks)", seq.id, self.paged_cache.block_count(slot));
|
||||
running.push(seq);
|
||||
}
|
||||
|
||||
// Step 3: Admit new sequences (block-aware). Only admit if the GPU
|
||||
// pool can hold the prompt AND leave one block of decode headroom
|
||||
// per already-running sequence, so admission never starves decode.
|
||||
{
|
||||
let mut avail = self.paged_cache.free_blocks();
|
||||
let decode_reserve = running.len();
|
||||
while running.len() < self.max_batch_size {
|
||||
let Some(front) = waiting.front() else { break; };
|
||||
let prompt_blocks = front.prompt_tokens.len().div_ceil(BLOCK_SIZE).max(1);
|
||||
if avail < prompt_blocks + decode_reserve { break; }
|
||||
let free_slot = (0..self.paged_cache.max_seqs())
|
||||
.find(|&s| self.paged_cache.is_slot_free(s));
|
||||
let Some(slot) = free_slot else { break; };
|
||||
let mut seq = waiting.pop_front().unwrap();
|
||||
self.paged_cache.register_sequence(slot).expect("register paged slot");
|
||||
seq.seq_slot = Some(slot);
|
||||
running.push(seq);
|
||||
} else {
|
||||
break;
|
||||
avail -= prompt_blocks; // projected free after this seq prefills
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: If nothing to do, blocking wait for new request
|
||||
if running.is_empty() {
|
||||
// Step 4: If nothing to do, blocking wait for new request.
|
||||
if running.is_empty() && waiting.is_empty() && swapped.is_empty() {
|
||||
match rx.recv() {
|
||||
Ok(req) => {
|
||||
let seq = self.make_sequence(req, &mut next_id);
|
||||
running.push(seq);
|
||||
waiting.push_back(seq);
|
||||
continue;
|
||||
}
|
||||
Err(_) => break, // channel closed
|
||||
}
|
||||
}
|
||||
// Nothing runnable this iteration (e.g. all swapped, waiting on
|
||||
// blocks to free): loop to retry swap-in/admission next iteration.
|
||||
if running.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Step 4a: Process prefills (one at a time — different prompt lengths)
|
||||
// Prefill sequences must be processed individually because they have
|
||||
// different prompt lengths and each needs a full forward pass.
|
||||
// Step 5a: Process prefills (one at a time — different prompt lengths).
|
||||
// Admission guaranteed block headroom, so ensure_capacity won't starve.
|
||||
let mut newly_prefilled = Vec::new();
|
||||
for seq in running.iter_mut() {
|
||||
if !seq.prefilled {
|
||||
let logits = self.model.forward_gpu_cache(&seq.prompt_tokens, seq.kv_cache.as_mut().unwrap());
|
||||
let slot = seq.seq_slot.expect("slot");
|
||||
let logits = self.model.forward_prefill_paged(
|
||||
&seq.prompt_tokens, slot, &mut self.paged_cache,
|
||||
);
|
||||
let next = sample(&logits, &seq.sampling);
|
||||
seq.generated_tokens.push(next);
|
||||
seq.prefilled = true;
|
||||
self.emit_token(seq, next);
|
||||
emit_token(&self.tokenizer, seq, next);
|
||||
newly_prefilled.push(seq.id);
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4b: Batched decode — batch all decode-ready sequences into one forward pass.
|
||||
// Projections and FFN run as [B, hidden] matmuls; attention remains per-seq.
|
||||
// Step 5b: Ensure block headroom for this decode step; preempt the
|
||||
// newest running sequences to host if the pool can't cover it.
|
||||
let mut needed = decode_block_need(&self.paged_cache, &running, &newly_prefilled);
|
||||
while self.paged_cache.free_blocks() < needed {
|
||||
// Victim: newest prefilled, decoding (not just-prefilled) sequence.
|
||||
let victim = (0..running.len()).rev().find(|&p| {
|
||||
running[p].prefilled
|
||||
&& !newly_prefilled.contains(&running[p].id)
|
||||
&& running[p].seq_slot.is_some()
|
||||
});
|
||||
let Some(pos) = victim else { break; };
|
||||
let seq = running.remove(pos);
|
||||
let slot = seq.seq_slot.unwrap();
|
||||
if self.paged_cache.can_swap_out(slot) {
|
||||
let nblocks = self.paged_cache.block_count(slot);
|
||||
self.paged_cache.swap_out(slot).expect("swap_out");
|
||||
eprintln!("[scheduler] preempt: swapped out seq {} ({nblocks} blocks) to host", seq.id);
|
||||
swapped.push(seq);
|
||||
needed = decode_block_need(&self.paged_cache, &running, &newly_prefilled);
|
||||
} else {
|
||||
running.insert(pos, seq); // CPU pool full — can't evict further
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5c: Batched paged decode for the surviving prefilled sequences.
|
||||
let decode_indices: Vec<usize> = running.iter().enumerate()
|
||||
.filter(|(_, s)| s.prefilled && !newly_prefilled.contains(&s.id))
|
||||
.map(|(i, _)| i)
|
||||
@@ -115,65 +243,44 @@ impl Engine {
|
||||
if !decode_indices.is_empty() {
|
||||
static LOG_ONCE: Once = Once::new();
|
||||
LOG_ONCE.call_once(|| {
|
||||
eprintln!("[scheduler] batched decode active");
|
||||
eprintln!("[scheduler] paged decode active");
|
||||
});
|
||||
eprintln!("[scheduler] decode batch_size={}", decode_indices.len());
|
||||
|
||||
if decode_indices.len() == 1 {
|
||||
// Single sequence: use per-seq path (no batching overhead)
|
||||
let i = decode_indices[0];
|
||||
let last = *running[i].generated_tokens.last().unwrap();
|
||||
let logits = self.model.forward_gpu_cache(&[last], running[i].kv_cache.as_mut().unwrap());
|
||||
let next = sample(&logits, &running[i].sampling);
|
||||
let tokens: Vec<u32> = decode_indices.iter()
|
||||
.map(|&i| *running[i].generated_tokens.last().unwrap())
|
||||
.collect();
|
||||
let positions: Vec<usize> = decode_indices.iter()
|
||||
.map(|&i| self.paged_cache.seq_len(running[i].seq_slot.unwrap()))
|
||||
.collect();
|
||||
let slots: Vec<usize> = decode_indices.iter()
|
||||
.map(|&i| running[i].seq_slot.unwrap())
|
||||
.collect();
|
||||
|
||||
let logits = self.model.forward_decode_paged(
|
||||
&tokens, &positions, &slots, &mut self.paged_cache,
|
||||
);
|
||||
|
||||
// Sample per-sequence from batched logits [B, vocab_size]
|
||||
let vocab_size = logits.shape()[1];
|
||||
let logits_cpu = logits.to_device(xserv_tensor::Device::Cpu);
|
||||
let data = logits_cpu.as_slice::<half::bf16>();
|
||||
for (j, &i) in decode_indices.iter().enumerate() {
|
||||
let row_start = j * vocab_size;
|
||||
let row_logits = &data[row_start..row_start + vocab_size];
|
||||
let next = if running[i].sampling.temperature == 0.0 {
|
||||
row_logits.iter().enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(idx, _)| idx as u32).unwrap()
|
||||
} else {
|
||||
let row_tensor = xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
|
||||
sample(&row_tensor, &running[i].sampling)
|
||||
};
|
||||
running[i].generated_tokens.push(next);
|
||||
self.emit_token(&running[i], next);
|
||||
} else {
|
||||
// Batched decode: extract tokens and positions
|
||||
let tokens: Vec<u32> = decode_indices.iter()
|
||||
.map(|&i| *running[i].generated_tokens.last().unwrap())
|
||||
.collect();
|
||||
let positions: Vec<usize> = decode_indices.iter()
|
||||
.map(|&i| running[i].kv_cache.as_ref().unwrap().seq_len())
|
||||
.collect();
|
||||
|
||||
// Take caches out of sequences via Option::take (no dummy allocation).
|
||||
let mut caches: Vec<GpuKVCache> = decode_indices.iter()
|
||||
.map(|&i| running[i].kv_cache.take().unwrap())
|
||||
.collect();
|
||||
let mut cache_refs: Vec<&mut GpuKVCache> = caches.iter_mut().collect();
|
||||
|
||||
let logits = self.model.forward_decode_batch(&tokens, &positions, &mut cache_refs);
|
||||
|
||||
// Put caches back: pop from end while iterating in reverse
|
||||
drop(cache_refs);
|
||||
for &i in decode_indices.iter().rev() {
|
||||
running[i].kv_cache = Some(caches.pop().unwrap());
|
||||
}
|
||||
|
||||
// Sample per-sequence from batched logits [B, vocab_size]
|
||||
let vocab_size = logits.shape()[1];
|
||||
let logits_cpu = logits.to_device(xserv_tensor::Device::Cpu);
|
||||
let data = logits_cpu.as_slice::<half::bf16>();
|
||||
for (j, &i) in decode_indices.iter().enumerate() {
|
||||
let row_start = j * vocab_size;
|
||||
let row_logits = &data[row_start..row_start + vocab_size];
|
||||
let next = if running[i].sampling.temperature == 0.0 {
|
||||
// Greedy: argmax
|
||||
row_logits.iter().enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(idx, _)| idx as u32).unwrap()
|
||||
} else {
|
||||
// Use the row as a single-row tensor for full sampling
|
||||
let row_tensor = xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
|
||||
sample(&row_tensor, &running[i].sampling)
|
||||
};
|
||||
running[i].generated_tokens.push(next);
|
||||
self.emit_token(&running[i], next);
|
||||
}
|
||||
emit_token(&self.tokenizer, &mut running[i], next);
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: Check for newly arrived requests (non-blocking)
|
||||
// Step 6: Check for newly arrived requests (non-blocking)
|
||||
loop {
|
||||
match rx.try_recv() {
|
||||
Ok(req) => {
|
||||
@@ -187,39 +294,62 @@ impl Engine {
|
||||
}
|
||||
}
|
||||
|
||||
fn make_sequence(&self, req: GenerateRequest, next_id: &mut u64) -> Sequence {
|
||||
fn make_sequence(&mut self, req: GenerateRequest, next_id: &mut u64) -> Sequence {
|
||||
let id = *next_id;
|
||||
*next_id += 1;
|
||||
let kv_cache = GpuKVCache::new(&self.config, self.max_seq_len, DType::BF16, 0);
|
||||
Sequence {
|
||||
id,
|
||||
prompt_tokens: req.prompt_tokens,
|
||||
generated_tokens: Vec::new(),
|
||||
max_tokens: req.max_tokens,
|
||||
sampling: req.sampling,
|
||||
kv_cache: Some(kv_cache),
|
||||
seq_slot: None,
|
||||
sender: req.sender,
|
||||
prefilled: false,
|
||||
eos_token_id: self.tokenizer.eos_token_id(),
|
||||
decode_buffer: Vec::new(),
|
||||
created_at: Instant::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_token(&self, seq: &Sequence, token_id: u32) {
|
||||
let text = self.tokenizer.decode(&[token_id]);
|
||||
/// Total additional GPU blocks the next decode step needs across all
|
||||
/// currently-decoding (prefilled, not just-prefilled) sequences.
|
||||
fn decode_block_need(paged: &PagedKVCache, running: &[Sequence], newly_prefilled: &[u64]) -> usize {
|
||||
running.iter()
|
||||
.filter(|s| s.prefilled && !newly_prefilled.contains(&s.id))
|
||||
.filter_map(|s| s.seq_slot)
|
||||
.map(|slot| paged.additional_blocks_needed(slot, 1))
|
||||
.sum()
|
||||
}
|
||||
|
||||
if self.tokenizer.eos_token_id() == Some(token_id) {
|
||||
let _ = seq.sender.blocking_send(GenerateEvent::Done {
|
||||
finish_reason: "stop".to_string(),
|
||||
});
|
||||
} else if seq.generated_tokens.len() >= seq.max_tokens {
|
||||
let _ = seq.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
|
||||
let _ = seq.sender.blocking_send(GenerateEvent::Done {
|
||||
finish_reason: "length".to_string(),
|
||||
});
|
||||
} else {
|
||||
let _ = seq.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
|
||||
}
|
||||
fn emit_token(tokenizer: &Tokenizer, seq: &mut Sequence, token_id: u32) {
|
||||
if tokenizer.eos_token_id() == Some(token_id) {
|
||||
let tail = tokenizer.flush_decode_stream(&mut seq.decode_buffer);
|
||||
send_token_if_nonempty(seq, tail);
|
||||
let _ = seq.sender.blocking_send(GenerateEvent::Done {
|
||||
finish_reason: "stop".to_string(),
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
let text = tokenizer.decode_token_stream(token_id, &mut seq.decode_buffer);
|
||||
if seq.generated_tokens.len() >= seq.max_tokens {
|
||||
let tail = tokenizer.flush_decode_stream(&mut seq.decode_buffer);
|
||||
send_token_if_nonempty(seq, text);
|
||||
send_token_if_nonempty(seq, tail);
|
||||
let _ = seq.sender.blocking_send(GenerateEvent::Done {
|
||||
finish_reason: "length".to_string(),
|
||||
});
|
||||
} else {
|
||||
send_token_if_nonempty(seq, text);
|
||||
}
|
||||
}
|
||||
|
||||
fn send_token_if_nonempty(seq: &Sequence, text: String) {
|
||||
if !text.is_empty() {
|
||||
let id = *seq.generated_tokens.last().unwrap_or(&0);
|
||||
let _ = seq.sender.blocking_send(GenerateEvent::Token { id, text });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -227,7 +357,5 @@ fn is_finished(seq: &Sequence) -> bool {
|
||||
if seq.generated_tokens.is_empty() { return false; }
|
||||
let last = *seq.generated_tokens.last().unwrap();
|
||||
if seq.generated_tokens.len() >= seq.max_tokens { return true; }
|
||||
// Check EOS — need tokenizer info. Use a simple heuristic:
|
||||
// If sender is closed (receiver dropped), also consider finished.
|
||||
seq.sender.is_closed() || seq.eos_token_id == Some(last)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ use axum::{routing::{get, post}, Extension, Router};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{mpsc, Arc, Mutex};
|
||||
use engine::GenerateRequest;
|
||||
use xserv_model::ModelConfig;
|
||||
|
||||
pub struct AppState {
|
||||
pub model_name: String,
|
||||
@@ -17,7 +18,7 @@ pub struct AppState {
|
||||
async fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N]");
|
||||
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N]");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
@@ -31,12 +32,31 @@ async fn main() {
|
||||
.position(|a| a == "--max-batch")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(4);
|
||||
let max_seq_len: usize = args.iter()
|
||||
.unwrap_or(4)
|
||||
.max(1);
|
||||
let requested_max_seq_len: usize = args.iter()
|
||||
.position(|a| a == "--max-seq-len")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(2048);
|
||||
.unwrap_or(2048)
|
||||
.max(1);
|
||||
let swap_space_gb: usize = args.iter()
|
||||
.position(|a| a == "--swap-space-gb")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(8);
|
||||
let model_config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
let model_max_seq_len = model_config.max_seq_len();
|
||||
if model_max_seq_len == 0 {
|
||||
eprintln!("model config has invalid max_seq_len=0");
|
||||
std::process::exit(1);
|
||||
}
|
||||
let max_seq_len = requested_max_seq_len.min(model_max_seq_len);
|
||||
if max_seq_len != requested_max_seq_len {
|
||||
eprintln!(
|
||||
"[server] --max-seq-len {requested_max_seq_len} exceeds model limit {model_max_seq_len}; using {max_seq_len}"
|
||||
);
|
||||
}
|
||||
|
||||
let model_name = model_dir.file_name()
|
||||
.map(|n| n.to_string_lossy().to_string())
|
||||
@@ -49,7 +69,7 @@ async fn main() {
|
||||
|
||||
let model_dir_clone = model_dir.clone();
|
||||
std::thread::spawn(move || {
|
||||
let engine = engine::Engine::load(&model_dir_clone, max_batch, max_seq_len);
|
||||
let mut engine = engine::Engine::load_with_swap(&model_dir_clone, max_batch, max_seq_len, swap_space_gb);
|
||||
engine.run(rx);
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user