server: Jinja chat template rendering via minijinja
Load the model's chat_template.jinja (or tokenizer_config.json chat_template field) at startup and render it with minijinja instead of hardcoded per-model prompt builders. Custom Jinja functions: strftime_now (date formatting), raise_exception (template validation errors). Falls back to Qwen3 ChatML template if no Jinja template is found. Removes the hardcoded build_prompt_gpt_oss() — the model's own template now drives prompt formatting, matching llama.cpp's behavior exactly. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -21,3 +21,4 @@ tokio.workspace = true
|
||||
axum.workspace = true
|
||||
uuid.workspace = true
|
||||
tokio-stream.workspace = true
|
||||
minijinja.workspace = true
|
||||
|
||||
@@ -5,6 +5,7 @@ use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::convert::Infallible;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
@@ -31,7 +32,7 @@ pub struct ChatRequest {
|
||||
pub top_p: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Deserialize, Serialize, Clone)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
@@ -54,6 +55,144 @@ pub struct ModelInfo {
|
||||
owned_by: &'static str,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Chat Template: Jinja2 rendering via minijinja
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct ChatTemplate {
|
||||
source: String,
|
||||
model_type: String,
|
||||
}
|
||||
|
||||
impl ChatTemplate {
|
||||
pub fn load(model_dir: &Path, model_type: &str) -> Self {
|
||||
// 1. Try standalone chat_template.jinja file
|
||||
let jinja_path = model_dir.join("chat_template.jinja");
|
||||
if jinja_path.exists() {
|
||||
let source = std::fs::read_to_string(&jinja_path)
|
||||
.unwrap_or_else(|e| panic!("failed to read {}: {e}", jinja_path.display()));
|
||||
eprintln!("[chat-template] loaded from {}", jinja_path.display());
|
||||
return Self { source, model_type: model_type.to_string() };
|
||||
}
|
||||
|
||||
// 2. Try tokenizer_config.json → chat_template field
|
||||
let tok_cfg_path = model_dir.join("tokenizer_config.json");
|
||||
if tok_cfg_path.exists() {
|
||||
if let Ok(data) = std::fs::read_to_string(&tok_cfg_path) {
|
||||
if let Ok(v) = serde_json::from_str::<serde_json::Value>(&data) {
|
||||
if let Some(ct) = v.get("chat_template").and_then(|v| v.as_str()) {
|
||||
eprintln!("[chat-template] loaded from tokenizer_config.json");
|
||||
return Self { source: ct.to_string(), model_type: model_type.to_string() };
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. No template found — use empty source, will fall back to hardcoded
|
||||
eprintln!("[chat-template] no Jinja template found, using hardcoded fallback");
|
||||
Self { source: String::new(), model_type: model_type.to_string() }
|
||||
}
|
||||
|
||||
pub fn render(&self, messages: &[Message]) -> String {
|
||||
if self.source.is_empty() {
|
||||
return build_prompt_hardcoded(messages, &self.model_type);
|
||||
}
|
||||
|
||||
match self.render_jinja(messages) {
|
||||
Ok(prompt) => prompt,
|
||||
Err(e) => {
|
||||
eprintln!("[chat-template] Jinja render error: {e}, falling back to hardcoded");
|
||||
build_prompt_hardcoded(messages, &self.model_type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn render_jinja(&self, messages: &[Message]) -> Result<String, minijinja::Error> {
|
||||
let mut env = minijinja::Environment::new();
|
||||
|
||||
// Register custom functions the template may call.
|
||||
env.add_function("strftime_now", strftime_now);
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
|
||||
env.add_template("chat", &self.source)?;
|
||||
let tmpl = env.get_template("chat")?;
|
||||
|
||||
let ctx = minijinja::context! {
|
||||
messages => minijinja::Value::from_serialize(messages),
|
||||
add_generation_prompt => true,
|
||||
bos_token => "",
|
||||
eos_token => "",
|
||||
};
|
||||
|
||||
tmpl.render(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
fn strftime_now(fmt: String) -> String {
|
||||
use std::time::SystemTime;
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
// Only support %Y-%m-%d (the only format used by known templates)
|
||||
let days = now / 86400;
|
||||
let (y, m, d) = days_to_ymd(days);
|
||||
fmt.replace("%Y", &format!("{y:04}"))
|
||||
.replace("%m", &format!("{m:02}"))
|
||||
.replace("%d", &format!("{d:02}"))
|
||||
}
|
||||
|
||||
fn days_to_ymd(days_since_epoch: u64) -> (u32, u32, u32) {
|
||||
// Civil calendar from days since 1970-01-01 (Rata Die algorithm)
|
||||
let z = days_since_epoch as i64 + 719468;
|
||||
let era = (if z >= 0 { z } else { z - 146096 }) / 146097;
|
||||
let doe = (z - era * 146097) as u32;
|
||||
let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365;
|
||||
let y = yoe as i64 + era * 400;
|
||||
let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
|
||||
let mp = (5 * doy + 2) / 153;
|
||||
let d = doy - (153 * mp + 2) / 5 + 1;
|
||||
let m = if mp < 10 { mp + 3 } else { mp - 9 };
|
||||
let y = if m <= 2 { y + 1 } else { y };
|
||||
(y as u32, m, d)
|
||||
}
|
||||
|
||||
fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
|
||||
Err(minijinja::Error::new(
|
||||
minijinja::ErrorKind::InvalidOperation,
|
||||
msg,
|
||||
))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Hardcoded fallback templates (for models without a Jinja template)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn build_prompt_hardcoded(messages: &[Message], model_type: &str) -> String {
|
||||
// Default: Qwen3 ChatML format
|
||||
let _ = model_type;
|
||||
let mut prompt = String::new();
|
||||
for msg in messages {
|
||||
match msg.role.as_str() {
|
||||
"system" | "user" | "assistant" => {
|
||||
prompt.push_str("<|im_start|>");
|
||||
prompt.push_str(&msg.role);
|
||||
prompt.push('\n');
|
||||
prompt.push_str(&msg.content);
|
||||
prompt.push_str("<|im_end|>\n");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
prompt.push_str("<|im_start|>assistant\n");
|
||||
prompt.push_str("<think>\n\n</think>\n\n");
|
||||
prompt
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HTTP handlers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn health() -> &'static str {
|
||||
"ok"
|
||||
}
|
||||
@@ -89,7 +228,7 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
|
||||
return response;
|
||||
}
|
||||
|
||||
let prompt = build_prompt(&req.messages, &state.model_type);
|
||||
let prompt = state.chat_template.render(&req.messages);
|
||||
let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt);
|
||||
let prompt_token_count = prompt_tokens.len();
|
||||
|
||||
@@ -159,7 +298,7 @@ fn chat_stream(
|
||||
return response;
|
||||
}
|
||||
|
||||
let prompt = build_prompt(&req.messages, &state.model_type);
|
||||
let prompt = state.chat_template.render(&req.messages);
|
||||
let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt);
|
||||
|
||||
let max_seq_len = state.max_seq_len;
|
||||
@@ -324,64 +463,3 @@ fn sampling_params(req: &ChatRequest) -> SamplingParams {
|
||||
top_p: req.top_p.unwrap_or(1.0),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_prompt(messages: &[Message], model_type: &str) -> String {
|
||||
if model_type == "gpt_oss" {
|
||||
return build_prompt_gpt_oss(messages);
|
||||
}
|
||||
// Default: Qwen3 ChatML format
|
||||
let mut prompt = String::new();
|
||||
for msg in messages {
|
||||
match msg.role.as_str() {
|
||||
"system" | "user" | "assistant" => {
|
||||
prompt.push_str("<|im_start|>");
|
||||
prompt.push_str(&msg.role);
|
||||
prompt.push('\n');
|
||||
prompt.push_str(&msg.content);
|
||||
prompt.push_str("<|im_end|>\n");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
prompt.push_str("<|im_start|>assistant\n");
|
||||
prompt.push_str("<think>\n\n</think>\n\n");
|
||||
prompt
|
||||
}
|
||||
|
||||
fn build_prompt_gpt_oss(messages: &[Message]) -> String {
|
||||
let mut prompt = String::new();
|
||||
// System (meta) block: channel declaration required by harmony.
|
||||
prompt.push_str("<|start|>system<|message|>");
|
||||
prompt.push_str("You are a helpful assistant.\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.");
|
||||
prompt.push_str("<|end|>");
|
||||
// Caller-supplied system prompt(s) go in a developer instructions block
|
||||
// (harmony puts user instructions on the `developer` role, not `system`).
|
||||
let dev_instructions: String = messages
|
||||
.iter()
|
||||
.filter(|m| m.role == "system")
|
||||
.map(|m| m.content.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n");
|
||||
if !dev_instructions.is_empty() {
|
||||
prompt.push_str("<|start|>developer<|message|># Instructions\n\n");
|
||||
prompt.push_str(&dev_instructions);
|
||||
prompt.push_str("<|end|>");
|
||||
}
|
||||
for msg in messages {
|
||||
match msg.role.as_str() {
|
||||
"user" => {
|
||||
prompt.push_str("<|start|>user<|message|>");
|
||||
prompt.push_str(&msg.content);
|
||||
prompt.push_str("<|end|>");
|
||||
}
|
||||
"assistant" => {
|
||||
prompt.push_str("<|start|>assistant<|channel|>final<|message|>");
|
||||
prompt.push_str(&msg.content);
|
||||
prompt.push_str("<|end|>");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
prompt.push_str("<|start|>assistant");
|
||||
prompt
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ use xserv_model::ModelConfig;
|
||||
|
||||
pub struct AppState {
|
||||
pub model_name: String,
|
||||
pub model_type: String,
|
||||
pub chat_template: api::ChatTemplate,
|
||||
pub engine_sender: Mutex<mpsc::Sender<GenerateRequest>>,
|
||||
pub engine_tokenizer: Mutex<xserv_tokenizer::Tokenizer>,
|
||||
pub max_seq_len: usize,
|
||||
@@ -101,9 +101,10 @@ async fn main() {
|
||||
});
|
||||
|
||||
let model_type = model_config.model_type.clone().unwrap_or_default();
|
||||
let chat_template = api::ChatTemplate::load(&model_dir, &model_type);
|
||||
let state = Arc::new(AppState {
|
||||
model_name,
|
||||
model_type,
|
||||
chat_template,
|
||||
engine_sender: Mutex::new(tx),
|
||||
engine_tokenizer: Mutex::new(tokenizer),
|
||||
max_seq_len,
|
||||
|
||||
Reference in New Issue
Block a user