//! Roofline cost model for prefill (PD disaggregation — decode not modeled). //! //! Two construction modes: //! //! **Architecture-derived** (`ModelConfig.hidden_size` present): //! All FLOPs, attention coefficients, and weight-stream costs are computed //! from the model shape. Handles standard / GQA / MLA attention projections, //! MoE routing, and DSA / sliding-window sub-quadratic attention patterns. //! //! **Legacy manual** (`hidden_size` absent): uses the raw //! `flops_per_token_prefill` + `attn_quadratic_coeff` scalars from the YAML. //! //! ```text //! prefill_time(N) = max(compute_time(N), mem_time) //! //! compute_time = sum over layers of: //! (N * linear_flops + attn_coeff * N * effective_ctx(N)) / gpu_flops //! //! mem_time = num_layers * weight_bytes_per_layer / gpu_mem_bw //! ``` //! //! `effective_ctx(N)` equals `N` for dense attention (→ O(N²) total) but //! is sub-linear for DSA / sliding-window. use crate::config::{AttentionConfig, HardwareConfig, ModelConfig}; /// Resolved attention pattern used at runtime. #[derive(Debug, Clone)] pub enum AttentionPattern { /// Full quadratic: effective_ctx = N. Dense, /// Sliding window: effective_ctx = min(N, window). SlidingWindow { window: f64 }, /// DeepSeek Sparse Attention: effective_ctx = min(N, dense_window) + /// max(0, N - dense_window) / sparse_stride. Dsa { dense_window: f64, sparse_stride: f64 }, } #[derive(Debug, Clone)] pub struct ComputeModel { /// Total transformer layers. pub num_layers: f64, /// How many initial layers use dense attention (rest use `attn_pattern`). /// For `Dense` pattern this equals `num_layers`. pub first_dense_layers: f64, /// Non-attention FLOPs per token per layer (QKV proj + output proj + MLP). pub linear_flops_per_token: f64, /// Attention score coefficient: per-layer attention FLOPs = /// `attn_coeff * N * effective_ctx(N)`. pub attn_coeff: f64, /// Attention pattern for non-dense layers. pub attn_pattern: AttentionPattern, /// Weight bytes read from HBM per layer (for memory-bound check). pub weight_bytes_per_layer: f64, /// Peak GPU FLOPs (aggregate across TP group). pub gpu_flops: f64, /// Peak GPU memory bandwidth (aggregate across TP group). pub gpu_mem_bw: f64, } impl ComputeModel { pub fn new(model: &ModelConfig, hw: &HardwareConfig) -> Self { if model.is_arch_mode() { Self::from_arch(model, hw) } else { Self::from_manual(model, hw) } } // ----- Architecture-derived construction -------------------------------- fn from_arch(model: &ModelConfig, hw: &HardwareConfig) -> Self { let h = model.hidden_size.unwrap() as f64; let n_heads = model.num_attention_heads.unwrap_or(model.num_kv_heads) as f64; let n_kv = model.num_kv_heads as f64; let hd = model.head_dim as f64; let inter = model.intermediate_size.unwrap_or(0) as f64; // Weight dtype for memory-bound check (separate from KV cache dtype). let wdtype = model.weight_dtype_bytes(); // --- Attention linear FLOPs/token/layer --- let attn_linear = if let Some(mla) = &model.mla { let qlr = mla.q_lora_rank as f64; let kvlr = mla.kv_lora_rank as f64; let qk_hd = (mla.qk_nope_head_dim + mla.qk_rope_head_dim) as f64; let qk_rd = mla.qk_rope_head_dim as f64; let vhd = mla.v_head_dim as f64; // Q: down-project + up-project let q = 2.0 * h * qlr + 2.0 * qlr * n_heads * qk_hd; // KV: down-project (compressed latent + RoPE key) let kv = 2.0 * h * (kvlr + qk_rd); // Output: up-project let o = 2.0 * n_heads * vhd * h; q + kv + o } else { // Standard / GQA let qkv = 2.0 * h * (n_heads + 2.0 * n_kv) * hd; let o = 2.0 * n_heads * hd * h; qkv + o }; // --- MLP FLOPs/token/layer (SwiGLU: gate + up + down = 3 matmuls) --- let mlp = if let Some(moe) = &model.moe { let expert_inter = moe.expert_intermediate_size .unwrap_or(model.intermediate_size.unwrap_or(0)) as f64; let active = moe.num_active_experts as f64; let shared = moe.num_shared_experts as f64; active * 6.0 * h * expert_inter + shared * 6.0 * h * inter } else { 6.0 * h * inter }; let linear_flops = attn_linear + mlp; // --- Attention quadratic coefficient --- // attn_flops_per_layer(N) = attn_coeff * N * effective_ctx(N) let attn_coeff = if let Some(mla) = &model.mla { let kvlr = mla.kv_lora_rank as f64; let qk_rd = mla.qk_rope_head_dim as f64; // Absorbed QK^T: each head dots over (kv_lora_rank + qk_rope_head_dim) dims. // Absorbed V: each head dots over kv_lora_rank dims. 2.0 * n_heads * (2.0 * kvlr + qk_rd) } else { // Standard: QK^T + attn@V, each 2 * n_heads * head_dim per pair. 4.0 * n_heads * hd }; // --- Weight bytes per layer (active params only for MoE) --- let attn_wt = if let Some(mla) = &model.mla { let qlr = mla.q_lora_rank as f64; let kvlr = mla.kv_lora_rank as f64; let qk_hd = (mla.qk_nope_head_dim + mla.qk_rope_head_dim) as f64; let qk_rd = mla.qk_rope_head_dim as f64; let vhd = mla.v_head_dim as f64; (h * qlr + qlr * n_heads * qk_hd + h * (kvlr + qk_rd) + n_heads * vhd * h) * wdtype } else { ((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * wdtype }; let mlp_wt = if let Some(moe) = &model.moe { let expert_inter = moe.expert_intermediate_size .unwrap_or(model.intermediate_size.unwrap_or(0)) as f64; let active = moe.num_active_experts as f64; let shared = moe.num_shared_experts as f64; (active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * wdtype } else { 3.0 * h * inter * wdtype }; let weight_bytes = attn_wt + mlp_wt; // --- Attention pattern --- let (attn_pattern, first_dense) = match &model.attention { Some(AttentionConfig::Dsa { dense_window, sparse_stride, first_dense_layers, }) => ( AttentionPattern::Dsa { dense_window: *dense_window as f64, sparse_stride: *sparse_stride as f64, }, *first_dense_layers as f64, ), Some(AttentionConfig::SlidingWindow { window_size }) => ( AttentionPattern::SlidingWindow { window: *window_size as f64, }, 0.0, ), Some(AttentionConfig::Dense) | None => ( AttentionPattern::Dense, model.num_layers as f64, ), }; Self { num_layers: model.num_layers as f64, first_dense_layers: first_dense, linear_flops_per_token: linear_flops, attn_coeff, attn_pattern, weight_bytes_per_layer: weight_bytes, gpu_flops: hw.gpu_flops, gpu_mem_bw: hw.gpu_mem_bw, } } // ----- Legacy manual construction --------------------------------------- fn from_manual(model: &ModelConfig, hw: &HardwareConfig) -> Self { Self { num_layers: model.num_layers as f64, first_dense_layers: model.num_layers as f64, linear_flops_per_token: model.flops_per_token_prefill.unwrap_or(0.0), attn_coeff: model.attn_quadratic_coeff.unwrap_or(0.0), attn_pattern: AttentionPattern::Dense, weight_bytes_per_layer: 0.0, gpu_flops: hw.gpu_flops, gpu_mem_bw: hw.gpu_mem_bw, } } // ----- Prefill time ----------------------------------------------------- /// Effective context length a single token attends to at sequence length N. fn effective_ctx(&self, n: f64, dense_layer: bool) -> f64 { if dense_layer { return n; } match &self.attn_pattern { AttentionPattern::Dense => n, AttentionPattern::SlidingWindow { window } => n.min(*window), AttentionPattern::Dsa { dense_window, sparse_stride, } => { if n <= *dense_window { n } else { *dense_window + (n - *dense_window) / *sparse_stride } } } } /// Time (s) to prefill `n` tokens. pub fn prefill_time(&self, n: u32) -> f64 { if n == 0 { return 0.0; } let n = n as f64; let linear = n * self.linear_flops_per_token; // Compute FLOPs across all layers (dense + sparse may differ). let dense_layers = self.first_dense_layers; let sparse_layers = self.num_layers - dense_layers; let dense_flops = dense_layers * (linear + self.attn_coeff * n * self.effective_ctx(n, true)); let sparse_flops = sparse_layers * (linear + self.attn_coeff * n * self.effective_ctx(n, false)); let total_flops = dense_flops + sparse_flops; let compute_time = total_flops / self.gpu_flops; // Weight stream: all layers' active weights read once from HBM. let mem_time = self.weight_bytes_per_layer * self.num_layers / self.gpu_mem_bw; compute_time.max(mem_time) } /// Print human-readable derived coefficients (for `validate` output). pub fn describe(&self) -> String { let pattern_str = match &self.attn_pattern { AttentionPattern::Dense => "dense".to_string(), AttentionPattern::SlidingWindow { window } => format!("sliding_window({})", *window as u64), AttentionPattern::Dsa { dense_window, sparse_stride, } => format!( "dsa(window={}, stride={}, {} dense layers)", *dense_window as u64, *sparse_stride as u64, self.first_dense_layers as u64 ), }; format!( "linear_flops/tok/layer={:.3e}, attn_coeff={:.0}, pattern={}, \ weight_bytes/layer={:.2e}", self.linear_flops_per_token, self.attn_coeff, pattern_str, self.weight_bytes_per_layer, ) } } #[cfg(test)] mod tests { use super::*; fn cm_legacy() -> ComputeModel { ComputeModel { num_layers: 28.0, first_dense_layers: 28.0, linear_flops_per_token: 1.4e10, attn_coeff: 1024.0, attn_pattern: AttentionPattern::Dense, weight_bytes_per_layer: 0.0, gpu_flops: 9.89e14, gpu_mem_bw: 3.35e12, } } #[test] fn prefill_monotonic_in_n() { let m = cm_legacy(); let mut prev = 0.0; for &n in &[1u32, 8, 64, 512, 4096, 32768] { let t = m.prefill_time(n); assert!(t > prev, "prefill_time should be monotonic; n={n} t={t}"); prev = t; } } #[test] fn quadratic_dominates_for_long_prompt() { let m = cm_legacy(); let lin = m.prefill_time(1024); let big = m.prefill_time(32768); assert!(big / lin > 32.0); } #[test] fn zero_tokens_is_free() { let m = cm_legacy(); assert_eq!(m.prefill_time(0), 0.0); } #[test] fn dsa_subquadratic() { // With DSA (window=4096, stride=8) the cost at 128k should be // MUCH less than pure quadratic. let dense = ComputeModel { num_layers: 78.0, first_dense_layers: 78.0, linear_flops_per_token: 1.0e9, attn_coeff: 139264.0, attn_pattern: AttentionPattern::Dense, weight_bytes_per_layer: 0.0, gpu_flops: 1.8e16, gpu_mem_bw: 6.4e13, }; let dsa = ComputeModel { attn_pattern: AttentionPattern::Dsa { dense_window: 4096.0, sparse_stride: 8.0, }, first_dense_layers: 3.0, ..dense.clone() }; let n = 131072; // 128k tokens let t_dense = dense.prefill_time(n); let t_dsa = dsa.prefill_time(n); // DSA should be dramatically cheaper at long context. assert!( t_dsa < t_dense * 0.3, "DSA should be <30% of dense at 128k: dense={t_dense:.3} dsa={t_dsa:.3}" ); // But still monotonic. assert!(t_dsa > dsa.prefill_time(n / 2)); } #[test] fn mem_bound_short_prefill() { // With very heavy weights and a short prompt, memory should dominate. let m = ComputeModel { num_layers: 10.0, first_dense_layers: 10.0, linear_flops_per_token: 1.0e6, // tiny compute attn_coeff: 1.0, attn_pattern: AttentionPattern::Dense, weight_bytes_per_layer: 1.0e12, // 1 TB per layer gpu_flops: 1.0e15, gpu_mem_bw: 1.0e12, }; let t1 = m.prefill_time(1); let t8 = m.prefill_time(8); // Memory time = 10 * 1e12 / 1e12 = 10s, should dominate. assert!((t1 - 10.0).abs() < 0.01); // Doubling tokens shouldn't change time much (mem-bound). assert!((t8 - t1).abs() / t1 < 0.01); } #[test] fn arch_derives_from_model_config() { // Minimal dense model: verify from_arch produces something sensible. let model = ModelConfig { name: "test".into(), num_layers: 4, num_kv_heads: 2, head_dim: 64, dtype_bytes: 2, block_size_tokens: 16, hidden_size: Some(256), num_attention_heads: Some(4), intermediate_size: Some(512), ..Default::default() }; let hw = HardwareConfig { gpu_flops: 1e14, gpu_fp8_flops: 0.0, gpu_fp4_flops: 0.0, gpu_mem_bw: 1e12, hbm_bytes: 1e9, dram_bytes: 4e9, pcie_bw: 32e9, pcie_latency_us: 1.0, rdma_bw: 12e9, rdma_latency_us: 5.0, max_batch_slots: 32, prefill_chunk_tokens: 1024, }; let cm = ComputeModel::new(&model, &hw); assert!(cm.linear_flops_per_token > 0.0); assert!(cm.attn_coeff > 0.0); assert!(cm.weight_bytes_per_layer > 0.0); let t = cm.prefill_time(1024); assert!(t > 0.0); } }