server: VRAM-sized KV pool + vLLM-style swap scheduler

Fixes the paged-KV OOM at large --max-seq-len and adds elastic memory:

- Size the GPU block pool to available VRAM (cudaMemGetInfo) instead of the
  worst-case blocks_per_seq * max_batch * 2 reservation, which OOM'd at 8192.
- Scheduler tracks waiting/running/swapped sets: block-aware admission,
  swap-in of resumable sequences when blocks free, and preemption of the
  newest running sequence to host when the pool can't cover a decode step.
- --swap-space-gb (default 8) sizes the pinned host swap pool;
  XSERV_MAX_KV_BLOCKS forces a small pool to exercise swapping.
- api: poison-tolerant lock + clean 503 when the engine thread is gone,
  instead of cascading mutex-poison panics.

Verified on RTX 5090: serves at --max-seq-len 8192 (previously OOM), and a
forced 40-block pool drives 48 lossless swap-out/swap-in cycles under
concurrency with coherent output.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-28 19:59:06 +08:00
parent d52baa0006
commit fc1900a745
3 changed files with 316 additions and 122 deletions

View File

@@ -85,18 +85,20 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
let model_name = state.model_name.clone();
let created = unix_timestamp();
if let Some(response) = validate_request(&req, &model_name) {
return response;
}
let prompt = build_prompt(&req.messages);
let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt);
let prompt_token_count = prompt_tokens.len();
let max_seq_len = state.max_seq_len;
if prompt_token_count >= max_seq_len {
return (StatusCode::BAD_REQUEST, Json(serde_json::json!({
"error": {
"message": format!("prompt is {} tokens, exceeds max_seq_len {}", prompt_token_count, max_seq_len),
"type": "invalid_request_error"
}
}))).into_response();
return bad_request(format!(
"prompt is {} tokens, exceeds max_seq_len {}",
prompt_token_count, max_seq_len
));
}
let max_tokens = req.max_tokens.min(max_seq_len - prompt_token_count);
@@ -107,12 +109,9 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
sampling: sampling_params(&req),
sender: tx,
};
state
.engine_sender
.lock()
.unwrap()
.send(gen_req)
.expect("engine channel closed");
if let Err(resp) = submit_to_engine(&state, gen_req) {
return resp;
}
let mut content = String::new();
let mut completion_token_count: usize = 0;
@@ -156,17 +155,19 @@ fn chat_stream(
let model_name = state.model_name.clone();
let created = unix_timestamp();
if let Some(response) = validate_request(&req, &model_name) {
return response;
}
let prompt = build_prompt(&req.messages);
let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt);
let max_seq_len = state.max_seq_len;
if prompt_tokens.len() >= max_seq_len {
return (StatusCode::BAD_REQUEST, Json(serde_json::json!({
"error": {
"message": format!("prompt is {} tokens, exceeds max_seq_len {}", prompt_tokens.len(), max_seq_len),
"type": "invalid_request_error"
}
}))).into_response();
return bad_request(format!(
"prompt is {} tokens, exceeds max_seq_len {}",
prompt_tokens.len(), max_seq_len
));
}
let max_tokens = req.max_tokens.min(max_seq_len - prompt_tokens.len());
@@ -177,12 +178,9 @@ fn chat_stream(
sampling: sampling_params(&req),
sender: engine_tx,
};
state
.engine_sender
.lock()
.unwrap()
.send(gen_req)
.expect("engine channel closed");
if let Err(resp) = submit_to_engine(&state, gen_req) {
return resp;
}
// SSE event channel: engine events -> SSE events
let (sse_tx, sse_rx) = tokio::sync::mpsc::channel::<Result<Event, Infallible>>(64);
@@ -228,6 +226,53 @@ fn chat_stream(
Sse::new(ReceiverStream::new(sse_rx)).keep_alive(KeepAlive::default()).into_response()
}
fn validate_request(req: &ChatRequest, model_name: &str) -> Option<Response> {
if let Some(model) = &req.model {
if model != model_name {
return Some(bad_request(format!(
"model '{model}' is not loaded; available model is '{model_name}'"
)));
}
}
if req.max_tokens == 0 {
return Some(bad_request("max_tokens must be greater than 0"));
}
None
}
/// Hand a request to the engine thread. Poison-tolerant (recovers the lock if a
/// prior handler panicked) and returns a clean 503 instead of panicking when the
/// engine thread is gone, so one dead engine doesn't cascade into every request.
fn submit_to_engine(state: &AppState, req: GenerateRequest) -> Result<(), Response> {
let sender = state.engine_sender.lock().unwrap_or_else(|e| e.into_inner());
sender.send(req).map_err(|_| service_unavailable("inference engine is not available"))
}
fn service_unavailable(message: impl Into<String>) -> Response {
(
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({
"error": { "message": message.into(), "type": "server_error" }
})),
)
.into_response()
}
fn bad_request(message: impl Into<String>) -> Response {
(
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": {
"message": message.into(),
"type": "invalid_request_error"
}
})),
)
.into_response()
}
fn make_chunk(
id: &str,
model: &str,
@@ -295,5 +340,6 @@ fn build_prompt(messages: &[Message]) -> String {
}
}
prompt.push_str("<|im_start|>assistant\n");
prompt.push_str("<think>\n\n</think>\n\n");
prompt
}

View File

@@ -3,7 +3,7 @@ use std::path::Path;
use std::sync::mpsc;
use std::sync::Once;
use std::time::Instant;
use xserv_model::{GpuKVCache, ModelConfig, Qwen3, SamplingParams, sample};
use xserv_model::{ModelConfig, PagedKVCache, Qwen3, SamplingParams, sample, BLOCK_SIZE};
use xserv_model::loader;
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -14,6 +14,7 @@ pub struct Engine {
tokenizer: Tokenizer,
max_batch_size: usize,
max_seq_len: usize,
paged_cache: PagedKVCache,
}
pub struct GenerateRequest {
@@ -34,15 +35,25 @@ struct Sequence {
generated_tokens: Vec<u32>,
max_tokens: usize,
sampling: SamplingParams,
kv_cache: Option<GpuKVCache>,
seq_slot: Option<usize>,
sender: tokio::sync::mpsc::Sender<GenerateEvent>,
prefilled: bool,
eos_token_id: Option<u32>,
decode_buffer: Vec<u8>,
created_at: Instant,
}
impl Engine {
pub fn load(model_dir: &Path, max_batch_size: usize, max_seq_len: usize) -> Self {
Self::load_with_swap(model_dir, max_batch_size, max_seq_len, 8)
}
pub fn load_with_swap(
model_dir: &Path,
max_batch_size: usize,
max_seq_len: usize,
swap_space_gb: usize,
) -> Self {
xserv_cuda::device::set_device(0).unwrap();
let config = ModelConfig::from_file(&model_dir.join("config.json"));
eprintln!("[engine] Loading weights...");
@@ -50,8 +61,55 @@ impl Engine {
eprintln!("[engine] Loaded {} tensors", weights.len());
let model = Qwen3::from_weights(config.clone(), weights);
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
eprintln!("[engine] Ready (max_batch_size={max_batch_size}, max_seq_len={max_seq_len})");
Self { model, config, tokenizer, max_batch_size, max_seq_len }
// Tier-1 sizing: size the GPU block pool to *available VRAM* after the
// weights are resident, not to worst-case max_batch * max_ctx. This is
// what makes paged attention elastic — sequences share the pool on
// demand, and overflow is swapped to host (Tier-2) rather than reserved.
let bytes_per_block = PagedKVCache::bytes_per_block(&config, DType::BF16);
let info = xserv_cuda::device::device_info(0).expect("device info");
// Reserve headroom for activations, cuBLAS workspace and the [B, vocab]
// logits buffer; the transpose peak during load is already behind us.
const ACTIVATION_RESERVE: usize = 3 * 1024 * 1024 * 1024; // 3 GiB
let util_num = 90; // use 90% of remaining free memory for KV
let usable = info.free_memory.saturating_sub(ACTIVATION_RESERVE);
let mut total_blocks = (usable * util_num / 100) / bytes_per_block;
// Cap at a sane upper bound and ensure a floor.
total_blocks = total_blocks.max(256);
// Test hook: force a small GPU pool to exercise the swap path. Must stay
// >= max_blocks_per_seq so a single max-length sequence still fits.
if let Ok(v) = std::env::var("XSERV_MAX_KV_BLOCKS") {
if let Ok(n) = v.parse::<usize>() {
total_blocks = total_blocks.min(n);
eprintln!("[engine] XSERV_MAX_KV_BLOCKS override: gpu_blocks={total_blocks}");
}
}
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
// Slots must cover running + swapped sequences, so be generous (cheap:
// each slot is just a block-table row of i32s).
let max_seqs_slots = (max_batch_size * 8).max(32);
// CPU swap pool: swap_space_gb of pinned host memory.
let cpu_total_blocks = (swap_space_gb * 1024 * 1024 * 1024) / bytes_per_block;
let paged_cache = PagedKVCache::new(
&config,
total_blocks,
cpu_total_blocks,
max_seqs_slots,
max_blocks_per_seq,
DType::BF16,
0,
);
eprintln!(
"[engine] Ready (max_batch={max_batch_size}, max_seq_len={max_seq_len}, \
gpu_blocks={total_blocks} ({:.1} GiB), swap_blocks={cpu_total_blocks} ({swap_space_gb} GiB), \
free_vram={:.1} GiB)",
(total_blocks * bytes_per_block) as f64 / 1e9,
info.free_memory as f64 / 1e9,
);
Self { model, config, tokenizer, max_batch_size, max_seq_len, paged_cache }
}
pub fn tokenizer(&self) -> &Tokenizer { &self.tokenizer }
@@ -59,54 +117,124 @@ impl Engine {
pub fn max_seq_len(&self) -> usize { self.max_seq_len }
/// Main scheduler loop. Receives requests from channel, manages concurrent sequences.
pub fn run(&self, rx: mpsc::Receiver<GenerateRequest>) {
///
/// Sequences move between three sets:
/// waiting — admitted to the queue, no GPU slot yet
/// running — KV resident on GPU, actively prefilling/decoding
/// swapped — KV evicted to pinned host memory (preempted), paused
/// When running sequences grow past the GPU block pool, the newest are
/// swapped out to host (vLLM-style) and swapped back in when blocks free up.
pub fn run(&mut self, rx: mpsc::Receiver<GenerateRequest>) {
let mut waiting: VecDeque<Sequence> = VecDeque::new();
let mut running: Vec<Sequence> = Vec::new();
let mut swapped: Vec<Sequence> = Vec::new();
let mut next_id: u64 = 0;
eprintln!("[scheduler] Listening for requests...");
loop {
// Step 1: Remove finished sequences
// Step 1: Remove finished sequences and return their slots.
let finished_slots: Vec<usize> = running.iter()
.filter(|s| is_finished(s))
.filter_map(|s| s.seq_slot)
.collect();
for slot in finished_slots {
self.paged_cache.free_sequence(slot);
}
running.retain(|seq| !is_finished(seq));
// Step 2: Admit new sequences from waiting queue
while running.len() < self.max_batch_size {
if let Some(seq) = waiting.pop_front() {
// Step 2: Swap previously-evicted sequences back in when there is
// room (oldest first). They resume decoding from where they paused.
while running.len() < self.max_batch_size && !swapped.is_empty() {
let slot = swapped[0].seq_slot.expect("swapped slot");
if !self.paged_cache.can_swap_in(slot) { break; }
self.paged_cache.swap_in(slot).expect("swap_in");
let seq = swapped.remove(0);
eprintln!("[scheduler] swapped in seq {} ({} blocks)", seq.id, self.paged_cache.block_count(slot));
running.push(seq);
}
// Step 3: Admit new sequences (block-aware). Only admit if the GPU
// pool can hold the prompt AND leave one block of decode headroom
// per already-running sequence, so admission never starves decode.
{
let mut avail = self.paged_cache.free_blocks();
let decode_reserve = running.len();
while running.len() < self.max_batch_size {
let Some(front) = waiting.front() else { break; };
let prompt_blocks = front.prompt_tokens.len().div_ceil(BLOCK_SIZE).max(1);
if avail < prompt_blocks + decode_reserve { break; }
let free_slot = (0..self.paged_cache.max_seqs())
.find(|&s| self.paged_cache.is_slot_free(s));
let Some(slot) = free_slot else { break; };
let mut seq = waiting.pop_front().unwrap();
self.paged_cache.register_sequence(slot).expect("register paged slot");
seq.seq_slot = Some(slot);
running.push(seq);
} else {
break;
avail -= prompt_blocks; // projected free after this seq prefills
}
}
// Step 3: If nothing to do, blocking wait for new request
if running.is_empty() {
// Step 4: If nothing to do, blocking wait for new request.
if running.is_empty() && waiting.is_empty() && swapped.is_empty() {
match rx.recv() {
Ok(req) => {
let seq = self.make_sequence(req, &mut next_id);
running.push(seq);
waiting.push_back(seq);
continue;
}
Err(_) => break, // channel closed
}
}
// Nothing runnable this iteration (e.g. all swapped, waiting on
// blocks to free): loop to retry swap-in/admission next iteration.
if running.is_empty() {
continue;
}
// Step 4a: Process prefills (one at a time — different prompt lengths)
// Prefill sequences must be processed individually because they have
// different prompt lengths and each needs a full forward pass.
// Step 5a: Process prefills (one at a time — different prompt lengths).
// Admission guaranteed block headroom, so ensure_capacity won't starve.
let mut newly_prefilled = Vec::new();
for seq in running.iter_mut() {
if !seq.prefilled {
let logits = self.model.forward_gpu_cache(&seq.prompt_tokens, seq.kv_cache.as_mut().unwrap());
let slot = seq.seq_slot.expect("slot");
let logits = self.model.forward_prefill_paged(
&seq.prompt_tokens, slot, &mut self.paged_cache,
);
let next = sample(&logits, &seq.sampling);
seq.generated_tokens.push(next);
seq.prefilled = true;
self.emit_token(seq, next);
emit_token(&self.tokenizer, seq, next);
newly_prefilled.push(seq.id);
}
}
// Step 4b: Batched decode — batch all decode-ready sequences into one forward pass.
// Projections and FFN run as [B, hidden] matmuls; attention remains per-seq.
// Step 5b: Ensure block headroom for this decode step; preempt the
// newest running sequences to host if the pool can't cover it.
let mut needed = decode_block_need(&self.paged_cache, &running, &newly_prefilled);
while self.paged_cache.free_blocks() < needed {
// Victim: newest prefilled, decoding (not just-prefilled) sequence.
let victim = (0..running.len()).rev().find(|&p| {
running[p].prefilled
&& !newly_prefilled.contains(&running[p].id)
&& running[p].seq_slot.is_some()
});
let Some(pos) = victim else { break; };
let seq = running.remove(pos);
let slot = seq.seq_slot.unwrap();
if self.paged_cache.can_swap_out(slot) {
let nblocks = self.paged_cache.block_count(slot);
self.paged_cache.swap_out(slot).expect("swap_out");
eprintln!("[scheduler] preempt: swapped out seq {} ({nblocks} blocks) to host", seq.id);
swapped.push(seq);
needed = decode_block_need(&self.paged_cache, &running, &newly_prefilled);
} else {
running.insert(pos, seq); // CPU pool full — can't evict further
break;
}
}
// Step 5c: Batched paged decode for the surviving prefilled sequences.
let decode_indices: Vec<usize> = running.iter().enumerate()
.filter(|(_, s)| s.prefilled && !newly_prefilled.contains(&s.id))
.map(|(i, _)| i)
@@ -115,65 +243,44 @@ impl Engine {
if !decode_indices.is_empty() {
static LOG_ONCE: Once = Once::new();
LOG_ONCE.call_once(|| {
eprintln!("[scheduler] batched decode active");
eprintln!("[scheduler] paged decode active");
});
eprintln!("[scheduler] decode batch_size={}", decode_indices.len());
if decode_indices.len() == 1 {
// Single sequence: use per-seq path (no batching overhead)
let i = decode_indices[0];
let last = *running[i].generated_tokens.last().unwrap();
let logits = self.model.forward_gpu_cache(&[last], running[i].kv_cache.as_mut().unwrap());
let next = sample(&logits, &running[i].sampling);
let tokens: Vec<u32> = decode_indices.iter()
.map(|&i| *running[i].generated_tokens.last().unwrap())
.collect();
let positions: Vec<usize> = decode_indices.iter()
.map(|&i| self.paged_cache.seq_len(running[i].seq_slot.unwrap()))
.collect();
let slots: Vec<usize> = decode_indices.iter()
.map(|&i| running[i].seq_slot.unwrap())
.collect();
let logits = self.model.forward_decode_paged(
&tokens, &positions, &slots, &mut self.paged_cache,
);
// Sample per-sequence from batched logits [B, vocab_size]
let vocab_size = logits.shape()[1];
let logits_cpu = logits.to_device(xserv_tensor::Device::Cpu);
let data = logits_cpu.as_slice::<half::bf16>();
for (j, &i) in decode_indices.iter().enumerate() {
let row_start = j * vocab_size;
let row_logits = &data[row_start..row_start + vocab_size];
let next = if running[i].sampling.temperature == 0.0 {
row_logits.iter().enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(idx, _)| idx as u32).unwrap()
} else {
let row_tensor = xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
sample(&row_tensor, &running[i].sampling)
};
running[i].generated_tokens.push(next);
self.emit_token(&running[i], next);
} else {
// Batched decode: extract tokens and positions
let tokens: Vec<u32> = decode_indices.iter()
.map(|&i| *running[i].generated_tokens.last().unwrap())
.collect();
let positions: Vec<usize> = decode_indices.iter()
.map(|&i| running[i].kv_cache.as_ref().unwrap().seq_len())
.collect();
// Take caches out of sequences via Option::take (no dummy allocation).
let mut caches: Vec<GpuKVCache> = decode_indices.iter()
.map(|&i| running[i].kv_cache.take().unwrap())
.collect();
let mut cache_refs: Vec<&mut GpuKVCache> = caches.iter_mut().collect();
let logits = self.model.forward_decode_batch(&tokens, &positions, &mut cache_refs);
// Put caches back: pop from end while iterating in reverse
drop(cache_refs);
for &i in decode_indices.iter().rev() {
running[i].kv_cache = Some(caches.pop().unwrap());
}
// Sample per-sequence from batched logits [B, vocab_size]
let vocab_size = logits.shape()[1];
let logits_cpu = logits.to_device(xserv_tensor::Device::Cpu);
let data = logits_cpu.as_slice::<half::bf16>();
for (j, &i) in decode_indices.iter().enumerate() {
let row_start = j * vocab_size;
let row_logits = &data[row_start..row_start + vocab_size];
let next = if running[i].sampling.temperature == 0.0 {
// Greedy: argmax
row_logits.iter().enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(idx, _)| idx as u32).unwrap()
} else {
// Use the row as a single-row tensor for full sampling
let row_tensor = xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
sample(&row_tensor, &running[i].sampling)
};
running[i].generated_tokens.push(next);
self.emit_token(&running[i], next);
}
emit_token(&self.tokenizer, &mut running[i], next);
}
}
// Step 5: Check for newly arrived requests (non-blocking)
// Step 6: Check for newly arrived requests (non-blocking)
loop {
match rx.try_recv() {
Ok(req) => {
@@ -187,39 +294,62 @@ impl Engine {
}
}
fn make_sequence(&self, req: GenerateRequest, next_id: &mut u64) -> Sequence {
fn make_sequence(&mut self, req: GenerateRequest, next_id: &mut u64) -> Sequence {
let id = *next_id;
*next_id += 1;
let kv_cache = GpuKVCache::new(&self.config, self.max_seq_len, DType::BF16, 0);
Sequence {
id,
prompt_tokens: req.prompt_tokens,
generated_tokens: Vec::new(),
max_tokens: req.max_tokens,
sampling: req.sampling,
kv_cache: Some(kv_cache),
seq_slot: None,
sender: req.sender,
prefilled: false,
eos_token_id: self.tokenizer.eos_token_id(),
decode_buffer: Vec::new(),
created_at: Instant::now(),
}
}
}
fn emit_token(&self, seq: &Sequence, token_id: u32) {
let text = self.tokenizer.decode(&[token_id]);
/// Total additional GPU blocks the next decode step needs across all
/// currently-decoding (prefilled, not just-prefilled) sequences.
fn decode_block_need(paged: &PagedKVCache, running: &[Sequence], newly_prefilled: &[u64]) -> usize {
running.iter()
.filter(|s| s.prefilled && !newly_prefilled.contains(&s.id))
.filter_map(|s| s.seq_slot)
.map(|slot| paged.additional_blocks_needed(slot, 1))
.sum()
}
if self.tokenizer.eos_token_id() == Some(token_id) {
let _ = seq.sender.blocking_send(GenerateEvent::Done {
finish_reason: "stop".to_string(),
});
} else if seq.generated_tokens.len() >= seq.max_tokens {
let _ = seq.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
let _ = seq.sender.blocking_send(GenerateEvent::Done {
finish_reason: "length".to_string(),
});
} else {
let _ = seq.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
}
fn emit_token(tokenizer: &Tokenizer, seq: &mut Sequence, token_id: u32) {
if tokenizer.eos_token_id() == Some(token_id) {
let tail = tokenizer.flush_decode_stream(&mut seq.decode_buffer);
send_token_if_nonempty(seq, tail);
let _ = seq.sender.blocking_send(GenerateEvent::Done {
finish_reason: "stop".to_string(),
});
return;
}
let text = tokenizer.decode_token_stream(token_id, &mut seq.decode_buffer);
if seq.generated_tokens.len() >= seq.max_tokens {
let tail = tokenizer.flush_decode_stream(&mut seq.decode_buffer);
send_token_if_nonempty(seq, text);
send_token_if_nonempty(seq, tail);
let _ = seq.sender.blocking_send(GenerateEvent::Done {
finish_reason: "length".to_string(),
});
} else {
send_token_if_nonempty(seq, text);
}
}
fn send_token_if_nonempty(seq: &Sequence, text: String) {
if !text.is_empty() {
let id = *seq.generated_tokens.last().unwrap_or(&0);
let _ = seq.sender.blocking_send(GenerateEvent::Token { id, text });
}
}
@@ -227,7 +357,5 @@ fn is_finished(seq: &Sequence) -> bool {
if seq.generated_tokens.is_empty() { return false; }
let last = *seq.generated_tokens.last().unwrap();
if seq.generated_tokens.len() >= seq.max_tokens { return true; }
// Check EOS — need tokenizer info. Use a simple heuristic:
// If sender is closed (receiver dropped), also consider finished.
seq.sender.is_closed() || seq.eos_token_id == Some(last)
}

View File

@@ -5,6 +5,7 @@ use axum::{routing::{get, post}, Extension, Router};
use std::path::PathBuf;
use std::sync::{mpsc, Arc, Mutex};
use engine::GenerateRequest;
use xserv_model::ModelConfig;
pub struct AppState {
pub model_name: String,
@@ -17,7 +18,7 @@ pub struct AppState {
async fn main() {
let args: Vec<String> = std::env::args().collect();
if args.len() < 2 {
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N]");
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N]");
std::process::exit(1);
}
@@ -31,12 +32,31 @@ async fn main() {
.position(|a| a == "--max-batch")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(4);
let max_seq_len: usize = args.iter()
.unwrap_or(4)
.max(1);
let requested_max_seq_len: usize = args.iter()
.position(|a| a == "--max-seq-len")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(2048);
.unwrap_or(2048)
.max(1);
let swap_space_gb: usize = args.iter()
.position(|a| a == "--swap-space-gb")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(8);
let model_config = ModelConfig::from_file(&model_dir.join("config.json"));
let model_max_seq_len = model_config.max_seq_len();
if model_max_seq_len == 0 {
eprintln!("model config has invalid max_seq_len=0");
std::process::exit(1);
}
let max_seq_len = requested_max_seq_len.min(model_max_seq_len);
if max_seq_len != requested_max_seq_len {
eprintln!(
"[server] --max-seq-len {requested_max_seq_len} exceeds model limit {model_max_seq_len}; using {max_seq_len}"
);
}
let model_name = model_dir.file_name()
.map(|n| n.to_string_lossy().to_string())
@@ -49,7 +69,7 @@ async fn main() {
let model_dir_clone = model_dir.clone();
std::thread::spawn(move || {
let engine = engine::Engine::load(&model_dir_clone, max_batch, max_seq_len);
let mut engine = engine::Engine::load_with_swap(&model_dir_clone, max_batch, max_seq_len, swap_space_gb);
engine.run(rx);
});