server: sampling-param validation, finish_reason normalization, backpressure

Three related hardening changes for the API surface:

- validate_request rejects NaN/negative temperature, out-of-range top_p,
  and absurd top_k before those values reach the CUDA sampling paths.
  Prevents NaN logits from downstream sampling and matches typical
  OpenAI-compatible server behavior (400 instead of 500).
- normalize_finish_reason maps engine strings to the OpenAI-standard
  subset. Currently only "error" (from tp/pp engine client-stall) needs
  normalization — it collapses to null so SDK clients see a clean stream
  close instead of an unknown finish_reason value. Applied to both
  streaming (SSE) and non-streaming JSON responses.
- Replace the unbounded std::sync::mpsc engine channel with a bounded
  sync_channel(256) and switch submit_to_engine to try_send. A saturated
  engine now returns 503 "engine is busy" instead of letting requests
  pile up in RAM. Also add axum DefaultBodyLimit(4 MiB) so a malicious
  or misbehaving client cannot exhaust memory with an arbitrary JSON POST.
This commit is contained in:
2026-07-01 15:13:24 +08:00
parent ce10e4a998
commit d96ee0766c
2 changed files with 52 additions and 9 deletions

View File

@@ -331,6 +331,10 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
}
}
let fr_value = match normalize_finish_reason(&finish_reason) {
Some(s) => serde_json::Value::String(s.to_string()),
None => serde_json::Value::Null,
};
Json(serde_json::json!({
"id": id,
"object": "chat.completion",
@@ -339,7 +343,7 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
"choices": [{
"index": 0,
"message": { "role": "assistant", "content": content },
"finish_reason": finish_reason,
"finish_reason": fr_value,
}],
"usage": {
"prompt_tokens": prompt_token_count,
@@ -412,8 +416,11 @@ fn chat_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
make_chunk(&id, &model_name, created, None, Some("assistant"), None);
let _ = sse_tx.send(Ok(Event::default().data(chunk))).await;
}
let chunk =
make_chunk(&id, &model_name, created, None, None, Some(&finish_reason));
// Only "stop" and "length" are OpenAI-standard values. Internal
// codes like "error" (client-stalled from tp/pp engine) map to
// null so SDK clients see a clean stream close.
let fr = normalize_finish_reason(&finish_reason);
let chunk = make_chunk(&id, &model_name, created, None, None, fr);
let _ = sse_tx.send(Ok(Event::default().data(chunk))).await;
let _ = sse_tx
.send(Ok(Event::default().data("[DONE]".to_string())))
@@ -442,6 +449,22 @@ fn validate_request(req: &ChatRequest, model_name: &str) -> Option<Response> {
return Some(bad_request("max_tokens must be greater than 0"));
}
if let Some(t) = req.temperature {
if !t.is_finite() || t < 0.0 {
return Some(bad_request("temperature must be a finite value >= 0"));
}
}
if let Some(p) = req.top_p {
if !p.is_finite() || !(0.0..=1.0).contains(&p) {
return Some(bad_request("top_p must be in [0, 1]"));
}
}
if let Some(k) = req.top_k {
if k > 1_000_000 {
return Some(bad_request("top_k must be <= 1_000_000"));
}
}
None
}
@@ -453,9 +476,14 @@ fn submit_to_engine(state: &AppState, req: GenerateRequest) -> Result<(), Respon
.engine_sender
.lock()
.unwrap_or_else(|e| e.into_inner());
sender
.send(req)
.map_err(|_| service_unavailable("inference engine is not available"))
sender.try_send(req).map_err(|err| match err {
std::sync::mpsc::TrySendError::Full(_) => {
service_unavailable("inference engine is busy, retry later")
}
std::sync::mpsc::TrySendError::Disconnected(_) => {
service_unavailable("inference engine is not available")
}
})
}
fn service_unavailable(message: impl Into<String>) -> Response {
@@ -532,3 +560,14 @@ fn sampling_params(req: &ChatRequest) -> SamplingParams {
top_p: req.top_p.unwrap_or(1.0),
}
}
/// Map engine finish_reason strings to OpenAI-standard values. Any engine-internal
/// code (e.g. "error" from tp/pp client-stall) collapses to None so SDK clients see
/// a clean null instead of an unknown value.
fn normalize_finish_reason(fr: &str) -> Option<&'static str> {
match fr {
"stop" => Some("stop"),
"length" => Some("length"),
_ => None,
}
}

View File

@@ -5,6 +5,7 @@ mod tp_engine;
use axum::{
Extension, Router,
extract::DefaultBodyLimit,
routing::{get, post},
};
use engine::GenerateRequest;
@@ -15,7 +16,7 @@ use xserv_model::ModelConfig;
pub struct AppState {
pub model_name: String,
pub chat_template: api::ChatTemplate,
pub engine_sender: Mutex<mpsc::Sender<GenerateRequest>>,
pub engine_sender: Mutex<mpsc::SyncSender<GenerateRequest>>,
pub engine_tokenizer: Mutex<xserv_tokenizer::Tokenizer>,
pub max_seq_len: usize,
}
@@ -104,8 +105,10 @@ async fn main() {
let tokenizer = xserv_tokenizer::Tokenizer::from_file(&model_dir.join("tokenizer.json"));
// Unbounded channel: allows multiple requests to queue up
let (tx, rx) = mpsc::channel::<GenerateRequest>();
// Bounded channel to backpressure incoming requests when the engine falls
// behind, instead of letting them pile up in RAM. try_send in the API
// handler surfaces this as 503 to the client.
let (tx, rx) = mpsc::sync_channel::<GenerateRequest>(256);
let model_dir_clone = model_dir.clone();
std::thread::spawn(move || {
@@ -140,6 +143,7 @@ async fn main() {
.route("/health", get(api::health))
.route("/v1/models", get(api::list_models))
.route("/v1/chat/completions", post(api::chat_completions))
.layer(DefaultBodyLimit::max(4 * 1024 * 1024))
.layer(Extension(state));
let addr = format!("0.0.0.0:{port}");