525 lines
18 KiB
Rust
525 lines
18 KiB
Rust
use anyhow::{Context, Result};
|
||
use clap::{Args, Parser, Subcommand};
|
||
use std::path::PathBuf;
|
||
|
||
use kvcache_simulator::config::{Config, RouterMode};
|
||
use kvcache_simulator::replay::ReplayEvictPolicy;
|
||
use kvcache_simulator::{driver, oracle, trace::TraceReader};
|
||
|
||
#[derive(Debug, Parser)]
|
||
#[command(name = "kvcache-sim", about = "Cluster-level KV cache simulator")]
|
||
struct Cli {
|
||
#[command(subcommand)]
|
||
cmd: Cmd,
|
||
}
|
||
|
||
/// Optional CLI overrides applied on top of the YAML config so the same
|
||
/// config can be reused across sweeps without editing the file.
|
||
#[derive(Debug, Args, Clone, Default)]
|
||
struct ConfigOverrides {
|
||
/// Override `cluster.num_instances`.
|
||
#[arg(long)]
|
||
num_instances: Option<u32>,
|
||
/// Override `sim.max_requests` (cap on processed trace records).
|
||
#[arg(long)]
|
||
max_requests: Option<u64>,
|
||
/// Override `sim.trace_path`.
|
||
#[arg(long)]
|
||
trace: Option<PathBuf>,
|
||
/// Override `sim.output_dir`.
|
||
#[arg(long)]
|
||
output_dir: Option<PathBuf>,
|
||
/// Override `sim.seed`.
|
||
#[arg(long)]
|
||
seed: Option<u64>,
|
||
/// Override `cluster.router.precise_probe_topk`.
|
||
#[arg(long)]
|
||
precise_topk: Option<u32>,
|
||
/// Override `cluster.meta_store.ttl_seconds`.
|
||
#[arg(long)]
|
||
ttl_seconds: Option<f64>,
|
||
/// Override `sim.input_length_min` — requests with `input_length` below
|
||
/// this value are dropped from the replay.
|
||
#[arg(long)]
|
||
input_length_min: Option<u32>,
|
||
/// Override `sim.input_length_max` — requests with `input_length` above
|
||
/// this value are dropped from the replay. Combine with `--input-length-min`
|
||
/// to carve out a specific input-length bucket for ablation.
|
||
#[arg(long)]
|
||
input_length_max: Option<u32>,
|
||
}
|
||
|
||
impl ConfigOverrides {
|
||
fn apply(&self, cfg: &mut Config) {
|
||
if let Some(n) = self.num_instances {
|
||
cfg.cluster.num_instances = n;
|
||
}
|
||
if let Some(m) = self.max_requests {
|
||
cfg.sim.max_requests = Some(m);
|
||
}
|
||
if let Some(t) = &self.trace {
|
||
cfg.sim.trace_path = t.to_string_lossy().into_owned();
|
||
}
|
||
if let Some(o) = &self.output_dir {
|
||
cfg.sim.output_dir = o.to_string_lossy().into_owned();
|
||
}
|
||
if let Some(s) = self.seed {
|
||
cfg.sim.seed = s;
|
||
}
|
||
if let Some(k) = self.precise_topk {
|
||
cfg.cluster.router.precise_probe_topk = k;
|
||
}
|
||
if let Some(ttl) = self.ttl_seconds {
|
||
cfg.cluster.meta_store.ttl_seconds = ttl;
|
||
}
|
||
if let Some(lo) = self.input_length_min {
|
||
cfg.sim.input_length_min = Some(lo);
|
||
}
|
||
if let Some(hi) = self.input_length_max {
|
||
cfg.sim.input_length_max = Some(hi);
|
||
}
|
||
}
|
||
}
|
||
|
||
#[derive(Debug, Subcommand)]
|
||
enum Cmd {
|
||
/// Run a single simulation with the router specified in the config.
|
||
Run {
|
||
#[arg(short, long)]
|
||
config: PathBuf,
|
||
#[command(flatten)]
|
||
overrides: ConfigOverrides,
|
||
},
|
||
/// Run the same trace under multiple routers and fixed-placement eviction
|
||
/// policies, then compare cache-hit summaries.
|
||
Ablate {
|
||
#[arg(short, long)]
|
||
config: PathBuf,
|
||
/// Comma-separated router modes
|
||
#[arg(
|
||
short,
|
||
long,
|
||
default_value = "random,least_loaded,least_tokens,ttl_aware,min_pd,cache_load,cache_score,estimated_ttft,prefix_affinity"
|
||
)]
|
||
routers: String,
|
||
/// Comma-separated eviction policies for ablation aggregation.
|
||
/// Currently only `lru` is supported.
|
||
#[arg(long, default_value = "lru")]
|
||
evict_policies: String,
|
||
/// Sweep `num_instances` from `--auto-candidates` with the
|
||
/// `--auto-probe-router` and pick the smallest cluster size whose
|
||
/// TTFT mean ≤ `--auto-target-ttft-mean`. Overrides the YAML
|
||
/// `num_instances` for the ablation run.
|
||
#[arg(long, default_value_t = false)]
|
||
auto_instances: bool,
|
||
/// Target TTFT mean (seconds) for auto-instances calibration.
|
||
#[arg(long, default_value_t = 4.0)]
|
||
auto_target_ttft_mean: f64,
|
||
/// Comma-separated candidate cluster sizes (ascending).
|
||
#[arg(long, default_value = "4,8,16,24,32,48,64,96,128")]
|
||
auto_candidates: String,
|
||
/// Router used as the calibration baseline. The smallest candidate
|
||
/// where this router's TTFT mean ≤ target is picked — all ablation
|
||
/// routers are then run at that cluster size.
|
||
#[arg(long, default_value = "cache_score")]
|
||
auto_probe_router: String,
|
||
/// Maximum number of routers to simulate concurrently.
|
||
/// `0` means auto-detect from available CPU parallelism.
|
||
#[arg(long, default_value_t = 0)]
|
||
jobs: usize,
|
||
#[command(flatten)]
|
||
overrides: ConfigOverrides,
|
||
},
|
||
/// Parse the config and trace head; do not run a simulation.
|
||
Validate {
|
||
#[arg(short, long)]
|
||
config: PathBuf,
|
||
#[command(flatten)]
|
||
overrides: ConfigOverrides,
|
||
},
|
||
/// Offline oracle analysis: theoretical hit-rate ceilings (unlimited
|
||
/// cache and offline-optimal Belady eviction at finite capacity), plus
|
||
/// LRU at the same capacity for comparison.
|
||
Oracle {
|
||
#[arg(short, long)]
|
||
config: PathBuf,
|
||
#[command(flatten)]
|
||
overrides: ConfigOverrides,
|
||
/// Cache capacity (in 16-token blocks) used for the Belady and LRU
|
||
/// analyses. Defaults to `num_instances * per_instance_HBM_blocks`
|
||
/// (the cluster-aggregate capacity).
|
||
#[arg(long)]
|
||
capacity_blocks: Option<u64>,
|
||
/// Use the per-instance HBM block budget instead of the
|
||
/// cluster-aggregate. Mutually exclusive with --capacity-blocks.
|
||
#[arg(long, default_value_t = false)]
|
||
per_instance: bool,
|
||
/// Optional output JSON path. Defaults to `<output_dir>/oracle.json`.
|
||
#[arg(long)]
|
||
out: Option<PathBuf>,
|
||
},
|
||
}
|
||
|
||
fn main() -> Result<()> {
|
||
let cli = Cli::parse();
|
||
match cli.cmd {
|
||
Cmd::Run { config, overrides } => cmd_run(&config, &overrides),
|
||
Cmd::Ablate {
|
||
config,
|
||
routers,
|
||
evict_policies,
|
||
auto_instances,
|
||
auto_target_ttft_mean,
|
||
auto_candidates,
|
||
auto_probe_router,
|
||
jobs,
|
||
overrides,
|
||
} => cmd_ablate(
|
||
&config,
|
||
&routers,
|
||
&evict_policies,
|
||
auto_instances,
|
||
auto_target_ttft_mean,
|
||
&auto_candidates,
|
||
&auto_probe_router,
|
||
jobs,
|
||
&overrides,
|
||
),
|
||
Cmd::Validate { config, overrides } => cmd_validate(&config, &overrides),
|
||
Cmd::Oracle {
|
||
config,
|
||
overrides,
|
||
capacity_blocks,
|
||
per_instance,
|
||
out,
|
||
} => cmd_oracle(
|
||
&config,
|
||
&overrides,
|
||
capacity_blocks,
|
||
per_instance,
|
||
out.as_deref(),
|
||
),
|
||
}
|
||
}
|
||
|
||
fn load(config: &PathBuf, overrides: &ConfigOverrides) -> Result<Config> {
|
||
let mut cfg = Config::from_yaml_path(config)?;
|
||
overrides.apply(&mut cfg);
|
||
Ok(cfg)
|
||
}
|
||
|
||
fn cmd_run(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> {
|
||
let cfg = load(path, overrides)?;
|
||
let out = driver::run(&cfg, None)?;
|
||
println!("{}", serde_json::to_string_pretty(&out.summary)?);
|
||
Ok(())
|
||
}
|
||
|
||
#[allow(clippy::too_many_arguments)]
|
||
fn cmd_ablate(
|
||
path: &PathBuf,
|
||
routers: &str,
|
||
evict_policies: &str,
|
||
auto_instances: bool,
|
||
auto_target_ttft_mean: f64,
|
||
auto_candidates: &str,
|
||
auto_probe_router: &str,
|
||
jobs: usize,
|
||
overrides: &ConfigOverrides,
|
||
) -> Result<()> {
|
||
let mut base = load(path, overrides)?;
|
||
let modes: Vec<RouterMode> = routers
|
||
.split(',')
|
||
.map(|s| s.trim())
|
||
.filter(|s| !s.is_empty())
|
||
.map(RouterMode::parse)
|
||
.collect::<Result<Vec<_>>>()
|
||
.with_context(|| format!("parsing --routers='{routers}'"))?;
|
||
let policies: Vec<ReplayEvictPolicy> = evict_policies
|
||
.split(',')
|
||
.map(|s| s.trim())
|
||
.filter(|s| !s.is_empty())
|
||
.map(ReplayEvictPolicy::parse)
|
||
.collect::<Result<Vec<_>>>()
|
||
.with_context(|| format!("parsing --evict-policies='{evict_policies}'"))?;
|
||
|
||
if auto_instances {
|
||
let candidates: Vec<u32> = auto_candidates
|
||
.split(',')
|
||
.map(|s| s.trim())
|
||
.filter(|s| !s.is_empty())
|
||
.map(|s| {
|
||
s.parse::<u32>()
|
||
.with_context(|| format!("parsing --auto-candidates entry '{s}'"))
|
||
})
|
||
.collect::<Result<Vec<_>>>()?;
|
||
if candidates.is_empty() {
|
||
return Err(anyhow::anyhow!("--auto-candidates is empty"));
|
||
}
|
||
// Ascending so the first hit is the smallest cluster meeting the target.
|
||
let mut sorted = candidates.clone();
|
||
sorted.sort_unstable();
|
||
let probe_mode = RouterMode::parse(auto_probe_router)
|
||
.with_context(|| format!("parsing --auto-probe-router='{auto_probe_router}'"))?;
|
||
let chosen = auto_select_instances(&base, &sorted, probe_mode, auto_target_ttft_mean)?;
|
||
eprintln!(
|
||
"[ablate] auto-instances chose num_instances={} (target ttft_mean ≤ {:.3}s, probe_router={})",
|
||
chosen,
|
||
auto_target_ttft_mean,
|
||
probe_mode.as_str()
|
||
);
|
||
base.cluster.num_instances = chosen;
|
||
}
|
||
|
||
eprintln!(
|
||
"[ablate] routers={} evict_policies={} num_instances={} jobs={}",
|
||
modes
|
||
.iter()
|
||
.map(RouterMode::as_str)
|
||
.collect::<Vec<_>>()
|
||
.join(","),
|
||
policies
|
||
.iter()
|
||
.map(ReplayEvictPolicy::as_str)
|
||
.collect::<Vec<_>>()
|
||
.join(","),
|
||
base.cluster.num_instances,
|
||
if jobs == 0 {
|
||
"auto".to_string()
|
||
} else {
|
||
jobs.to_string()
|
||
},
|
||
);
|
||
let all = driver::ablate_fixed_placement_with_parallelism(&base, &modes, &policies, jobs)?;
|
||
let agg_path = std::path::Path::new(&base.sim.output_dir).join("ablation.json");
|
||
std::fs::create_dir_all(&base.sim.output_dir)?;
|
||
std::fs::write(&agg_path, serde_json::to_string_pretty(&all)?)?;
|
||
println!("{}", serde_json::to_string_pretty(&all)?);
|
||
eprintln!("[ablate] wrote {}", agg_path.display());
|
||
Ok(())
|
||
}
|
||
|
||
/// Sweep candidate cluster sizes ascending and return the smallest one whose
|
||
/// TTFT mean under `probe` is ≤ `target_ttft_mean`. Per-candidate calibration
|
||
/// summaries are written under `<output_dir>/auto_instances/` so the picked
|
||
/// N is auditable. If no candidate meets the target, returns an error naming
|
||
/// the best achievable TTFT.
|
||
fn auto_select_instances(
|
||
base: &Config,
|
||
candidates: &[u32],
|
||
probe: RouterMode,
|
||
target_ttft_mean: f64,
|
||
) -> Result<u32> {
|
||
#[derive(serde::Serialize)]
|
||
struct CalibRow {
|
||
num_instances: u32,
|
||
router: String,
|
||
ttft_mean: f64,
|
||
ttft_p50: f64,
|
||
ttft_p95: f64,
|
||
ttft_p99: f64,
|
||
num_requests: u64,
|
||
hit_rate_l0: f64,
|
||
passed: bool,
|
||
}
|
||
|
||
let out_root = std::path::Path::new(&base.sim.output_dir).join("auto_instances");
|
||
std::fs::create_dir_all(&out_root)?;
|
||
let mut log: Vec<CalibRow> = Vec::new();
|
||
let mut chosen: Option<u32> = None;
|
||
|
||
for &n in candidates {
|
||
let mut cfg = base.clone();
|
||
cfg.cluster.num_instances = n;
|
||
cfg.cluster.router.mode = probe;
|
||
// Isolate calibration output so ablation runs don't overwrite it.
|
||
cfg.sim.output_dir = out_root
|
||
.join(format!("n{n}__{}", probe.as_str()))
|
||
.to_string_lossy()
|
||
.into_owned();
|
||
eprintln!(
|
||
"[auto-instances] probing num_instances={} router={} ...",
|
||
n,
|
||
probe.as_str()
|
||
);
|
||
let run = driver::run(&cfg, None)?;
|
||
let passed = run.summary.ttft_mean <= target_ttft_mean;
|
||
eprintln!(
|
||
"[auto-instances] ttft_mean={:.3}s p95={:.3}s hit_l0={:.4} -> {}",
|
||
run.summary.ttft_mean,
|
||
run.summary.ttft_p95,
|
||
run.summary.hit_rate_l0,
|
||
if passed { "PASS" } else { "fail" },
|
||
);
|
||
log.push(CalibRow {
|
||
num_instances: n,
|
||
router: probe.as_str().to_string(),
|
||
ttft_mean: run.summary.ttft_mean,
|
||
ttft_p50: run.summary.ttft_p50,
|
||
ttft_p95: run.summary.ttft_p95,
|
||
ttft_p99: run.summary.ttft_p99,
|
||
num_requests: run.summary.num_requests,
|
||
hit_rate_l0: run.summary.hit_rate_l0,
|
||
passed,
|
||
});
|
||
if passed && chosen.is_none() {
|
||
chosen = Some(n);
|
||
// Keep sweeping if you want a curve; here we stop at the smallest
|
||
// passing N to satisfy the "not too small, but as small as it can
|
||
// be while meeting the SLA" requirement.
|
||
break;
|
||
}
|
||
}
|
||
|
||
// Persist the calibration log either way so failures are debuggable.
|
||
let log_path = out_root.join("calibration.json");
|
||
std::fs::write(
|
||
&log_path,
|
||
serde_json::to_string_pretty(&serde_json::json!({
|
||
"target_ttft_mean": target_ttft_mean,
|
||
"probe_router": probe.as_str(),
|
||
"candidates": candidates,
|
||
"chosen": chosen,
|
||
"runs": log,
|
||
}))?,
|
||
)?;
|
||
eprintln!("[auto-instances] wrote {}", log_path.display());
|
||
|
||
chosen.ok_or_else(|| {
|
||
let best = log
|
||
.iter()
|
||
.min_by(|a, b| a.ttft_mean.partial_cmp(&b.ttft_mean).unwrap())
|
||
.map(|r| (r.num_instances, r.ttft_mean))
|
||
.unwrap_or((0, f64::INFINITY));
|
||
anyhow::anyhow!(
|
||
"no candidate met target ttft_mean ≤ {:.3}s; best was n={} at {:.3}s — \
|
||
widen --auto-candidates or raise --auto-target-ttft-mean",
|
||
target_ttft_mean,
|
||
best.0,
|
||
best.1,
|
||
)
|
||
})
|
||
}
|
||
|
||
fn cmd_validate(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> {
|
||
use kvcache_simulator::instance::compute::ComputeModel;
|
||
let cfg = load(path, overrides)?;
|
||
eprintln!("config OK: {}", cfg.model.name);
|
||
eprintln!(
|
||
"mode = {}",
|
||
if cfg.model.is_arch_mode() {
|
||
"architecture-derived"
|
||
} else {
|
||
"legacy manual"
|
||
}
|
||
);
|
||
let cm = ComputeModel::new(&cfg.model, &cfg.hardware, &cfg.calibration);
|
||
eprintln!("compute: {}", cm.describe());
|
||
eprintln!(
|
||
"kv_block_bytes = {} ({:.2} MB{})",
|
||
cfg.model.kv_block_bytes(),
|
||
cfg.model.kv_block_bytes() as f64 / 1e6,
|
||
if cfg.model.mla.is_some() {
|
||
", MLA compressed"
|
||
} else {
|
||
""
|
||
},
|
||
);
|
||
let block_bytes = cfg.model.kv_block_bytes() as f64;
|
||
let hbm_blocks = (cfg.hardware.hbm_bytes / block_bytes) as u64;
|
||
let dram_blocks = (cfg.hardware.dram_bytes / block_bytes) as u64;
|
||
eprintln!("per-instance HBM blocks = {hbm_blocks}, DRAM blocks = {dram_blocks}");
|
||
eprintln!("num_instances = {}", cfg.cluster.num_instances);
|
||
// Sample prefill times at a few prompt lengths.
|
||
eprintln!("prefill_time samples:");
|
||
for &n in &[256, 1024, 4096, 16384, 65536, 131072] {
|
||
let t = cm.prefill_time(n);
|
||
eprintln!(" N={n:>7} -> {t:.4} s");
|
||
}
|
||
let reader = TraceReader::open(&cfg.sim.trace_path, Some(5))?;
|
||
for rec in reader {
|
||
let rec = rec?;
|
||
eprintln!(
|
||
" req {} chat={} t={:.3}s in={} out={} blocks={}",
|
||
rec.req_id,
|
||
rec.chat_id,
|
||
rec.arrival,
|
||
rec.input_len,
|
||
rec.output_len,
|
||
rec.hash_ids.len()
|
||
);
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
fn cmd_oracle(
|
||
path: &PathBuf,
|
||
overrides: &ConfigOverrides,
|
||
capacity_blocks: Option<u64>,
|
||
per_instance: bool,
|
||
out_path: Option<&std::path::Path>,
|
||
) -> Result<()> {
|
||
let cfg = load(path, overrides)?;
|
||
let block_bytes = cfg.model.kv_block_bytes() as f64;
|
||
let per_instance_blocks = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as u64;
|
||
let aggregate_blocks = per_instance_blocks * cfg.cluster.num_instances as u64;
|
||
let capacity = match (capacity_blocks, per_instance) {
|
||
(Some(_), true) => {
|
||
return Err(anyhow::anyhow!(
|
||
"--capacity-blocks and --per-instance are mutually exclusive"
|
||
))
|
||
}
|
||
(Some(c), false) => c,
|
||
(None, true) => per_instance_blocks,
|
||
(None, false) => aggregate_blocks,
|
||
};
|
||
|
||
eprintln!(
|
||
"[oracle] loading trace {} (max_requests={:?})",
|
||
cfg.sim.trace_path, cfg.sim.max_requests
|
||
);
|
||
let reader = TraceReader::open(&cfg.sim.trace_path, cfg.sim.max_requests)?;
|
||
let mut records: Vec<_> = reader.collect::<Result<Vec<_>, _>>()?;
|
||
let raw_count = records.len();
|
||
driver::apply_input_length_filter(&mut records, &cfg.sim);
|
||
if records.len() != raw_count {
|
||
eprintln!(
|
||
"[oracle] input_length filter [{}, {}] kept {}/{} requests",
|
||
cfg.sim.input_length_min.unwrap_or(0),
|
||
cfg.sim
|
||
.input_length_max
|
||
.map_or("∞".to_string(), |v| v.to_string()),
|
||
records.len(),
|
||
raw_count,
|
||
);
|
||
}
|
||
eprintln!(
|
||
"[oracle] loaded {} requests; analyzing with capacity = {} blocks \
|
||
({} per-instance × {} instances{})",
|
||
records.len(),
|
||
capacity,
|
||
per_instance_blocks,
|
||
cfg.cluster.num_instances,
|
||
if per_instance {
|
||
", per-instance mode"
|
||
} else {
|
||
""
|
||
}
|
||
);
|
||
|
||
let result = oracle::analyze(&records, capacity);
|
||
let json = serde_json::to_string_pretty(&result)?;
|
||
println!("{}", json);
|
||
|
||
let target = match out_path {
|
||
Some(p) => p.to_path_buf(),
|
||
None => std::path::Path::new(&cfg.sim.output_dir).join("oracle.json"),
|
||
};
|
||
if let Some(parent) = target.parent() {
|
||
std::fs::create_dir_all(parent)?;
|
||
}
|
||
std::fs::write(&target, &json)?;
|
||
eprintln!("[oracle] wrote {}", target.display());
|
||
Ok(())
|
||
}
|