server: Jinja chat template rendering via minijinja

Load the model's chat_template.jinja (or tokenizer_config.json
chat_template field) at startup and render it with minijinja instead of
hardcoded per-model prompt builders.

Custom Jinja functions: strftime_now (date formatting), raise_exception
(template validation errors).  Falls back to Qwen3 ChatML template if
no Jinja template is found.

Removes the hardcoded build_prompt_gpt_oss() — the model's own template
now drives prompt formatting, matching llama.cpp's behavior exactly.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Gahow Wang
2026-05-31 13:23:18 +08:00
parent 4368e79695
commit 1d0ec32e8d
5 changed files with 175 additions and 66 deletions

View File

@@ -21,3 +21,4 @@ tokio.workspace = true
axum.workspace = true
uuid.workspace = true
tokio-stream.workspace = true
minijinja.workspace = true

View File

@@ -5,6 +5,7 @@ use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::path::Path;
use std::sync::Arc;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::ReceiverStream;
@@ -31,7 +32,7 @@ pub struct ChatRequest {
pub top_p: Option<f32>,
}
#[derive(Deserialize)]
#[derive(Deserialize, Serialize, Clone)]
pub struct Message {
pub role: String,
pub content: String,
@@ -54,6 +55,144 @@ pub struct ModelInfo {
owned_by: &'static str,
}
// ---------------------------------------------------------------------------
// Chat Template: Jinja2 rendering via minijinja
// ---------------------------------------------------------------------------
pub struct ChatTemplate {
source: String,
model_type: String,
}
impl ChatTemplate {
pub fn load(model_dir: &Path, model_type: &str) -> Self {
// 1. Try standalone chat_template.jinja file
let jinja_path = model_dir.join("chat_template.jinja");
if jinja_path.exists() {
let source = std::fs::read_to_string(&jinja_path)
.unwrap_or_else(|e| panic!("failed to read {}: {e}", jinja_path.display()));
eprintln!("[chat-template] loaded from {}", jinja_path.display());
return Self { source, model_type: model_type.to_string() };
}
// 2. Try tokenizer_config.json → chat_template field
let tok_cfg_path = model_dir.join("tokenizer_config.json");
if tok_cfg_path.exists() {
if let Ok(data) = std::fs::read_to_string(&tok_cfg_path) {
if let Ok(v) = serde_json::from_str::<serde_json::Value>(&data) {
if let Some(ct) = v.get("chat_template").and_then(|v| v.as_str()) {
eprintln!("[chat-template] loaded from tokenizer_config.json");
return Self { source: ct.to_string(), model_type: model_type.to_string() };
}
}
}
}
// 3. No template found — use empty source, will fall back to hardcoded
eprintln!("[chat-template] no Jinja template found, using hardcoded fallback");
Self { source: String::new(), model_type: model_type.to_string() }
}
pub fn render(&self, messages: &[Message]) -> String {
if self.source.is_empty() {
return build_prompt_hardcoded(messages, &self.model_type);
}
match self.render_jinja(messages) {
Ok(prompt) => prompt,
Err(e) => {
eprintln!("[chat-template] Jinja render error: {e}, falling back to hardcoded");
build_prompt_hardcoded(messages, &self.model_type)
}
}
}
fn render_jinja(&self, messages: &[Message]) -> Result<String, minijinja::Error> {
let mut env = minijinja::Environment::new();
// Register custom functions the template may call.
env.add_function("strftime_now", strftime_now);
env.add_function("raise_exception", raise_exception);
env.add_template("chat", &self.source)?;
let tmpl = env.get_template("chat")?;
let ctx = minijinja::context! {
messages => minijinja::Value::from_serialize(messages),
add_generation_prompt => true,
bos_token => "",
eos_token => "",
};
tmpl.render(ctx)
}
}
fn strftime_now(fmt: String) -> String {
use std::time::SystemTime;
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs();
// Only support %Y-%m-%d (the only format used by known templates)
let days = now / 86400;
let (y, m, d) = days_to_ymd(days);
fmt.replace("%Y", &format!("{y:04}"))
.replace("%m", &format!("{m:02}"))
.replace("%d", &format!("{d:02}"))
}
fn days_to_ymd(days_since_epoch: u64) -> (u32, u32, u32) {
// Civil calendar from days since 1970-01-01 (Rata Die algorithm)
let z = days_since_epoch as i64 + 719468;
let era = (if z >= 0 { z } else { z - 146096 }) / 146097;
let doe = (z - era * 146097) as u32;
let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365;
let y = yoe as i64 + era * 400;
let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
let mp = (5 * doy + 2) / 153;
let d = doy - (153 * mp + 2) / 5 + 1;
let m = if mp < 10 { mp + 3 } else { mp - 9 };
let y = if m <= 2 { y + 1 } else { y };
(y as u32, m, d)
}
fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
Err(minijinja::Error::new(
minijinja::ErrorKind::InvalidOperation,
msg,
))
}
// ---------------------------------------------------------------------------
// Hardcoded fallback templates (for models without a Jinja template)
// ---------------------------------------------------------------------------
fn build_prompt_hardcoded(messages: &[Message], model_type: &str) -> String {
// Default: Qwen3 ChatML format
let _ = model_type;
let mut prompt = String::new();
for msg in messages {
match msg.role.as_str() {
"system" | "user" | "assistant" => {
prompt.push_str("<|im_start|>");
prompt.push_str(&msg.role);
prompt.push('\n');
prompt.push_str(&msg.content);
prompt.push_str("<|im_end|>\n");
}
_ => {}
}
}
prompt.push_str("<|im_start|>assistant\n");
prompt.push_str("<think>\n\n</think>\n\n");
prompt
}
// ---------------------------------------------------------------------------
// HTTP handlers
// ---------------------------------------------------------------------------
pub async fn health() -> &'static str {
"ok"
}
@@ -89,7 +228,7 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
return response;
}
let prompt = build_prompt(&req.messages, &state.model_type);
let prompt = state.chat_template.render(&req.messages);
let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt);
let prompt_token_count = prompt_tokens.len();
@@ -159,7 +298,7 @@ fn chat_stream(
return response;
}
let prompt = build_prompt(&req.messages, &state.model_type);
let prompt = state.chat_template.render(&req.messages);
let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt);
let max_seq_len = state.max_seq_len;
@@ -324,64 +463,3 @@ fn sampling_params(req: &ChatRequest) -> SamplingParams {
top_p: req.top_p.unwrap_or(1.0),
}
}
fn build_prompt(messages: &[Message], model_type: &str) -> String {
if model_type == "gpt_oss" {
return build_prompt_gpt_oss(messages);
}
// Default: Qwen3 ChatML format
let mut prompt = String::new();
for msg in messages {
match msg.role.as_str() {
"system" | "user" | "assistant" => {
prompt.push_str("<|im_start|>");
prompt.push_str(&msg.role);
prompt.push('\n');
prompt.push_str(&msg.content);
prompt.push_str("<|im_end|>\n");
}
_ => {}
}
}
prompt.push_str("<|im_start|>assistant\n");
prompt.push_str("<think>\n\n</think>\n\n");
prompt
}
fn build_prompt_gpt_oss(messages: &[Message]) -> String {
let mut prompt = String::new();
// System (meta) block: channel declaration required by harmony.
prompt.push_str("<|start|>system<|message|>");
prompt.push_str("You are a helpful assistant.\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.");
prompt.push_str("<|end|>");
// Caller-supplied system prompt(s) go in a developer instructions block
// (harmony puts user instructions on the `developer` role, not `system`).
let dev_instructions: String = messages
.iter()
.filter(|m| m.role == "system")
.map(|m| m.content.as_str())
.collect::<Vec<_>>()
.join("\n\n");
if !dev_instructions.is_empty() {
prompt.push_str("<|start|>developer<|message|># Instructions\n\n");
prompt.push_str(&dev_instructions);
prompt.push_str("<|end|>");
}
for msg in messages {
match msg.role.as_str() {
"user" => {
prompt.push_str("<|start|>user<|message|>");
prompt.push_str(&msg.content);
prompt.push_str("<|end|>");
}
"assistant" => {
prompt.push_str("<|start|>assistant<|channel|>final<|message|>");
prompt.push_str(&msg.content);
prompt.push_str("<|end|>");
}
_ => {}
}
}
prompt.push_str("<|start|>assistant");
prompt
}

View File

@@ -11,7 +11,7 @@ use xserv_model::ModelConfig;
pub struct AppState {
pub model_name: String,
pub model_type: String,
pub chat_template: api::ChatTemplate,
pub engine_sender: Mutex<mpsc::Sender<GenerateRequest>>,
pub engine_tokenizer: Mutex<xserv_tokenizer::Tokenizer>,
pub max_seq_len: usize,
@@ -101,9 +101,10 @@ async fn main() {
});
let model_type = model_config.model_type.clone().unwrap_or_default();
let chat_template = api::ChatTemplate::load(&model_dir, &model_type);
let state = Arc::new(AppState {
model_name,
model_type,
chat_template,
engine_sender: Mutex::new(tx),
engine_tokenizer: Mutex::new(tokenizer),
max_seq_len,