Files
kvcache-simulator/src/main.rs
Gahow Wang c86d931d8f feat(ablate): input-length bucketing + auto-instance sizing
- Add sim.input_length_{min,max} (+ CLI overrides) that drop requests
  outside the bucket after trace load, enabling per-bucket ablation
  (e.g. 0-40k) without rewriting the trace file. Applied uniformly in
  both `run`/`ablate` driver path and `oracle` analysis.

- Add cache_score_strong router (alpha=1, beta=1) to isolate how much
  of cache_affinity's win is reproducible by just retuning beta in the
  existing cache_score framework (no rendezvous, no meta-store bonus).

- Add --auto-instances to ablate: sweeps --auto-candidates ascending
  with --auto-probe-router and picks the smallest cluster size whose
  TTFT mean <= --auto-target-ttft-mean. Per-candidate calibration
  results are persisted under runs/<output_dir>/auto_instances/ so the
  pick is auditable; the chosen N is then used for the whole ablation.
2026-04-15 19:42:28 +08:00

516 lines
18 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use anyhow::{Context, Result};
use clap::{Args, Parser, Subcommand};
use std::path::PathBuf;
use kvcache_simulator::config::{Config, RouterMode};
use kvcache_simulator::replay::ReplayEvictPolicy;
use kvcache_simulator::{driver, oracle, trace::TraceReader};
// Top-level CLI: exactly one required subcommand (see `Cmd`).
// NOTE: reviewer comments here use `//` (not `///`) on purpose — clap turns
// doc comments on Parser types into `--help` text, which must not change.
#[derive(Debug, Parser)]
#[command(name = "kvcache-sim", about = "Cluster-level KV cache simulator")]
struct Cli {
    #[command(subcommand)]
    cmd: Cmd,
}
/// Optional CLI overrides applied on top of the YAML config so the same
/// config can be reused across sweeps without editing the file.
// Every field is an `Option`: `None` means "keep the YAML value". The overlay
// itself happens in `ConfigOverrides::apply`. Field docs below use `///`
// because clap renders them as `--help` text — do not edit them casually.
#[derive(Debug, Args, Clone, Default)]
struct ConfigOverrides {
    /// Override `cluster.num_instances`.
    #[arg(long)]
    num_instances: Option<u32>,
    /// Override `sim.max_requests` (cap on processed trace records).
    #[arg(long)]
    max_requests: Option<u64>,
    /// Override `sim.trace_path`.
    #[arg(long)]
    trace: Option<PathBuf>,
    /// Override `sim.output_dir`.
    #[arg(long)]
    output_dir: Option<PathBuf>,
    /// Override `sim.seed`.
    #[arg(long)]
    seed: Option<u64>,
    /// Override `cluster.router.precise_probe_topk`.
    #[arg(long)]
    precise_topk: Option<u32>,
    /// Override `cluster.meta_store.ttl_seconds`.
    #[arg(long)]
    ttl_seconds: Option<f64>,
    /// Override `sim.input_length_min` — requests with `input_length` below
    /// this value are dropped from the replay.
    #[arg(long)]
    input_length_min: Option<u32>,
    /// Override `sim.input_length_max` — requests with `input_length` above
    /// this value are dropped from the replay. Combine with `--input-length-min`
    /// to carve out a specific input-length bucket for ablation.
    #[arg(long)]
    input_length_max: Option<u32>,
}
impl ConfigOverrides {
    /// Overlay every flag that was actually supplied on the command line onto
    /// `cfg`; flags left unset keep the value loaded from YAML. Each override
    /// is independent, so the order of the checks below is arbitrary.
    fn apply(&self, cfg: &mut Config) {
        if let Some(v) = self.num_instances {
            cfg.cluster.num_instances = v;
        }
        if let Some(v) = self.max_requests {
            cfg.sim.max_requests = Some(v);
        }
        // Path-valued overrides are stored as strings in the config; lossy
        // conversion matches however the YAML value would have been read.
        if let Some(p) = self.trace.as_ref() {
            cfg.sim.trace_path = p.to_string_lossy().into_owned();
        }
        if let Some(p) = self.output_dir.as_ref() {
            cfg.sim.output_dir = p.to_string_lossy().into_owned();
        }
        if let Some(v) = self.seed {
            cfg.sim.seed = v;
        }
        if let Some(v) = self.precise_topk {
            cfg.cluster.router.precise_probe_topk = v;
        }
        if let Some(v) = self.ttl_seconds {
            cfg.cluster.meta_store.ttl_seconds = v;
        }
        if let Some(v) = self.input_length_min {
            cfg.sim.input_length_min = Some(v);
        }
        if let Some(v) = self.input_length_max {
            cfg.sim.input_length_max = Some(v);
        }
    }
}
// Subcommand surface of the binary. The `///` comments below are user-visible:
// clap renders them as `--help` text, so reviewer-only notes use plain `//`.
#[derive(Debug, Subcommand)]
enum Cmd {
    /// Run a single simulation with the router specified in the config.
    Run {
        #[arg(short, long)]
        config: PathBuf,
        #[command(flatten)]
        overrides: ConfigOverrides,
    },
    /// Run the same trace under multiple routers and fixed-placement eviction
    /// policies, then compare cache-hit summaries.
    Ablate {
        #[arg(short, long)]
        config: PathBuf,
        /// Comma-separated router modes
        #[arg(
            short,
            long,
            default_value = "random,least_loaded,least_tokens,ttl_aware,min_pd,cache_load,cache_score,estimated_ttft,prefix_affinity"
        )]
        routers: String,
        /// Comma-separated eviction policies for ablation aggregation.
        /// Currently only `lru` is supported.
        #[arg(long, default_value = "lru")]
        evict_policies: String,
        /// Sweep `num_instances` from `--auto-candidates` with the
        /// `--auto-probe-router` and pick the smallest cluster size whose
        /// TTFT mean ≤ `--auto-target-ttft-mean`. Overrides the YAML
        /// `num_instances` for the ablation run.
        #[arg(long, default_value_t = false)]
        auto_instances: bool,
        /// Target TTFT mean (seconds) for auto-instances calibration.
        #[arg(long, default_value_t = 4.0)]
        auto_target_ttft_mean: f64,
        /// Comma-separated candidate cluster sizes (ascending).
        // Candidates are re-sorted defensively in cmd_ablate, so ascending
        // order here is a convention, not a hard requirement.
        #[arg(long, default_value = "4,8,16,24,32,48,64,96,128")]
        auto_candidates: String,
        /// Router used as the calibration baseline. The smallest candidate
        /// where this router's TTFT mean ≤ target is picked — all ablation
        /// routers are then run at that cluster size.
        #[arg(long, default_value = "cache_score")]
        auto_probe_router: String,
        #[command(flatten)]
        overrides: ConfigOverrides,
    },
    /// Parse the config and trace head; do not run a simulation.
    Validate {
        #[arg(short, long)]
        config: PathBuf,
        #[command(flatten)]
        overrides: ConfigOverrides,
    },
    /// Offline oracle analysis: theoretical hit-rate ceilings (unlimited
    /// cache and offline-optimal Belady eviction at finite capacity), plus
    /// LRU at the same capacity for comparison.
    Oracle {
        #[arg(short, long)]
        config: PathBuf,
        #[command(flatten)]
        overrides: ConfigOverrides,
        /// Cache capacity (in 16-token blocks) used for the Belady and LRU
        /// analyses. Defaults to `num_instances * per_instance_HBM_blocks`
        /// (the cluster-aggregate capacity).
        #[arg(long)]
        capacity_blocks: Option<u64>,
        /// Use the per-instance HBM block budget instead of the
        /// cluster-aggregate. Mutually exclusive with --capacity-blocks.
        // NOTE(review): exclusivity is enforced at runtime in cmd_oracle, not
        // via clap's `conflicts_with` — consider declaring it here instead.
        #[arg(long, default_value_t = false)]
        per_instance: bool,
        /// Optional output JSON path. Defaults to `<output_dir>/oracle.json`.
        #[arg(long)]
        out: Option<PathBuf>,
    },
}
/// Entry point: parse the CLI and dispatch to the matching subcommand
/// handler, bubbling any `anyhow` error up as the process exit status.
fn main() -> Result<()> {
    match Cli::parse().cmd {
        Cmd::Run { config, overrides } => cmd_run(&config, &overrides),
        Cmd::Validate { config, overrides } => cmd_validate(&config, &overrides),
        Cmd::Ablate {
            config,
            routers,
            evict_policies,
            auto_instances,
            auto_target_ttft_mean,
            auto_candidates,
            auto_probe_router,
            overrides,
        } => cmd_ablate(
            &config,
            &routers,
            &evict_policies,
            auto_instances,
            auto_target_ttft_mean,
            &auto_candidates,
            &auto_probe_router,
            &overrides,
        ),
        Cmd::Oracle {
            config,
            overrides,
            capacity_blocks,
            per_instance,
            out,
        } => {
            // `as_deref` turns the owned Option<PathBuf> into Option<&Path>.
            cmd_oracle(&config, &overrides, capacity_blocks, per_instance, out.as_deref())
        }
    }
}
/// Load the YAML config at `config`, then layer the CLI overrides on top so
/// one file can drive many sweep variants.
///
/// # Errors
/// Fails if the file cannot be read or parsed; the offending path is attached
/// to the error so sweep scripts report *which* config broke.
// NOTE(review): `&Path` would be the more idiomatic parameter type; `&PathBuf`
// is kept because `Config::from_yaml_path`'s exact signature isn't visible here.
fn load(config: &PathBuf, overrides: &ConfigOverrides) -> Result<Config> {
    let mut cfg = Config::from_yaml_path(config)
        .with_context(|| format!("loading config {}", config.display()))?;
    overrides.apply(&mut cfg);
    Ok(cfg)
}
/// `run` subcommand: execute a single simulation with the configured router
/// and print the summary as pretty JSON on stdout.
fn cmd_run(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> {
    let cfg = load(path, overrides)?;
    let result = driver::run(&cfg, None)?;
    let summary_json = serde_json::to_string_pretty(&result.summary)?;
    println!("{summary_json}");
    Ok(())
}
/// Split a comma-separated CLI list into trimmed, non-empty entries.
///
/// Shared by `--routers`, `--evict-policies`, and `--auto-candidates` so all
/// three flags tolerate whitespace and stray commas identically.
fn split_csv(raw: &str) -> impl Iterator<Item = &str> {
    raw.split(',').map(str::trim).filter(|s| !s.is_empty())
}

/// `ablate` subcommand: run the same trace under every requested router and
/// fixed-placement eviction policy, optionally auto-calibrating the cluster
/// size first, then write and print the aggregated comparison.
///
/// # Errors
/// Fails on unparsable flag values, an empty `--auto-candidates` list, a
/// calibration sweep that never meets the TTFT target, or any driver/IO error.
#[allow(clippy::too_many_arguments)]
fn cmd_ablate(
    path: &PathBuf,
    routers: &str,
    evict_policies: &str,
    auto_instances: bool,
    auto_target_ttft_mean: f64,
    auto_candidates: &str,
    auto_probe_router: &str,
    overrides: &ConfigOverrides,
) -> Result<()> {
    let mut base = load(path, overrides)?;
    let modes: Vec<RouterMode> = split_csv(routers)
        .map(RouterMode::parse)
        .collect::<Result<Vec<_>>>()
        .with_context(|| format!("parsing --routers='{routers}'"))?;
    let policies: Vec<ReplayEvictPolicy> = split_csv(evict_policies)
        .map(ReplayEvictPolicy::parse)
        .collect::<Result<Vec<_>>>()
        .with_context(|| format!("parsing --evict-policies='{evict_policies}'"))?;
    if auto_instances {
        let mut candidates: Vec<u32> = split_csv(auto_candidates)
            .map(|s| {
                s.parse::<u32>()
                    .with_context(|| format!("parsing --auto-candidates entry '{s}'"))
            })
            .collect::<Result<Vec<_>>>()?;
        if candidates.is_empty() {
            return Err(anyhow::anyhow!("--auto-candidates is empty"));
        }
        // Ascending so the first passing candidate is the smallest cluster
        // meeting the target. Sorting in place avoids the previous clone.
        candidates.sort_unstable();
        let probe_mode = RouterMode::parse(auto_probe_router)
            .with_context(|| format!("parsing --auto-probe-router='{auto_probe_router}'"))?;
        let chosen = auto_select_instances(
            &base,
            &candidates,
            probe_mode,
            auto_target_ttft_mean,
        )?;
        eprintln!(
            "[ablate] auto-instances chose num_instances={} (target ttft_mean ≤ {:.3}s, probe_router={})",
            chosen,
            auto_target_ttft_mean,
            probe_mode.as_str()
        );
        // The calibrated size governs every router in the ablation below.
        base.cluster.num_instances = chosen;
    }
    eprintln!(
        "[ablate] routers={} evict_policies={} num_instances={}",
        modes
            .iter()
            .map(RouterMode::as_str)
            .collect::<Vec<_>>()
            .join(","),
        policies
            .iter()
            .map(ReplayEvictPolicy::as_str)
            .collect::<Vec<_>>()
            .join(","),
        base.cluster.num_instances,
    );
    let all = driver::ablate_fixed_placement(&base, &modes, &policies)?;
    let agg_path = std::path::Path::new(&base.sim.output_dir).join("ablation.json");
    std::fs::create_dir_all(&base.sim.output_dir)?;
    // Serialize once; the same JSON goes to the file and to stdout.
    let all_json = serde_json::to_string_pretty(&all)?;
    std::fs::write(&agg_path, &all_json)?;
    println!("{all_json}");
    eprintln!("[ablate] wrote {}", agg_path.display());
    Ok(())
}
/// Sweep candidate cluster sizes ascending and return the smallest one whose
/// TTFT mean under `probe` is ≤ `target_ttft_mean`. Per-candidate calibration
/// summaries are written under `<output_dir>/auto_instances/` so the picked
/// N is auditable. If no candidate meets the target, returns an error naming
/// the best achievable TTFT.
fn auto_select_instances(
    base: &Config,
    candidates: &[u32],
    probe: RouterMode,
    target_ttft_mean: f64,
) -> Result<u32> {
    /// One calibration run's summary, persisted to `calibration.json`.
    #[derive(serde::Serialize)]
    struct CalibRow {
        num_instances: u32,
        router: String,
        ttft_mean: f64,
        ttft_p50: f64,
        ttft_p95: f64,
        ttft_p99: f64,
        num_requests: u64,
        hit_rate_l0: f64,
        passed: bool,
    }
    let out_root = std::path::Path::new(&base.sim.output_dir).join("auto_instances");
    std::fs::create_dir_all(&out_root)?;
    let mut log: Vec<CalibRow> = Vec::new();
    let mut chosen: Option<u32> = None;
    for &n in candidates {
        let mut cfg = base.clone();
        cfg.cluster.num_instances = n;
        cfg.cluster.router.mode = probe;
        // Isolate calibration output so ablation runs don't overwrite it.
        cfg.sim.output_dir = out_root
            .join(format!("n{n}__{}", probe.as_str()))
            .to_string_lossy()
            .into_owned();
        eprintln!(
            "[auto-instances] probing num_instances={} router={} ...",
            n,
            probe.as_str()
        );
        let run = driver::run(&cfg, None)?;
        let passed = run.summary.ttft_mean <= target_ttft_mean;
        eprintln!(
            "[auto-instances] ttft_mean={:.3}s p95={:.3}s hit_l0={:.4} -> {}",
            run.summary.ttft_mean,
            run.summary.ttft_p95,
            run.summary.hit_rate_l0,
            if passed { "PASS" } else { "fail" },
        );
        log.push(CalibRow {
            num_instances: n,
            router: probe.as_str().to_string(),
            ttft_mean: run.summary.ttft_mean,
            ttft_p50: run.summary.ttft_p50,
            ttft_p95: run.summary.ttft_p95,
            ttft_p99: run.summary.ttft_p99,
            num_requests: run.summary.num_requests,
            hit_rate_l0: run.summary.hit_rate_l0,
            passed,
        });
        if passed {
            // Candidates are ascending, so the first pass is the smallest
            // cluster meeting the SLA — stop here. (The previous
            // `chosen.is_none()` guard was dead code: we break immediately,
            // so `chosen` can never already be set at this point.)
            chosen = Some(n);
            break;
        }
    }
    // Persist the calibration log either way so failures are debuggable.
    let log_path = out_root.join("calibration.json");
    std::fs::write(
        &log_path,
        serde_json::to_string_pretty(&serde_json::json!({
            "target_ttft_mean": target_ttft_mean,
            "probe_router": probe.as_str(),
            "candidates": candidates,
            "chosen": chosen,
            "runs": log,
        }))?,
    )?;
    eprintln!("[auto-instances] wrote {}", log_path.display());
    chosen.ok_or_else(|| {
        // `total_cmp` is a total order over f64 (NaN-safe); the previous
        // `partial_cmp(..).unwrap()` would panic if any TTFT came back NaN.
        let best = log
            .iter()
            .min_by(|a, b| a.ttft_mean.total_cmp(&b.ttft_mean))
            .map(|r| (r.num_instances, r.ttft_mean))
            .unwrap_or((0, f64::INFINITY));
        anyhow::anyhow!(
            "no candidate met target ttft_mean ≤ {:.3}s; best was n={} at {:.3}s — \
            widen --auto-candidates or raise --auto-target-ttft-mean",
            target_ttft_mean,
            best.0,
            best.1,
        )
    })
}
/// `validate` subcommand: parse the config, print derived model/hardware
/// numbers and sample prefill times, and echo the head of the trace —
/// no simulation is run.
fn cmd_validate(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> {
    use kvcache_simulator::instance::compute::ComputeModel;
    let cfg = load(path, overrides)?;
    eprintln!("config OK: {}", cfg.model.name);
    eprintln!(
        "mode = {}",
        if cfg.model.is_arch_mode() {
            "architecture-derived"
        } else {
            "legacy manual"
        }
    );
    // Build the compute model once so describe() reflects the final
    // (override-applied) config.
    let cm = ComputeModel::new(&cfg.model, &cfg.hardware, &cfg.calibration);
    eprintln!("compute: {}", cm.describe());
    eprintln!(
        "kv_block_bytes = {} ({:.2} MB{})",
        cfg.model.kv_block_bytes(),
        cfg.model.kv_block_bytes() as f64 / 1e6,
        if cfg.model.mla.is_some() {
            ", MLA compressed"
        } else {
            ""
        },
    );
    // Translate byte budgets into whole KV-block counts (truncating `as u64`).
    let block_bytes = cfg.model.kv_block_bytes() as f64;
    let hbm_blocks = (cfg.hardware.hbm_bytes / block_bytes) as u64;
    let dram_blocks = (cfg.hardware.dram_bytes / block_bytes) as u64;
    eprintln!("per-instance HBM blocks = {hbm_blocks}, DRAM blocks = {dram_blocks}");
    eprintln!("num_instances = {}", cfg.cluster.num_instances);
    // Sample prefill times at a few prompt lengths.
    eprintln!("prefill_time samples:");
    for &n in &[256, 1024, 4096, 16384, 65536, 131072] {
        let t = cm.prefill_time(n);
        eprintln!(" N={n:>7} -> {t:.4} s");
    }
    // Only the first 5 trace records are read — enough to sanity-check the
    // schema without replaying the whole trace.
    let reader = TraceReader::open(&cfg.sim.trace_path, Some(5))?;
    for rec in reader {
        let rec = rec?;
        eprintln!(
            " req {} chat={} t={:.3}s in={} out={} blocks={}",
            rec.req_id,
            rec.chat_id,
            rec.arrival,
            rec.input_len,
            rec.output_len,
            rec.hash_ids.len()
        );
    }
    Ok(())
}
/// `oracle` subcommand: offline analysis of theoretical hit-rate ceilings at
/// a chosen cache capacity; the result JSON is printed and written to disk.
fn cmd_oracle(
    path: &PathBuf,
    overrides: &ConfigOverrides,
    capacity_blocks: Option<u64>,
    per_instance: bool,
    out_path: Option<&std::path::Path>,
) -> Result<()> {
    let cfg = load(path, overrides)?;
    // Derive both capacity candidates from the config: per-instance HBM
    // blocks (clamped to at least 1) and the cluster-aggregate total.
    let block_bytes = cfg.model.kv_block_bytes() as f64;
    let per_instance_blocks = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as u64;
    let aggregate_blocks = per_instance_blocks * cfg.cluster.num_instances as u64;
    // Explicit --capacity-blocks wins; --per-instance selects the smaller
    // budget; default is the cluster aggregate. Supplying both is an error.
    let capacity = match (capacity_blocks, per_instance) {
        (Some(_), true) => {
            return Err(anyhow::anyhow!(
                "--capacity-blocks and --per-instance are mutually exclusive"
            ))
        }
        (Some(c), false) => c,
        (None, true) => per_instance_blocks,
        (None, false) => aggregate_blocks,
    };
    eprintln!(
        "[oracle] loading trace {} (max_requests={:?})",
        cfg.sim.trace_path, cfg.sim.max_requests
    );
    let reader = TraceReader::open(&cfg.sim.trace_path, cfg.sim.max_requests)?;
    let mut records: Vec<_> = reader.collect::<Result<Vec<_>, _>>()?;
    let raw_count = records.len();
    // Apply the same input-length bucket filter as the run/ablate driver so
    // oracle numbers are comparable per bucket.
    driver::apply_input_length_filter(&mut records, &cfg.sim);
    if records.len() != raw_count {
        // NOTE(review): an unset max prints as the empty string, e.g.
        // "[0, ]" — consider "inf" for readability (left unchanged here
        // because log text is observable behavior).
        eprintln!(
            "[oracle] input_length filter [{}, {}] kept {}/{} requests",
            cfg.sim.input_length_min.unwrap_or(0),
            cfg.sim.input_length_max.map_or("".to_string(), |v| v.to_string()),
            records.len(),
            raw_count,
        );
    }
    eprintln!(
        "[oracle] loaded {} requests; analyzing with capacity = {} blocks \
        ({} per-instance × {} instances{})",
        records.len(),
        capacity,
        per_instance_blocks,
        cfg.cluster.num_instances,
        if per_instance {
            ", per-instance mode"
        } else {
            ""
        }
    );
    let result = oracle::analyze(&records, capacity);
    let json = serde_json::to_string_pretty(&result)?;
    println!("{}", json);
    // Default output path lives next to the other run artifacts.
    let target = match out_path {
        Some(p) => p.to_path_buf(),
        None => std::path::Path::new(&cfg.sim.output_dir).join("oracle.json"),
    };
    if let Some(parent) = target.parent() {
        std::fs::create_dir_all(parent)?;
    }
    std::fs::write(&target, &json)?;
    eprintln!("[oracle] wrote {}", target.display());
    Ok(())
}