fix: guard legacy runtime paths for bucketed configs
This commit is contained in:
125
src/config.rs
125
src/config.rs
@@ -419,6 +419,15 @@ pub enum GlobalRouterMode {
|
||||
BucketScore,
|
||||
}
|
||||
|
||||
impl GlobalRouterMode {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::StrictInputLength => "strict_input_length",
|
||||
Self::BucketScore => "bucket_score",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MetaStoreConfig {
|
||||
pub ttl_seconds: f64,
|
||||
@@ -468,6 +477,15 @@ fn default_prefix_k() -> usize {
|
||||
}
|
||||
|
||||
impl ClusterConfig {
|
||||
pub fn require_legacy_single_pool(&self, context: &str) -> Result<u32> {
|
||||
if !self.buckets.is_empty() {
|
||||
anyhow::bail!("{context} does not support cluster.buckets until Task 2 lands");
|
||||
}
|
||||
|
||||
self.legacy_num_instances()
|
||||
.ok_or_else(|| anyhow::anyhow!("{context} requires cluster.num_instances"))
|
||||
}
|
||||
|
||||
pub fn legacy_num_instances(&self) -> Option<u32> {
|
||||
if self.buckets.is_empty() {
|
||||
self.num_instances
|
||||
@@ -481,6 +499,24 @@ impl ClusterConfig {
|
||||
.unwrap_or_else(|| self.buckets.iter().map(|bucket| bucket.num_instances).sum())
|
||||
}
|
||||
|
||||
pub fn bucket_index_for_input_len(&self, input_len: u32) -> Result<usize> {
|
||||
if self.buckets.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
self.buckets
|
||||
.iter()
|
||||
.position(|bucket| {
|
||||
bucket.input_length_min <= input_len && input_len <= bucket.input_length_max
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"cluster.global_router.mode={} has no bucket for input_length={input_len}",
|
||||
self.global_router.mode.as_str()
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn effective_buckets(&self) -> Vec<BucketConfig> {
|
||||
if !self.buckets.is_empty() {
|
||||
return self.buckets.clone();
|
||||
@@ -1270,4 +1306,93 @@ sim:
|
||||
assert!(err.to_string().contains("num_instances"));
|
||||
assert!(err.to_string().contains("buckets"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bucketed_config_reports_unmatched_input_length_gaps_clearly() {
|
||||
let gapful = write_temp_config(
|
||||
r#"
|
||||
model:
|
||||
name: test
|
||||
num_layers: 4
|
||||
num_kv_heads: 2
|
||||
head_dim: 64
|
||||
dtype_bytes: 2
|
||||
block_size_tokens: 16
|
||||
flops_per_token_prefill: 1.0e9
|
||||
attn_quadratic_coeff: 64.0
|
||||
hardware:
|
||||
gpu_flops: 1.0e14
|
||||
gpu_mem_bw: 1.0e12
|
||||
hbm_bytes: 1.0e9
|
||||
cluster:
|
||||
meta_store:
|
||||
ttl_seconds: 10.0
|
||||
router:
|
||||
mode: cache_affinity
|
||||
global_router:
|
||||
mode: strict_input_length
|
||||
buckets:
|
||||
- name: short
|
||||
input_length_min: 0
|
||||
input_length_max: 32
|
||||
num_instances: 2
|
||||
- name: long
|
||||
input_length_min: 64
|
||||
input_length_max: 96
|
||||
num_instances: 1
|
||||
sim:
|
||||
trace_path: trace.jsonl
|
||||
output_dir: runs/test
|
||||
"#,
|
||||
);
|
||||
let cfg = Config::from_yaml_path(&gapful).unwrap();
|
||||
|
||||
let err = cfg.cluster.bucket_index_for_input_len(40).unwrap_err();
|
||||
assert!(err.to_string().contains("40"));
|
||||
assert!(err.to_string().contains("no bucket"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bucketed_config_requires_legacy_single_pool_for_legacy_runtime_paths() {
|
||||
let bucketed = write_temp_config(
|
||||
r#"
|
||||
model:
|
||||
name: test
|
||||
num_layers: 4
|
||||
num_kv_heads: 2
|
||||
head_dim: 64
|
||||
dtype_bytes: 2
|
||||
block_size_tokens: 16
|
||||
flops_per_token_prefill: 1.0e9
|
||||
attn_quadratic_coeff: 64.0
|
||||
hardware:
|
||||
gpu_flops: 1.0e14
|
||||
gpu_mem_bw: 1.0e12
|
||||
hbm_bytes: 1.0e9
|
||||
cluster:
|
||||
meta_store:
|
||||
ttl_seconds: 10.0
|
||||
router:
|
||||
mode: cache_affinity
|
||||
global_router:
|
||||
mode: strict_input_length
|
||||
buckets:
|
||||
- name: short
|
||||
input_length_min: 0
|
||||
input_length_max: 32
|
||||
num_instances: 2
|
||||
sim:
|
||||
trace_path: trace.jsonl
|
||||
output_dir: runs/test
|
||||
"#,
|
||||
);
|
||||
let cfg = Config::from_yaml_path(&bucketed).unwrap();
|
||||
|
||||
let err = cfg
|
||||
.cluster
|
||||
.require_legacy_single_pool("driver run")
|
||||
.unwrap_err();
|
||||
assert!(err.to_string().contains("driver run"));
|
||||
assert!(err.to_string().contains("cluster.buckets"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,6 +49,9 @@ struct InflightInfo {
|
||||
}
|
||||
|
||||
pub fn run(config: &Config, output_subdir: Option<&str>) -> Result<RunOutputs> {
|
||||
config
|
||||
.cluster
|
||||
.require_legacy_single_pool("driver run")?;
|
||||
let mut cluster = Cluster::new(config, &config.model);
|
||||
let mut q = EventQueue::new();
|
||||
|
||||
@@ -217,6 +220,8 @@ pub fn ablate_fixed_placement_with_parallelism(
|
||||
evict_policies: &[ReplayEvictPolicy],
|
||||
jobs: usize,
|
||||
) -> Result<Vec<AblationRow>> {
|
||||
base.cluster
|
||||
.require_legacy_single_pool("fixed-placement ablation")?;
|
||||
let mut out = Vec::new();
|
||||
for &policy in evict_policies {
|
||||
if policy != ReplayEvictPolicy::Lru {
|
||||
|
||||
137
src/main.rs
137
src/main.rs
@@ -50,10 +50,14 @@ struct ConfigOverrides {
|
||||
}
|
||||
|
||||
impl ConfigOverrides {
|
||||
fn apply(&self, cfg: &mut Config) {
|
||||
fn apply(&self, cfg: &mut Config) -> Result<()> {
|
||||
if let Some(n) = self.num_instances {
|
||||
if !cfg.cluster.buckets.is_empty() {
|
||||
anyhow::bail!(
|
||||
"--num-instances does not support cluster.buckets until Task 2 lands"
|
||||
);
|
||||
}
|
||||
cfg.cluster.num_instances = Some(n);
|
||||
cfg.cluster.buckets.clear();
|
||||
}
|
||||
if let Some(m) = self.max_requests {
|
||||
cfg.sim.max_requests = Some(m);
|
||||
@@ -79,6 +83,7 @@ impl ConfigOverrides {
|
||||
if let Some(hi) = self.input_length_max {
|
||||
cfg.sim.input_length_max = Some(hi);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,7 +210,7 @@ fn main() -> Result<()> {
|
||||
|
||||
fn load(config: &PathBuf, overrides: &ConfigOverrides) -> Result<Config> {
|
||||
let mut cfg = Config::from_yaml_path(config)?;
|
||||
overrides.apply(&mut cfg);
|
||||
overrides.apply(&mut cfg)?;
|
||||
cfg.cluster.validate()?;
|
||||
Ok(cfg)
|
||||
}
|
||||
@@ -313,6 +318,9 @@ fn auto_select_instances(
|
||||
probe: RouterMode,
|
||||
target_ttft_mean: f64,
|
||||
) -> Result<u32> {
|
||||
base.cluster
|
||||
.require_legacy_single_pool("auto-instances calibration")?;
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct CalibRow {
|
||||
num_instances: u32,
|
||||
@@ -334,7 +342,6 @@ fn auto_select_instances(
|
||||
for &n in candidates {
|
||||
let mut cfg = base.clone();
|
||||
cfg.cluster.num_instances = Some(n);
|
||||
cfg.cluster.buckets.clear();
|
||||
cfg.cluster.router.mode = probe;
|
||||
// Isolate calibration output so ablation runs don't overwrite it.
|
||||
cfg.sim.output_dir = out_root
|
||||
@@ -464,6 +471,8 @@ fn cmd_oracle(
|
||||
out_path: Option<&std::path::Path>,
|
||||
) -> Result<()> {
|
||||
let cfg = load(path, overrides)?;
|
||||
cfg.cluster
|
||||
.require_legacy_single_pool("oracle analysis")?;
|
||||
let block_bytes = cfg.model.kv_block_bytes() as f64;
|
||||
let per_instance_blocks = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as u64;
|
||||
let aggregate_blocks = per_instance_blocks * cfg.cluster.total_instances() as u64;
|
||||
@@ -545,3 +554,123 @@ fn cmd_oracle(
|
||||
eprintln!("[oracle] wrote {}", target.display());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use kvcache_simulator::config::{
|
||||
BucketConfig, CalibrationConfig, ClusterConfig, GlobalRouterConfig, HardwareConfig,
|
||||
MetaStoreConfig, ModelConfig, RouterConfig, SimConfig,
|
||||
};
|
||||
|
||||
fn bucketed_config(out_dir: &str) -> Config {
|
||||
Config {
|
||||
model: ModelConfig {
|
||||
name: "test".into(),
|
||||
num_layers: 4,
|
||||
num_kv_heads: 2,
|
||||
head_dim: 64,
|
||||
dtype_bytes: 2,
|
||||
block_size_tokens: 16,
|
||||
flops_per_token_prefill: Some(1.0e9),
|
||||
attn_quadratic_coeff: Some(64.0),
|
||||
..Default::default()
|
||||
},
|
||||
hardware: HardwareConfig {
|
||||
gpu_flops: 1.0e14,
|
||||
gpu_fp8_flops: 0.0,
|
||||
gpu_fp4_flops: 0.0,
|
||||
gpu_mem_bw: 1.0e12,
|
||||
hbm_bytes: 1.0e9,
|
||||
dram_bytes: 4.0e9,
|
||||
host_dram_bw: 5.0e11,
|
||||
pcie_bw: 32.0e9,
|
||||
pcie_latency_us: 1.0,
|
||||
rdma_bw: 12.0e9,
|
||||
rdma_latency_us: 5.0,
|
||||
intra_node_tp_bw: 9.0e11,
|
||||
intra_node_tp_latency_us: 2.0,
|
||||
tp_degree: 1,
|
||||
max_batch_slots: 32,
|
||||
prefill_chunk_tokens: 1024,
|
||||
},
|
||||
calibration: CalibrationConfig::default(),
|
||||
cluster: ClusterConfig {
|
||||
num_instances: None,
|
||||
buckets: vec![
|
||||
BucketConfig {
|
||||
name: "short".into(),
|
||||
input_length_min: 0,
|
||||
input_length_max: 64,
|
||||
num_instances: 2,
|
||||
},
|
||||
BucketConfig {
|
||||
name: "long".into(),
|
||||
input_length_min: 65,
|
||||
input_length_max: 128,
|
||||
num_instances: 1,
|
||||
},
|
||||
],
|
||||
global_router: GlobalRouterConfig::default(),
|
||||
meta_store: MetaStoreConfig {
|
||||
ttl_seconds: 1000.0,
|
||||
},
|
||||
router: RouterConfig {
|
||||
mode: RouterMode::Random,
|
||||
precise_probe_latency_us: 10.0,
|
||||
precise_probe_topk: 4,
|
||||
load_alpha: 0.1,
|
||||
score_alpha: 1.0,
|
||||
score_beta: 0.1,
|
||||
prefix_k: 8,
|
||||
affinity_fan_out: 0,
|
||||
},
|
||||
},
|
||||
sim: SimConfig {
|
||||
trace_path: "unused.jsonl".into(),
|
||||
max_requests: None,
|
||||
output_dir: out_dir.into(),
|
||||
sample_interval_s: 0.0,
|
||||
seed: 7,
|
||||
input_length_min: None,
|
||||
input_length_max: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn num_instances_override_rejects_bucketed_configs() {
|
||||
let mut cfg = bucketed_config(std::env::temp_dir().to_str().unwrap());
|
||||
let overrides = ConfigOverrides {
|
||||
num_instances: Some(8),
|
||||
..ConfigOverrides::default()
|
||||
};
|
||||
|
||||
let err = overrides.apply(&mut cfg).unwrap_err();
|
||||
assert!(err.to_string().contains("--num-instances"));
|
||||
assert!(err.to_string().contains("cluster.buckets"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_instances_rejects_bucketed_configs() {
|
||||
let cfg = bucketed_config(std::env::temp_dir().to_str().unwrap());
|
||||
|
||||
let err = auto_select_instances(&cfg, &[4, 8], RouterMode::Random, 1.0).unwrap_err();
|
||||
assert!(err.to_string().contains("auto-instances"));
|
||||
assert!(err.to_string().contains("cluster.buckets"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oracle_rejects_bucketed_configs() {
|
||||
let tmp = std::env::temp_dir().join("kvcache_sim_main_tests");
|
||||
let _ = std::fs::remove_dir_all(&tmp);
|
||||
std::fs::create_dir_all(&tmp).unwrap();
|
||||
let path = tmp.join("bucketed.yaml");
|
||||
let cfg = bucketed_config(tmp.to_str().unwrap());
|
||||
std::fs::write(&path, serde_yaml::to_string(&cfg).unwrap()).unwrap();
|
||||
|
||||
let err = cmd_oracle(&path, &ConfigOverrides::default(), None, false, None).unwrap_err();
|
||||
assert!(err.to_string().contains("oracle analysis"));
|
||||
assert!(err.to_string().contains("cluster.buckets"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -496,6 +496,8 @@ pub fn replay_fixed_placement(
|
||||
placements: &[PlacementEntry],
|
||||
policy: ReplayEvictPolicy,
|
||||
) -> Result<ReplaySummary> {
|
||||
cfg.cluster
|
||||
.require_legacy_single_pool("fixed-placement replay")?;
|
||||
if records.len() != placements.len() {
|
||||
return Err(anyhow!(
|
||||
"records/placements length mismatch: {} vs {}",
|
||||
|
||||
Reference in New Issue
Block a user