fix: cache calculation

2026-04-15 17:31:39 +08:00
parent 365ceac3be
commit ff316c6873
23 changed files with 500 additions and 336 deletions

View File

@@ -33,7 +33,10 @@ pub enum AttentionPattern {
     SlidingWindow { window: f64 },
     /// DeepSeek Sparse Attention: effective_ctx = min(N, dense_window) +
     /// max(0, N - dense_window) / sparse_stride.
-    Dsa { dense_window: f64, sparse_stride: f64 },
+    Dsa {
+        dense_window: f64,
+        sparse_stride: f64,
+    },
 }
 
 #[derive(Debug, Clone)]
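
The doc comment above gives the DSA effective-context formula; a minimal standalone sketch of it (the free function and the example numbers are illustrative, not part of this crate):

fn dsa_effective_ctx(n: f64, dense_window: f64, sparse_stride: f64) -> f64 {
    // effective_ctx = min(N, dense_window) + max(0, N - dense_window) / sparse_stride
    n.min(dense_window) + (n - dense_window).max(0.0) / sparse_stride
}

// e.g. n = 131072, dense_window = 4096, sparse_stride = 16:
// 4096 + 126976 / 16 = 12032 effective context tokens per query.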
@@ -101,8 +104,9 @@ impl ComputeModel {
 
         // --- MLP FLOPs/token/layer (SwiGLU: gate + up + down = 3 matmuls) ---
         let mlp = if let Some(moe) = &model.moe {
-            let expert_inter = moe.expert_intermediate_size
-                .unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
+            let expert_inter =
+                moe.expert_intermediate_size
+                    .unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
             let active = moe.num_active_experts as f64;
             let shared = moe.num_shared_experts as f64;
             active * 6.0 * h * expert_inter + shared * 6.0 * h * inter
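
The 6.0 factor above is the three SwiGLU matmuls (gate, up, down) at 2 FLOPs per multiply-accumulate; a standalone sketch of the same per-token, per-layer MoE estimate (the function and the example values are illustrative only):

// Routed experts use the expert intermediate size, shared experts the dense one.
fn moe_mlp_flops_per_token_layer(h: f64, inter: f64, expert_inter: f64, active: f64, shared: f64) -> f64 {
    active * 6.0 * h * expert_inter + shared * 6.0 * h * inter
}

// e.g. h = 7168, expert_inter = 2048, active = 8, shared = 1, inter = 18432
// gives roughly 1.5e9 FLOPs per token per layer.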
@@ -132,16 +136,14 @@ impl ComputeModel {
             let qk_hd = (mla.qk_nope_head_dim + mla.qk_rope_head_dim) as f64;
             let qk_rd = mla.qk_rope_head_dim as f64;
             let vhd = mla.v_head_dim as f64;
-            (h * qlr + qlr * n_heads * qk_hd
-                + h * (kvlr + qk_rd)
-                + n_heads * vhd * h)
-                * wdtype
+            (h * qlr + qlr * n_heads * qk_hd + h * (kvlr + qk_rd) + n_heads * vhd * h) * wdtype
         } else {
             ((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * wdtype
         };
         let mlp_wt = if let Some(moe) = &model.moe {
-            let expert_inter = moe.expert_intermediate_size
-                .unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
+            let expert_inter =
+                moe.expert_intermediate_size
+                    .unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
             let active = moe.num_active_experts as f64;
             let shared = moe.num_shared_experts as f64;
             (active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * wdtype
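
For scale, the non-MLA branch counts the Q/K/V projections ((n_heads + 2*n_kv) * head_dim * hidden params) plus the output projection (n_heads * head_dim * hidden), times bytes per weight element; a sketch with illustrative numbers not taken from any config in this repo:

fn gqa_attn_weight_bytes_per_layer(h: f64, n_heads: f64, n_kv: f64, hd: f64, wdtype: f64) -> f64 {
    // QKV projections + output projection, in bytes (wdtype = bytes per weight element).
    ((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * wdtype
}

// e.g. h = 8192, n_heads = 64, n_kv = 8, hd = 128, bf16 (wdtype = 2.0):
// (80 * 128 * 8192 + 64 * 128 * 8192) * 2 ≈ 3.0e8 bytes (~0.3 GB) per layer.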
@@ -169,10 +171,9 @@ impl ComputeModel {
                 },
                 0.0,
             ),
-            Some(AttentionConfig::Dense) | None => (
-                AttentionPattern::Dense,
-                model.num_layers as f64,
-            ),
+            Some(AttentionConfig::Dense) | None => {
+                (AttentionPattern::Dense, model.num_layers as f64)
+            }
         };
 
         Self {
@@ -237,10 +238,10 @@ impl ComputeModel {
 
         let dense_layers = self.first_dense_layers;
         let sparse_layers = self.num_layers - dense_layers;
-        let dense_flops = dense_layers
-            * (linear + self.attn_coeff * n * self.effective_ctx(n, true));
-        let sparse_flops = sparse_layers
-            * (linear + self.attn_coeff * n * self.effective_ctx(n, false));
+        let dense_flops =
+            dense_layers * (linear + self.attn_coeff * n * self.effective_ctx(n, true));
+        let sparse_flops =
+            sparse_layers * (linear + self.attn_coeff * n * self.effective_ctx(n, false));
         let total_flops = dense_flops + sparse_flops;
         let compute_time = total_flops / self.gpu_flops;
 
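
The split above charges the first dense layers the full effective context and the remaining layers the sparse one, with the same per-layer linear term; a standalone paraphrase of that shape (argument names are illustrative, not the crate's API):

// linear: per-layer linear FLOPs for the whole n-token prompt;
// attn_coeff: attention FLOPs per (query token x context token) per layer.
fn prefill_seconds(
    n: f64,
    linear: f64,
    attn_coeff: f64,
    dense_layers: f64,
    sparse_layers: f64,
    dense_ctx: f64,
    sparse_ctx: f64,
    gpu_flops: f64,
) -> f64 {
    let dense_flops = dense_layers * (linear + attn_coeff * n * dense_ctx);
    let sparse_flops = sparse_layers * (linear + attn_coeff * n * sparse_ctx);
    (dense_flops + sparse_flops) / gpu_flops
}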
@@ -254,7 +255,9 @@ impl ComputeModel {
     pub fn describe(&self) -> String {
         let pattern_str = match &self.attn_pattern {
             AttentionPattern::Dense => "dense".to_string(),
-            AttentionPattern::SlidingWindow { window } => format!("sliding_window({})", *window as u64),
+            AttentionPattern::SlidingWindow { window } => {
+                format!("sliding_window({})", *window as u64)
+            }
             AttentionPattern::Dsa {
                 dense_window,
                 sparse_stride,
@@ -266,8 +269,7 @@ impl ComputeModel {
         format!(
             "linear_flops/tok/layer={:.3e}, attn_coeff={:.0}, pattern={}, \
             weight_bytes/layer={:.2e}",
-            self.linear_flops_per_token, self.attn_coeff, pattern_str,
-            self.weight_bytes_per_layer,
+            self.linear_flops_per_token, self.attn_coeff, pattern_str, self.weight_bytes_per_layer,
         )
     }
 }

View File

@@ -1,6 +1,6 @@
pub mod compute;
pub mod kv_cache;
#[allow(clippy::module_inception)]
pub mod instance;
pub mod kv_cache;
pub use instance::Instance;