From 7de38fa9987e635f5a71b66ed4a3a788a24f8e5c Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 17 Apr 2026 14:35:09 +0800 Subject: [PATCH] fix: guard legacy runtime paths for bucketed configs --- src/config.rs | 125 ++++++++++++++++++++++++++++++++++++++++++++ src/driver.rs | 5 ++ src/main.rs | 137 +++++++++++++++++++++++++++++++++++++++++++++++-- src/replay.rs | 2 + tests/smoke.rs | 55 +++++++++++++++++++- 5 files changed, 319 insertions(+), 5 deletions(-) diff --git a/src/config.rs b/src/config.rs index f646dcb..5be0ba6 100644 --- a/src/config.rs +++ b/src/config.rs @@ -419,6 +419,15 @@ pub enum GlobalRouterMode { BucketScore, } +impl GlobalRouterMode { + pub fn as_str(&self) -> &'static str { + match self { + Self::StrictInputLength => "strict_input_length", + Self::BucketScore => "bucket_score", + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MetaStoreConfig { pub ttl_seconds: f64, @@ -468,6 +477,15 @@ fn default_prefix_k() -> usize { } impl ClusterConfig { + pub fn require_legacy_single_pool(&self, context: &str) -> Result { + if !self.buckets.is_empty() { + anyhow::bail!("{context} does not support cluster.buckets until Task 2 lands"); + } + + self.legacy_num_instances() + .ok_or_else(|| anyhow::anyhow!("{context} requires cluster.num_instances")) + } + pub fn legacy_num_instances(&self) -> Option { if self.buckets.is_empty() { self.num_instances @@ -481,6 +499,24 @@ impl ClusterConfig { .unwrap_or_else(|| self.buckets.iter().map(|bucket| bucket.num_instances).sum()) } + pub fn bucket_index_for_input_len(&self, input_len: u32) -> Result { + if self.buckets.is_empty() { + return Ok(0); + } + + self.buckets + .iter() + .position(|bucket| { + bucket.input_length_min <= input_len && input_len <= bucket.input_length_max + }) + .ok_or_else(|| { + anyhow::anyhow!( + "cluster.global_router.mode={} has no bucket for input_length={input_len}", + self.global_router.mode.as_str() + ) + }) + } + pub fn effective_buckets(&self) -> Vec { if !self.buckets.is_empty() { return self.buckets.clone(); @@ -1270,4 +1306,93 @@ sim: assert!(err.to_string().contains("num_instances")); assert!(err.to_string().contains("buckets")); } + + #[test] + fn bucketed_config_reports_unmatched_input_length_gaps_clearly() { + let gapful = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: strict_input_length + buckets: + - name: short + input_length_min: 0 + input_length_max: 32 + num_instances: 2 + - name: long + input_length_min: 64 + input_length_max: 96 + num_instances: 1 +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + let cfg = Config::from_yaml_path(&gapful).unwrap(); + + let err = cfg.cluster.bucket_index_for_input_len(40).unwrap_err(); + assert!(err.to_string().contains("40")); + assert!(err.to_string().contains("no bucket")); + } + + #[test] + fn bucketed_config_requires_legacy_single_pool_for_legacy_runtime_paths() { + let bucketed = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: strict_input_length + buckets: + - name: short + input_length_min: 0 + input_length_max: 32 + num_instances: 2 +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + let cfg = Config::from_yaml_path(&bucketed).unwrap(); + + let err = cfg + .cluster + .require_legacy_single_pool("driver run") + .unwrap_err(); + assert!(err.to_string().contains("driver run")); + assert!(err.to_string().contains("cluster.buckets")); + } } diff --git a/src/driver.rs b/src/driver.rs index 6c1c064..05e9531 100644 --- a/src/driver.rs +++ b/src/driver.rs @@ -49,6 +49,9 @@ struct InflightInfo { } pub fn run(config: &Config, output_subdir: Option<&str>) -> Result { + config + .cluster + .require_legacy_single_pool("driver run")?; let mut cluster = Cluster::new(config, &config.model); let mut q = EventQueue::new(); @@ -217,6 +220,8 @@ pub fn ablate_fixed_placement_with_parallelism( evict_policies: &[ReplayEvictPolicy], jobs: usize, ) -> Result> { + base.cluster + .require_legacy_single_pool("fixed-placement ablation")?; let mut out = Vec::new(); for &policy in evict_policies { if policy != ReplayEvictPolicy::Lru { diff --git a/src/main.rs b/src/main.rs index 37d6677..b174148 100644 --- a/src/main.rs +++ b/src/main.rs @@ -50,10 +50,14 @@ struct ConfigOverrides { } impl ConfigOverrides { - fn apply(&self, cfg: &mut Config) { + fn apply(&self, cfg: &mut Config) -> Result<()> { if let Some(n) = self.num_instances { + if !cfg.cluster.buckets.is_empty() { + anyhow::bail!( + "--num-instances does not support cluster.buckets until Task 2 lands" + ); + } cfg.cluster.num_instances = Some(n); - cfg.cluster.buckets.clear(); } if let Some(m) = self.max_requests { cfg.sim.max_requests = Some(m); @@ -79,6 +83,7 @@ impl ConfigOverrides { if let Some(hi) = self.input_length_max { cfg.sim.input_length_max = Some(hi); } + Ok(()) } } @@ -205,7 +210,7 @@ fn main() -> Result<()> { fn load(config: &PathBuf, overrides: &ConfigOverrides) -> Result { let mut cfg = Config::from_yaml_path(config)?; - overrides.apply(&mut cfg); + overrides.apply(&mut cfg)?; cfg.cluster.validate()?; Ok(cfg) } @@ -313,6 +318,9 @@ fn auto_select_instances( probe: RouterMode, target_ttft_mean: f64, ) -> Result { + base.cluster + .require_legacy_single_pool("auto-instances calibration")?; + #[derive(serde::Serialize)] struct CalibRow { num_instances: u32, @@ -334,7 +342,6 @@ fn auto_select_instances( for &n in candidates { let mut cfg = base.clone(); cfg.cluster.num_instances = Some(n); - cfg.cluster.buckets.clear(); cfg.cluster.router.mode = probe; // Isolate calibration output so ablation runs don't overwrite it. cfg.sim.output_dir = out_root @@ -464,6 +471,8 @@ fn cmd_oracle( out_path: Option<&std::path::Path>, ) -> Result<()> { let cfg = load(path, overrides)?; + cfg.cluster + .require_legacy_single_pool("oracle analysis")?; let block_bytes = cfg.model.kv_block_bytes() as f64; let per_instance_blocks = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as u64; let aggregate_blocks = per_instance_blocks * cfg.cluster.total_instances() as u64; @@ -545,3 +554,123 @@ fn cmd_oracle( eprintln!("[oracle] wrote {}", target.display()); Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use kvcache_simulator::config::{ + BucketConfig, CalibrationConfig, ClusterConfig, GlobalRouterConfig, HardwareConfig, + MetaStoreConfig, ModelConfig, RouterConfig, SimConfig, + }; + + fn bucketed_config(out_dir: &str) -> Config { + Config { + model: ModelConfig { + name: "test".into(), + num_layers: 4, + num_kv_heads: 2, + head_dim: 64, + dtype_bytes: 2, + block_size_tokens: 16, + flops_per_token_prefill: Some(1.0e9), + attn_quadratic_coeff: Some(64.0), + ..Default::default() + }, + hardware: HardwareConfig { + gpu_flops: 1.0e14, + gpu_fp8_flops: 0.0, + gpu_fp4_flops: 0.0, + gpu_mem_bw: 1.0e12, + hbm_bytes: 1.0e9, + dram_bytes: 4.0e9, + host_dram_bw: 5.0e11, + pcie_bw: 32.0e9, + pcie_latency_us: 1.0, + rdma_bw: 12.0e9, + rdma_latency_us: 5.0, + intra_node_tp_bw: 9.0e11, + intra_node_tp_latency_us: 2.0, + tp_degree: 1, + max_batch_slots: 32, + prefill_chunk_tokens: 1024, + }, + calibration: CalibrationConfig::default(), + cluster: ClusterConfig { + num_instances: None, + buckets: vec![ + BucketConfig { + name: "short".into(), + input_length_min: 0, + input_length_max: 64, + num_instances: 2, + }, + BucketConfig { + name: "long".into(), + input_length_min: 65, + input_length_max: 128, + num_instances: 1, + }, + ], + global_router: GlobalRouterConfig::default(), + meta_store: MetaStoreConfig { + ttl_seconds: 1000.0, + }, + router: RouterConfig { + mode: RouterMode::Random, + precise_probe_latency_us: 10.0, + precise_probe_topk: 4, + load_alpha: 0.1, + score_alpha: 1.0, + score_beta: 0.1, + prefix_k: 8, + affinity_fan_out: 0, + }, + }, + sim: SimConfig { + trace_path: "unused.jsonl".into(), + max_requests: None, + output_dir: out_dir.into(), + sample_interval_s: 0.0, + seed: 7, + input_length_min: None, + input_length_max: None, + }, + } + } + + #[test] + fn num_instances_override_rejects_bucketed_configs() { + let mut cfg = bucketed_config(std::env::temp_dir().to_str().unwrap()); + let overrides = ConfigOverrides { + num_instances: Some(8), + ..ConfigOverrides::default() + }; + + let err = overrides.apply(&mut cfg).unwrap_err(); + assert!(err.to_string().contains("--num-instances")); + assert!(err.to_string().contains("cluster.buckets")); + } + + #[test] + fn auto_instances_rejects_bucketed_configs() { + let cfg = bucketed_config(std::env::temp_dir().to_str().unwrap()); + + let err = auto_select_instances(&cfg, &[4, 8], RouterMode::Random, 1.0).unwrap_err(); + assert!(err.to_string().contains("auto-instances")); + assert!(err.to_string().contains("cluster.buckets")); + } + + #[test] + fn oracle_rejects_bucketed_configs() { + let tmp = std::env::temp_dir().join("kvcache_sim_main_tests"); + let _ = std::fs::remove_dir_all(&tmp); + std::fs::create_dir_all(&tmp).unwrap(); + let path = tmp.join("bucketed.yaml"); + let cfg = bucketed_config(tmp.to_str().unwrap()); + std::fs::write(&path, serde_yaml::to_string(&cfg).unwrap()).unwrap(); + + let err = cmd_oracle(&path, &ConfigOverrides::default(), None, false, None).unwrap_err(); + assert!(err.to_string().contains("oracle analysis")); + assert!(err.to_string().contains("cluster.buckets")); + } +} diff --git a/src/replay.rs b/src/replay.rs index 2199dc8..a26b724 100644 --- a/src/replay.rs +++ b/src/replay.rs @@ -496,6 +496,8 @@ pub fn replay_fixed_placement( placements: &[PlacementEntry], policy: ReplayEvictPolicy, ) -> Result { + cfg.cluster + .require_legacy_single_pool("fixed-placement replay")?; if records.len() != placements.len() { return Err(anyhow!( "records/placements length mismatch: {} vs {}", diff --git a/tests/smoke.rs b/tests/smoke.rs index 33c81db..76374f4 100644 --- a/tests/smoke.rs +++ b/tests/smoke.rs @@ -6,7 +6,7 @@ use std::io::Write; use kvcache_simulator::config::*; use kvcache_simulator::driver; -use kvcache_simulator::replay::ReplayEvictPolicy; +use kvcache_simulator::replay::{self, PlacementEntry, ReplayEvictPolicy}; fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config { Config { @@ -70,6 +70,26 @@ fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config { } } +fn bucketed_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config { + let mut cfg = base_config(trace_path, out_dir, mode); + cfg.cluster.num_instances = None; + cfg.cluster.buckets = vec![ + BucketConfig { + name: "short".into(), + input_length_min: 0, + input_length_max: 64, + num_instances: 2, + }, + BucketConfig { + name: "long".into(), + input_length_min: 65, + input_length_max: 128, + num_instances: 1, + }, + ]; + cfg +} + fn write_synthetic_trace(path: &std::path::Path) { // 5 distinct conversations, each with 8 turns. Within a conversation, // turn k+1 reuses the prefix of turn k (shared first ~10 blocks) and @@ -274,3 +294,36 @@ fn ablation_parallel_matches_serial() { assert!((lhs.miss_rate - rhs.miss_rate).abs() < 1e-12); } } + +#[test] +fn bucketed_configs_are_rejected_by_legacy_runtime_paths() { + let tmp = std::env::temp_dir().join("kvcache_sim_bucketed_reject"); + let _ = std::fs::remove_dir_all(&tmp); + std::fs::create_dir_all(&tmp).unwrap(); + let trace_path = tmp.join("trace.jsonl"); + write_synthetic_trace(&trace_path); + + let cfg = bucketed_config( + trace_path.to_str().unwrap(), + tmp.to_str().unwrap(), + RouterMode::Random, + ); + + let result = driver::run(&cfg, Some("bucketed_guard")); + assert!(result.is_err(), "bucketed run should fail"); + let err = result.err().unwrap(); + assert!(err.to_string().contains("cluster.buckets")); + + let err = driver::ablate_fixed_placement(&cfg, &[RouterMode::Random], &[ReplayEvictPolicy::Lru]) + .expect_err("bucketed ablation should fail"); + assert!(err.to_string().contains("cluster.buckets")); + + let err = replay::replay_fixed_placement( + &cfg, + &[], + &Vec::::new(), + ReplayEvictPolicy::Lru, + ) + .expect_err("bucketed replay should fail"); + assert!(err.to_string().contains("cluster.buckets")); +}