server: sampling-param validation, finish_reason normalization, backpressure
Three related hardening changes for the API surface: - validate_request rejects NaN/negative temperature, out-of-range top_p, and absurd top_k before those values reach the CUDA sampling paths. Prevents NaN logits from downstream sampling and matches typical OpenAI-compatible server behavior (400 instead of 500). - normalize_finish_reason maps engine strings to the OpenAI-standard subset. Currently only "error" (from tp/pp engine client-stall) needs normalization — it collapses to null so SDK clients see a clean stream close instead of an unknown finish_reason value. Applied to both streaming (SSE) and non-streaming JSON responses. - Replace the unbounded std::sync::mpsc engine channel with a bounded sync_channel(256) and switch submit_to_engine to try_send. A saturated engine now returns 503 "engine is busy" instead of letting requests pile up in RAM. Also add axum DefaultBodyLimit(4 MiB) so a malicious or misbehaving client cannot exhaust memory with an arbitrary JSON POST.
This commit is contained in:
@@ -331,6 +331,10 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let fr_value = match normalize_finish_reason(&finish_reason) {
|
||||||
|
Some(s) => serde_json::Value::String(s.to_string()),
|
||||||
|
None => serde_json::Value::Null,
|
||||||
|
};
|
||||||
Json(serde_json::json!({
|
Json(serde_json::json!({
|
||||||
"id": id,
|
"id": id,
|
||||||
"object": "chat.completion",
|
"object": "chat.completion",
|
||||||
@@ -339,7 +343,7 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
|
|||||||
"choices": [{
|
"choices": [{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"message": { "role": "assistant", "content": content },
|
"message": { "role": "assistant", "content": content },
|
||||||
"finish_reason": finish_reason,
|
"finish_reason": fr_value,
|
||||||
}],
|
}],
|
||||||
"usage": {
|
"usage": {
|
||||||
"prompt_tokens": prompt_token_count,
|
"prompt_tokens": prompt_token_count,
|
||||||
@@ -412,8 +416,11 @@ fn chat_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
|
|||||||
make_chunk(&id, &model_name, created, None, Some("assistant"), None);
|
make_chunk(&id, &model_name, created, None, Some("assistant"), None);
|
||||||
let _ = sse_tx.send(Ok(Event::default().data(chunk))).await;
|
let _ = sse_tx.send(Ok(Event::default().data(chunk))).await;
|
||||||
}
|
}
|
||||||
let chunk =
|
// Only "stop" and "length" are OpenAI-standard values. Internal
|
||||||
make_chunk(&id, &model_name, created, None, None, Some(&finish_reason));
|
// codes like "error" (client-stalled from tp/pp engine) map to
|
||||||
|
// null so SDK clients see a clean stream close.
|
||||||
|
let fr = normalize_finish_reason(&finish_reason);
|
||||||
|
let chunk = make_chunk(&id, &model_name, created, None, None, fr);
|
||||||
let _ = sse_tx.send(Ok(Event::default().data(chunk))).await;
|
let _ = sse_tx.send(Ok(Event::default().data(chunk))).await;
|
||||||
let _ = sse_tx
|
let _ = sse_tx
|
||||||
.send(Ok(Event::default().data("[DONE]".to_string())))
|
.send(Ok(Event::default().data("[DONE]".to_string())))
|
||||||
@@ -442,6 +449,22 @@ fn validate_request(req: &ChatRequest, model_name: &str) -> Option<Response> {
|
|||||||
return Some(bad_request("max_tokens must be greater than 0"));
|
return Some(bad_request("max_tokens must be greater than 0"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(t) = req.temperature {
|
||||||
|
if !t.is_finite() || t < 0.0 {
|
||||||
|
return Some(bad_request("temperature must be a finite value >= 0"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(p) = req.top_p {
|
||||||
|
if !p.is_finite() || !(0.0..=1.0).contains(&p) {
|
||||||
|
return Some(bad_request("top_p must be in [0, 1]"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(k) = req.top_k {
|
||||||
|
if k > 1_000_000 {
|
||||||
|
return Some(bad_request("top_k must be <= 1_000_000"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -453,9 +476,14 @@ fn submit_to_engine(state: &AppState, req: GenerateRequest) -> Result<(), Respon
|
|||||||
.engine_sender
|
.engine_sender
|
||||||
.lock()
|
.lock()
|
||||||
.unwrap_or_else(|e| e.into_inner());
|
.unwrap_or_else(|e| e.into_inner());
|
||||||
sender
|
sender.try_send(req).map_err(|err| match err {
|
||||||
.send(req)
|
std::sync::mpsc::TrySendError::Full(_) => {
|
||||||
.map_err(|_| service_unavailable("inference engine is not available"))
|
service_unavailable("inference engine is busy, retry later")
|
||||||
|
}
|
||||||
|
std::sync::mpsc::TrySendError::Disconnected(_) => {
|
||||||
|
service_unavailable("inference engine is not available")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn service_unavailable(message: impl Into<String>) -> Response {
|
fn service_unavailable(message: impl Into<String>) -> Response {
|
||||||
@@ -532,3 +560,14 @@ fn sampling_params(req: &ChatRequest) -> SamplingParams {
|
|||||||
top_p: req.top_p.unwrap_or(1.0),
|
top_p: req.top_p.unwrap_or(1.0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Map engine finish_reason strings to OpenAI-standard values. Any engine-internal
|
||||||
|
/// code (e.g. "error" from tp/pp client-stall) collapses to None so SDK clients see
|
||||||
|
/// a clean null instead of an unknown value.
|
||||||
|
fn normalize_finish_reason(fr: &str) -> Option<&'static str> {
|
||||||
|
match fr {
|
||||||
|
"stop" => Some("stop"),
|
||||||
|
"length" => Some("length"),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ mod tp_engine;
|
|||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
Extension, Router,
|
Extension, Router,
|
||||||
|
extract::DefaultBodyLimit,
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
};
|
};
|
||||||
use engine::GenerateRequest;
|
use engine::GenerateRequest;
|
||||||
@@ -15,7 +16,7 @@ use xserv_model::ModelConfig;
|
|||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub model_name: String,
|
pub model_name: String,
|
||||||
pub chat_template: api::ChatTemplate,
|
pub chat_template: api::ChatTemplate,
|
||||||
pub engine_sender: Mutex<mpsc::Sender<GenerateRequest>>,
|
pub engine_sender: Mutex<mpsc::SyncSender<GenerateRequest>>,
|
||||||
pub engine_tokenizer: Mutex<xserv_tokenizer::Tokenizer>,
|
pub engine_tokenizer: Mutex<xserv_tokenizer::Tokenizer>,
|
||||||
pub max_seq_len: usize,
|
pub max_seq_len: usize,
|
||||||
}
|
}
|
||||||
@@ -104,8 +105,10 @@ async fn main() {
|
|||||||
|
|
||||||
let tokenizer = xserv_tokenizer::Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
let tokenizer = xserv_tokenizer::Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||||
|
|
||||||
// Unbounded channel: allows multiple requests to queue up
|
// Bounded channel to backpressure incoming requests when the engine falls
|
||||||
let (tx, rx) = mpsc::channel::<GenerateRequest>();
|
// behind, instead of letting them pile up in RAM. try_send in the API
|
||||||
|
// handler surfaces this as 503 to the client.
|
||||||
|
let (tx, rx) = mpsc::sync_channel::<GenerateRequest>(256);
|
||||||
|
|
||||||
let model_dir_clone = model_dir.clone();
|
let model_dir_clone = model_dir.clone();
|
||||||
std::thread::spawn(move || {
|
std::thread::spawn(move || {
|
||||||
@@ -140,6 +143,7 @@ async fn main() {
|
|||||||
.route("/health", get(api::health))
|
.route("/health", get(api::health))
|
||||||
.route("/v1/models", get(api::list_models))
|
.route("/v1/models", get(api::list_models))
|
||||||
.route("/v1/chat/completions", post(api::chat_completions))
|
.route("/v1/chat/completions", post(api::chat_completions))
|
||||||
|
.layer(DefaultBodyLimit::max(4 * 1024 * 1024))
|
||||||
.layer(Extension(state));
|
.layer(Extension(state));
|
||||||
|
|
||||||
let addr = format!("0.0.0.0:{port}");
|
let addr = format!("0.0.0.0:{port}");
|
||||||
|
|||||||
Reference in New Issue
Block a user