fix: cache calculation
This commit is contained in:
@@ -33,7 +33,10 @@ pub enum AttentionPattern {
|
||||
SlidingWindow { window: f64 },
|
||||
/// DeepSeek Sparse Attention: effective_ctx = min(N, dense_window) +
|
||||
/// max(0, N - dense_window) / sparse_stride.
|
||||
Dsa { dense_window: f64, sparse_stride: f64 },
|
||||
Dsa {
|
||||
dense_window: f64,
|
||||
sparse_stride: f64,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -101,8 +104,9 @@ impl ComputeModel {
|
||||
|
||||
// --- MLP FLOPs/token/layer (SwiGLU: gate + up + down = 3 matmuls) ---
|
||||
let mlp = if let Some(moe) = &model.moe {
|
||||
let expert_inter = moe.expert_intermediate_size
|
||||
.unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
|
||||
let expert_inter =
|
||||
moe.expert_intermediate_size
|
||||
.unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
|
||||
let active = moe.num_active_experts as f64;
|
||||
let shared = moe.num_shared_experts as f64;
|
||||
active * 6.0 * h * expert_inter + shared * 6.0 * h * inter
|
||||
@@ -132,16 +136,14 @@ impl ComputeModel {
|
||||
let qk_hd = (mla.qk_nope_head_dim + mla.qk_rope_head_dim) as f64;
|
||||
let qk_rd = mla.qk_rope_head_dim as f64;
|
||||
let vhd = mla.v_head_dim as f64;
|
||||
(h * qlr + qlr * n_heads * qk_hd
|
||||
+ h * (kvlr + qk_rd)
|
||||
+ n_heads * vhd * h)
|
||||
* wdtype
|
||||
(h * qlr + qlr * n_heads * qk_hd + h * (kvlr + qk_rd) + n_heads * vhd * h) * wdtype
|
||||
} else {
|
||||
((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * wdtype
|
||||
};
|
||||
let mlp_wt = if let Some(moe) = &model.moe {
|
||||
let expert_inter = moe.expert_intermediate_size
|
||||
.unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
|
||||
let expert_inter =
|
||||
moe.expert_intermediate_size
|
||||
.unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
|
||||
let active = moe.num_active_experts as f64;
|
||||
let shared = moe.num_shared_experts as f64;
|
||||
(active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * wdtype
|
||||
@@ -169,10 +171,9 @@ impl ComputeModel {
|
||||
},
|
||||
0.0,
|
||||
),
|
||||
Some(AttentionConfig::Dense) | None => (
|
||||
AttentionPattern::Dense,
|
||||
model.num_layers as f64,
|
||||
),
|
||||
Some(AttentionConfig::Dense) | None => {
|
||||
(AttentionPattern::Dense, model.num_layers as f64)
|
||||
}
|
||||
};
|
||||
|
||||
Self {
|
||||
@@ -237,10 +238,10 @@ impl ComputeModel {
|
||||
let dense_layers = self.first_dense_layers;
|
||||
let sparse_layers = self.num_layers - dense_layers;
|
||||
|
||||
let dense_flops = dense_layers
|
||||
* (linear + self.attn_coeff * n * self.effective_ctx(n, true));
|
||||
let sparse_flops = sparse_layers
|
||||
* (linear + self.attn_coeff * n * self.effective_ctx(n, false));
|
||||
let dense_flops =
|
||||
dense_layers * (linear + self.attn_coeff * n * self.effective_ctx(n, true));
|
||||
let sparse_flops =
|
||||
sparse_layers * (linear + self.attn_coeff * n * self.effective_ctx(n, false));
|
||||
let total_flops = dense_flops + sparse_flops;
|
||||
|
||||
let compute_time = total_flops / self.gpu_flops;
|
||||
@@ -254,7 +255,9 @@ impl ComputeModel {
|
||||
pub fn describe(&self) -> String {
|
||||
let pattern_str = match &self.attn_pattern {
|
||||
AttentionPattern::Dense => "dense".to_string(),
|
||||
AttentionPattern::SlidingWindow { window } => format!("sliding_window({})", *window as u64),
|
||||
AttentionPattern::SlidingWindow { window } => {
|
||||
format!("sliding_window({})", *window as u64)
|
||||
}
|
||||
AttentionPattern::Dsa {
|
||||
dense_window,
|
||||
sparse_stride,
|
||||
@@ -266,8 +269,7 @@ impl ComputeModel {
|
||||
format!(
|
||||
"linear_flops/tok/layer={:.3e}, attn_coeff={:.0}, pattern={}, \
|
||||
weight_bytes/layer={:.2e}",
|
||||
self.linear_flops_per_token, self.attn_coeff, pattern_str,
|
||||
self.weight_bytes_per_layer,
|
||||
self.linear_flops_per_token, self.attn_coeff, pattern_str, self.weight_bytes_per_layer,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
pub mod compute;
|
||||
pub mod kv_cache;
|
||||
#[allow(clippy::module_inception)]
|
||||
pub mod instance;
|
||||
pub mod kv_cache;
|
||||
|
||||
pub use instance::Instance;
|
||||
|
||||
Reference in New Issue
Block a user