fix: guard legacy runtime paths for bucketed configs

This commit is contained in:
2026-04-17 14:35:09 +08:00
parent d8a0796506
commit 7de38fa998
5 changed files with 319 additions and 5 deletions

View File

@@ -419,6 +419,15 @@ pub enum GlobalRouterMode {
BucketScore,
}
impl GlobalRouterMode {
pub fn as_str(&self) -> &'static str {
match self {
Self::StrictInputLength => "strict_input_length",
Self::BucketScore => "bucket_score",
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaStoreConfig {
pub ttl_seconds: f64,
@@ -468,6 +477,15 @@ fn default_prefix_k() -> usize {
}
impl ClusterConfig {
/// Returns the instance count for callers that only handle a single, un-bucketed pool.
///
/// `context` names the calling code path and is used as the prefix of any error message.
///
/// # Errors
///
/// Fails when `buckets` is non-empty, or when `legacy_num_instances()` yields `None`.
pub fn require_legacy_single_pool(&self, context: &str) -> Result<u32> {
    if self.buckets.is_empty() {
        // No buckets configured: the legacy count must be present.
        return match self.legacy_num_instances() {
            Some(count) => Ok(count),
            None => Err(anyhow::anyhow!("{context} requires cluster.num_instances")),
        };
    }
    Err(anyhow::anyhow!(
        "{context} does not support cluster.buckets until Task 2 lands"
    ))
}
pub fn legacy_num_instances(&self) -> Option<u32> {
if self.buckets.is_empty() {
self.num_instances
@@ -481,6 +499,24 @@ impl ClusterConfig {
.unwrap_or_else(|| self.buckets.iter().map(|bucket| bucket.num_instances).sum())
}
/// Finds the index of the first bucket whose inclusive
/// `[input_length_min, input_length_max]` range contains `input_len`.
///
/// Returns `Ok(0)` when no buckets are configured.
///
/// # Errors
///
/// Fails when buckets are configured but none of them covers `input_len`; the
/// message includes the global router mode and the offending input length.
pub fn bucket_index_for_input_len(&self, input_len: u32) -> Result<usize> {
    if self.buckets.is_empty() {
        return Ok(0);
    }
    for (idx, bucket) in self.buckets.iter().enumerate() {
        let covered = bucket.input_length_min..=bucket.input_length_max;
        if covered.contains(&input_len) {
            return Ok(idx);
        }
    }
    let mode = self.global_router.mode.as_str();
    Err(anyhow::anyhow!(
        "cluster.global_router.mode={} has no bucket for input_length={input_len}",
        mode
    ))
}
pub fn effective_buckets(&self) -> Vec<BucketConfig> {
if !self.buckets.is_empty() {
return self.buckets.clone();
@@ -1270,4 +1306,93 @@ sim:
assert!(err.to_string().contains("num_instances"));
assert!(err.to_string().contains("buckets"));
}
// Checks that looking up an input length that falls between two configured
// buckets produces an error naming the length and stating that no bucket matched.
#[test]
fn bucketed_config_reports_unmatched_input_length_gaps_clearly() {
// The two buckets below cover 0..=32 and 64..=96, leaving 33..=63 uncovered.
// NOTE(review): the YAML appears here without nesting indentation — confirm
// against the original file before relying on this text verbatim.
let gapful = write_temp_config(
r#"
model:
name: test
num_layers: 4
num_kv_heads: 2
head_dim: 64
dtype_bytes: 2
block_size_tokens: 16
flops_per_token_prefill: 1.0e9
attn_quadratic_coeff: 64.0
hardware:
gpu_flops: 1.0e14
gpu_mem_bw: 1.0e12
hbm_bytes: 1.0e9
cluster:
meta_store:
ttl_seconds: 10.0
router:
mode: cache_affinity
global_router:
mode: strict_input_length
buckets:
- name: short
input_length_min: 0
input_length_max: 32
num_instances: 2
- name: long
input_length_min: 64
input_length_max: 96
num_instances: 1
sim:
trace_path: trace.jsonl
output_dir: runs/test
"#,
);
// Loading is expected to succeed; only the lookup below should fail.
let cfg = Config::from_yaml_path(&gapful).unwrap();
// 40 lies in the gap between the two bucket ranges.
let err = cfg.cluster.bucket_index_for_input_len(40).unwrap_err();
assert!(err.to_string().contains("40"));
assert!(err.to_string().contains("no bucket"));
}
// Checks that `require_legacy_single_pool` rejects a config that declares
// buckets, and that the error mentions both the caller-supplied context and
// `cluster.buckets`.
#[test]
fn bucketed_config_requires_legacy_single_pool_for_legacy_runtime_paths() {
// A config with a single bucket and no `cluster.num_instances` entry.
// NOTE(review): the YAML appears here without nesting indentation — confirm
// against the original file before relying on this text verbatim.
let bucketed = write_temp_config(
r#"
model:
name: test
num_layers: 4
num_kv_heads: 2
head_dim: 64
dtype_bytes: 2
block_size_tokens: 16
flops_per_token_prefill: 1.0e9
attn_quadratic_coeff: 64.0
hardware:
gpu_flops: 1.0e14
gpu_mem_bw: 1.0e12
hbm_bytes: 1.0e9
cluster:
meta_store:
ttl_seconds: 10.0
router:
mode: cache_affinity
global_router:
mode: strict_input_length
buckets:
- name: short
input_length_min: 0
input_length_max: 32
num_instances: 2
sim:
trace_path: trace.jsonl
output_dir: runs/test
"#,
);
// Loading is expected to succeed; the guard call below is what should fail.
let cfg = Config::from_yaml_path(&bucketed).unwrap();
let err = cfg
.cluster
.require_legacy_single_pool("driver run")
.unwrap_err();
// The context string passed above must be echoed in the message.
assert!(err.to_string().contains("driver run"));
assert!(err.to_string().contains("cluster.buckets"));
}
}

View File

@@ -49,6 +49,9 @@ struct InflightInfo {
}
pub fn run(config: &Config, output_subdir: Option<&str>) -> Result<RunOutputs> {
config
.cluster
.require_legacy_single_pool("driver run")?;
let mut cluster = Cluster::new(config, &config.model);
let mut q = EventQueue::new();
@@ -217,6 +220,8 @@ pub fn ablate_fixed_placement_with_parallelism(
evict_policies: &[ReplayEvictPolicy],
jobs: usize,
) -> Result<Vec<AblationRow>> {
base.cluster
.require_legacy_single_pool("fixed-placement ablation")?;
let mut out = Vec::new();
for &policy in evict_policies {
if policy != ReplayEvictPolicy::Lru {

View File

@@ -50,10 +50,14 @@ struct ConfigOverrides {
}
impl ConfigOverrides {
fn apply(&self, cfg: &mut Config) {
fn apply(&self, cfg: &mut Config) -> Result<()> {
if let Some(n) = self.num_instances {
if !cfg.cluster.buckets.is_empty() {
anyhow::bail!(
"--num-instances does not support cluster.buckets until Task 2 lands"
);
}
cfg.cluster.num_instances = Some(n);
cfg.cluster.buckets.clear();
}
if let Some(m) = self.max_requests {
cfg.sim.max_requests = Some(m);
@@ -79,6 +83,7 @@ impl ConfigOverrides {
if let Some(hi) = self.input_length_max {
cfg.sim.input_length_max = Some(hi);
}
Ok(())
}
}
@@ -205,7 +210,7 @@ fn main() -> Result<()> {
/// Reads a `Config` from the YAML file at `config`, applies the command-line
/// `overrides` to it, and validates the resulting cluster section.
///
/// # Errors
///
/// Propagates any error from parsing the file, from applying the overrides,
/// or from `cfg.cluster.validate()`.
fn load(config: &PathBuf, overrides: &ConfigOverrides) -> Result<Config> {
let mut cfg = Config::from_yaml_path(config)?;
// NOTE(review): the two consecutive `apply` calls look like the old and new
// sides of one diff line — confirm that only the `?` form is present in the file.
overrides.apply(&mut cfg);
overrides.apply(&mut cfg)?;
// Validation runs after overrides so that it sees the final cluster settings.
cfg.cluster.validate()?;
Ok(cfg)
}
@@ -313,6 +318,9 @@ fn auto_select_instances(
probe: RouterMode,
target_ttft_mean: f64,
) -> Result<u32> {
base.cluster
.require_legacy_single_pool("auto-instances calibration")?;
#[derive(serde::Serialize)]
struct CalibRow {
num_instances: u32,
@@ -334,7 +342,6 @@ fn auto_select_instances(
for &n in candidates {
let mut cfg = base.clone();
cfg.cluster.num_instances = Some(n);
cfg.cluster.buckets.clear();
cfg.cluster.router.mode = probe;
// Isolate calibration output so ablation runs don't overwrite it.
cfg.sim.output_dir = out_root
@@ -464,6 +471,8 @@ fn cmd_oracle(
out_path: Option<&std::path::Path>,
) -> Result<()> {
let cfg = load(path, overrides)?;
cfg.cluster
.require_legacy_single_pool("oracle analysis")?;
let block_bytes = cfg.model.kv_block_bytes() as f64;
let per_instance_blocks = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as u64;
let aggregate_blocks = per_instance_blocks * cfg.cluster.total_instances() as u64;
@@ -545,3 +554,123 @@ fn cmd_oracle(
eprintln!("[oracle] wrote {}", target.display());
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use kvcache_simulator::config::{
BucketConfig, CalibrationConfig, ClusterConfig, GlobalRouterConfig, HardwareConfig,
MetaStoreConfig, ModelConfig, RouterConfig, SimConfig,
};
/// Builds an in-memory `Config` with two buckets ("short" 0..=64 with two
/// instances, "long" 65..=128 with one) and no `num_instances`, writing
/// simulation output under `out_dir`.
fn bucketed_config(out_dir: &str) -> Config {
    let model = ModelConfig {
        name: "test".into(),
        num_layers: 4,
        num_kv_heads: 2,
        head_dim: 64,
        dtype_bytes: 2,
        block_size_tokens: 16,
        flops_per_token_prefill: Some(1.0e9),
        attn_quadratic_coeff: Some(64.0),
        ..Default::default()
    };

    let hardware = HardwareConfig {
        gpu_flops: 1.0e14,
        gpu_fp8_flops: 0.0,
        gpu_fp4_flops: 0.0,
        gpu_mem_bw: 1.0e12,
        hbm_bytes: 1.0e9,
        dram_bytes: 4.0e9,
        host_dram_bw: 5.0e11,
        pcie_bw: 32.0e9,
        pcie_latency_us: 1.0,
        rdma_bw: 12.0e9,
        rdma_latency_us: 5.0,
        intra_node_tp_bw: 9.0e11,
        intra_node_tp_latency_us: 2.0,
        tp_degree: 1,
        max_batch_slots: 32,
        prefill_chunk_tokens: 1024,
    };

    let calibration = CalibrationConfig::default();

    // Two adjacent, non-overlapping input-length ranges.
    let short_bucket = BucketConfig {
        name: "short".into(),
        input_length_min: 0,
        input_length_max: 64,
        num_instances: 2,
    };
    let long_bucket = BucketConfig {
        name: "long".into(),
        input_length_min: 65,
        input_length_max: 128,
        num_instances: 1,
    };

    let global_router = GlobalRouterConfig::default();
    let meta_store = MetaStoreConfig {
        ttl_seconds: 1000.0,
    };
    let router = RouterConfig {
        mode: RouterMode::Random,
        precise_probe_latency_us: 10.0,
        precise_probe_topk: 4,
        load_alpha: 0.1,
        score_alpha: 1.0,
        score_beta: 0.1,
        prefix_k: 8,
        affinity_fan_out: 0,
    };

    let cluster = ClusterConfig {
        num_instances: None,
        buckets: vec![short_bucket, long_bucket],
        global_router,
        meta_store,
        router,
    };

    let sim = SimConfig {
        trace_path: "unused.jsonl".into(),
        max_requests: None,
        output_dir: out_dir.into(),
        sample_interval_s: 0.0,
        seed: 7,
        input_length_min: None,
        input_length_max: None,
    };

    Config {
        model,
        hardware,
        calibration,
        cluster,
        sim,
    }
}
/// Applying a `num_instances` override to a bucketed config must fail with a
/// message that names both the flag and `cluster.buckets`.
#[test]
fn num_instances_override_rejects_bucketed_configs() {
    let out_dir = std::env::temp_dir();
    let mut cfg = bucketed_config(out_dir.to_str().unwrap());
    let overrides = ConfigOverrides {
        num_instances: Some(8),
        ..ConfigOverrides::default()
    };

    let message = overrides.apply(&mut cfg).unwrap_err().to_string();

    assert!(message.contains("--num-instances"));
    assert!(message.contains("cluster.buckets"));
}
/// Auto-instance calibration must refuse a bucketed config and say why.
#[test]
fn auto_instances_rejects_bucketed_configs() {
    let out_dir = std::env::temp_dir();
    let cfg = bucketed_config(out_dir.to_str().unwrap());
    let candidates = [4, 8];

    let message = auto_select_instances(&cfg, &candidates, RouterMode::Random, 1.0)
        .unwrap_err()
        .to_string();

    assert!(message.contains("auto-instances"));
    assert!(message.contains("cluster.buckets"));
}
/// The oracle command must refuse a bucketed config loaded from disk.
#[test]
fn oracle_rejects_bucketed_configs() {
    // Start from a clean scratch directory; a missing directory is not an error.
    let scratch = std::env::temp_dir().join("kvcache_sim_main_tests");
    let _ = std::fs::remove_dir_all(&scratch);
    std::fs::create_dir_all(&scratch).unwrap();

    // Serialize the bucketed config so `cmd_oracle` can read it back from a path.
    let config_path = scratch.join("bucketed.yaml");
    let cfg = bucketed_config(scratch.to_str().unwrap());
    let yaml = serde_yaml::to_string(&cfg).unwrap();
    std::fs::write(&config_path, yaml).unwrap();

    let message = cmd_oracle(&config_path, &ConfigOverrides::default(), None, false, None)
        .unwrap_err()
        .to_string();

    assert!(message.contains("oracle analysis"));
    assert!(message.contains("cluster.buckets"));
}
}

View File

@@ -496,6 +496,8 @@ pub fn replay_fixed_placement(
placements: &[PlacementEntry],
policy: ReplayEvictPolicy,
) -> Result<ReplaySummary> {
cfg.cluster
.require_legacy_single_pool("fixed-placement replay")?;
if records.len() != placements.len() {
return Err(anyhow!(
"records/placements length mismatch: {} vs {}",

View File

@@ -6,7 +6,7 @@ use std::io::Write;
use kvcache_simulator::config::*;
use kvcache_simulator::driver;
use kvcache_simulator::replay::ReplayEvictPolicy;
use kvcache_simulator::replay::{self, PlacementEntry, ReplayEvictPolicy};
fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config {
Config {
@@ -70,6 +70,26 @@ fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config {
}
}
/// Returns `base_config(..)` with `cluster.num_instances` cleared and two
/// buckets installed: "short" (0..=64, two instances) and "long" (65..=128,
/// one instance).
fn bucketed_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config {
    let mut bucketed = base_config(trace_path, out_dir, mode);

    let short_bucket = BucketConfig {
        name: "short".into(),
        input_length_min: 0,
        input_length_max: 64,
        num_instances: 2,
    };
    let long_bucket = BucketConfig {
        name: "long".into(),
        input_length_min: 65,
        input_length_max: 128,
        num_instances: 1,
    };

    bucketed.cluster.buckets = vec![short_bucket, long_bucket];
    bucketed.cluster.num_instances = None;
    bucketed
}
fn write_synthetic_trace(path: &std::path::Path) {
// 5 distinct conversations, each with 8 turns. Within a conversation,
// turn k+1 reuses the prefix of turn k (shared first ~10 blocks) and
@@ -274,3 +294,36 @@ fn ablation_parallel_matches_serial() {
assert!((lhs.miss_rate - rhs.miss_rate).abs() < 1e-12);
}
}
/// The driver run, the fixed-placement ablation and the fixed-placement replay
/// must each reject a bucketed config with an error mentioning `cluster.buckets`.
#[test]
fn bucketed_configs_are_rejected_by_legacy_runtime_paths() {
    // Fresh scratch directory; a missing directory on cleanup is not an error.
    let scratch = std::env::temp_dir().join("kvcache_sim_bucketed_reject");
    let _ = std::fs::remove_dir_all(&scratch);
    std::fs::create_dir_all(&scratch).unwrap();

    let trace_file = scratch.join("trace.jsonl");
    write_synthetic_trace(&trace_file);

    let cfg = bucketed_config(
        trace_file.to_str().unwrap(),
        scratch.to_str().unwrap(),
        RouterMode::Random,
    );

    // Entry point 1: full driver run.
    let run_result = driver::run(&cfg, Some("bucketed_guard"));
    assert!(run_result.is_err(), "bucketed run should fail");
    let run_message = run_result.err().unwrap().to_string();
    assert!(run_message.contains("cluster.buckets"));

    // Entry point 2: fixed-placement ablation.
    let routers = [RouterMode::Random];
    let policies = [ReplayEvictPolicy::Lru];
    let ablation_message = driver::ablate_fixed_placement(&cfg, &routers, &policies)
        .expect_err("bucketed ablation should fail")
        .to_string();
    assert!(ablation_message.contains("cluster.buckets"));

    // Entry point 3: fixed-placement replay, called with empty inputs.
    let no_placements = Vec::<PlacementEntry>::new();
    let replay_message =
        replay::replay_fixed_placement(&cfg, &[], &no_placements, ReplayEvictPolicy::Lru)
            .expect_err("bucketed replay should fail")
            .to_string();
    assert!(replay_message.contains("cluster.buckets"));
}