fix: cache calculation
This commit is contained in:
@@ -7,7 +7,7 @@ use anyhow::{Context, Result};
|
||||
use serde_json::Value;
|
||||
use std::path::Path;
|
||||
|
||||
use crate::config::{AttentionConfig, MlaConfig, MoeConfig, ModelConfig};
|
||||
use crate::config::{AttentionConfig, MlaConfig, ModelConfig, MoeConfig};
|
||||
|
||||
/// Parse a HuggingFace config.json and return a partially-populated
|
||||
/// [`ModelConfig`]. The caller must still set `dtype_bytes` and
|
||||
@@ -34,8 +34,7 @@ fn parse_value(v: &Value) -> Result<ModelConfig> {
|
||||
let num_layers = u32_field(v, "num_hidden_layers");
|
||||
let hidden_size = u32_field(v, "hidden_size");
|
||||
let num_attention_heads = u32_field(v, "num_attention_heads");
|
||||
let num_kv_heads = u32_field(v, "num_key_value_heads")
|
||||
.or(num_attention_heads); // default to MHA
|
||||
let num_kv_heads = u32_field(v, "num_key_value_heads").or(num_attention_heads); // default to MHA
|
||||
let head_dim = u32_field(v, "head_dim").or_else(|| {
|
||||
// Infer: hidden_size / num_attention_heads
|
||||
match (hidden_size, num_attention_heads) {
|
||||
@@ -70,25 +69,24 @@ fn parse_value(v: &Value) -> Result<ModelConfig> {
|
||||
});
|
||||
|
||||
// --- Attention pattern ---
|
||||
let attention =
|
||||
if let Some(first_dense) = u32_field(v, "first_k_dense_replace") {
|
||||
// DSA-style model (GLM-5, DeepSeek-V3).
|
||||
// dense_window and sparse_stride are typically not in config.json;
|
||||
// use sensible defaults the user can override in YAML.
|
||||
Some(AttentionConfig::Dsa {
|
||||
dense_window: 4096,
|
||||
sparse_stride: 8,
|
||||
first_dense_layers: first_dense,
|
||||
})
|
||||
} else if let Some(sw) = v
|
||||
.get("sliding_window")
|
||||
.and_then(|x| x.as_u64())
|
||||
.map(|x| x as u32)
|
||||
{
|
||||
Some(AttentionConfig::SlidingWindow { window_size: sw })
|
||||
} else {
|
||||
None // dense by default
|
||||
};
|
||||
let attention = if let Some(first_dense) = u32_field(v, "first_k_dense_replace") {
|
||||
// DSA-style model (GLM-5, DeepSeek-V3).
|
||||
// dense_window and sparse_stride are typically not in config.json;
|
||||
// use sensible defaults the user can override in YAML.
|
||||
Some(AttentionConfig::Dsa {
|
||||
dense_window: 4096,
|
||||
sparse_stride: 8,
|
||||
first_dense_layers: first_dense,
|
||||
})
|
||||
} else if let Some(sw) = v
|
||||
.get("sliding_window")
|
||||
.and_then(|x| x.as_u64())
|
||||
.map(|x| x as u32)
|
||||
{
|
||||
Some(AttentionConfig::SlidingWindow { window_size: sw })
|
||||
} else {
|
||||
None // dense by default
|
||||
};
|
||||
|
||||
Ok(ModelConfig {
|
||||
name,
|
||||
@@ -188,6 +186,12 @@ mod tests {
|
||||
assert_eq!(moe.num_active_experts, 8);
|
||||
let mla = m.mla.as_ref().unwrap();
|
||||
assert_eq!(mla.kv_lora_rank, 512);
|
||||
assert!(matches!(m.attention, Some(AttentionConfig::Dsa { first_dense_layers: 3, .. })));
|
||||
assert!(matches!(
|
||||
m.attention,
|
||||
Some(AttentionConfig::Dsa {
|
||||
first_dense_layers: 3,
|
||||
..
|
||||
})
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user