//! Parse a HuggingFace `config.json` into [`ModelConfig`] fields.
//!
//! Handles common architectures: standard transformer, GQA, MoE, MLA
//! (Multi-head Latent Attention), and DSA (DeepSeek Sparse Attention).

use anyhow::{Context, Result};
use serde_json::Value;
use std::path::Path;

use crate::config::{AttentionConfig, MlaConfig, ModelConfig, MoeConfig};

/// Parse a HuggingFace config.json and return a partially-populated
/// [`ModelConfig`]. The caller must still set `dtype_bytes` and
/// `block_size_tokens` (not part of the HF schema).
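///
/// A minimal usage sketch. The module path `hf_config` and the concrete
/// file path are illustrative assumptions, not part of this crate:
///
/// ```ignore
/// use std::path::Path;
///
/// fn main() -> anyhow::Result<()> {
///     let mut model = hf_config::parse(Path::new("/models/qwen2-7b/config.json"))?;
///     // Deployment fields are not in the HF schema; the caller sets them.
///     model.dtype_bytes = 2;        // e.g. bf16
///     model.block_size_tokens = 16; // e.g. paged-KV block size
///     Ok(())
/// }
/// ```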
pub fn parse(path: &Path) -> Result<ModelConfig> {
    let raw = std::fs::read_to_string(path)
        .with_context(|| format!("reading config.json at {}", path.display()))?;
    let v: Value = serde_json::from_str(&raw)
        .with_context(|| format!("parsing config.json at {}", path.display()))?;
    parse_value(&v)
}

/// Read `key` from the JSON object as a `u32`; `None` if the key is absent
/// or not an unsigned integer.
fn u32_field(v: &Value, key: &str) -> Option<u32> {
    v.get(key).and_then(|x| x.as_u64()).map(|x| x as u32)
}

fn parse_value(v: &Value) -> Result<ModelConfig> {
    let name = v
        .get("model_type")
        .and_then(|x| x.as_str())
        .unwrap_or("unknown")
        .to_string();

    let num_layers = u32_field(v, "num_hidden_layers");
    let hidden_size = u32_field(v, "hidden_size");
    let num_attention_heads = u32_field(v, "num_attention_heads");
    let num_kv_heads = u32_field(v, "num_key_value_heads").or(num_attention_heads); // default to MHA
    let head_dim = u32_field(v, "head_dim").or_else(|| {
        // Infer: hidden_size / num_attention_heads.
        match (hidden_size, num_attention_heads) {
            (Some(h), Some(n)) if n > 0 => Some(h / n),
            _ => None,
        }
    });
    let intermediate_size = u32_field(v, "intermediate_size");

    // --- MoE detection ---
    // The expert count lives under different keys per family
    // (`n_routed_experts`: DeepSeek-style, `num_local_experts`: Mixtral-style,
    // `num_experts`: Qwen-style); the first one present wins.
    let moe = u32_field(v, "n_routed_experts")
        .or_else(|| u32_field(v, "num_local_experts"))
        .or_else(|| u32_field(v, "num_experts"))
        .map(|num_experts| MoeConfig {
            num_experts,
            num_active_experts: u32_field(v, "num_experts_per_tok")
                .or_else(|| u32_field(v, "num_experts_per_topk"))
                .unwrap_or(2),
            num_shared_experts: u32_field(v, "n_shared_experts").unwrap_or(0),
            expert_intermediate_size: u32_field(v, "moe_intermediate_size"),
        });

    // --- MLA detection (kv_lora_rank present → MLA) ---
    // Each `?` below bails out of the closure, so `mla` stays `None` unless
    // all sub-fields are present.
    let mla = u32_field(v, "kv_lora_rank").and_then(|kv_lora_rank| {
        Some(MlaConfig {
            kv_lora_rank,
            q_lora_rank: u32_field(v, "q_lora_rank")?,
            qk_nope_head_dim: u32_field(v, "qk_nope_head_dim")?,
            qk_rope_head_dim: u32_field(v, "qk_rope_head_dim")?,
            v_head_dim: u32_field(v, "v_head_dim")?,
        })
    });

    // --- Attention pattern ---
    let attention = if let Some(first_dense) = u32_field(v, "first_k_dense_replace") {
        // DSA-style model (GLM-5, DeepSeek-V3).
        // dense_window and sparse_stride are typically not in config.json;
        // use sensible defaults the user can override in YAML.
        Some(AttentionConfig::Dsa {
            dense_window: 4096,
            sparse_stride: 8,
            first_dense_layers: first_dense,
        })
    } else if let Some(sw) = u32_field(v, "sliding_window") {
        Some(AttentionConfig::SlidingWindow { window_size: sw })
    } else {
        None // dense by default
    };

    Ok(ModelConfig {
        name,
        num_layers: num_layers.unwrap_or(0),
        num_kv_heads: num_kv_heads.unwrap_or(0),
        head_dim: head_dim.unwrap_or(0),
        hidden_size,
        num_attention_heads,
        intermediate_size,
        moe,
        mla,
        attention,
        // Deployment fields: must come from YAML.
        dtype_bytes: 0,
        block_size_tokens: 0,
        ..Default::default()
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_dense_model() {
        let json = serde_json::json!({
            "model_type": "qwen2",
            "num_hidden_layers": 28,
            "hidden_size": 3584,
            "num_attention_heads": 28,
            "num_key_value_heads": 4,
            "intermediate_size": 18944,
        });
        let m = parse_value(&json).unwrap();
        assert_eq!(m.num_layers, 28);
        assert_eq!(m.hidden_size, Some(3584));
        assert_eq!(m.num_kv_heads, 4);
        assert_eq!(m.head_dim, 128); // inferred: 3584 / 28
        assert!(m.moe.is_none());
        assert!(m.mla.is_none());
        assert!(m.attention.is_none());
    }

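    // Added coverage for the sliding-window branch. The shape values mirror a
    // Mistral-7B-style config.json and are illustrative, not tied to a
    // specific checkpoint.
    #[test]
    fn parse_sliding_window_model() {
        let json = serde_json::json!({
            "model_type": "mistral",
            "num_hidden_layers": 32,
            "hidden_size": 4096,
            "num_attention_heads": 32,
            "num_key_value_heads": 8,
            "intermediate_size": 14336,
            "sliding_window": 4096,
        });
        let m = parse_value(&json).unwrap();
        assert!(matches!(
            m.attention,
            Some(AttentionConfig::SlidingWindow { window_size: 4096 })
        ));
        assert!(m.moe.is_none());
        assert!(m.mla.is_none());
    }
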
    #[test]
    fn parse_qwen3_moe() {
        let json = serde_json::json!({
            "model_type": "qwen3_moe",
            "num_hidden_layers": 62,
            "hidden_size": 6144,
            "num_attention_heads": 96,
            "num_key_value_heads": 8,
            "head_dim": 128,
            "intermediate_size": 8192,
            "num_experts": 160,
            "num_experts_per_tok": 8,
            "moe_intermediate_size": 2560,
        });
        let m = parse_value(&json).unwrap();
        assert_eq!(m.num_layers, 62);
        assert_eq!(m.num_kv_heads, 8);
        assert_eq!(m.head_dim, 128);
        let moe = m.moe.as_ref().unwrap();
        assert_eq!(moe.num_experts, 160);
        assert_eq!(moe.num_active_experts, 8);
        assert_eq!(moe.expert_intermediate_size, Some(2560));
        assert_eq!(moe.num_shared_experts, 0);
        assert!(m.mla.is_none());
        assert!(m.attention.is_none());
    }

    #[test]
    fn parse_moe_mla_dsa() {
        let json = serde_json::json!({
            "model_type": "glm_moe_dsa",
            "num_hidden_layers": 78,
            "hidden_size": 6144,
            "num_attention_heads": 64,
            "num_key_value_heads": 64,
            "head_dim": 64,
            "intermediate_size": 12288,
            "n_routed_experts": 256,
            "num_experts_per_tok": 8,
            "n_shared_experts": 1,
            "moe_intermediate_size": 2048,
            "kv_lora_rank": 512,
            "q_lora_rank": 2048,
            "qk_nope_head_dim": 192,
            "qk_rope_head_dim": 64,
            "v_head_dim": 256,
            "first_k_dense_replace": 3,
        });
        let m = parse_value(&json).unwrap();
        assert_eq!(m.num_layers, 78);
        assert_eq!(m.head_dim, 64);
        let moe = m.moe.as_ref().unwrap();
        assert_eq!(moe.num_experts, 256);
        assert_eq!(moe.num_active_experts, 8);
        let mla = m.mla.as_ref().unwrap();
        assert_eq!(mla.kv_lora_rank, 512);
        assert!(matches!(
            m.attention,
            Some(AttentionConfig::Dsa {
                first_dense_layers: 3,
                ..
            })
        ));
    }
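
    // Added coverage for partial MLA configs: if `kv_lora_rank` is present
    // but any other MLA field is missing, the `?` inside the detection
    // closure should yield `None` rather than a half-filled MlaConfig.
    // Shape values here are arbitrary.
    #[test]
    fn mla_requires_all_fields() {
        let json = serde_json::json!({
            "model_type": "partial_mla",
            "num_hidden_layers": 4,
            "hidden_size": 1024,
            "num_attention_heads": 8,
            "kv_lora_rank": 512,
        });
        let m = parse_value(&json).unwrap();
        assert!(m.mla.is_none());
    }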
}