feat(ablate): input-length bucketing + auto-instance sizing
- Add sim.input_length_{min,max} (+ CLI overrides) that drop requests
outside the bucket after trace load, enabling per-bucket ablation
(e.g. 0-40k) without rewriting the trace file. Applied uniformly in
both `run`/`ablate` driver path and `oracle` analysis.
- Add cache_score_strong router (alpha=1, beta=1) to isolate how much
of cache_affinity's win is reproducible by just retuning beta in the
existing cache_score framework (no rendezvous, no meta-store bonus).
- Add --auto-instances to ablate: sweeps --auto-candidates ascending
with --auto-probe-router and picks the smallest cluster size whose
TTFT mean <= --auto-target-ttft-mean. Per-candidate calibration
results are persisted under runs/<output_dir>/auto_instances/ so the
pick is auditable; the chosen N is then used for the whole ablation.
This commit is contained in:
208
src/main.rs
208
src/main.rs
@@ -38,6 +38,15 @@ struct ConfigOverrides {
|
||||
/// Override `cluster.meta_store.ttl_seconds`.
|
||||
#[arg(long)]
|
||||
ttl_seconds: Option<f64>,
|
||||
/// Override `sim.input_length_min` — requests with `input_length` below
|
||||
/// this value are dropped from the replay.
|
||||
#[arg(long)]
|
||||
input_length_min: Option<u32>,
|
||||
/// Override `sim.input_length_max` — requests with `input_length` above
|
||||
/// this value are dropped from the replay. Combine with `--input-length-min`
|
||||
/// to carve out a specific input-length bucket for ablation.
|
||||
#[arg(long)]
|
||||
input_length_max: Option<u32>,
|
||||
}
|
||||
|
||||
impl ConfigOverrides {
|
||||
@@ -63,6 +72,12 @@ impl ConfigOverrides {
|
||||
if let Some(ttl) = self.ttl_seconds {
|
||||
cfg.cluster.meta_store.ttl_seconds = ttl;
|
||||
}
|
||||
if let Some(lo) = self.input_length_min {
|
||||
cfg.sim.input_length_min = Some(lo);
|
||||
}
|
||||
if let Some(hi) = self.input_length_max {
|
||||
cfg.sim.input_length_max = Some(hi);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,6 +106,23 @@ enum Cmd {
|
||||
/// Currently only `lru` is supported.
|
||||
#[arg(long, default_value = "lru")]
|
||||
evict_policies: String,
|
||||
/// Sweep `num_instances` from `--auto-candidates` with the
|
||||
/// `--auto-probe-router` and pick the smallest cluster size whose
|
||||
/// TTFT mean ≤ `--auto-target-ttft-mean`. Overrides the YAML
|
||||
/// `num_instances` for the ablation run.
|
||||
#[arg(long, default_value_t = false)]
|
||||
auto_instances: bool,
|
||||
/// Target TTFT mean (seconds) for auto-instances calibration.
|
||||
#[arg(long, default_value_t = 4.0)]
|
||||
auto_target_ttft_mean: f64,
|
||||
/// Comma-separated candidate cluster sizes (ascending).
|
||||
#[arg(long, default_value = "4,8,16,24,32,48,64,96,128")]
|
||||
auto_candidates: String,
|
||||
/// Router used as the calibration baseline. The smallest candidate
|
||||
/// where this router's TTFT mean ≤ target is picked — all ablation
|
||||
/// routers are then run at that cluster size.
|
||||
#[arg(long, default_value = "cache_score")]
|
||||
auto_probe_router: String,
|
||||
#[command(flatten)]
|
||||
overrides: ConfigOverrides,
|
||||
},
|
||||
@@ -132,8 +164,21 @@ fn main() -> Result<()> {
|
||||
config,
|
||||
routers,
|
||||
evict_policies,
|
||||
auto_instances,
|
||||
auto_target_ttft_mean,
|
||||
auto_candidates,
|
||||
auto_probe_router,
|
||||
overrides,
|
||||
} => cmd_ablate(&config, &routers, &evict_policies, &overrides),
|
||||
} => cmd_ablate(
|
||||
&config,
|
||||
&routers,
|
||||
&evict_policies,
|
||||
auto_instances,
|
||||
auto_target_ttft_mean,
|
||||
&auto_candidates,
|
||||
&auto_probe_router,
|
||||
&overrides,
|
||||
),
|
||||
Cmd::Validate { config, overrides } => cmd_validate(&config, &overrides),
|
||||
Cmd::Oracle {
|
||||
config,
|
||||
@@ -164,13 +209,18 @@ fn cmd_run(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn cmd_ablate(
|
||||
path: &PathBuf,
|
||||
routers: &str,
|
||||
evict_policies: &str,
|
||||
auto_instances: bool,
|
||||
auto_target_ttft_mean: f64,
|
||||
auto_candidates: &str,
|
||||
auto_probe_router: &str,
|
||||
overrides: &ConfigOverrides,
|
||||
) -> Result<()> {
|
||||
let base = load(path, overrides)?;
|
||||
let mut base = load(path, overrides)?;
|
||||
let modes: Vec<RouterMode> = routers
|
||||
.split(',')
|
||||
.map(|s| s.trim())
|
||||
@@ -185,8 +235,42 @@ fn cmd_ablate(
|
||||
.map(ReplayEvictPolicy::parse)
|
||||
.collect::<Result<Vec<_>>>()
|
||||
.with_context(|| format!("parsing --evict-policies='{evict_policies}'"))?;
|
||||
|
||||
if auto_instances {
|
||||
let candidates: Vec<u32> = auto_candidates
|
||||
.split(',')
|
||||
.map(|s| s.trim())
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|s| {
|
||||
s.parse::<u32>()
|
||||
.with_context(|| format!("parsing --auto-candidates entry '{s}'"))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
if candidates.is_empty() {
|
||||
return Err(anyhow::anyhow!("--auto-candidates is empty"));
|
||||
}
|
||||
// Ascending so the first hit is the smallest cluster meeting the target.
|
||||
let mut sorted = candidates.clone();
|
||||
sorted.sort_unstable();
|
||||
let probe_mode = RouterMode::parse(auto_probe_router)
|
||||
.with_context(|| format!("parsing --auto-probe-router='{auto_probe_router}'"))?;
|
||||
let chosen = auto_select_instances(
|
||||
&base,
|
||||
&sorted,
|
||||
probe_mode,
|
||||
auto_target_ttft_mean,
|
||||
)?;
|
||||
eprintln!(
|
||||
"[ablate] auto-instances chose num_instances={} (target ttft_mean ≤ {:.3}s, probe_router={})",
|
||||
chosen,
|
||||
auto_target_ttft_mean,
|
||||
probe_mode.as_str()
|
||||
);
|
||||
base.cluster.num_instances = chosen;
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"[ablate] routers={} evict_policies={}",
|
||||
"[ablate] routers={} evict_policies={} num_instances={}",
|
||||
modes
|
||||
.iter()
|
||||
.map(RouterMode::as_str)
|
||||
@@ -196,7 +280,8 @@ fn cmd_ablate(
|
||||
.iter()
|
||||
.map(ReplayEvictPolicy::as_str)
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
.join(","),
|
||||
base.cluster.num_instances,
|
||||
);
|
||||
let all = driver::ablate_fixed_placement(&base, &modes, &policies)?;
|
||||
let agg_path = std::path::Path::new(&base.sim.output_dir).join("ablation.json");
|
||||
@@ -207,6 +292,108 @@ fn cmd_ablate(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sweep candidate cluster sizes ascending and return the smallest one whose
|
||||
/// TTFT mean under `probe` is ≤ `target_ttft_mean`. Per-candidate calibration
|
||||
/// summaries are written under `<output_dir>/auto_instances/` so the picked
|
||||
/// N is auditable. If no candidate meets the target, returns an error naming
|
||||
/// the best achievable TTFT.
|
||||
fn auto_select_instances(
|
||||
base: &Config,
|
||||
candidates: &[u32],
|
||||
probe: RouterMode,
|
||||
target_ttft_mean: f64,
|
||||
) -> Result<u32> {
|
||||
#[derive(serde::Serialize)]
|
||||
struct CalibRow {
|
||||
num_instances: u32,
|
||||
router: String,
|
||||
ttft_mean: f64,
|
||||
ttft_p50: f64,
|
||||
ttft_p95: f64,
|
||||
ttft_p99: f64,
|
||||
num_requests: u64,
|
||||
hit_rate_l0: f64,
|
||||
passed: bool,
|
||||
}
|
||||
|
||||
let out_root = std::path::Path::new(&base.sim.output_dir).join("auto_instances");
|
||||
std::fs::create_dir_all(&out_root)?;
|
||||
let mut log: Vec<CalibRow> = Vec::new();
|
||||
let mut chosen: Option<u32> = None;
|
||||
|
||||
for &n in candidates {
|
||||
let mut cfg = base.clone();
|
||||
cfg.cluster.num_instances = n;
|
||||
cfg.cluster.router.mode = probe;
|
||||
// Isolate calibration output so ablation runs don't overwrite it.
|
||||
cfg.sim.output_dir = out_root
|
||||
.join(format!("n{n}__{}", probe.as_str()))
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
eprintln!(
|
||||
"[auto-instances] probing num_instances={} router={} ...",
|
||||
n,
|
||||
probe.as_str()
|
||||
);
|
||||
let run = driver::run(&cfg, None)?;
|
||||
let passed = run.summary.ttft_mean <= target_ttft_mean;
|
||||
eprintln!(
|
||||
"[auto-instances] ttft_mean={:.3}s p95={:.3}s hit_l0={:.4} -> {}",
|
||||
run.summary.ttft_mean,
|
||||
run.summary.ttft_p95,
|
||||
run.summary.hit_rate_l0,
|
||||
if passed { "PASS" } else { "fail" },
|
||||
);
|
||||
log.push(CalibRow {
|
||||
num_instances: n,
|
||||
router: probe.as_str().to_string(),
|
||||
ttft_mean: run.summary.ttft_mean,
|
||||
ttft_p50: run.summary.ttft_p50,
|
||||
ttft_p95: run.summary.ttft_p95,
|
||||
ttft_p99: run.summary.ttft_p99,
|
||||
num_requests: run.summary.num_requests,
|
||||
hit_rate_l0: run.summary.hit_rate_l0,
|
||||
passed,
|
||||
});
|
||||
if passed && chosen.is_none() {
|
||||
chosen = Some(n);
|
||||
// Keep sweeping if you want a curve; here we stop at the smallest
|
||||
// passing N to satisfy the "not too small, but as small as it can
|
||||
// be while meeting the SLA" requirement.
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Persist the calibration log either way so failures are debuggable.
|
||||
let log_path = out_root.join("calibration.json");
|
||||
std::fs::write(
|
||||
&log_path,
|
||||
serde_json::to_string_pretty(&serde_json::json!({
|
||||
"target_ttft_mean": target_ttft_mean,
|
||||
"probe_router": probe.as_str(),
|
||||
"candidates": candidates,
|
||||
"chosen": chosen,
|
||||
"runs": log,
|
||||
}))?,
|
||||
)?;
|
||||
eprintln!("[auto-instances] wrote {}", log_path.display());
|
||||
|
||||
chosen.ok_or_else(|| {
|
||||
let best = log
|
||||
.iter()
|
||||
.min_by(|a, b| a.ttft_mean.partial_cmp(&b.ttft_mean).unwrap())
|
||||
.map(|r| (r.num_instances, r.ttft_mean))
|
||||
.unwrap_or((0, f64::INFINITY));
|
||||
anyhow::anyhow!(
|
||||
"no candidate met target ttft_mean ≤ {:.3}s; best was n={} at {:.3}s — \
|
||||
widen --auto-candidates or raise --auto-target-ttft-mean",
|
||||
target_ttft_mean,
|
||||
best.0,
|
||||
best.1,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn cmd_validate(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> {
|
||||
use kvcache_simulator::instance::compute::ComputeModel;
|
||||
let cfg = load(path, overrides)?;
|
||||
@@ -285,7 +472,18 @@ fn cmd_oracle(
|
||||
cfg.sim.trace_path, cfg.sim.max_requests
|
||||
);
|
||||
let reader = TraceReader::open(&cfg.sim.trace_path, cfg.sim.max_requests)?;
|
||||
let records: Vec<_> = reader.collect::<Result<Vec<_>, _>>()?;
|
||||
let mut records: Vec<_> = reader.collect::<Result<Vec<_>, _>>()?;
|
||||
let raw_count = records.len();
|
||||
driver::apply_input_length_filter(&mut records, &cfg.sim);
|
||||
if records.len() != raw_count {
|
||||
eprintln!(
|
||||
"[oracle] input_length filter [{}, {}] kept {}/{} requests",
|
||||
cfg.sim.input_length_min.unwrap_or(0),
|
||||
cfg.sim.input_length_max.map_or("∞".to_string(), |v| v.to_string()),
|
||||
records.len(),
|
||||
raw_count,
|
||||
);
|
||||
}
|
||||
eprintln!(
|
||||
"[oracle] loaded {} requests; analyzing with capacity = {} blocks \
|
||||
({} per-instance × {} instances{})",
|
||||
|
||||
Reference in New Issue
Block a user