server: Jinja chat template rendering via minijinja
Load the model's chat_template.jinja (or tokenizer_config.json chat_template field) at startup and render it with minijinja instead of hardcoded per-model prompt builders. Custom Jinja functions: strftime_now (date formatting), raise_exception (template validation errors). Falls back to Qwen3 ChatML template if no Jinja template is found. Removes the hardcoded build_prompt_gpt_oss() — the model's own template now drives prompt formatting, matching llama.cpp's behavior exactly. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
28
Cargo.lock
generated
28
Cargo.lock
generated
@@ -408,12 +408,28 @@ version = "2.8.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
|
checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "memo-map"
|
||||||
|
version = "0.3.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mime"
|
name = "mime"
|
||||||
version = "0.3.17"
|
version = "0.3.17"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
|
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "minijinja"
|
||||||
|
version = "2.20.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2929e494b2280e1e18959bb2e121da03347ae896896fdfaceaab43c88a02803f"
|
||||||
|
dependencies = [
|
||||||
|
"memo-map",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mio"
|
name = "mio"
|
||||||
version = "1.2.0"
|
version = "1.2.0"
|
||||||
@@ -1097,6 +1113,14 @@ dependencies = [
|
|||||||
"rand 0.9.4",
|
"rand 0.9.4",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "xserv-distributed"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"half",
|
||||||
|
"xserv-cuda",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "xserv-kernels"
|
name = "xserv-kernels"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
@@ -1112,12 +1136,14 @@ name = "xserv-model"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"half",
|
"half",
|
||||||
|
"libc",
|
||||||
"rand 0.8.6",
|
"rand 0.8.6",
|
||||||
"safetensors",
|
"safetensors",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"smallvec",
|
"smallvec",
|
||||||
"xserv-cuda",
|
"xserv-cuda",
|
||||||
|
"xserv-distributed",
|
||||||
"xserv-kernels",
|
"xserv-kernels",
|
||||||
"xserv-tensor",
|
"xserv-tensor",
|
||||||
"xserv-tokenizer",
|
"xserv-tokenizer",
|
||||||
@@ -1129,12 +1155,14 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"axum",
|
"axum",
|
||||||
"half",
|
"half",
|
||||||
|
"minijinja",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"uuid",
|
"uuid",
|
||||||
"xserv-cuda",
|
"xserv-cuda",
|
||||||
|
"xserv-distributed",
|
||||||
"xserv-kernels",
|
"xserv-kernels",
|
||||||
"xserv-model",
|
"xserv-model",
|
||||||
"xserv-tensor",
|
"xserv-tensor",
|
||||||
|
|||||||
@@ -28,3 +28,4 @@ axum = "0.8"
|
|||||||
uuid = { version = "1", features = ["v4"] }
|
uuid = { version = "1", features = ["v4"] }
|
||||||
tokio-stream = "0.1"
|
tokio-stream = "0.1"
|
||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
|
minijinja = { version = "2", features = ["builtins"] }
|
||||||
|
|||||||
@@ -21,3 +21,4 @@ tokio.workspace = true
|
|||||||
axum.workspace = true
|
axum.workspace = true
|
||||||
uuid.workspace = true
|
uuid.workspace = true
|
||||||
tokio-stream.workspace = true
|
tokio-stream.workspace = true
|
||||||
|
minijinja.workspace = true
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ use axum::response::sse::{Event, KeepAlive, Sse};
|
|||||||
use axum::response::{IntoResponse, Response};
|
use axum::response::{IntoResponse, Response};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
|
use std::path::Path;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
use tokio_stream::wrappers::ReceiverStream;
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
@@ -31,7 +32,7 @@ pub struct ChatRequest {
|
|||||||
pub top_p: Option<f32>,
|
pub top_p: Option<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize, Serialize, Clone)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
pub role: String,
|
pub role: String,
|
||||||
pub content: String,
|
pub content: String,
|
||||||
@@ -54,6 +55,144 @@ pub struct ModelInfo {
|
|||||||
owned_by: &'static str,
|
owned_by: &'static str,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Chat Template: Jinja2 rendering via minijinja
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
pub struct ChatTemplate {
|
||||||
|
source: String,
|
||||||
|
model_type: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatTemplate {
|
||||||
|
pub fn load(model_dir: &Path, model_type: &str) -> Self {
|
||||||
|
// 1. Try standalone chat_template.jinja file
|
||||||
|
let jinja_path = model_dir.join("chat_template.jinja");
|
||||||
|
if jinja_path.exists() {
|
||||||
|
let source = std::fs::read_to_string(&jinja_path)
|
||||||
|
.unwrap_or_else(|e| panic!("failed to read {}: {e}", jinja_path.display()));
|
||||||
|
eprintln!("[chat-template] loaded from {}", jinja_path.display());
|
||||||
|
return Self { source, model_type: model_type.to_string() };
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Try tokenizer_config.json → chat_template field
|
||||||
|
let tok_cfg_path = model_dir.join("tokenizer_config.json");
|
||||||
|
if tok_cfg_path.exists() {
|
||||||
|
if let Ok(data) = std::fs::read_to_string(&tok_cfg_path) {
|
||||||
|
if let Ok(v) = serde_json::from_str::<serde_json::Value>(&data) {
|
||||||
|
if let Some(ct) = v.get("chat_template").and_then(|v| v.as_str()) {
|
||||||
|
eprintln!("[chat-template] loaded from tokenizer_config.json");
|
||||||
|
return Self { source: ct.to_string(), model_type: model_type.to_string() };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. No template found — use empty source, will fall back to hardcoded
|
||||||
|
eprintln!("[chat-template] no Jinja template found, using hardcoded fallback");
|
||||||
|
Self { source: String::new(), model_type: model_type.to_string() }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn render(&self, messages: &[Message]) -> String {
|
||||||
|
if self.source.is_empty() {
|
||||||
|
return build_prompt_hardcoded(messages, &self.model_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.render_jinja(messages) {
|
||||||
|
Ok(prompt) => prompt,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("[chat-template] Jinja render error: {e}, falling back to hardcoded");
|
||||||
|
build_prompt_hardcoded(messages, &self.model_type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_jinja(&self, messages: &[Message]) -> Result<String, minijinja::Error> {
|
||||||
|
let mut env = minijinja::Environment::new();
|
||||||
|
|
||||||
|
// Register custom functions the template may call.
|
||||||
|
env.add_function("strftime_now", strftime_now);
|
||||||
|
env.add_function("raise_exception", raise_exception);
|
||||||
|
|
||||||
|
env.add_template("chat", &self.source)?;
|
||||||
|
let tmpl = env.get_template("chat")?;
|
||||||
|
|
||||||
|
let ctx = minijinja::context! {
|
||||||
|
messages => minijinja::Value::from_serialize(messages),
|
||||||
|
add_generation_prompt => true,
|
||||||
|
bos_token => "",
|
||||||
|
eos_token => "",
|
||||||
|
};
|
||||||
|
|
||||||
|
tmpl.render(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Template-callable `strftime_now(fmt)`: formats the current UTC time.
///
/// Supported specifiers: `%Y %y %m %d %H %M %S %b %B %a %A %%` (English
/// month and weekday names). Any other `%x` sequence is copied through
/// unchanged. A system clock set before the Unix epoch is treated as the
/// epoch itself rather than panicking in the middle of a request.
fn strftime_now(fmt: String) -> String {
    use std::time::SystemTime;

    let secs = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0);
    format_utc_timestamp(&fmt, secs)
}

/// Expands strftime-style specifiers in `fmt` for the UTC instant `secs`
/// seconds after 1970-01-01T00:00:00.
fn format_utc_timestamp(fmt: &str, secs: u64) -> String {
    const SECS_PER_DAY: u64 = 86_400;
    const MONTH_NAMES: [&str; 12] = [
        "January", "February", "March", "April", "May", "June",
        "July", "August", "September", "October", "November", "December",
    ];
    const WEEKDAY_NAMES: [&str; 7] = [
        "Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday",
    ];

    let days = secs / SECS_PER_DAY;
    let (year, month, day) = days_to_ymd(days);
    let second_of_day = secs % SECS_PER_DAY;
    let hour = second_of_day / 3600;
    let minute = second_of_day % 3600 / 60;
    let second = second_of_day % 60;
    // 1970-01-01 was a Thursday (index 4 when Sunday is 0).
    let weekday_name = WEEKDAY_NAMES[((days + 4) % 7) as usize];
    let month_name = MONTH_NAMES[(month - 1) as usize];

    // Single left-to-right pass so that "%%Y" yields a literal "%Y" instead of
    // having the year substituted into it.
    let mut out = String::with_capacity(fmt.len() + 8);
    let mut chars = fmt.chars();
    while let Some(c) = chars.next() {
        if c != '%' {
            out.push(c);
            continue;
        }
        match chars.next() {
            Some('Y') => out.push_str(&format!("{year:04}")),
            Some('y') => out.push_str(&format!("{:02}", year % 100)),
            Some('m') => out.push_str(&format!("{month:02}")),
            Some('d') => out.push_str(&format!("{day:02}")),
            Some('H') => out.push_str(&format!("{hour:02}")),
            Some('M') => out.push_str(&format!("{minute:02}")),
            Some('S') => out.push_str(&format!("{second:02}")),
            Some('B') => out.push_str(month_name),
            Some('b') => out.push_str(&month_name[..3]),
            Some('A') => out.push_str(weekday_name),
            Some('a') => out.push_str(&weekday_name[..3]),
            Some('%') => out.push('%'),
            // Unknown specifier: keep it verbatim.
            Some(other) => {
                out.push('%');
                out.push(other);
            }
            // Trailing lone '%'.
            None => out.push('%'),
        }
    }
    out
}

/// Converts a day count since 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)` triple, with `month` in `1..=12` and `day` in `1..=31`.
///
/// This is the "civil from days" algorithm: the count is shifted so that eras
/// of 400 years (146097 days) start on March 1st, which puts the leap day at
/// the end of the (shifted) year.
fn days_to_ymd(days_since_epoch: u64) -> (u32, u32, u32) {
    // 719468 = days from 0000-03-01 to 1970-01-01.
    let z = days_since_epoch as i64 + 719468;
    let era = (if z >= 0 { z } else { z - 146096 }) / 146097;
    // Day of era, in [0, 146096].
    let doe = (z - era * 146097) as u32;
    // Year of era, in [0, 399].
    let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365;
    let y = yoe as i64 + era * 400;
    // Day of the March-based year, in [0, 365].
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
    // March-based month index, in [0, 11] (0 = March).
    let mp = (5 * doy + 2) / 153;
    let d = doy - (153 * mp + 2) / 5 + 1;
    let m = if mp < 10 { mp + 3 } else { mp - 9 };
    // January and February belong to the next civil year.
    let y = if m <= 2 { y + 1 } else { y };
    (y as u32, m, d)
}
|
||||||
|
|
||||||
|
fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
|
||||||
|
Err(minijinja::Error::new(
|
||||||
|
minijinja::ErrorKind::InvalidOperation,
|
||||||
|
msg,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Hardcoded fallback templates (for models without a Jinja template)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
fn build_prompt_hardcoded(messages: &[Message], model_type: &str) -> String {
|
||||||
|
// Default: Qwen3 ChatML format
|
||||||
|
let _ = model_type;
|
||||||
|
let mut prompt = String::new();
|
||||||
|
for msg in messages {
|
||||||
|
match msg.role.as_str() {
|
||||||
|
"system" | "user" | "assistant" => {
|
||||||
|
prompt.push_str("<|im_start|>");
|
||||||
|
prompt.push_str(&msg.role);
|
||||||
|
prompt.push('\n');
|
||||||
|
prompt.push_str(&msg.content);
|
||||||
|
prompt.push_str("<|im_end|>\n");
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
prompt.push_str("<|im_start|>assistant\n");
|
||||||
|
prompt.push_str("<think>\n\n</think>\n\n");
|
||||||
|
prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HTTP handlers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
pub async fn health() -> &'static str {
|
pub async fn health() -> &'static str {
|
||||||
"ok"
|
"ok"
|
||||||
}
|
}
|
||||||
@@ -89,7 +228,7 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
|
|||||||
return response;
|
return response;
|
||||||
}
|
}
|
||||||
|
|
||||||
let prompt = build_prompt(&req.messages, &state.model_type);
|
let prompt = state.chat_template.render(&req.messages);
|
||||||
let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt);
|
let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt);
|
||||||
let prompt_token_count = prompt_tokens.len();
|
let prompt_token_count = prompt_tokens.len();
|
||||||
|
|
||||||
@@ -159,7 +298,7 @@ fn chat_stream(
|
|||||||
return response;
|
return response;
|
||||||
}
|
}
|
||||||
|
|
||||||
let prompt = build_prompt(&req.messages, &state.model_type);
|
let prompt = state.chat_template.render(&req.messages);
|
||||||
let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt);
|
let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt);
|
||||||
|
|
||||||
let max_seq_len = state.max_seq_len;
|
let max_seq_len = state.max_seq_len;
|
||||||
@@ -324,64 +463,3 @@ fn sampling_params(req: &ChatRequest) -> SamplingParams {
|
|||||||
top_p: req.top_p.unwrap_or(1.0),
|
top_p: req.top_p.unwrap_or(1.0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_prompt(messages: &[Message], model_type: &str) -> String {
|
|
||||||
if model_type == "gpt_oss" {
|
|
||||||
return build_prompt_gpt_oss(messages);
|
|
||||||
}
|
|
||||||
// Default: Qwen3 ChatML format
|
|
||||||
let mut prompt = String::new();
|
|
||||||
for msg in messages {
|
|
||||||
match msg.role.as_str() {
|
|
||||||
"system" | "user" | "assistant" => {
|
|
||||||
prompt.push_str("<|im_start|>");
|
|
||||||
prompt.push_str(&msg.role);
|
|
||||||
prompt.push('\n');
|
|
||||||
prompt.push_str(&msg.content);
|
|
||||||
prompt.push_str("<|im_end|>\n");
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
prompt.push_str("<|im_start|>assistant\n");
|
|
||||||
prompt.push_str("<think>\n\n</think>\n\n");
|
|
||||||
prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_prompt_gpt_oss(messages: &[Message]) -> String {
|
|
||||||
let mut prompt = String::new();
|
|
||||||
// System (meta) block: channel declaration required by harmony.
|
|
||||||
prompt.push_str("<|start|>system<|message|>");
|
|
||||||
prompt.push_str("You are a helpful assistant.\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.");
|
|
||||||
prompt.push_str("<|end|>");
|
|
||||||
// Caller-supplied system prompt(s) go in a developer instructions block
|
|
||||||
// (harmony puts user instructions on the `developer` role, not `system`).
|
|
||||||
let dev_instructions: String = messages
|
|
||||||
.iter()
|
|
||||||
.filter(|m| m.role == "system")
|
|
||||||
.map(|m| m.content.as_str())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join("\n\n");
|
|
||||||
if !dev_instructions.is_empty() {
|
|
||||||
prompt.push_str("<|start|>developer<|message|># Instructions\n\n");
|
|
||||||
prompt.push_str(&dev_instructions);
|
|
||||||
prompt.push_str("<|end|>");
|
|
||||||
}
|
|
||||||
for msg in messages {
|
|
||||||
match msg.role.as_str() {
|
|
||||||
"user" => {
|
|
||||||
prompt.push_str("<|start|>user<|message|>");
|
|
||||||
prompt.push_str(&msg.content);
|
|
||||||
prompt.push_str("<|end|>");
|
|
||||||
}
|
|
||||||
"assistant" => {
|
|
||||||
prompt.push_str("<|start|>assistant<|channel|>final<|message|>");
|
|
||||||
prompt.push_str(&msg.content);
|
|
||||||
prompt.push_str("<|end|>");
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
prompt.push_str("<|start|>assistant");
|
|
||||||
prompt
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ use xserv_model::ModelConfig;
|
|||||||
|
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub model_name: String,
|
pub model_name: String,
|
||||||
pub model_type: String,
|
pub chat_template: api::ChatTemplate,
|
||||||
pub engine_sender: Mutex<mpsc::Sender<GenerateRequest>>,
|
pub engine_sender: Mutex<mpsc::Sender<GenerateRequest>>,
|
||||||
pub engine_tokenizer: Mutex<xserv_tokenizer::Tokenizer>,
|
pub engine_tokenizer: Mutex<xserv_tokenizer::Tokenizer>,
|
||||||
pub max_seq_len: usize,
|
pub max_seq_len: usize,
|
||||||
@@ -101,9 +101,10 @@ async fn main() {
|
|||||||
});
|
});
|
||||||
|
|
||||||
let model_type = model_config.model_type.clone().unwrap_or_default();
|
let model_type = model_config.model_type.clone().unwrap_or_default();
|
||||||
|
let chat_template = api::ChatTemplate::load(&model_dir, &model_type);
|
||||||
let state = Arc::new(AppState {
|
let state = Arc::new(AppState {
|
||||||
model_name,
|
model_name,
|
||||||
model_type,
|
chat_template,
|
||||||
engine_sender: Mutex::new(tx),
|
engine_sender: Mutex::new(tx),
|
||||||
engine_tokenizer: Mutex::new(tokenizer),
|
engine_tokenizer: Mutex::new(tokenizer),
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
|
|||||||
Reference in New Issue
Block a user