Files
kvcache-simulator/src/instance/compute.rs
Gahow Wang ec73a95e05 KVCache simulator for LLM serving cluster routing research
Discrete-event simulator for evaluating KV cache-aware routing policies
in prefill-disaggregated LLM serving clusters. Models a two-tier KV cache
hierarchy (L0 GPU HBM + L1 CPU DRAM) with RDMA/PCIe link contention,
architecture-derived roofline compute (MoE, MLA, DSA), and a cluster-wide
meta-store for prefix-aware routing decisions.

Includes 11 routing policies (random, round_robin, least_loaded,
least_tokens, ttl_aware, precise, min_pd, cache_load, cache_score,
estimated_ttft, prefix_affinity), HuggingFace config.json auto-parsing,
built-in GPU hardware presets (H100/H800/H20/A100/B200), and ablation
tooling for systematic policy comparison across real Alibaba serving traces.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-14 01:16:02 +08:00

406 lines
14 KiB
Rust

//! Roofline cost model for prefill (PD disaggregation — decode not modeled).
//!
//! Two construction modes:
//!
//! **Architecture-derived** (`ModelConfig.hidden_size` present):
//! All FLOPs, attention coefficients, and weight-stream costs are computed
//! from the model shape. Handles standard / GQA / MLA attention projections,
//! MoE routing, and DSA / sliding-window attention patterns that shrink the
//! per-token attention context.
//!
//! **Legacy manual** (`hidden_size` absent): uses the raw
//! `flops_per_token_prefill` + `attn_quadratic_coeff` scalars from the YAML.
//!
//! ```text
//! prefill_time(N) = max(compute_time(N), mem_time)
//!
//! compute_time = sum over layers of:
//! (N * linear_flops + attn_coeff * N * effective_ctx(N)) / gpu_flops
//!
//! mem_time = num_layers * weight_bytes_per_layer / gpu_mem_bw
//! ```
//!
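//! `effective_ctx(N)` equals `N` for dense attention (→ O(N²) total). For
//! sliding-window it is capped at the window size (→ O(N·W) total). For DSA it
//! is `dense_window + (N - dense_window) / sparse_stride` once `N` exceeds the
//! window, so the total stays quadratic but with a constant roughly
//! `sparse_stride` times smaller. The first `first_dense_layers` layers always
//! use the dense form.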
use crate::config::{AttentionConfig, HardwareConfig, ModelConfig};
/// Attention sparsity pattern applied at runtime to the non-dense layers.
///
/// Each variant fixes `effective_ctx(N)`, the number of keys a single query
/// token is charged for when the sequence holds `N` tokens.
#[derive(Debug, Clone)]
pub enum AttentionPattern {
    /// Full causal prefix: `effective_ctx = N`, giving O(N²) total work.
    Dense,
    /// Bounded window: `effective_ctx = min(N, window)`.
    SlidingWindow {
        /// Window length in tokens.
        window: f64,
    },
    /// DeepSeek Sparse Attention:
    /// `effective_ctx = min(N, dense_window) + max(0, N - dense_window) / sparse_stride`.
    Dsa {
        /// Context length (tokens) charged at full density.
        dense_window: f64,
        /// Divisor applied to the context beyond `dense_window`.
        sparse_stride: f64,
    },
}
/// Roofline prefill cost model.
///
/// Built either from the architecture (`ComputeModel::new` with `hidden_size`
/// set) or from legacy manual scalars. All fields are `f64`.
#[derive(Debug, Clone)]
pub struct ComputeModel {
    /// Number of transformer layers.
    pub num_layers: f64,
    /// Leading layers that always use dense attention; the remaining
    /// `num_layers - first_dense_layers` layers follow `attn_pattern`.
    /// Equals `num_layers` when the pattern is `Dense`.
    pub first_dense_layers: f64,
    /// Per-token, per-layer FLOPs that scale linearly with N
    /// (attention projections plus the MLP).
    pub linear_flops_per_token: f64,
    /// Multiplier of the attention-score term: one layer spends
    /// `attn_coeff * N * effective_ctx(N)` FLOPs on attention.
    pub attn_coeff: f64,
    /// Pattern used by the layers after the first `first_dense_layers`.
    pub attn_pattern: AttentionPattern,
    /// Weight bytes streamed from HBM per layer; sets the memory-bound floor
    /// (zero disables it).
    pub weight_bytes_per_layer: f64,
    /// Peak FLOP/s, aggregated across the TP group.
    pub gpu_flops: f64,
    /// Peak HBM bandwidth in bytes/s, aggregated across the TP group.
    pub gpu_mem_bw: f64,
}
impl ComputeModel {
pub fn new(model: &ModelConfig, hw: &HardwareConfig) -> Self {
if model.is_arch_mode() {
Self::from_arch(model, hw)
} else {
Self::from_manual(model, hw)
}
}
// ----- Architecture-derived construction --------------------------------
fn from_arch(model: &ModelConfig, hw: &HardwareConfig) -> Self {
let h = model.hidden_size.unwrap() as f64;
let n_heads = model.num_attention_heads.unwrap_or(model.num_kv_heads) as f64;
let n_kv = model.num_kv_heads as f64;
let hd = model.head_dim as f64;
let inter = model.intermediate_size.unwrap_or(0) as f64;
let dtype = model.dtype_bytes as f64;
// --- Attention linear FLOPs/token/layer ---
let attn_linear = if let Some(mla) = &model.mla {
let qlr = mla.q_lora_rank as f64;
let kvlr = mla.kv_lora_rank as f64;
let qk_hd = (mla.qk_nope_head_dim + mla.qk_rope_head_dim) as f64;
let qk_rd = mla.qk_rope_head_dim as f64;
let vhd = mla.v_head_dim as f64;
// Q: down-project + up-project
let q = 2.0 * h * qlr + 2.0 * qlr * n_heads * qk_hd;
// KV: down-project (compressed latent + RoPE key)
let kv = 2.0 * h * (kvlr + qk_rd);
// Output: up-project
let o = 2.0 * n_heads * vhd * h;
q + kv + o
} else {
// Standard / GQA
let qkv = 2.0 * h * (n_heads + 2.0 * n_kv) * hd;
let o = 2.0 * n_heads * hd * h;
qkv + o
};
// --- MLP FLOPs/token/layer (SwiGLU: gate + up + down = 3 matmuls) ---
let mlp = if let Some(moe) = &model.moe {
let expert_inter = moe.expert_intermediate_size
.unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
let active = moe.num_active_experts as f64;
let shared = moe.num_shared_experts as f64;
active * 6.0 * h * expert_inter + shared * 6.0 * h * inter
} else {
6.0 * h * inter
};
let linear_flops = attn_linear + mlp;
// --- Attention quadratic coefficient ---
// attn_flops_per_layer(N) = attn_coeff * N * effective_ctx(N)
let attn_coeff = if let Some(mla) = &model.mla {
let kvlr = mla.kv_lora_rank as f64;
let qk_rd = mla.qk_rope_head_dim as f64;
// Absorbed QK^T: each head dots over (kv_lora_rank + qk_rope_head_dim) dims.
// Absorbed V: each head dots over kv_lora_rank dims.
2.0 * n_heads * (2.0 * kvlr + qk_rd)
} else {
// Standard: QK^T + attn@V, each 2 * n_heads * head_dim per pair.
4.0 * n_heads * hd
};
// --- Weight bytes per layer (active params only for MoE) ---
let attn_wt = if let Some(mla) = &model.mla {
let qlr = mla.q_lora_rank as f64;
let kvlr = mla.kv_lora_rank as f64;
let qk_hd = (mla.qk_nope_head_dim + mla.qk_rope_head_dim) as f64;
let qk_rd = mla.qk_rope_head_dim as f64;
let vhd = mla.v_head_dim as f64;
(h * qlr + qlr * n_heads * qk_hd
+ h * (kvlr + qk_rd)
+ n_heads * vhd * h)
* dtype
} else {
((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * dtype
};
let mlp_wt = if let Some(moe) = &model.moe {
let expert_inter = moe.expert_intermediate_size
.unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
let active = moe.num_active_experts as f64;
let shared = moe.num_shared_experts as f64;
(active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * dtype
} else {
3.0 * h * inter * dtype
};
let weight_bytes = attn_wt + mlp_wt;
// --- Attention pattern ---
let (attn_pattern, first_dense) = match &model.attention {
Some(AttentionConfig::Dsa {
dense_window,
sparse_stride,
first_dense_layers,
}) => (
AttentionPattern::Dsa {
dense_window: *dense_window as f64,
sparse_stride: *sparse_stride as f64,
},
*first_dense_layers as f64,
),
Some(AttentionConfig::SlidingWindow { window_size }) => (
AttentionPattern::SlidingWindow {
window: *window_size as f64,
},
0.0,
),
Some(AttentionConfig::Dense) | None => (
AttentionPattern::Dense,
model.num_layers as f64,
),
};
Self {
num_layers: model.num_layers as f64,
first_dense_layers: first_dense,
linear_flops_per_token: linear_flops,
attn_coeff,
attn_pattern,
weight_bytes_per_layer: weight_bytes,
gpu_flops: hw.gpu_flops,
gpu_mem_bw: hw.gpu_mem_bw,
}
}
// ----- Legacy manual construction ---------------------------------------
fn from_manual(model: &ModelConfig, hw: &HardwareConfig) -> Self {
Self {
num_layers: model.num_layers as f64,
first_dense_layers: model.num_layers as f64,
linear_flops_per_token: model.flops_per_token_prefill.unwrap_or(0.0),
attn_coeff: model.attn_quadratic_coeff.unwrap_or(0.0),
attn_pattern: AttentionPattern::Dense,
weight_bytes_per_layer: 0.0,
gpu_flops: hw.gpu_flops,
gpu_mem_bw: hw.gpu_mem_bw,
}
}
// ----- Prefill time -----------------------------------------------------
/// Effective context length a single token attends to at sequence length N.
fn effective_ctx(&self, n: f64, dense_layer: bool) -> f64 {
if dense_layer {
return n;
}
match &self.attn_pattern {
AttentionPattern::Dense => n,
AttentionPattern::SlidingWindow { window } => n.min(*window),
AttentionPattern::Dsa {
dense_window,
sparse_stride,
} => {
if n <= *dense_window {
n
} else {
*dense_window + (n - *dense_window) / *sparse_stride
}
}
}
}
/// Time (s) to prefill `n` tokens.
pub fn prefill_time(&self, n: u32) -> f64 {
if n == 0 {
return 0.0;
}
let n = n as f64;
let linear = n * self.linear_flops_per_token;
// Compute FLOPs across all layers (dense + sparse may differ).
let dense_layers = self.first_dense_layers;
let sparse_layers = self.num_layers - dense_layers;
let dense_flops = dense_layers
* (linear + self.attn_coeff * n * self.effective_ctx(n, true));
let sparse_flops = sparse_layers
* (linear + self.attn_coeff * n * self.effective_ctx(n, false));
let total_flops = dense_flops + sparse_flops;
let compute_time = total_flops / self.gpu_flops;
// Weight stream: all layers' active weights read once from HBM.
let mem_time = self.weight_bytes_per_layer * self.num_layers / self.gpu_mem_bw;
compute_time.max(mem_time)
}
/// Print human-readable derived coefficients (for `validate` output).
pub fn describe(&self) -> String {
let pattern_str = match &self.attn_pattern {
AttentionPattern::Dense => "dense".to_string(),
AttentionPattern::SlidingWindow { window } => format!("sliding_window({})", *window as u64),
AttentionPattern::Dsa {
dense_window,
sparse_stride,
} => format!(
"dsa(window={}, stride={}, {} dense layers)",
*dense_window as u64, *sparse_stride as u64, self.first_dense_layers as u64
),
};
format!(
"linear_flops/tok/layer={:.3e}, attn_coeff={:.0}, pattern={}, \
weight_bytes/layer={:.2e}",
self.linear_flops_per_token, self.attn_coeff, pattern_str,
self.weight_bytes_per_layer,
)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Legacy-style model: dense attention everywhere and no weight streaming,
    /// so prefill time is purely compute-bound.
    fn legacy_model() -> ComputeModel {
        ComputeModel {
            num_layers: 28.0,
            first_dense_layers: 28.0,
            linear_flops_per_token: 1.4e10,
            attn_coeff: 1024.0,
            attn_pattern: AttentionPattern::Dense,
            weight_bytes_per_layer: 0.0,
            gpu_flops: 9.89e14,
            gpu_mem_bw: 3.35e12,
        }
    }

    #[test]
    fn prefill_monotonic_in_n() {
        let model = legacy_model();
        // Each size must strictly exceed the previous one; the first must exceed 0.
        let mut floor = 0.0;
        for n in [1u32, 8, 64, 512, 4096, 32768] {
            let t = model.prefill_time(n);
            assert!(t > floor, "prefill_time should be monotonic; n={n} t={t}");
            floor = t;
        }
    }

    #[test]
    fn quadratic_dominates_for_long_prompt() {
        let model = legacy_model();
        // 32x the tokens must cost more than 32x the time once the N² term bites.
        let ratio = model.prefill_time(32768) / model.prefill_time(1024);
        assert!(ratio > 32.0);
    }

    #[test]
    fn zero_tokens_is_free() {
        assert_eq!(legacy_model().prefill_time(0), 0.0);
    }

    #[test]
    fn dsa_subquadratic() {
        // The same model twice: all-dense versus DSA (window=4096, stride=8,
        // 3 dense layers). At 128k tokens DSA must be far cheaper yet still growing.
        let dense = ComputeModel {
            num_layers: 78.0,
            first_dense_layers: 78.0,
            linear_flops_per_token: 1.0e9,
            attn_coeff: 139264.0,
            attn_pattern: AttentionPattern::Dense,
            weight_bytes_per_layer: 0.0,
            gpu_flops: 1.8e16,
            gpu_mem_bw: 6.4e13,
        };
        let dsa = ComputeModel {
            first_dense_layers: 3.0,
            attn_pattern: AttentionPattern::Dsa {
                dense_window: 4096.0,
                sparse_stride: 8.0,
            },
            ..dense.clone()
        };
        let n: u32 = 131_072; // 128k tokens
        let (t_dense, t_dsa) = (dense.prefill_time(n), dsa.prefill_time(n));
        assert!(
            t_dsa < t_dense * 0.3,
            "DSA should be <30% of dense at 128k: dense={t_dense:.3} dsa={t_dsa:.3}"
        );
        // Sparse, but still monotonic in n.
        assert!(t_dsa > dsa.prefill_time(n / 2));
    }

    #[test]
    fn mem_bound_short_prefill() {
        // Huge weights and a tiny prompt: the HBM weight stream sets the time.
        let model = ComputeModel {
            num_layers: 10.0,
            first_dense_layers: 10.0,
            linear_flops_per_token: 1.0e6, // tiny compute
            attn_coeff: 1.0,
            attn_pattern: AttentionPattern::Dense,
            weight_bytes_per_layer: 1.0e12, // 1 TB per layer
            gpu_flops: 1.0e15,
            gpu_mem_bw: 1.0e12,
        };
        // mem_time = 10 layers * 1e12 B / 1e12 B/s = 10 s.
        let one = model.prefill_time(1);
        assert!((one - 10.0).abs() < 0.01);
        // Going from 1 to 8 tokens barely moves a memory-bound prefill.
        let eight = model.prefill_time(8);
        assert!((eight - one).abs() / one < 0.01);
    }

    #[test]
    fn arch_derives_from_model_config() {
        // Tiny dense model: from_arch should yield positive coefficients and time.
        let model = ModelConfig {
            name: "test".into(),
            num_layers: 4,
            num_kv_heads: 2,
            head_dim: 64,
            dtype_bytes: 2,
            block_size_tokens: 16,
            hidden_size: Some(256),
            num_attention_heads: Some(4),
            intermediate_size: Some(512),
            ..Default::default()
        };
        let hw = HardwareConfig {
            gpu_flops: 1e14,
            gpu_mem_bw: 1e12,
            hbm_bytes: 1e9,
            dram_bytes: 4e9,
            pcie_bw: 32e9,
            pcie_latency_us: 1.0,
            rdma_bw: 12e9,
            rdma_latency_us: 5.0,
            max_batch_slots: 32,
            prefill_chunk_tokens: 1024,
        };
        let cm = ComputeModel::new(&model, &hw);
        assert!(cm.linear_flops_per_token > 0.0);
        assert!(cm.attn_coeff > 0.0);
        assert!(cm.weight_bytes_per_layer > 0.0);
        assert!(cm.prefill_time(1024) > 0.0);
    }
}