feat(ablate): input-length bucketing + auto-instance sizing

- Add sim.input_length_{min,max} (+ CLI overrides) that drop requests
  outside the bucket after trace load, enabling per-bucket ablation
  (e.g. 0-40k) without rewriting the trace file. Applied uniformly in
  both `run`/`ablate` driver path and `oracle` analysis.

- Add cache_score_strong router (alpha=1, beta=1) to isolate how much
  of cache_affinity's win is reproducible by just retuning beta in the
  existing cache_score framework (no rendezvous, no meta-store bonus).

- Add --auto-instances to ablate: sweeps --auto-candidates ascending
  with --auto-probe-router and picks the smallest cluster size whose
  TTFT mean <= --auto-target-ttft-mean. Per-candidate calibration
  results are persisted under runs/<output_dir>/auto_instances/ so the
  pick is auditable; the chosen N is then used for the whole ablation.
This commit is contained in:
2026-04-15 19:42:28 +08:00
parent a3f386c858
commit c86d931d8f
6 changed files with 255 additions and 6 deletions

View File

@@ -243,6 +243,8 @@ mod tests {
output_dir: String::new(),
sample_interval_s: 0.0,
seed: 7,
input_length_min: None,
input_length_max: None,
},
}
}

View File

@@ -429,6 +429,7 @@ pub enum RouterMode {
CacheAffinityWeakRend,
CacheAffinityStrongOnly,
CacheScore,
CacheScoreStrong,
CacheScoreTtl,
EstimatedTtft,
PrefixAffinity,
@@ -449,6 +450,7 @@ impl RouterMode {
"cache_affinity_weak_rend" | "caff_weak" => Ok(Self::CacheAffinityWeakRend),
"cache_affinity_strong_only" | "caff_strong" => Ok(Self::CacheAffinityStrongOnly),
"cache_score" | "cs" => Ok(Self::CacheScore),
"cache_score_strong" | "cs_strong" | "css" => Ok(Self::CacheScoreStrong),
"cache_score_ttl" | "csttl" | "cs_ttl" => Ok(Self::CacheScoreTtl),
"estimated_ttft" | "ettft" | "optimal" => Ok(Self::EstimatedTtft),
"prefix_affinity" | "affinity" | "pa" => Ok(Self::PrefixAffinity),
@@ -470,6 +472,7 @@ impl RouterMode {
Self::CacheAffinityWeakRend => "cache_affinity_weak_rend",
Self::CacheAffinityStrongOnly => "cache_affinity_strong_only",
Self::CacheScore => "cache_score",
Self::CacheScoreStrong => "cache_score_strong",
Self::CacheScoreTtl => "cache_score_ttl",
Self::EstimatedTtft => "estimated_ttft",
Self::PrefixAffinity => "prefix_affinity",
@@ -491,6 +494,15 @@ pub struct SimConfig {
pub sample_interval_s: f64,
#[serde(default)]
pub seed: u64,
/// Optional lower bound on `input_length` (tokens, inclusive). Requests
/// outside `[input_length_min, input_length_max]` are dropped from the
/// replay after trace load — use this to focus an ablation on a specific
/// input-length bucket (e.g. 0–40k) without touching the trace file.
#[serde(default)]
pub input_length_min: Option<u32>,
/// Optional upper bound on `input_length` (tokens, inclusive).
#[serde(default)]
pub input_length_max: Option<u32>,
}
fn default_sample_interval() -> f64 {

View File

@@ -16,6 +16,21 @@ use crate::replay::ReplayEvictPolicy;
use crate::sim::{Event, EventQueue};
use crate::trace::{RequestRecord, TraceReader};
/// Drop records whose `input_len` falls outside `sim.input_length_{min,max}`.
/// Used to carve an ablation onto a specific input-length bucket (e.g. 0–40k)
/// without rewriting the trace file. No-op if both bounds are unset.
pub fn apply_input_length_filter(
    records: &mut Vec<RequestRecord>,
    cfg: &crate::config::SimConfig,
) {
    // An unset bound defaults to the widest possible value on that side.
    let bounds = cfg.input_length_min.unwrap_or(u32::MIN)
        ..=cfg.input_length_max.unwrap_or(u32::MAX);
    // Fast path: with no bound configured nothing can ever be dropped.
    if bounds == (u32::MIN..=u32::MAX) {
        return;
    }
    records.retain(|rec| bounds.contains(&rec.input_len));
}
pub struct RunOutputs {
pub summary: Summary,
pub rows: Vec<PerRequestRow>,
@@ -55,7 +70,18 @@ pub fn run(config: &Config, output_subdir: Option<&str>) -> Result<RunOutputs> {
// Load all records (cheap for moderate traces) so we can index by req_id.
// For very large traces a streaming approach with a peekable iterator
// would be better; this keeps the driver simple.
let records: Vec<RequestRecord> = (&mut trace).collect::<Result<Vec<_>, _>>()?;
let mut records: Vec<RequestRecord> = (&mut trace).collect::<Result<Vec<_>, _>>()?;
let raw_count = records.len();
apply_input_length_filter(&mut records, &config.sim);
if records.len() != raw_count {
eprintln!(
"[driver] input_length filter [{}, {}] kept {}/{} requests",
config.sim.input_length_min.unwrap_or(0),
config.sim.input_length_max.map_or("".to_string(), |v| v.to_string()),
records.len(),
raw_count,
);
}
let mut by_id: HashMap<u64, RequestRecord> = HashMap::with_capacity(records.len());
for r in &records {
q.schedule(r.arrival, Event::Arrival { req_id: r.req_id });

View File

@@ -38,6 +38,15 @@ struct ConfigOverrides {
/// Override `cluster.meta_store.ttl_seconds`.
#[arg(long)]
ttl_seconds: Option<f64>,
/// Override `sim.input_length_min` — requests with `input_length` below
/// this value are dropped from the replay.
#[arg(long)]
input_length_min: Option<u32>,
/// Override `sim.input_length_max` — requests with `input_length` above
/// this value are dropped from the replay. Combine with `--input-length-min`
/// to carve out a specific input-length bucket for ablation.
#[arg(long)]
input_length_max: Option<u32>,
}
impl ConfigOverrides {
@@ -63,6 +72,12 @@ impl ConfigOverrides {
if let Some(ttl) = self.ttl_seconds {
cfg.cluster.meta_store.ttl_seconds = ttl;
}
if let Some(lo) = self.input_length_min {
cfg.sim.input_length_min = Some(lo);
}
if let Some(hi) = self.input_length_max {
cfg.sim.input_length_max = Some(hi);
}
}
}
@@ -91,6 +106,23 @@ enum Cmd {
/// Currently only `lru` is supported.
#[arg(long, default_value = "lru")]
evict_policies: String,
/// Sweep `num_instances` from `--auto-candidates` with the
/// `--auto-probe-router` and pick the smallest cluster size whose
/// TTFT mean ≤ `--auto-target-ttft-mean`. Overrides the YAML
/// `num_instances` for the ablation run.
#[arg(long, default_value_t = false)]
auto_instances: bool,
/// Target TTFT mean (seconds) for auto-instances calibration.
#[arg(long, default_value_t = 4.0)]
auto_target_ttft_mean: f64,
/// Comma-separated candidate cluster sizes (ascending).
#[arg(long, default_value = "4,8,16,24,32,48,64,96,128")]
auto_candidates: String,
/// Router used as the calibration baseline. The smallest candidate
/// where this router's TTFT mean ≤ target is picked — all ablation
/// routers are then run at that cluster size.
#[arg(long, default_value = "cache_score")]
auto_probe_router: String,
#[command(flatten)]
overrides: ConfigOverrides,
},
@@ -132,8 +164,21 @@ fn main() -> Result<()> {
config,
routers,
evict_policies,
auto_instances,
auto_target_ttft_mean,
auto_candidates,
auto_probe_router,
overrides,
} => cmd_ablate(&config, &routers, &evict_policies, &overrides),
} => cmd_ablate(
&config,
&routers,
&evict_policies,
auto_instances,
auto_target_ttft_mean,
&auto_candidates,
&auto_probe_router,
&overrides,
),
Cmd::Validate { config, overrides } => cmd_validate(&config, &overrides),
Cmd::Oracle {
config,
@@ -164,13 +209,18 @@ fn cmd_run(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> {
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn cmd_ablate(
path: &PathBuf,
routers: &str,
evict_policies: &str,
auto_instances: bool,
auto_target_ttft_mean: f64,
auto_candidates: &str,
auto_probe_router: &str,
overrides: &ConfigOverrides,
) -> Result<()> {
let base = load(path, overrides)?;
let mut base = load(path, overrides)?;
let modes: Vec<RouterMode> = routers
.split(',')
.map(|s| s.trim())
@@ -185,8 +235,42 @@ fn cmd_ablate(
.map(ReplayEvictPolicy::parse)
.collect::<Result<Vec<_>>>()
.with_context(|| format!("parsing --evict-policies='{evict_policies}'"))?;
if auto_instances {
let candidates: Vec<u32> = auto_candidates
.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.map(|s| {
s.parse::<u32>()
.with_context(|| format!("parsing --auto-candidates entry '{s}'"))
})
.collect::<Result<Vec<_>>>()?;
if candidates.is_empty() {
return Err(anyhow::anyhow!("--auto-candidates is empty"));
}
// Ascending so the first hit is the smallest cluster meeting the target.
let mut sorted = candidates.clone();
sorted.sort_unstable();
let probe_mode = RouterMode::parse(auto_probe_router)
.with_context(|| format!("parsing --auto-probe-router='{auto_probe_router}'"))?;
let chosen = auto_select_instances(
&base,
&sorted,
probe_mode,
auto_target_ttft_mean,
)?;
eprintln!(
"[ablate] auto-instances chose num_instances={} (target ttft_mean ≤ {:.3}s, probe_router={})",
chosen,
auto_target_ttft_mean,
probe_mode.as_str()
);
base.cluster.num_instances = chosen;
}
eprintln!(
"[ablate] routers={} evict_policies={}",
"[ablate] routers={} evict_policies={} num_instances={}",
modes
.iter()
.map(RouterMode::as_str)
@@ -196,7 +280,8 @@ fn cmd_ablate(
.iter()
.map(ReplayEvictPolicy::as_str)
.collect::<Vec<_>>()
.join(",")
.join(","),
base.cluster.num_instances,
);
let all = driver::ablate_fixed_placement(&base, &modes, &policies)?;
let agg_path = std::path::Path::new(&base.sim.output_dir).join("ablation.json");
@@ -207,6 +292,108 @@ fn cmd_ablate(
Ok(())
}
/// Sweep candidate cluster sizes ascending and return the smallest one whose
/// TTFT mean under `probe` is ≤ `target_ttft_mean`. Per-candidate calibration
/// summaries are written under `<output_dir>/auto_instances/` so the picked
/// N is auditable. If no candidate meets the target, returns an error naming
/// the best achievable TTFT.
///
/// # Errors
/// Propagates simulation/IO failures from `driver::run` and the log write;
/// returns a descriptive error when no candidate meets the target.
fn auto_select_instances(
    base: &Config,
    candidates: &[u32],
    probe: RouterMode,
    target_ttft_mean: f64,
) -> Result<u32> {
    // One row per probed cluster size, persisted so the pick is auditable.
    #[derive(serde::Serialize)]
    struct CalibRow {
        num_instances: u32,
        router: String,
        ttft_mean: f64,
        ttft_p50: f64,
        ttft_p95: f64,
        ttft_p99: f64,
        num_requests: u64,
        hit_rate_l0: f64,
        passed: bool,
    }
    let out_root = std::path::Path::new(&base.sim.output_dir).join("auto_instances");
    std::fs::create_dir_all(&out_root)?;
    let mut log: Vec<CalibRow> = Vec::new();
    let mut chosen: Option<u32> = None;
    for &n in candidates {
        let mut cfg = base.clone();
        cfg.cluster.num_instances = n;
        cfg.cluster.router.mode = probe;
        // Isolate calibration output so ablation runs don't overwrite it.
        cfg.sim.output_dir = out_root
            .join(format!("n{n}__{}", probe.as_str()))
            .to_string_lossy()
            .into_owned();
        eprintln!(
            "[auto-instances] probing num_instances={} router={} ...",
            n,
            probe.as_str()
        );
        let run = driver::run(&cfg, None)?;
        let passed = run.summary.ttft_mean <= target_ttft_mean;
        eprintln!(
            "[auto-instances] ttft_mean={:.3}s p95={:.3}s hit_l0={:.4} -> {}",
            run.summary.ttft_mean,
            run.summary.ttft_p95,
            run.summary.hit_rate_l0,
            if passed { "PASS" } else { "fail" },
        );
        log.push(CalibRow {
            num_instances: n,
            router: probe.as_str().to_string(),
            ttft_mean: run.summary.ttft_mean,
            ttft_p50: run.summary.ttft_p50,
            ttft_p95: run.summary.ttft_p95,
            ttft_p99: run.summary.ttft_p99,
            num_requests: run.summary.num_requests,
            hit_rate_l0: run.summary.hit_rate_l0,
            passed,
        });
        if passed {
            // Candidates are swept ascending, so the first pass is already
            // the smallest cluster meeting the SLA — stop here. (No need to
            // re-check `chosen`: it can only still be `None` at this point.)
            chosen = Some(n);
            break;
        }
    }
    // Persist the calibration log either way so failures are debuggable.
    let log_path = out_root.join("calibration.json");
    std::fs::write(
        &log_path,
        serde_json::to_string_pretty(&serde_json::json!({
            "target_ttft_mean": target_ttft_mean,
            "probe_router": probe.as_str(),
            "candidates": candidates,
            "chosen": chosen,
            "runs": log,
        }))?,
    )?;
    eprintln!("[auto-instances] wrote {}", log_path.display());
    chosen.ok_or_else(|| {
        // `total_cmp` is a total order over f64 (NaN-safe); the previous
        // `partial_cmp(...).unwrap()` would panic if any mean were NaN.
        let best = log
            .iter()
            .min_by(|a, b| a.ttft_mean.total_cmp(&b.ttft_mean))
            .map(|r| (r.num_instances, r.ttft_mean))
            .unwrap_or((0, f64::INFINITY));
        anyhow::anyhow!(
            "no candidate met target ttft_mean ≤ {:.3}s; best was n={} at {:.3}s — \
widen --auto-candidates or raise --auto-target-ttft-mean",
            target_ttft_mean,
            best.0,
            best.1,
        )
    })
}
fn cmd_validate(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> {
use kvcache_simulator::instance::compute::ComputeModel;
let cfg = load(path, overrides)?;
@@ -285,7 +472,18 @@ fn cmd_oracle(
cfg.sim.trace_path, cfg.sim.max_requests
);
let reader = TraceReader::open(&cfg.sim.trace_path, cfg.sim.max_requests)?;
let records: Vec<_> = reader.collect::<Result<Vec<_>, _>>()?;
let mut records: Vec<_> = reader.collect::<Result<Vec<_>, _>>()?;
let raw_count = records.len();
driver::apply_input_length_filter(&mut records, &cfg.sim);
if records.len() != raw_count {
eprintln!(
"[oracle] input_length filter [{}, {}] kept {}/{} requests",
cfg.sim.input_length_min.unwrap_or(0),
cfg.sim.input_length_max.map_or("".to_string(), |v| v.to_string()),
records.len(),
raw_count,
);
}
eprintln!(
"[oracle] loaded {} requests; analyzing with capacity = {} blocks \
({} per-instance × {} instances{})",

View File

@@ -91,6 +91,13 @@ pub fn build(full: &Config, seed: u64) -> Box<dyn Router> {
cfg.score_alpha,
cfg.score_beta,
)) as Box<dyn Router>,
// Parity probe for the cache_affinity reweight claim: same scoring
// framework as cache_score, but β=1.0 so a single L0-hit block fully
// offsets one queue position. Demonstrates how much of cache_affinity's
// gain is reproducible by just retuning β (no rendezvous, no meta-store
// bonus).
CacheScoreStrong => Box::new(cache_score::CacheScoreRouter::new(1.0, 1.0))
as Box<dyn Router>,
CacheScoreTtl => Box::new(cache_score_ttl::CacheScoreTtlRouter::new(
cfg.score_alpha,
cfg.score_beta,
@@ -184,6 +191,8 @@ mod tests {
output_dir: String::new(),
sample_interval_s: 0.0,
seed: 7,
input_length_min: None,
input_length_max: None,
},
}
}