diff --git a/crates/xserv-server/src/api.rs b/crates/xserv-server/src/api.rs index 73eacc6..c7a7d70 100644 --- a/crates/xserv-server/src/api.rs +++ b/crates/xserv-server/src/api.rs @@ -331,6 +331,10 @@ async fn chat_non_stream(state: Arc, req: ChatRequest) -> Response { } } + let fr_value = match normalize_finish_reason(&finish_reason) { + Some(s) => serde_json::Value::String(s.to_string()), + None => serde_json::Value::Null, + }; Json(serde_json::json!({ "id": id, "object": "chat.completion", @@ -339,7 +343,7 @@ async fn chat_non_stream(state: Arc, req: ChatRequest) -> Response { "choices": [{ "index": 0, "message": { "role": "assistant", "content": content }, - "finish_reason": finish_reason, + "finish_reason": fr_value, }], "usage": { "prompt_tokens": prompt_token_count, @@ -412,8 +416,11 @@ fn chat_stream(state: Arc, req: ChatRequest) -> Response { make_chunk(&id, &model_name, created, None, Some("assistant"), None); let _ = sse_tx.send(Ok(Event::default().data(chunk))).await; } - let chunk = - make_chunk(&id, &model_name, created, None, None, Some(&finish_reason)); + // Only "stop" and "length" are OpenAI-standard values. Internal + // codes like "error" (client-stalled from tp/pp engine) map to + // null so SDK clients see a clean stream close. + let fr = normalize_finish_reason(&finish_reason); + let chunk = make_chunk(&id, &model_name, created, None, None, fr); let _ = sse_tx.send(Ok(Event::default().data(chunk))).await; let _ = sse_tx .send(Ok(Event::default().data("[DONE]".to_string()))) @@ -442,6 +449,22 @@ fn validate_request(req: &ChatRequest, model_name: &str) -> Option { return Some(bad_request("max_tokens must be greater than 0")); } + if let Some(t) = req.temperature { + if !t.is_finite() || t < 0.0 { + return Some(bad_request("temperature must be a finite value >= 0")); + } + } + if let Some(p) = req.top_p { + if !p.is_finite() || !(0.0..=1.0).contains(&p) { + return Some(bad_request("top_p must be in [0, 1]")); + } + } + if let Some(k) = req.top_k { + if k > 1_000_000 { + return Some(bad_request("top_k must be <= 1_000_000")); + } + } + None } @@ -453,9 +476,14 @@ fn submit_to_engine(state: &AppState, req: GenerateRequest) -> Result<(), Respon .engine_sender .lock() .unwrap_or_else(|e| e.into_inner()); - sender - .send(req) - .map_err(|_| service_unavailable("inference engine is not available")) + sender.try_send(req).map_err(|err| match err { + std::sync::mpsc::TrySendError::Full(_) => { + service_unavailable("inference engine is busy, retry later") + } + std::sync::mpsc::TrySendError::Disconnected(_) => { + service_unavailable("inference engine is not available") + } + }) } fn service_unavailable(message: impl Into) -> Response { @@ -532,3 +560,14 @@ fn sampling_params(req: &ChatRequest) -> SamplingParams { top_p: req.top_p.unwrap_or(1.0), } } + +/// Map engine finish_reason strings to OpenAI-standard values. Any engine-internal +/// code (e.g. "error" from tp/pp client-stall) collapses to None so SDK clients see +/// a clean null instead of an unknown value. +fn normalize_finish_reason(fr: &str) -> Option<&'static str> { + match fr { + "stop" => Some("stop"), + "length" => Some("length"), + _ => None, + } +} diff --git a/crates/xserv-server/src/main.rs b/crates/xserv-server/src/main.rs index 88e9ed0..2610a4c 100644 --- a/crates/xserv-server/src/main.rs +++ b/crates/xserv-server/src/main.rs @@ -5,6 +5,7 @@ mod tp_engine; use axum::{ Extension, Router, + extract::DefaultBodyLimit, routing::{get, post}, }; use engine::GenerateRequest; @@ -15,7 +16,7 @@ use xserv_model::ModelConfig; pub struct AppState { pub model_name: String, pub chat_template: api::ChatTemplate, - pub engine_sender: Mutex>, + pub engine_sender: Mutex>, pub engine_tokenizer: Mutex, pub max_seq_len: usize, } @@ -104,8 +105,10 @@ async fn main() { let tokenizer = xserv_tokenizer::Tokenizer::from_file(&model_dir.join("tokenizer.json")); - // Unbounded channel: allows multiple requests to queue up - let (tx, rx) = mpsc::channel::(); + // Bounded channel to backpressure incoming requests when the engine falls + // behind, instead of letting them pile up in RAM. try_send in the API + // handler surfaces this as 503 to the client. + let (tx, rx) = mpsc::sync_channel::(256); let model_dir_clone = model_dir.clone(); std::thread::spawn(move || { @@ -140,6 +143,7 @@ async fn main() { .route("/health", get(api::health)) .route("/v1/models", get(api::list_models)) .route("/v1/chat/completions", post(api::chat_completions)) + .layer(DefaultBodyLimit::max(4 * 1024 * 1024)) .layer(Extension(state)); let addr = format!("0.0.0.0:{port}");