From d96ee0766c18ed522f1e6befe6102403072d895d Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Wed, 1 Jul 2026 15:13:24 +0800 Subject: [PATCH] server: sampling-param validation, finish_reason normalization, backpressure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three related hardening changes for the API surface: - validate_request rejects non-finite (NaN/infinite) or negative temperature, top_p outside [0, 1], and top_k above 1,000,000 before those values reach the CUDA sampling paths. Prevents NaN logits from downstream sampling and matches typical OpenAI-compatible server behavior (400 instead of 500). - normalize_finish_reason maps engine strings to the OpenAI-standard subset: "stop" and "length" pass through unchanged, and any other value collapses to null. Currently only "error" (from tp/pp engine client-stall) takes that path, so SDK clients see a clean stream close instead of an unknown finish_reason value. Applied to both streaming (SSE) and non-streaming JSON responses. - Replace the unbounded std::sync::mpsc engine channel with a bounded sync_channel(256) and switch submit_to_engine to try_send. A saturated engine now returns 503 "inference engine is busy, retry later" instead of letting requests pile up in RAM. Also add axum DefaultBodyLimit(4 MiB) so a malicious or misbehaving client cannot exhaust memory with an arbitrary JSON POST. 
--- crates/xserv-server/src/api.rs | 51 +++++++++++++++++++++++++++++---- crates/xserv-server/src/main.rs | 10 +++++-- 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/crates/xserv-server/src/api.rs b/crates/xserv-server/src/api.rs index 73eacc6..c7a7d70 100644 --- a/crates/xserv-server/src/api.rs +++ b/crates/xserv-server/src/api.rs @@ -331,6 +331,10 @@ async fn chat_non_stream(state: Arc, req: ChatRequest) -> Response { } } + let fr_value = match normalize_finish_reason(&finish_reason) { + Some(s) => serde_json::Value::String(s.to_string()), + None => serde_json::Value::Null, + }; Json(serde_json::json!({ "id": id, "object": "chat.completion", @@ -339,7 +343,7 @@ async fn chat_non_stream(state: Arc, req: ChatRequest) -> Response { "choices": [{ "index": 0, "message": { "role": "assistant", "content": content }, - "finish_reason": finish_reason, + "finish_reason": fr_value, }], "usage": { "prompt_tokens": prompt_token_count, @@ -412,8 +416,11 @@ fn chat_stream(state: Arc, req: ChatRequest) -> Response { make_chunk(&id, &model_name, created, None, Some("assistant"), None); let _ = sse_tx.send(Ok(Event::default().data(chunk))).await; } - let chunk = - make_chunk(&id, &model_name, created, None, None, Some(&finish_reason)); + // Only "stop" and "length" are OpenAI-standard values. Internal + // codes like "error" (client-stalled from tp/pp engine) map to + // null so SDK clients see a clean stream close. 
+ let fr = normalize_finish_reason(&finish_reason); + let chunk = make_chunk(&id, &model_name, created, None, None, fr); let _ = sse_tx.send(Ok(Event::default().data(chunk))).await; let _ = sse_tx .send(Ok(Event::default().data("[DONE]".to_string()))) @@ -442,6 +449,22 @@ fn validate_request(req: &ChatRequest, model_name: &str) -> Option { return Some(bad_request("max_tokens must be greater than 0")); } + if let Some(t) = req.temperature { + if !t.is_finite() || t < 0.0 { + return Some(bad_request("temperature must be a finite value >= 0")); + } + } + if let Some(p) = req.top_p { + if !p.is_finite() || !(0.0..=1.0).contains(&p) { + return Some(bad_request("top_p must be in [0, 1]")); + } + } + if let Some(k) = req.top_k { + if k > 1_000_000 { + return Some(bad_request("top_k must be <= 1_000_000")); + } + } + None } @@ -453,9 +476,14 @@ fn submit_to_engine(state: &AppState, req: GenerateRequest) -> Result<(), Respon .engine_sender .lock() .unwrap_or_else(|e| e.into_inner()); - sender - .send(req) - .map_err(|_| service_unavailable("inference engine is not available")) + sender.try_send(req).map_err(|err| match err { + std::sync::mpsc::TrySendError::Full(_) => { + service_unavailable("inference engine is busy, retry later") + } + std::sync::mpsc::TrySendError::Disconnected(_) => { + service_unavailable("inference engine is not available") + } + }) } fn service_unavailable(message: impl Into) -> Response { @@ -532,3 +560,14 @@ fn sampling_params(req: &ChatRequest) -> SamplingParams { top_p: req.top_p.unwrap_or(1.0), } } + +/// Map engine finish_reason strings to OpenAI-standard values. Any engine-internal +/// code (e.g. "error" from tp/pp client-stall) collapses to None so SDK clients see +/// a clean null instead of an unknown value. 
+fn normalize_finish_reason(fr: &str) -> Option<&'static str> { + match fr { + "stop" => Some("stop"), + "length" => Some("length"), + _ => None, + } +} diff --git a/crates/xserv-server/src/main.rs b/crates/xserv-server/src/main.rs index 88e9ed0..2610a4c 100644 --- a/crates/xserv-server/src/main.rs +++ b/crates/xserv-server/src/main.rs @@ -5,6 +5,7 @@ mod tp_engine; use axum::{ Extension, Router, + extract::DefaultBodyLimit, routing::{get, post}, }; use engine::GenerateRequest; @@ -15,7 +16,7 @@ use xserv_model::ModelConfig; pub struct AppState { pub model_name: String, pub chat_template: api::ChatTemplate, - pub engine_sender: Mutex>, + pub engine_sender: Mutex>, pub engine_tokenizer: Mutex, pub max_seq_len: usize, } @@ -104,8 +105,10 @@ async fn main() { let tokenizer = xserv_tokenizer::Tokenizer::from_file(&model_dir.join("tokenizer.json")); - // Unbounded channel: allows multiple requests to queue up - let (tx, rx) = mpsc::channel::(); + // Bounded channel to backpressure incoming requests when the engine falls + // behind, instead of letting them pile up in RAM. try_send in the API + // handler surfaces this as 503 to the client. + let (tx, rx) = mpsc::sync_channel::(256); let model_dir_clone = model_dir.clone(); std::thread::spawn(move || { @@ -140,6 +143,7 @@ async fn main() { .route("/health", get(api::health)) .route("/v1/models", get(api::list_models)) .route("/v1/chat/completions", post(api::chat_completions)) + .layer(DefaultBodyLimit::max(4 * 1024 * 1024)) .layer(Extension(state)); let addr = format!("0.0.0.0:{port}");