use anyhow::{Context, Result}; use clap::{Args, Parser, Subcommand}; use std::path::PathBuf; use kvcache_simulator::config::{Config, RouterMode}; use kvcache_simulator::replay::ReplayEvictPolicy; use kvcache_simulator::{driver, oracle, trace::TraceReader}; #[derive(Debug, Parser)] #[command(name = "kvcache-sim", about = "Cluster-level KV cache simulator")] struct Cli { #[command(subcommand)] cmd: Cmd, } /// Optional CLI overrides applied on top of the YAML config so the same /// config can be reused across sweeps without editing the file. #[derive(Debug, Args, Clone, Default)] struct ConfigOverrides { /// Override `cluster.num_instances`. #[arg(long)] num_instances: Option, /// Override `sim.max_requests` (cap on processed trace records). #[arg(long)] max_requests: Option, /// Override `sim.trace_path`. #[arg(long)] trace: Option, /// Override `sim.output_dir`. #[arg(long)] output_dir: Option, /// Override `sim.seed`. #[arg(long)] seed: Option, /// Override `cluster.router.precise_probe_topk`. #[arg(long)] precise_topk: Option, /// Override `cluster.meta_store.ttl_seconds`. #[arg(long)] ttl_seconds: Option, /// Override `sim.input_length_min` — requests with `input_length` below /// this value are dropped from the replay. #[arg(long)] input_length_min: Option, /// Override `sim.input_length_max` — requests with `input_length` above /// this value are dropped from the replay. Combine with `--input-length-min` /// to carve out a specific input-length bucket for ablation. #[arg(long)] input_length_max: Option, } impl ConfigOverrides { fn apply(&self, cfg: &mut Config) { if let Some(n) = self.num_instances { cfg.cluster.num_instances = n; } if let Some(m) = self.max_requests { cfg.sim.max_requests = Some(m); } if let Some(t) = &self.trace { cfg.sim.trace_path = t.to_string_lossy().into_owned(); } if let Some(o) = &self.output_dir { cfg.sim.output_dir = o.to_string_lossy().into_owned(); } if let Some(s) = self.seed { cfg.sim.seed = s; } if let Some(k) = self.precise_topk { cfg.cluster.router.precise_probe_topk = k; } if let Some(ttl) = self.ttl_seconds { cfg.cluster.meta_store.ttl_seconds = ttl; } if let Some(lo) = self.input_length_min { cfg.sim.input_length_min = Some(lo); } if let Some(hi) = self.input_length_max { cfg.sim.input_length_max = Some(hi); } } } #[derive(Debug, Subcommand)] enum Cmd { /// Run a single simulation with the router specified in the config. Run { #[arg(short, long)] config: PathBuf, #[command(flatten)] overrides: ConfigOverrides, }, /// Run the same trace under multiple routers and fixed-placement eviction /// policies, then compare cache-hit summaries. Ablate { #[arg(short, long)] config: PathBuf, /// Comma-separated router modes #[arg( short, long, default_value = "random,least_loaded,least_tokens,ttl_aware,min_pd,cache_load,cache_score,estimated_ttft,prefix_affinity" )] routers: String, /// Comma-separated eviction policies for ablation aggregation. /// Currently only `lru` is supported. #[arg(long, default_value = "lru")] evict_policies: String, /// Sweep `num_instances` from `--auto-candidates` with the /// `--auto-probe-router` and pick the smallest cluster size whose /// TTFT mean ≤ `--auto-target-ttft-mean`. Overrides the YAML /// `num_instances` for the ablation run. #[arg(long, default_value_t = false)] auto_instances: bool, /// Target TTFT mean (seconds) for auto-instances calibration. #[arg(long, default_value_t = 4.0)] auto_target_ttft_mean: f64, /// Comma-separated candidate cluster sizes (ascending). #[arg(long, default_value = "4,8,16,24,32,48,64,96,128")] auto_candidates: String, /// Router used as the calibration baseline. The smallest candidate /// where this router's TTFT mean ≤ target is picked — all ablation /// routers are then run at that cluster size. #[arg(long, default_value = "cache_score")] auto_probe_router: String, #[command(flatten)] overrides: ConfigOverrides, }, /// Parse the config and trace head; do not run a simulation. Validate { #[arg(short, long)] config: PathBuf, #[command(flatten)] overrides: ConfigOverrides, }, /// Offline oracle analysis: theoretical hit-rate ceilings (unlimited /// cache and offline-optimal Belady eviction at finite capacity), plus /// LRU at the same capacity for comparison. Oracle { #[arg(short, long)] config: PathBuf, #[command(flatten)] overrides: ConfigOverrides, /// Cache capacity (in 16-token blocks) used for the Belady and LRU /// analyses. Defaults to `num_instances * per_instance_HBM_blocks` /// (the cluster-aggregate capacity). #[arg(long)] capacity_blocks: Option, /// Use the per-instance HBM block budget instead of the /// cluster-aggregate. Mutually exclusive with --capacity-blocks. #[arg(long, default_value_t = false)] per_instance: bool, /// Optional output JSON path. Defaults to `/oracle.json`. #[arg(long)] out: Option, }, } fn main() -> Result<()> { let cli = Cli::parse(); match cli.cmd { Cmd::Run { config, overrides } => cmd_run(&config, &overrides), Cmd::Ablate { config, routers, evict_policies, auto_instances, auto_target_ttft_mean, auto_candidates, auto_probe_router, overrides, } => cmd_ablate( &config, &routers, &evict_policies, auto_instances, auto_target_ttft_mean, &auto_candidates, &auto_probe_router, &overrides, ), Cmd::Validate { config, overrides } => cmd_validate(&config, &overrides), Cmd::Oracle { config, overrides, capacity_blocks, per_instance, out, } => cmd_oracle( &config, &overrides, capacity_blocks, per_instance, out.as_deref(), ), } } fn load(config: &PathBuf, overrides: &ConfigOverrides) -> Result { let mut cfg = Config::from_yaml_path(config)?; overrides.apply(&mut cfg); Ok(cfg) } fn cmd_run(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> { let cfg = load(path, overrides)?; let out = driver::run(&cfg, None)?; println!("{}", serde_json::to_string_pretty(&out.summary)?); Ok(()) } #[allow(clippy::too_many_arguments)] fn cmd_ablate( path: &PathBuf, routers: &str, evict_policies: &str, auto_instances: bool, auto_target_ttft_mean: f64, auto_candidates: &str, auto_probe_router: &str, overrides: &ConfigOverrides, ) -> Result<()> { let mut base = load(path, overrides)?; let modes: Vec = routers .split(',') .map(|s| s.trim()) .filter(|s| !s.is_empty()) .map(RouterMode::parse) .collect::>>() .with_context(|| format!("parsing --routers='{routers}'"))?; let policies: Vec = evict_policies .split(',') .map(|s| s.trim()) .filter(|s| !s.is_empty()) .map(ReplayEvictPolicy::parse) .collect::>>() .with_context(|| format!("parsing --evict-policies='{evict_policies}'"))?; if auto_instances { let candidates: Vec = auto_candidates .split(',') .map(|s| s.trim()) .filter(|s| !s.is_empty()) .map(|s| { s.parse::() .with_context(|| format!("parsing --auto-candidates entry '{s}'")) }) .collect::>>()?; if candidates.is_empty() { return Err(anyhow::anyhow!("--auto-candidates is empty")); } // Ascending so the first hit is the smallest cluster meeting the target. let mut sorted = candidates.clone(); sorted.sort_unstable(); let probe_mode = RouterMode::parse(auto_probe_router) .with_context(|| format!("parsing --auto-probe-router='{auto_probe_router}'"))?; let chosen = auto_select_instances( &base, &sorted, probe_mode, auto_target_ttft_mean, )?; eprintln!( "[ablate] auto-instances chose num_instances={} (target ttft_mean ≤ {:.3}s, probe_router={})", chosen, auto_target_ttft_mean, probe_mode.as_str() ); base.cluster.num_instances = chosen; } eprintln!( "[ablate] routers={} evict_policies={} num_instances={}", modes .iter() .map(RouterMode::as_str) .collect::>() .join(","), policies .iter() .map(ReplayEvictPolicy::as_str) .collect::>() .join(","), base.cluster.num_instances, ); let all = driver::ablate_fixed_placement(&base, &modes, &policies)?; let agg_path = std::path::Path::new(&base.sim.output_dir).join("ablation.json"); std::fs::create_dir_all(&base.sim.output_dir)?; std::fs::write(&agg_path, serde_json::to_string_pretty(&all)?)?; println!("{}", serde_json::to_string_pretty(&all)?); eprintln!("[ablate] wrote {}", agg_path.display()); Ok(()) } /// Sweep candidate cluster sizes ascending and return the smallest one whose /// TTFT mean under `probe` is ≤ `target_ttft_mean`. Per-candidate calibration /// summaries are written under `/auto_instances/` so the picked /// N is auditable. If no candidate meets the target, returns an error naming /// the best achievable TTFT. fn auto_select_instances( base: &Config, candidates: &[u32], probe: RouterMode, target_ttft_mean: f64, ) -> Result { #[derive(serde::Serialize)] struct CalibRow { num_instances: u32, router: String, ttft_mean: f64, ttft_p50: f64, ttft_p95: f64, ttft_p99: f64, num_requests: u64, hit_rate_l0: f64, passed: bool, } let out_root = std::path::Path::new(&base.sim.output_dir).join("auto_instances"); std::fs::create_dir_all(&out_root)?; let mut log: Vec = Vec::new(); let mut chosen: Option = None; for &n in candidates { let mut cfg = base.clone(); cfg.cluster.num_instances = n; cfg.cluster.router.mode = probe; // Isolate calibration output so ablation runs don't overwrite it. cfg.sim.output_dir = out_root .join(format!("n{n}__{}", probe.as_str())) .to_string_lossy() .into_owned(); eprintln!( "[auto-instances] probing num_instances={} router={} ...", n, probe.as_str() ); let run = driver::run(&cfg, None)?; let passed = run.summary.ttft_mean <= target_ttft_mean; eprintln!( "[auto-instances] ttft_mean={:.3}s p95={:.3}s hit_l0={:.4} -> {}", run.summary.ttft_mean, run.summary.ttft_p95, run.summary.hit_rate_l0, if passed { "PASS" } else { "fail" }, ); log.push(CalibRow { num_instances: n, router: probe.as_str().to_string(), ttft_mean: run.summary.ttft_mean, ttft_p50: run.summary.ttft_p50, ttft_p95: run.summary.ttft_p95, ttft_p99: run.summary.ttft_p99, num_requests: run.summary.num_requests, hit_rate_l0: run.summary.hit_rate_l0, passed, }); if passed && chosen.is_none() { chosen = Some(n); // Keep sweeping if you want a curve; here we stop at the smallest // passing N to satisfy the "not too small, but as small as it can // be while meeting the SLA" requirement. break; } } // Persist the calibration log either way so failures are debuggable. let log_path = out_root.join("calibration.json"); std::fs::write( &log_path, serde_json::to_string_pretty(&serde_json::json!({ "target_ttft_mean": target_ttft_mean, "probe_router": probe.as_str(), "candidates": candidates, "chosen": chosen, "runs": log, }))?, )?; eprintln!("[auto-instances] wrote {}", log_path.display()); chosen.ok_or_else(|| { let best = log .iter() .min_by(|a, b| a.ttft_mean.partial_cmp(&b.ttft_mean).unwrap()) .map(|r| (r.num_instances, r.ttft_mean)) .unwrap_or((0, f64::INFINITY)); anyhow::anyhow!( "no candidate met target ttft_mean ≤ {:.3}s; best was n={} at {:.3}s — \ widen --auto-candidates or raise --auto-target-ttft-mean", target_ttft_mean, best.0, best.1, ) }) } fn cmd_validate(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> { use kvcache_simulator::instance::compute::ComputeModel; let cfg = load(path, overrides)?; eprintln!("config OK: {}", cfg.model.name); eprintln!( "mode = {}", if cfg.model.is_arch_mode() { "architecture-derived" } else { "legacy manual" } ); let cm = ComputeModel::new(&cfg.model, &cfg.hardware, &cfg.calibration); eprintln!("compute: {}", cm.describe()); eprintln!( "kv_block_bytes = {} ({:.2} MB{})", cfg.model.kv_block_bytes(), cfg.model.kv_block_bytes() as f64 / 1e6, if cfg.model.mla.is_some() { ", MLA compressed" } else { "" }, ); let block_bytes = cfg.model.kv_block_bytes() as f64; let hbm_blocks = (cfg.hardware.hbm_bytes / block_bytes) as u64; let dram_blocks = (cfg.hardware.dram_bytes / block_bytes) as u64; eprintln!("per-instance HBM blocks = {hbm_blocks}, DRAM blocks = {dram_blocks}"); eprintln!("num_instances = {}", cfg.cluster.num_instances); // Sample prefill times at a few prompt lengths. eprintln!("prefill_time samples:"); for &n in &[256, 1024, 4096, 16384, 65536, 131072] { let t = cm.prefill_time(n); eprintln!(" N={n:>7} -> {t:.4} s"); } let reader = TraceReader::open(&cfg.sim.trace_path, Some(5))?; for rec in reader { let rec = rec?; eprintln!( " req {} chat={} t={:.3}s in={} out={} blocks={}", rec.req_id, rec.chat_id, rec.arrival, rec.input_len, rec.output_len, rec.hash_ids.len() ); } Ok(()) } fn cmd_oracle( path: &PathBuf, overrides: &ConfigOverrides, capacity_blocks: Option, per_instance: bool, out_path: Option<&std::path::Path>, ) -> Result<()> { let cfg = load(path, overrides)?; let block_bytes = cfg.model.kv_block_bytes() as f64; let per_instance_blocks = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as u64; let aggregate_blocks = per_instance_blocks * cfg.cluster.num_instances as u64; let capacity = match (capacity_blocks, per_instance) { (Some(_), true) => { return Err(anyhow::anyhow!( "--capacity-blocks and --per-instance are mutually exclusive" )) } (Some(c), false) => c, (None, true) => per_instance_blocks, (None, false) => aggregate_blocks, }; eprintln!( "[oracle] loading trace {} (max_requests={:?})", cfg.sim.trace_path, cfg.sim.max_requests ); let reader = TraceReader::open(&cfg.sim.trace_path, cfg.sim.max_requests)?; let mut records: Vec<_> = reader.collect::, _>>()?; let raw_count = records.len(); driver::apply_input_length_filter(&mut records, &cfg.sim); if records.len() != raw_count { eprintln!( "[oracle] input_length filter [{}, {}] kept {}/{} requests", cfg.sim.input_length_min.unwrap_or(0), cfg.sim.input_length_max.map_or("∞".to_string(), |v| v.to_string()), records.len(), raw_count, ); } eprintln!( "[oracle] loaded {} requests; analyzing with capacity = {} blocks \ ({} per-instance × {} instances{})", records.len(), capacity, per_instance_blocks, cfg.cluster.num_instances, if per_instance { ", per-instance mode" } else { "" } ); let result = oracle::analyze(&records, capacity); let json = serde_json::to_string_pretty(&result)?; println!("{}", json); let target = match out_path { Some(p) => p.to_path_buf(), None => std::path::Path::new(&cfg.sim.output_dir).join("oracle.json"), }; if let Some(parent) = target.parent() { std::fs::create_dir_all(parent)?; } std::fs::write(&target, &json)?; eprintln!("[oracle] wrote {}", target.display()); Ok(()) }