kvcache-simulator/src/hf_config.rs
//! Parse a HuggingFace `config.json` into [`ModelConfig`] fields.
//!
//! Handles common architectures: standard transformer, GQA, MoE, MLA
//! (Multi-head Latent Attention), and DSA (DeepSeek Sparse Attention).

use anyhow::{Context, Result};
use serde_json::Value;
use std::path::Path;
use crate::config::{AttentionConfig, MlaConfig, ModelConfig, MoeConfig};

/// Parse a HuggingFace config.json and return a partially-populated
/// [`ModelConfig`]. The caller must still set `dtype_bytes` and
/// `block_size_tokens` (not part of the HF schema).
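///
/// A minimal usage sketch (not compiled as a doctest; the crate path, the
/// model directory, and the field values below are illustrative assumptions,
/// and direct field assignment assumes those fields are public):
///
/// ```ignore
/// use std::path::Path;
///
/// let mut cfg = kvcache_simulator::hf_config::parse(Path::new("Qwen2.5-7B/config.json"))?;
/// cfg.dtype_bytes = 2;        // e.g. fp16/bf16 KV cache
/// cfg.block_size_tokens = 16; // e.g. a 16-token paged-attention block
/// ```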
pub fn parse(path: &Path) -> Result<ModelConfig> {
let raw = std::fs::read_to_string(path)
.with_context(|| format!("reading config.json at {}", path.display()))?;
let v: Value = serde_json::from_str(&raw)
.with_context(|| format!("parsing config.json at {}", path.display()))?;
parse_value(&v)
}
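
/// Read `key` as an unsigned JSON integer and narrow it to `u32`; returns
/// `None` if the key is absent, null, or not an unsigned integer.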
fn u32_field(v: &Value, key: &str) -> Option<u32> {
v.get(key).and_then(|x| x.as_u64()).map(|x| x as u32)
}
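
/// Core conversion from an already-parsed JSON value, kept separate from
/// [`parse`] so the unit tests below can feed `serde_json::json!` literals directly.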
fn parse_value(v: &Value) -> Result<ModelConfig> {
let name = v
.get("model_type")
.and_then(|x| x.as_str())
.unwrap_or("unknown")
.to_string();
let num_layers = u32_field(v, "num_hidden_layers");
let hidden_size = u32_field(v, "hidden_size");
let num_attention_heads = u32_field(v, "num_attention_heads");
let num_kv_heads = u32_field(v, "num_key_value_heads").or(num_attention_heads); // default to MHA
let head_dim = u32_field(v, "head_dim").or_else(|| {
// Infer: hidden_size / num_attention_heads
match (hidden_size, num_attention_heads) {
(Some(h), Some(n)) if n > 0 => Some(h / n),
_ => None,
}
});
let intermediate_size = u32_field(v, "intermediate_size");
// --- MoE detection ---
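    // Families name the routed-expert count differently: `n_routed_experts`
    // (DeepSeek-style), `num_local_experts` (Mixtral-style), `num_experts`
    // (Qwen-MoE-style); the first key found wins.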
let moe = u32_field(v, "n_routed_experts")
.or_else(|| u32_field(v, "num_local_experts"))
.or_else(|| u32_field(v, "num_experts"))
.map(|num_experts| MoeConfig {
num_experts,
num_active_experts: u32_field(v, "num_experts_per_tok")
.or_else(|| u32_field(v, "num_experts_per_topk"))
.unwrap_or(2),
num_shared_experts: u32_field(v, "n_shared_experts").unwrap_or(0),
expert_intermediate_size: u32_field(v, "moe_intermediate_size"),
});
// --- MLA detection (kv_lora_rank present → MLA) ---
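    // Every sub-dimension below is required: if any field is missing, the `?`
    // inside the closure collapses the whole `MlaConfig` to `None`.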
let mla = u32_field(v, "kv_lora_rank").and_then(|kv_lora_rank| {
Some(MlaConfig {
kv_lora_rank,
q_lora_rank: u32_field(v, "q_lora_rank")?,
qk_nope_head_dim: u32_field(v, "qk_nope_head_dim")?,
qk_rope_head_dim: u32_field(v, "qk_rope_head_dim")?,
v_head_dim: u32_field(v, "v_head_dim")?,
})
});
// --- Attention pattern ---
let attention = if let Some(first_dense) = u32_field(v, "first_k_dense_replace") {
// DSA-style model (GLM-5, DeepSeek-V3).
// dense_window and sparse_stride are typically not in config.json;
// use sensible defaults the user can override in YAML.
Some(AttentionConfig::Dsa {
dense_window: 4096,
sparse_stride: 8,
first_dense_layers: first_dense,
})
    } else if let Some(sw) = u32_field(v, "sliding_window") {
Some(AttentionConfig::SlidingWindow { window_size: sw })
} else {
None // dense by default
};
Ok(ModelConfig {
name,
num_layers: num_layers.unwrap_or(0),
num_kv_heads: num_kv_heads.unwrap_or(0),
head_dim: head_dim.unwrap_or(0),
hidden_size,
num_attention_heads,
intermediate_size,
moe,
mla,
attention,
// Deployment fields: must come from YAML
dtype_bytes: 0,
block_size_tokens: 0,
..Default::default()
})
}

#[cfg(test)]
mod tests {
use super::*;

    #[test]
fn parse_dense_model() {
let json = serde_json::json!({
"model_type": "qwen2",
"num_hidden_layers": 28,
"hidden_size": 3584,
"num_attention_heads": 28,
"num_key_value_heads": 4,
"intermediate_size": 18944,
});
let m = parse_value(&json).unwrap();
assert_eq!(m.num_layers, 28);
assert_eq!(m.hidden_size, Some(3584));
assert_eq!(m.num_kv_heads, 4);
assert_eq!(m.head_dim, 128); // 3584 / 28
assert!(m.moe.is_none());
assert!(m.mla.is_none());
assert!(m.attention.is_none());
}

    #[test]
fn parse_qwen3_moe() {
let json = serde_json::json!({
"model_type": "qwen3_moe",
"num_hidden_layers": 62,
"hidden_size": 6144,
"num_attention_heads": 96,
"num_key_value_heads": 8,
"head_dim": 128,
"intermediate_size": 8192,
"num_experts": 160,
"num_experts_per_tok": 8,
"moe_intermediate_size": 2560,
});
let m = parse_value(&json).unwrap();
assert_eq!(m.num_layers, 62);
assert_eq!(m.num_kv_heads, 8);
assert_eq!(m.head_dim, 128);
let moe = m.moe.as_ref().unwrap();
assert_eq!(moe.num_experts, 160);
assert_eq!(moe.num_active_experts, 8);
assert_eq!(moe.expert_intermediate_size, Some(2560));
assert_eq!(moe.num_shared_experts, 0);
assert!(m.mla.is_none());
assert!(m.attention.is_none());
}

    #[test]
fn parse_moe_mla_dsa() {
let json = serde_json::json!({
"model_type": "glm_moe_dsa",
"num_hidden_layers": 78,
"hidden_size": 6144,
"num_attention_heads": 64,
"num_key_value_heads": 64,
"head_dim": 64,
"intermediate_size": 12288,
"n_routed_experts": 256,
"num_experts_per_tok": 8,
"n_shared_experts": 1,
"moe_intermediate_size": 2048,
"kv_lora_rank": 512,
"q_lora_rank": 2048,
"qk_nope_head_dim": 192,
"qk_rope_head_dim": 64,
"v_head_dim": 256,
"first_k_dense_replace": 3,
});
let m = parse_value(&json).unwrap();
assert_eq!(m.num_layers, 78);
assert_eq!(m.head_dim, 64);
let moe = m.moe.as_ref().unwrap();
assert_eq!(moe.num_experts, 256);
assert_eq!(moe.num_active_experts, 8);
let mla = m.mla.as_ref().unwrap();
assert_eq!(mla.kv_lora_rank, 512);
assert!(matches!(
m.attention,
Some(AttentionConfig::Dsa {
first_dense_layers: 3,
..
})
));
}
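
    // A sliding-window case, exercising the `sliding_window` branch above.
    // The concrete numbers are illustrative (Mistral-style defaults), not
    // copied from any particular released config.
    #[test]
    fn parse_sliding_window_model() {
        let json = serde_json::json!({
            "model_type": "mistral",
            "num_hidden_layers": 32,
            "hidden_size": 4096,
            "num_attention_heads": 32,
            "num_key_value_heads": 8,
            "intermediate_size": 14336,
            "sliding_window": 4096,
        });
        let m = parse_value(&json).unwrap();
        assert_eq!(m.num_kv_heads, 8);
        assert_eq!(m.head_dim, 128); // inferred: 4096 / 32
        assert!(m.moe.is_none());
        assert!(m.mla.is_none());
        assert!(matches!(
            m.attention,
            Some(AttentionConfig::SlidingWindow { window_size: 4096 })
        ));
    }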
}