diff --git a/Cargo.lock b/Cargo.lock index a35fc09..95e4c24 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -408,12 +408,28 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "memo-map" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" + [[package]] name = "mime" version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minijinja" +version = "2.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2929e494b2280e1e18959bb2e121da03347ae896896fdfaceaab43c88a02803f" +dependencies = [ + "memo-map", + "serde", +] + [[package]] name = "mio" version = "1.2.0" @@ -1097,6 +1113,14 @@ dependencies = [ "rand 0.9.4", ] +[[package]] +name = "xserv-distributed" +version = "0.1.0" +dependencies = [ + "half", + "xserv-cuda", +] + [[package]] name = "xserv-kernels" version = "0.1.0" @@ -1112,12 +1136,14 @@ name = "xserv-model" version = "0.1.0" dependencies = [ "half", + "libc", "rand 0.8.6", "safetensors", "serde", "serde_json", "smallvec", "xserv-cuda", + "xserv-distributed", "xserv-kernels", "xserv-tensor", "xserv-tokenizer", @@ -1129,12 +1155,14 @@ version = "0.1.0" dependencies = [ "axum", "half", + "minijinja", "serde", "serde_json", "tokio", "tokio-stream", "uuid", "xserv-cuda", + "xserv-distributed", "xserv-kernels", "xserv-model", "xserv-tensor", diff --git a/Cargo.toml b/Cargo.toml index 316d24f..20f9217 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,3 +28,4 @@ axum = "0.8" uuid = { version = "1", features = ["v4"] } tokio-stream = "0.1" rand = "0.8" +minijinja = { version = "2", features = ["builtins"] } diff --git a/crates/xserv-server/Cargo.toml b/crates/xserv-server/Cargo.toml index 40ed14a..9e6b088 100644 --- a/crates/xserv-server/Cargo.toml +++ b/crates/xserv-server/Cargo.toml @@ -21,3 +21,4 @@ tokio.workspace = true axum.workspace = true uuid.workspace = true tokio-stream.workspace = true +minijinja.workspace = true diff --git a/crates/xserv-server/src/api.rs b/crates/xserv-server/src/api.rs index 6a17d33..3026df5 100644 --- a/crates/xserv-server/src/api.rs +++ b/crates/xserv-server/src/api.rs @@ -5,6 +5,7 @@ use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::{IntoResponse, Response}; use serde::{Deserialize, Serialize}; use std::convert::Infallible; +use std::path::Path; use std::sync::Arc; use tokio_stream::StreamExt; use tokio_stream::wrappers::ReceiverStream; @@ -31,7 +32,7 @@ pub struct ChatRequest { pub top_p: Option, } -#[derive(Deserialize)] +#[derive(Deserialize, Serialize, Clone)] pub struct Message { pub role: String, pub content: String, @@ -54,6 +55,144 @@ pub struct ModelInfo { owned_by: &'static str, } +// --------------------------------------------------------------------------- +// Chat Template: Jinja2 rendering via minijinja +// --------------------------------------------------------------------------- + +pub struct ChatTemplate { + source: String, + model_type: String, +} + +impl ChatTemplate { + pub fn load(model_dir: &Path, model_type: &str) -> Self { + // 1. Try standalone chat_template.jinja file + let jinja_path = model_dir.join("chat_template.jinja"); + if jinja_path.exists() { + let source = std::fs::read_to_string(&jinja_path) + .unwrap_or_else(|e| panic!("failed to read {}: {e}", jinja_path.display())); + eprintln!("[chat-template] loaded from {}", jinja_path.display()); + return Self { source, model_type: model_type.to_string() }; + } + + // 2. Try tokenizer_config.json → chat_template field + let tok_cfg_path = model_dir.join("tokenizer_config.json"); + if tok_cfg_path.exists() { + if let Ok(data) = std::fs::read_to_string(&tok_cfg_path) { + if let Ok(v) = serde_json::from_str::(&data) { + if let Some(ct) = v.get("chat_template").and_then(|v| v.as_str()) { + eprintln!("[chat-template] loaded from tokenizer_config.json"); + return Self { source: ct.to_string(), model_type: model_type.to_string() }; + } + } + } + } + + // 3. No template found — use empty source, will fall back to hardcoded + eprintln!("[chat-template] no Jinja template found, using hardcoded fallback"); + Self { source: String::new(), model_type: model_type.to_string() } + } + + pub fn render(&self, messages: &[Message]) -> String { + if self.source.is_empty() { + return build_prompt_hardcoded(messages, &self.model_type); + } + + match self.render_jinja(messages) { + Ok(prompt) => prompt, + Err(e) => { + eprintln!("[chat-template] Jinja render error: {e}, falling back to hardcoded"); + build_prompt_hardcoded(messages, &self.model_type) + } + } + } + + fn render_jinja(&self, messages: &[Message]) -> Result { + let mut env = minijinja::Environment::new(); + + // Register custom functions the template may call. + env.add_function("strftime_now", strftime_now); + env.add_function("raise_exception", raise_exception); + + env.add_template("chat", &self.source)?; + let tmpl = env.get_template("chat")?; + + let ctx = minijinja::context! { + messages => minijinja::Value::from_serialize(messages), + add_generation_prompt => true, + bos_token => "", + eos_token => "", + }; + + tmpl.render(ctx) + } +} + +fn strftime_now(fmt: String) -> String { + use std::time::SystemTime; + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(); + // Only support %Y-%m-%d (the only format used by known templates) + let days = now / 86400; + let (y, m, d) = days_to_ymd(days); + fmt.replace("%Y", &format!("{y:04}")) + .replace("%m", &format!("{m:02}")) + .replace("%d", &format!("{d:02}")) +} + +fn days_to_ymd(days_since_epoch: u64) -> (u32, u32, u32) { + // Civil calendar from days since 1970-01-01 (Rata Die algorithm) + let z = days_since_epoch as i64 + 719468; + let era = (if z >= 0 { z } else { z - 146096 }) / 146097; + let doe = (z - era * 146097) as u32; + let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; + let y = yoe as i64 + era * 400; + let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); + let mp = (5 * doy + 2) / 153; + let d = doy - (153 * mp + 2) / 5 + 1; + let m = if mp < 10 { mp + 3 } else { mp - 9 }; + let y = if m <= 2 { y + 1 } else { y }; + (y as u32, m, d) +} + +fn raise_exception(msg: String) -> Result { + Err(minijinja::Error::new( + minijinja::ErrorKind::InvalidOperation, + msg, + )) +} + +// --------------------------------------------------------------------------- +// Hardcoded fallback templates (for models without a Jinja template) +// --------------------------------------------------------------------------- + +fn build_prompt_hardcoded(messages: &[Message], model_type: &str) -> String { + // Default: Qwen3 ChatML format + let _ = model_type; + let mut prompt = String::new(); + for msg in messages { + match msg.role.as_str() { + "system" | "user" | "assistant" => { + prompt.push_str("<|im_start|>"); + prompt.push_str(&msg.role); + prompt.push('\n'); + prompt.push_str(&msg.content); + prompt.push_str("<|im_end|>\n"); + } + _ => {} + } + } + prompt.push_str("<|im_start|>assistant\n"); + prompt.push_str("\n\n\n\n"); + prompt +} + +// --------------------------------------------------------------------------- +// HTTP handlers +// --------------------------------------------------------------------------- + pub async fn health() -> &'static str { "ok" } @@ -89,7 +228,7 @@ async fn chat_non_stream(state: Arc, req: ChatRequest) -> Response { return response; } - let prompt = build_prompt(&req.messages, &state.model_type); + let prompt = state.chat_template.render(&req.messages); let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt); let prompt_token_count = prompt_tokens.len(); @@ -159,7 +298,7 @@ fn chat_stream( return response; } - let prompt = build_prompt(&req.messages, &state.model_type); + let prompt = state.chat_template.render(&req.messages); let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt); let max_seq_len = state.max_seq_len; @@ -324,64 +463,3 @@ fn sampling_params(req: &ChatRequest) -> SamplingParams { top_p: req.top_p.unwrap_or(1.0), } } - -fn build_prompt(messages: &[Message], model_type: &str) -> String { - if model_type == "gpt_oss" { - return build_prompt_gpt_oss(messages); - } - // Default: Qwen3 ChatML format - let mut prompt = String::new(); - for msg in messages { - match msg.role.as_str() { - "system" | "user" | "assistant" => { - prompt.push_str("<|im_start|>"); - prompt.push_str(&msg.role); - prompt.push('\n'); - prompt.push_str(&msg.content); - prompt.push_str("<|im_end|>\n"); - } - _ => {} - } - } - prompt.push_str("<|im_start|>assistant\n"); - prompt.push_str("\n\n\n\n"); - prompt -} - -fn build_prompt_gpt_oss(messages: &[Message]) -> String { - let mut prompt = String::new(); - // System (meta) block: channel declaration required by harmony. - prompt.push_str("<|start|>system<|message|>"); - prompt.push_str("You are a helpful assistant.\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message."); - prompt.push_str("<|end|>"); - // Caller-supplied system prompt(s) go in a developer instructions block - // (harmony puts user instructions on the `developer` role, not `system`). - let dev_instructions: String = messages - .iter() - .filter(|m| m.role == "system") - .map(|m| m.content.as_str()) - .collect::>() - .join("\n\n"); - if !dev_instructions.is_empty() { - prompt.push_str("<|start|>developer<|message|># Instructions\n\n"); - prompt.push_str(&dev_instructions); - prompt.push_str("<|end|>"); - } - for msg in messages { - match msg.role.as_str() { - "user" => { - prompt.push_str("<|start|>user<|message|>"); - prompt.push_str(&msg.content); - prompt.push_str("<|end|>"); - } - "assistant" => { - prompt.push_str("<|start|>assistant<|channel|>final<|message|>"); - prompt.push_str(&msg.content); - prompt.push_str("<|end|>"); - } - _ => {} - } - } - prompt.push_str("<|start|>assistant"); - prompt -} diff --git a/crates/xserv-server/src/main.rs b/crates/xserv-server/src/main.rs index 7dd78b7..aee0d53 100644 --- a/crates/xserv-server/src/main.rs +++ b/crates/xserv-server/src/main.rs @@ -11,7 +11,7 @@ use xserv_model::ModelConfig; pub struct AppState { pub model_name: String, - pub model_type: String, + pub chat_template: api::ChatTemplate, pub engine_sender: Mutex>, pub engine_tokenizer: Mutex, pub max_seq_len: usize, @@ -101,9 +101,10 @@ async fn main() { }); let model_type = model_config.model_type.clone().unwrap_or_default(); + let chat_template = api::ChatTemplate::load(&model_dir, &model_type); let state = Arc::new(AppState { model_name, - model_type, + chat_template, engine_sender: Mutex::new(tx), engine_tokenizer: Mutex::new(tokenizer), max_seq_len,