From a723d7a811c1b0080f6c8d136d19ffd443e5e44e Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 17 Apr 2026 14:16:56 +0800 Subject: [PATCH 1/9] feat: model explicit bucketed cluster config --- src/config.rs | 375 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 372 insertions(+), 3 deletions(-) diff --git a/src/config.rs b/src/config.rs index 562ac5e..24658d7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -13,7 +13,11 @@ use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::ops::Deref; use std::path::Path; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::{Mutex, OnceLock}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Config { @@ -384,6 +388,88 @@ pub struct ClusterConfig { pub router: RouterConfig, } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct BucketConfig { + pub name: String, + pub input_length_min: u32, + pub input_length_max: u32, + pub num_instances: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct GlobalRouterConfig { + #[serde(default)] + pub mode: GlobalRouterMode, +} + +impl Default for GlobalRouterConfig { + fn default() -> Self { + Self { + mode: GlobalRouterMode::StrictInputLength, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum GlobalRouterMode { + #[default] + StrictInputLength, + BucketScore, +} + +#[derive(Debug, Clone)] +pub struct ClusterBucketView { + pub buckets: Vec, + pub global_router: GlobalRouterConfig, +} + +impl Default for ClusterBucketView { + fn default() -> Self { + Self { + buckets: Vec::new(), + global_router: GlobalRouterConfig::default(), + } + } +} + +static DEFAULT_CLUSTER_BUCKET_VIEW: OnceLock = OnceLock::new(); +static CLUSTER_BUCKET_VIEWS: OnceLock>> = + OnceLock::new(); +static NEXT_CLUSTER_BUCKET_VIEW_ID: AtomicU32 = AtomicU32::new(1); + +fn default_cluster_bucket_view() -> &'static ClusterBucketView { + DEFAULT_CLUSTER_BUCKET_VIEW.get_or_init(ClusterBucketView::default) +} + +fn cluster_bucket_views() -> &'static Mutex> { + CLUSTER_BUCKET_VIEWS.get_or_init(|| Mutex::new(HashMap::new())) +} + +fn register_cluster_bucket_view(view: ClusterBucketView) -> u32 { + if view.buckets.is_empty() && view.global_router == GlobalRouterConfig::default() { + return 0; + } + + let id = NEXT_CLUSTER_BUCKET_VIEW_ID.fetch_add(1, Ordering::Relaxed); + let leaked = Box::leak(Box::new(view)); + cluster_bucket_views().lock().unwrap().insert(id, leaked); + id +} + +fn lookup_cluster_bucket_view(id: u32) -> &'static ClusterBucketView { + if id == 0 { + return default_cluster_bucket_view(); + } + + cluster_bucket_views() + .lock() + .unwrap() + .get(&id) + .copied() + .unwrap_or_else(default_cluster_bucket_view) +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MetaStoreConfig { pub ttl_seconds: f64, @@ -432,6 +518,96 @@ fn default_prefix_k() -> usize { 8 } +impl ClusterConfig { + fn bucket_view_id(&self) -> u32 { + if self.router.precise_probe_latency_us < 0.0 { + (-self.router.precise_probe_latency_us) as u32 + } else { + 0 + } + } + + fn bucket_view(&self) -> &'static ClusterBucketView { + lookup_cluster_bucket_view(self.bucket_view_id()) + } + + pub fn legacy_num_instances(&self) -> Option { + self.buckets.is_empty().then_some(self.num_instances) + } + + pub fn total_instances(&self) -> u32 { + if self.buckets.is_empty() { + self.num_instances + } else { + self.buckets.iter().map(|bucket| bucket.num_instances).sum() + } + } + + pub fn effective_buckets(&self) -> Vec { + if !self.buckets.is_empty() { + return self.buckets.clone(); + } + + vec![BucketConfig { + name: "default".to_string(), + input_length_min: 0, + input_length_max: u32::MAX, + num_instances: self.num_instances, + }] + } + + pub fn validate(&self) -> Result<()> { + if self.num_instances == 0 && self.buckets.is_empty() { + anyhow::bail!("cluster must set either num_instances or buckets"); + } + + if self.num_instances > 0 && !self.buckets.is_empty() { + anyhow::bail!("cluster.num_instances and cluster.buckets are mutually exclusive"); + } + + if self.buckets.is_empty() { + anyhow::ensure!(self.num_instances > 0, "cluster.num_instances must be positive"); + return Ok(()); + } + + for bucket in &self.buckets { + anyhow::ensure!( + bucket.input_length_min <= bucket.input_length_max, + "cluster bucket '{}' has input_length_min > input_length_max", + bucket.name + ); + anyhow::ensure!( + bucket.num_instances > 0, + "cluster bucket '{}' must have num_instances > 0", + bucket.name + ); + } + + let mut sorted = self.buckets.iter().collect::>(); + sorted.sort_by_key(|bucket| (bucket.input_length_min, bucket.input_length_max)); + for pair in sorted.windows(2) { + let prev = pair[0]; + let next = pair[1]; + anyhow::ensure!( + prev.input_length_max < next.input_length_min, + "cluster buckets '{}' and '{}' overlap", + prev.name, + next.name + ); + } + + Ok(()) + } +} + +impl Deref for ClusterConfig { + type Target = ClusterBucketView; + + fn deref(&self) -> &Self::Target { + self.bucket_view() + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum RouterMode { @@ -544,7 +720,7 @@ impl Config { .with_context(|| format!("parsing config {}", path.display()))?; let yaml_dir = path.parent().unwrap_or(Path::new(".")); raw.resolve(yaml_dir) - .with_context(|| format!("resolving config {}", path.display())) + .map_err(|err| anyhow::anyhow!("resolving config {}: {err}", path.display())) } } @@ -562,10 +738,22 @@ struct RawConfig { hardware: RawHardwareConfig, #[serde(default)] calibration: CalibrationConfig, - cluster: ClusterConfig, + cluster: RawClusterConfig, sim: SimConfig, } +#[derive(Deserialize)] +struct RawClusterConfig { + #[serde(default)] + num_instances: Option, + meta_store: MetaStoreConfig, + router: RouterConfig, + #[serde(default)] + global_router: GlobalRouterConfig, + #[serde(default)] + buckets: Vec, +} + #[derive(Deserialize)] struct RawModelConfig { /// Path to a HuggingFace `config.json`. Resolved relative to the YAML @@ -678,12 +866,36 @@ impl RawConfig { model, hardware, calibration: self.calibration, - cluster: self.cluster, + cluster: self.cluster.resolve()?, sim: self.sim, }) } } +impl RawClusterConfig { + fn resolve(self) -> Result { + let view = ClusterBucketView { + buckets: self.buckets, + global_router: self.global_router, + }; + let mut cluster = ClusterConfig { + num_instances: self.num_instances.unwrap_or(0), + meta_store: self.meta_store, + router: self.router, + }; + + let view_id = register_cluster_bucket_view(view); + if view_id > 0 { + // Preserve the public ClusterConfig layout for existing callers while + // carrying bucketed config through validation and tests. + cluster.router.precise_probe_latency_us = -(view_id as f64); + } + + cluster.validate()?; + Ok(cluster) + } +} + impl RawModelConfig { fn resolve(self, yaml_dir: &Path) -> Result { // Start from HF config.json if specified, else empty default. @@ -1003,4 +1215,161 @@ sim: assert_eq!(cfg.model.weight_dtype.as_deref(), Some("fp4")); assert!((cfg.model.weight_dtype_bytes() - 0.5).abs() < 1e-12); } + + #[test] + fn bucketed_config_loads_and_preserves_legacy_single_pool() { + let legacy = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + num_instances: 2 + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: strict_input_length +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + let cfg = Config::from_yaml_path(&legacy).unwrap(); + assert_eq!(cfg.cluster.legacy_num_instances(), Some(2)); + assert_eq!(cfg.cluster.buckets.len(), 0); + + let bucketed = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: strict_input_length + buckets: + - name: short + input_length_min: 0 + input_length_max: 32 + num_instances: 2 + - name: long + input_length_min: 33 + input_length_max: 96 + num_instances: 1 +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + let cfg = Config::from_yaml_path(&bucketed).unwrap(); + assert_eq!(cfg.cluster.legacy_num_instances(), None); + assert_eq!(cfg.cluster.buckets.len(), 2); + assert_eq!(cfg.cluster.buckets[0].name, "short"); + assert_eq!(cfg.cluster.buckets[1].num_instances, 1); + assert_eq!(cfg.cluster.total_instances(), 3); + } + + #[test] + fn bucketed_config_rejects_overlapping_ranges_and_mixed_modes() { + let overlap = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: strict_input_length + buckets: + - name: short + input_length_min: 0 + input_length_max: 32 + num_instances: 2 + - name: overlap + input_length_min: 32 + input_length_max: 96 + num_instances: 1 +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + let err = Config::from_yaml_path(&overlap).unwrap_err(); + assert!(err.to_string().contains("overlap")); + + let mixed = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + num_instances: 2 + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: strict_input_length + buckets: + - name: short + input_length_min: 0 + input_length_max: 32 + num_instances: 2 +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + let err = Config::from_yaml_path(&mixed).unwrap_err(); + assert!(err.to_string().contains("num_instances")); + assert!(err.to_string().contains("buckets")); + } } From d8a079650674ca67578631353309856e373f2fa2 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 17 Apr 2026 14:21:34 +0800 Subject: [PATCH 2/9] fix: close bucketed cluster config model gaps --- src/cluster/cluster.rs | 9 +- src/config.rs | 152 ++++++-------------------------- src/main.rs | 18 ++-- src/replay.rs | 2 +- src/router/adaptive_affinity.rs | 2 +- src/router/lineage_affinity.rs | 2 +- src/router/mod.rs | 4 +- src/router/prefix_affinity.rs | 2 +- tests/smoke.rs | 4 +- 9 files changed, 52 insertions(+), 143 deletions(-) diff --git a/src/cluster/cluster.rs b/src/cluster/cluster.rs index 9927cdc..31d42fb 100644 --- a/src/cluster/cluster.rs +++ b/src/cluster/cluster.rs @@ -37,8 +37,9 @@ pub struct Cluster { impl Cluster { pub fn new(config: &Config, model: &ModelConfig) -> Self { - let mut instances = Vec::with_capacity(config.cluster.num_instances as usize); - for id in 0..config.cluster.num_instances { + let total_instances = config.cluster.total_instances(); + let mut instances = Vec::with_capacity(total_instances as usize); + for id in 0..total_instances { instances.push(Instance::new( id as InstanceId, model, @@ -226,7 +227,9 @@ mod tests { ..CalibrationConfig::default() }, cluster: ClusterConfig { - num_instances: 1, + num_instances: Some(1), + buckets: Vec::new(), + global_router: Default::default(), meta_store: MetaStoreConfig { ttl_seconds: 1000.0, }, diff --git a/src/config.rs b/src/config.rs index 24658d7..f646dcb 100644 --- a/src/config.rs +++ b/src/config.rs @@ -13,11 +13,7 @@ use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::ops::Deref; use std::path::Path; -use std::sync::atomic::{AtomicU32, Ordering}; -use std::sync::{Mutex, OnceLock}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Config { @@ -383,7 +379,12 @@ fn default_first_token_ready_us() -> f64 { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ClusterConfig { - pub num_instances: u32, + #[serde(default)] + pub num_instances: Option, + #[serde(default)] + pub buckets: Vec, + #[serde(default)] + pub global_router: GlobalRouterConfig, pub meta_store: MetaStoreConfig, pub router: RouterConfig, } @@ -418,58 +419,6 @@ pub enum GlobalRouterMode { BucketScore, } -#[derive(Debug, Clone)] -pub struct ClusterBucketView { - pub buckets: Vec, - pub global_router: GlobalRouterConfig, -} - -impl Default for ClusterBucketView { - fn default() -> Self { - Self { - buckets: Vec::new(), - global_router: GlobalRouterConfig::default(), - } - } -} - -static DEFAULT_CLUSTER_BUCKET_VIEW: OnceLock = OnceLock::new(); -static CLUSTER_BUCKET_VIEWS: OnceLock>> = - OnceLock::new(); -static NEXT_CLUSTER_BUCKET_VIEW_ID: AtomicU32 = AtomicU32::new(1); - -fn default_cluster_bucket_view() -> &'static ClusterBucketView { - DEFAULT_CLUSTER_BUCKET_VIEW.get_or_init(ClusterBucketView::default) -} - -fn cluster_bucket_views() -> &'static Mutex> { - CLUSTER_BUCKET_VIEWS.get_or_init(|| Mutex::new(HashMap::new())) -} - -fn register_cluster_bucket_view(view: ClusterBucketView) -> u32 { - if view.buckets.is_empty() && view.global_router == GlobalRouterConfig::default() { - return 0; - } - - let id = NEXT_CLUSTER_BUCKET_VIEW_ID.fetch_add(1, Ordering::Relaxed); - let leaked = Box::leak(Box::new(view)); - cluster_bucket_views().lock().unwrap().insert(id, leaked); - id -} - -fn lookup_cluster_bucket_view(id: u32) -> &'static ClusterBucketView { - if id == 0 { - return default_cluster_bucket_view(); - } - - cluster_bucket_views() - .lock() - .unwrap() - .get(&id) - .copied() - .unwrap_or_else(default_cluster_bucket_view) -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MetaStoreConfig { pub ttl_seconds: f64, @@ -519,30 +468,19 @@ fn default_prefix_k() -> usize { } impl ClusterConfig { - fn bucket_view_id(&self) -> u32 { - if self.router.precise_probe_latency_us < 0.0 { - (-self.router.precise_probe_latency_us) as u32 - } else { - 0 - } - } - - fn bucket_view(&self) -> &'static ClusterBucketView { - lookup_cluster_bucket_view(self.bucket_view_id()) - } - pub fn legacy_num_instances(&self) -> Option { - self.buckets.is_empty().then_some(self.num_instances) - } - - pub fn total_instances(&self) -> u32 { if self.buckets.is_empty() { self.num_instances } else { - self.buckets.iter().map(|bucket| bucket.num_instances).sum() + None } } + pub fn total_instances(&self) -> u32 { + self.legacy_num_instances() + .unwrap_or_else(|| self.buckets.iter().map(|bucket| bucket.num_instances).sum()) + } + pub fn effective_buckets(&self) -> Vec { if !self.buckets.is_empty() { return self.buckets.clone(); @@ -552,21 +490,22 @@ impl ClusterConfig { name: "default".to_string(), input_length_min: 0, input_length_max: u32::MAX, - num_instances: self.num_instances, + num_instances: self + .num_instances + .expect("legacy single-pool cluster must have num_instances"), }] } pub fn validate(&self) -> Result<()> { - if self.num_instances == 0 && self.buckets.is_empty() { - anyhow::bail!("cluster must set either num_instances or buckets"); - } - - if self.num_instances > 0 && !self.buckets.is_empty() { + if self.num_instances.is_some() && !self.buckets.is_empty() { anyhow::bail!("cluster.num_instances and cluster.buckets are mutually exclusive"); } if self.buckets.is_empty() { - anyhow::ensure!(self.num_instances > 0, "cluster.num_instances must be positive"); + let num_instances = self.num_instances.ok_or_else(|| { + anyhow::anyhow!("cluster must set either num_instances or buckets") + })?; + anyhow::ensure!(num_instances > 0, "cluster.num_instances must be positive"); return Ok(()); } @@ -600,14 +539,6 @@ impl ClusterConfig { } } -impl Deref for ClusterConfig { - type Target = ClusterBucketView; - - fn deref(&self) -> &Self::Target { - self.bucket_view() - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum RouterMode { @@ -738,22 +669,10 @@ struct RawConfig { hardware: RawHardwareConfig, #[serde(default)] calibration: CalibrationConfig, - cluster: RawClusterConfig, + cluster: ClusterConfig, sim: SimConfig, } -#[derive(Deserialize)] -struct RawClusterConfig { - #[serde(default)] - num_instances: Option, - meta_store: MetaStoreConfig, - router: RouterConfig, - #[serde(default)] - global_router: GlobalRouterConfig, - #[serde(default)] - buckets: Vec, -} - #[derive(Deserialize)] struct RawModelConfig { /// Path to a HuggingFace `config.json`. Resolved relative to the YAML @@ -866,36 +785,15 @@ impl RawConfig { model, hardware, calibration: self.calibration, - cluster: self.cluster.resolve()?, + cluster: { + self.cluster.validate()?; + self.cluster + }, sim: self.sim, }) } } -impl RawClusterConfig { - fn resolve(self) -> Result { - let view = ClusterBucketView { - buckets: self.buckets, - global_router: self.global_router, - }; - let mut cluster = ClusterConfig { - num_instances: self.num_instances.unwrap_or(0), - meta_store: self.meta_store, - router: self.router, - }; - - let view_id = register_cluster_bucket_view(view); - if view_id > 0 { - // Preserve the public ClusterConfig layout for existing callers while - // carrying bucketed config through validation and tests. - cluster.router.precise_probe_latency_us = -(view_id as f64); - } - - cluster.validate()?; - Ok(cluster) - } -} - impl RawModelConfig { fn resolve(self, yaml_dir: &Path) -> Result { // Start from HF config.json if specified, else empty default. diff --git a/src/main.rs b/src/main.rs index c0b6e6f..37d6677 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,7 +52,8 @@ struct ConfigOverrides { impl ConfigOverrides { fn apply(&self, cfg: &mut Config) { if let Some(n) = self.num_instances { - cfg.cluster.num_instances = n; + cfg.cluster.num_instances = Some(n); + cfg.cluster.buckets.clear(); } if let Some(m) = self.max_requests { cfg.sim.max_requests = Some(m); @@ -205,6 +206,7 @@ fn main() -> Result<()> { fn load(config: &PathBuf, overrides: &ConfigOverrides) -> Result { let mut cfg = Config::from_yaml_path(config)?; overrides.apply(&mut cfg); + cfg.cluster.validate()?; Ok(cfg) } @@ -268,7 +270,8 @@ fn cmd_ablate( auto_target_ttft_mean, probe_mode.as_str() ); - base.cluster.num_instances = chosen; + base.cluster.num_instances = Some(chosen); + base.cluster.buckets.clear(); } eprintln!( @@ -283,7 +286,7 @@ fn cmd_ablate( .map(ReplayEvictPolicy::as_str) .collect::>() .join(","), - base.cluster.num_instances, + base.cluster.total_instances(), if jobs == 0 { "auto".to_string() } else { @@ -330,7 +333,8 @@ fn auto_select_instances( for &n in candidates { let mut cfg = base.clone(); - cfg.cluster.num_instances = n; + cfg.cluster.num_instances = Some(n); + cfg.cluster.buckets.clear(); cfg.cluster.router.mode = probe; // Isolate calibration output so ablation runs don't overwrite it. cfg.sim.output_dir = out_root @@ -429,7 +433,7 @@ fn cmd_validate(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> { let hbm_blocks = (cfg.hardware.hbm_bytes / block_bytes) as u64; let dram_blocks = (cfg.hardware.dram_bytes / block_bytes) as u64; eprintln!("per-instance HBM blocks = {hbm_blocks}, DRAM blocks = {dram_blocks}"); - eprintln!("num_instances = {}", cfg.cluster.num_instances); + eprintln!("num_instances = {}", cfg.cluster.total_instances()); // Sample prefill times at a few prompt lengths. eprintln!("prefill_time samples:"); for &n in &[256, 1024, 4096, 16384, 65536, 131072] { @@ -462,7 +466,7 @@ fn cmd_oracle( let cfg = load(path, overrides)?; let block_bytes = cfg.model.kv_block_bytes() as f64; let per_instance_blocks = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as u64; - let aggregate_blocks = per_instance_blocks * cfg.cluster.num_instances as u64; + let aggregate_blocks = per_instance_blocks * cfg.cluster.total_instances() as u64; let capacity = match (capacity_blocks, per_instance) { (Some(_), true) => { return Err(anyhow::anyhow!( @@ -518,7 +522,7 @@ fn cmd_oracle( records.len(), capacity, per_instance_blocks, - cfg.cluster.num_instances, + cfg.cluster.total_instances(), if per_instance { ", per-instance mode" } else { diff --git a/src/replay.rs b/src/replay.rs index 13a7cb5..2199dc8 100644 --- a/src/replay.rs +++ b/src/replay.rs @@ -519,7 +519,7 @@ pub fn replay_fixed_placement( let block_bytes = cfg.model.kv_block_bytes() as f64; let l0_cap = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as usize; let l1_cap = (cfg.hardware.dram_bytes / block_bytes).max(1.0) as usize; - let num_instances = cfg.cluster.num_instances as usize; + let num_instances = cfg.cluster.total_instances() as usize; let mut caches: Vec = (0..num_instances) .map(|_| ReplayInstanceCache::new(policy, l0_cap, l1_cap)) .collect(); diff --git a/src/router/adaptive_affinity.rs b/src/router/adaptive_affinity.rs index 7b94239..c1ff154 100644 --- a/src/router/adaptive_affinity.rs +++ b/src/router/adaptive_affinity.rs @@ -57,7 +57,7 @@ pub struct AdaptiveAffinityRouter { impl AdaptiveAffinityRouter { pub fn new(config: &Config) -> Self { - let n = config.cluster.num_instances as usize; + let n = config.cluster.total_instances() as usize; let configured_fan_out = config.cluster.router.affinity_fan_out; let max_fan_out = if configured_fan_out > 0 { configured_fan_out.max(2).min(n) diff --git a/src/router/lineage_affinity.rs b/src/router/lineage_affinity.rs index f8d51e6..d23a580 100644 --- a/src/router/lineage_affinity.rs +++ b/src/router/lineage_affinity.rs @@ -53,7 +53,7 @@ pub struct LineageAffinityRouter { impl LineageAffinityRouter { pub fn new(config: &Config) -> Self { - let n = config.cluster.num_instances as usize; + let n = config.cluster.total_instances() as usize; let configured_fan_out = config.cluster.router.affinity_fan_out; let max_fan_out = if configured_fan_out > 0 { configured_fan_out.max(2).min(n) diff --git a/src/router/mod.rs b/src/router/mod.rs index 7773e3d..2a23e3c 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -181,7 +181,9 @@ mod tests { hardware: test_hardware(), calibration: CalibrationConfig::default(), cluster: ClusterConfig { - num_instances: 2, + num_instances: Some(2), + buckets: Vec::new(), + global_router: Default::default(), meta_store: MetaStoreConfig { ttl_seconds: 1000.0, }, diff --git a/src/router/prefix_affinity.rs b/src/router/prefix_affinity.rs index 8f99b7a..ca6bcb3 100644 --- a/src/router/prefix_affinity.rs +++ b/src/router/prefix_affinity.rs @@ -51,7 +51,7 @@ pub struct PrefixAffinityRouter { impl PrefixAffinityRouter { pub fn new(config: &Config) -> Self { - let n = config.cluster.num_instances as usize; + let n = config.cluster.total_instances() as usize; let cfg_fan = config.cluster.router.affinity_fan_out; // fan_out: if configured, use it; otherwise auto = max(2, n/8). let fan_out = if cfg_fan > 0 { diff --git a/tests/smoke.rs b/tests/smoke.rs index 700408d..33c81db 100644 --- a/tests/smoke.rs +++ b/tests/smoke.rs @@ -41,7 +41,9 @@ fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config { }, calibration: CalibrationConfig::default(), cluster: ClusterConfig { - num_instances: 4, + num_instances: Some(4), + buckets: Vec::new(), + global_router: Default::default(), meta_store: MetaStoreConfig { ttl_seconds: 1000.0, }, From 7de38fa9987e635f5a71b66ed4a3a788a24f8e5c Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 17 Apr 2026 14:35:09 +0800 Subject: [PATCH 3/9] fix: guard legacy runtime paths for bucketed configs --- src/config.rs | 125 ++++++++++++++++++++++++++++++++++++++++++++ src/driver.rs | 5 ++ src/main.rs | 137 +++++++++++++++++++++++++++++++++++++++++++++++-- src/replay.rs | 2 + tests/smoke.rs | 55 +++++++++++++++++++- 5 files changed, 319 insertions(+), 5 deletions(-) diff --git a/src/config.rs b/src/config.rs index f646dcb..5be0ba6 100644 --- a/src/config.rs +++ b/src/config.rs @@ -419,6 +419,15 @@ pub enum GlobalRouterMode { BucketScore, } +impl GlobalRouterMode { + pub fn as_str(&self) -> &'static str { + match self { + Self::StrictInputLength => "strict_input_length", + Self::BucketScore => "bucket_score", + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MetaStoreConfig { pub ttl_seconds: f64, @@ -468,6 +477,15 @@ fn default_prefix_k() -> usize { } impl ClusterConfig { + pub fn require_legacy_single_pool(&self, context: &str) -> Result { + if !self.buckets.is_empty() { + anyhow::bail!("{context} does not support cluster.buckets until Task 2 lands"); + } + + self.legacy_num_instances() + .ok_or_else(|| anyhow::anyhow!("{context} requires cluster.num_instances")) + } + pub fn legacy_num_instances(&self) -> Option { if self.buckets.is_empty() { self.num_instances @@ -481,6 +499,24 @@ impl ClusterConfig { .unwrap_or_else(|| self.buckets.iter().map(|bucket| bucket.num_instances).sum()) } + pub fn bucket_index_for_input_len(&self, input_len: u32) -> Result { + if self.buckets.is_empty() { + return Ok(0); + } + + self.buckets + .iter() + .position(|bucket| { + bucket.input_length_min <= input_len && input_len <= bucket.input_length_max + }) + .ok_or_else(|| { + anyhow::anyhow!( + "cluster.global_router.mode={} has no bucket for input_length={input_len}", + self.global_router.mode.as_str() + ) + }) + } + pub fn effective_buckets(&self) -> Vec { if !self.buckets.is_empty() { return self.buckets.clone(); @@ -1270,4 +1306,93 @@ sim: assert!(err.to_string().contains("num_instances")); assert!(err.to_string().contains("buckets")); } + + #[test] + fn bucketed_config_reports_unmatched_input_length_gaps_clearly() { + let gapful = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: strict_input_length + buckets: + - name: short + input_length_min: 0 + input_length_max: 32 + num_instances: 2 + - name: long + input_length_min: 64 + input_length_max: 96 + num_instances: 1 +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + let cfg = Config::from_yaml_path(&gapful).unwrap(); + + let err = cfg.cluster.bucket_index_for_input_len(40).unwrap_err(); + assert!(err.to_string().contains("40")); + assert!(err.to_string().contains("no bucket")); + } + + #[test] + fn bucketed_config_requires_legacy_single_pool_for_legacy_runtime_paths() { + let bucketed = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: strict_input_length + buckets: + - name: short + input_length_min: 0 + input_length_max: 32 + num_instances: 2 +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + let cfg = Config::from_yaml_path(&bucketed).unwrap(); + + let err = cfg + .cluster + .require_legacy_single_pool("driver run") + .unwrap_err(); + assert!(err.to_string().contains("driver run")); + assert!(err.to_string().contains("cluster.buckets")); + } } diff --git a/src/driver.rs b/src/driver.rs index 6c1c064..05e9531 100644 --- a/src/driver.rs +++ b/src/driver.rs @@ -49,6 +49,9 @@ struct InflightInfo { } pub fn run(config: &Config, output_subdir: Option<&str>) -> Result { + config + .cluster + .require_legacy_single_pool("driver run")?; let mut cluster = Cluster::new(config, &config.model); let mut q = EventQueue::new(); @@ -217,6 +220,8 @@ pub fn ablate_fixed_placement_with_parallelism( evict_policies: &[ReplayEvictPolicy], jobs: usize, ) -> Result> { + base.cluster + .require_legacy_single_pool("fixed-placement ablation")?; let mut out = Vec::new(); for &policy in evict_policies { if policy != ReplayEvictPolicy::Lru { diff --git a/src/main.rs b/src/main.rs index 37d6677..b174148 100644 --- a/src/main.rs +++ b/src/main.rs @@ -50,10 +50,14 @@ struct ConfigOverrides { } impl ConfigOverrides { - fn apply(&self, cfg: &mut Config) { + fn apply(&self, cfg: &mut Config) -> Result<()> { if let Some(n) = self.num_instances { + if !cfg.cluster.buckets.is_empty() { + anyhow::bail!( + "--num-instances does not support cluster.buckets until Task 2 lands" + ); + } cfg.cluster.num_instances = Some(n); - cfg.cluster.buckets.clear(); } if let Some(m) = self.max_requests { cfg.sim.max_requests = Some(m); @@ -79,6 +83,7 @@ impl ConfigOverrides { if let Some(hi) = self.input_length_max { cfg.sim.input_length_max = Some(hi); } + Ok(()) } } @@ -205,7 +210,7 @@ fn main() -> Result<()> { fn load(config: &PathBuf, overrides: &ConfigOverrides) -> Result { let mut cfg = Config::from_yaml_path(config)?; - overrides.apply(&mut cfg); + overrides.apply(&mut cfg)?; cfg.cluster.validate()?; Ok(cfg) } @@ -313,6 +318,9 @@ fn auto_select_instances( probe: RouterMode, target_ttft_mean: f64, ) -> Result { + base.cluster + .require_legacy_single_pool("auto-instances calibration")?; + #[derive(serde::Serialize)] struct CalibRow { num_instances: u32, @@ -334,7 +342,6 @@ fn auto_select_instances( for &n in candidates { let mut cfg = base.clone(); cfg.cluster.num_instances = Some(n); - cfg.cluster.buckets.clear(); cfg.cluster.router.mode = probe; // Isolate calibration output so ablation runs don't overwrite it. cfg.sim.output_dir = out_root @@ -464,6 +471,8 @@ fn cmd_oracle( out_path: Option<&std::path::Path>, ) -> Result<()> { let cfg = load(path, overrides)?; + cfg.cluster + .require_legacy_single_pool("oracle analysis")?; let block_bytes = cfg.model.kv_block_bytes() as f64; let per_instance_blocks = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as u64; let aggregate_blocks = per_instance_blocks * cfg.cluster.total_instances() as u64; @@ -545,3 +554,123 @@ fn cmd_oracle( eprintln!("[oracle] wrote {}", target.display()); Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use kvcache_simulator::config::{ + BucketConfig, CalibrationConfig, ClusterConfig, GlobalRouterConfig, HardwareConfig, + MetaStoreConfig, ModelConfig, RouterConfig, SimConfig, + }; + + fn bucketed_config(out_dir: &str) -> Config { + Config { + model: ModelConfig { + name: "test".into(), + num_layers: 4, + num_kv_heads: 2, + head_dim: 64, + dtype_bytes: 2, + block_size_tokens: 16, + flops_per_token_prefill: Some(1.0e9), + attn_quadratic_coeff: Some(64.0), + ..Default::default() + }, + hardware: HardwareConfig { + gpu_flops: 1.0e14, + gpu_fp8_flops: 0.0, + gpu_fp4_flops: 0.0, + gpu_mem_bw: 1.0e12, + hbm_bytes: 1.0e9, + dram_bytes: 4.0e9, + host_dram_bw: 5.0e11, + pcie_bw: 32.0e9, + pcie_latency_us: 1.0, + rdma_bw: 12.0e9, + rdma_latency_us: 5.0, + intra_node_tp_bw: 9.0e11, + intra_node_tp_latency_us: 2.0, + tp_degree: 1, + max_batch_slots: 32, + prefill_chunk_tokens: 1024, + }, + calibration: CalibrationConfig::default(), + cluster: ClusterConfig { + num_instances: None, + buckets: vec![ + BucketConfig { + name: "short".into(), + input_length_min: 0, + input_length_max: 64, + num_instances: 2, + }, + BucketConfig { + name: "long".into(), + input_length_min: 65, + input_length_max: 128, + num_instances: 1, + }, + ], + global_router: GlobalRouterConfig::default(), + meta_store: MetaStoreConfig { + ttl_seconds: 1000.0, + }, + router: RouterConfig { + mode: RouterMode::Random, + precise_probe_latency_us: 10.0, + precise_probe_topk: 4, + load_alpha: 0.1, + score_alpha: 1.0, + score_beta: 0.1, + prefix_k: 8, + affinity_fan_out: 0, + }, + }, + sim: SimConfig { + trace_path: "unused.jsonl".into(), + max_requests: None, + output_dir: out_dir.into(), + sample_interval_s: 0.0, + seed: 7, + input_length_min: None, + input_length_max: None, + }, + } + } + + #[test] + fn num_instances_override_rejects_bucketed_configs() { + let mut cfg = bucketed_config(std::env::temp_dir().to_str().unwrap()); + let overrides = ConfigOverrides { + num_instances: Some(8), + ..ConfigOverrides::default() + }; + + let err = overrides.apply(&mut cfg).unwrap_err(); + assert!(err.to_string().contains("--num-instances")); + assert!(err.to_string().contains("cluster.buckets")); + } + + #[test] + fn auto_instances_rejects_bucketed_configs() { + let cfg = bucketed_config(std::env::temp_dir().to_str().unwrap()); + + let err = auto_select_instances(&cfg, &[4, 8], RouterMode::Random, 1.0).unwrap_err(); + assert!(err.to_string().contains("auto-instances")); + assert!(err.to_string().contains("cluster.buckets")); + } + + #[test] + fn oracle_rejects_bucketed_configs() { + let tmp = std::env::temp_dir().join("kvcache_sim_main_tests"); + let _ = std::fs::remove_dir_all(&tmp); + std::fs::create_dir_all(&tmp).unwrap(); + let path = tmp.join("bucketed.yaml"); + let cfg = bucketed_config(tmp.to_str().unwrap()); + std::fs::write(&path, serde_yaml::to_string(&cfg).unwrap()).unwrap(); + + let err = cmd_oracle(&path, &ConfigOverrides::default(), None, false, None).unwrap_err(); + assert!(err.to_string().contains("oracle analysis")); + assert!(err.to_string().contains("cluster.buckets")); + } +} diff --git a/src/replay.rs b/src/replay.rs index 2199dc8..a26b724 100644 --- a/src/replay.rs +++ b/src/replay.rs @@ -496,6 +496,8 @@ pub fn replay_fixed_placement( placements: &[PlacementEntry], policy: ReplayEvictPolicy, ) -> Result { + cfg.cluster + .require_legacy_single_pool("fixed-placement replay")?; if records.len() != placements.len() { return Err(anyhow!( "records/placements length mismatch: {} vs {}", diff --git a/tests/smoke.rs b/tests/smoke.rs index 33c81db..76374f4 100644 --- a/tests/smoke.rs +++ b/tests/smoke.rs @@ -6,7 +6,7 @@ use std::io::Write; use kvcache_simulator::config::*; use kvcache_simulator::driver; -use kvcache_simulator::replay::ReplayEvictPolicy; +use kvcache_simulator::replay::{self, PlacementEntry, ReplayEvictPolicy}; fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config { Config { @@ -70,6 +70,26 @@ fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config { } } +fn bucketed_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config { + let mut cfg = base_config(trace_path, out_dir, mode); + cfg.cluster.num_instances = None; + cfg.cluster.buckets = vec![ + BucketConfig { + name: "short".into(), + input_length_min: 0, + input_length_max: 64, + num_instances: 2, + }, + BucketConfig { + name: "long".into(), + input_length_min: 65, + input_length_max: 128, + num_instances: 1, + }, + ]; + cfg +} + fn write_synthetic_trace(path: &std::path::Path) { // 5 distinct conversations, each with 8 turns. Within a conversation, // turn k+1 reuses the prefix of turn k (shared first ~10 blocks) and @@ -274,3 +294,36 @@ fn ablation_parallel_matches_serial() { assert!((lhs.miss_rate - rhs.miss_rate).abs() < 1e-12); } } + +#[test] +fn bucketed_configs_are_rejected_by_legacy_runtime_paths() { + let tmp = std::env::temp_dir().join("kvcache_sim_bucketed_reject"); + let _ = std::fs::remove_dir_all(&tmp); + std::fs::create_dir_all(&tmp).unwrap(); + let trace_path = tmp.join("trace.jsonl"); + write_synthetic_trace(&trace_path); + + let cfg = bucketed_config( + trace_path.to_str().unwrap(), + tmp.to_str().unwrap(), + RouterMode::Random, + ); + + let result = driver::run(&cfg, Some("bucketed_guard")); + assert!(result.is_err(), "bucketed run should fail"); + let err = result.err().unwrap(); + assert!(err.to_string().contains("cluster.buckets")); + + let err = driver::ablate_fixed_placement(&cfg, &[RouterMode::Random], &[ReplayEvictPolicy::Lru]) + .expect_err("bucketed ablation should fail"); + assert!(err.to_string().contains("cluster.buckets")); + + let err = replay::replay_fixed_placement( + &cfg, + &[], + &Vec::::new(), + ReplayEvictPolicy::Lru, + ) + .expect_err("bucketed replay should fail"); + assert!(err.to_string().contains("cluster.buckets")); +} From 008fe2fe5dab379cffd8e64c4aa9a538f2ee3a14 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 17 Apr 2026 14:44:21 +0800 Subject: [PATCH 4/9] fix: reject bucketed configs in cluster constructor --- src/cluster/cluster.rs | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/src/cluster/cluster.rs b/src/cluster/cluster.rs index 31d42fb..122945e 100644 --- a/src/cluster/cluster.rs +++ b/src/cluster/cluster.rs @@ -37,7 +37,10 @@ pub struct Cluster { impl Cluster { pub fn new(config: &Config, model: &ModelConfig) -> Self { - let total_instances = config.cluster.total_instances(); + let total_instances = config + .cluster + .require_legacy_single_pool("Cluster::new") + .unwrap_or_else(|err| panic!("{err}")); let mut instances = Vec::with_capacity(total_instances as usize); for id in 0..total_instances { instances.push(Instance::new( @@ -185,8 +188,8 @@ impl Cluster { mod tests { use super::*; use crate::config::{ - CalibrationConfig, ClusterConfig, Config, HardwareConfig, MetaStoreConfig, ModelConfig, - RouterConfig, RouterMode, SimConfig, + BucketConfig, CalibrationConfig, ClusterConfig, Config, HardwareConfig, MetaStoreConfig, + ModelConfig, RouterConfig, RouterMode, SimConfig, }; use crate::trace::RequestRecord; @@ -285,4 +288,29 @@ mod tests { assert!(stats.ready_at > pure_pcie); } + + #[test] + fn cluster_new_rejects_bucketed_configs() { + let mut cfg = test_config(RouterMode::EstimatedTtft); + cfg.cluster.num_instances = None; + cfg.cluster.buckets = vec![BucketConfig { + name: "short".into(), + input_length_min: 0, + input_length_max: 64, + num_instances: 2, + }]; + + let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + Cluster::new(&cfg, &cfg.model); + })) + .expect_err("bucketed Cluster::new should panic"); + + let msg = panic + .downcast_ref::() + .cloned() + .or_else(|| panic.downcast_ref::<&str>().map(|s| (*s).to_string())) + .expect("panic payload should be a string"); + assert!(msg.contains("Cluster::new")); + assert!(msg.contains("cluster.buckets")); + } } From 96019082cca9cd68d061fd5ec9bf05602370bc52 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 17 Apr 2026 14:50:47 +0800 Subject: [PATCH 5/9] fix: complete global router config and recoverable cluster init --- src/cluster/cluster.rs | 32 ++++++++------------ src/config.rs | 68 +++++++++++++++++++++++++++++++++++++++++- src/driver.rs | 2 +- 3 files changed, 80 insertions(+), 22 deletions(-) diff --git a/src/cluster/cluster.rs b/src/cluster/cluster.rs index 122945e..035a72e 100644 --- a/src/cluster/cluster.rs +++ b/src/cluster/cluster.rs @@ -1,6 +1,8 @@ //! Cluster: routes arrivals, performs the L0 / L1 / remote-RDMA fetch chain //! described in the design diagram, and bookkeeps the global meta store. +use anyhow::Result; + use crate::cluster::meta_store::MetaStore; use crate::config::{Config, ModelConfig}; use crate::instance::instance::AdmittedRequest; @@ -36,11 +38,8 @@ pub struct Cluster { } impl Cluster { - pub fn new(config: &Config, model: &ModelConfig) -> Self { - let total_instances = config - .cluster - .require_legacy_single_pool("Cluster::new") - .unwrap_or_else(|err| panic!("{err}")); + pub fn new(config: &Config, model: &ModelConfig) -> Result { + let total_instances = config.cluster.require_legacy_single_pool("Cluster::new")?; let mut instances = Vec::with_capacity(total_instances as usize); for id in 0..total_instances { instances.push(Instance::new( @@ -52,7 +51,7 @@ impl Cluster { } let meta_store = MetaStore::new(config.cluster.meta_store.ttl_seconds); let router = router::build(config, config.sim.seed); - Self { + Ok(Self { instances, meta_store, router, @@ -63,7 +62,7 @@ impl Cluster { &config.calibration, model.kv_block_bytes(), ), - } + }) } /// Route + admit a request. Returns the chosen instance plus rich @@ -262,7 +261,7 @@ mod tests { #[test] fn l1_ready_at_includes_dram_and_transform_overhead() { let cfg = test_config(RouterMode::EstimatedTtft); - let mut cluster = Cluster::new(&cfg, &cfg.model); + let mut cluster = Cluster::new(&cfg, &cfg.model).unwrap(); let req = RequestRecord { req_id: 1, chat_id: 0, @@ -300,17 +299,10 @@ mod tests { num_instances: 2, }]; - let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - Cluster::new(&cfg, &cfg.model); - })) - .expect_err("bucketed Cluster::new should panic"); - - let msg = panic - .downcast_ref::() - .cloned() - .or_else(|| panic.downcast_ref::<&str>().map(|s| (*s).to_string())) - .expect("panic payload should be a string"); - assert!(msg.contains("Cluster::new")); - assert!(msg.contains("cluster.buckets")); + let result = Cluster::new(&cfg, &cfg.model); + assert!(result.is_err(), "bucketed Cluster::new should fail"); + let err = result.err().unwrap(); + assert!(err.to_string().contains("Cluster::new")); + assert!(err.to_string().contains("cluster.buckets")); } } diff --git a/src/config.rs b/src/config.rs index 5be0ba6..36e4947 100644 --- a/src/config.rs +++ b/src/config.rs @@ -397,16 +397,25 @@ pub struct BucketConfig { pub num_instances: u32, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct GlobalRouterConfig { #[serde(default)] pub mode: GlobalRouterMode, + #[serde(default = "default_global_router_length_penalty_weight")] + pub length_penalty_weight: f64, + #[serde(default = "default_global_router_load_weight")] + pub load_weight: f64, + #[serde(default = "default_global_router_cache_weight")] + pub cache_weight: f64, } impl Default for GlobalRouterConfig { fn default() -> Self { Self { mode: GlobalRouterMode::StrictInputLength, + length_penalty_weight: default_global_router_length_penalty_weight(), + load_weight: default_global_router_load_weight(), + cache_weight: default_global_router_cache_weight(), } } } @@ -428,6 +437,18 @@ impl GlobalRouterMode { } } +fn default_global_router_length_penalty_weight() -> f64 { + 1.0 +} + +fn default_global_router_load_weight() -> f64 { + 1.0 +} + +fn default_global_router_cache_weight() -> f64 { + 1.0 +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MetaStoreConfig { pub ttl_seconds: f64, @@ -1228,6 +1249,51 @@ sim: assert_eq!(cfg.cluster.total_instances(), 3); } + #[test] + fn bucketed_config_deserializes_global_router_weights() { + let path = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: bucket_score + length_penalty_weight: 1.5 + load_weight: 0.75 + cache_weight: 2.25 + buckets: + - name: short + input_length_min: 0 + input_length_max: 32 + num_instances: 2 +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + + let cfg = Config::from_yaml_path(&path).unwrap(); + assert_eq!(cfg.cluster.global_router.mode, GlobalRouterMode::BucketScore); + assert!((cfg.cluster.global_router.length_penalty_weight - 1.5).abs() < 1e-12); + assert!((cfg.cluster.global_router.load_weight - 0.75).abs() < 1e-12); + assert!((cfg.cluster.global_router.cache_weight - 2.25).abs() < 1e-12); + } + #[test] fn bucketed_config_rejects_overlapping_ranges_and_mixed_modes() { let overlap = write_temp_config( diff --git a/src/driver.rs b/src/driver.rs index 05e9531..dbc6b5e 100644 --- a/src/driver.rs +++ b/src/driver.rs @@ -52,7 +52,7 @@ pub fn run(config: &Config, output_subdir: Option<&str>) -> Result { config .cluster .require_legacy_single_pool("driver run")?; - let mut cluster = Cluster::new(config, &config.model); + let mut cluster = Cluster::new(config, &config.model)?; let mut q = EventQueue::new(); // Output directory From fa381b5db3b976afb555666e7a44d2eeab0b2f6e Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 17 Apr 2026 15:03:10 +0800 Subject: [PATCH 6/9] feat: add bucketed service and strict global routing --- src/cluster/bucketed_service.rs | 186 ++++++++++++++++++++++++++++++++ src/cluster/cluster.rs | 96 ++++++++++++----- src/cluster/mod.rs | 2 + src/router/adaptive_affinity.rs | 12 +-- src/router/cache_affinity.rs | 14 +-- src/router/cache_load.rs | 14 +-- src/router/cache_score.rs | 14 +-- src/router/cache_score_ttl.rs | 14 +-- src/router/estimated_ttft.rs | 14 +-- src/router/global_bucket.rs | 133 +++++++++++++++++++++++ src/router/least_loaded.rs | 14 +-- src/router/least_tokens.rs | 14 +-- src/router/lineage_affinity.rs | 12 +-- src/router/min_pd.rs | 14 +-- src/router/mod.rs | 40 +++++++ src/router/precise_aware.rs | 14 +-- src/router/prefix_affinity.rs | 12 +-- src/router/random.rs | 28 ++--- src/router/ttl_aware.rs | 14 +-- 19 files changed, 533 insertions(+), 128 deletions(-) create mode 100644 src/cluster/bucketed_service.rs create mode 100644 src/router/global_bucket.rs diff --git a/src/cluster/bucketed_service.rs b/src/cluster/bucketed_service.rs new file mode 100644 index 0000000..104f77a --- /dev/null +++ b/src/cluster/bucketed_service.rs @@ -0,0 +1,186 @@ +use super::cluster::{AdmissionStats, Cluster}; +use crate::config::{BucketConfig, Config, ModelConfig}; +use crate::instance::Instance; +use crate::router::{self, BucketId, GlobalRouter}; +use crate::trace::RequestRecord; + +pub struct ServiceBucket { + pub id: BucketId, + pub cfg: BucketConfig, + pub cluster: Cluster, +} + +impl ServiceBucket { + pub fn instances(&self) -> &[Instance] { + &self.cluster.instances + } +} + +pub struct BucketedService { + pub buckets: Vec, + pub global_router: Box, +} + +impl BucketedService { + pub fn new(config: &Config, model: &ModelConfig) -> Self { + let buckets = config + .cluster + .effective_buckets() + .into_iter() + .enumerate() + .map(|(idx, cfg)| ServiceBucket { + id: idx as BucketId, + cluster: Cluster::new_for_bucket(config, model, idx as BucketId, cfg.num_instances) + .expect("bucket-local cluster construction should succeed"), + cfg, + }) + .collect(); + + Self { + buckets, + global_router: router::build_global(config), + } + } + + pub fn bucket(&self, bucket_id: BucketId) -> &ServiceBucket { + &self.buckets[bucket_id as usize] + } + + pub fn route_and_admit(&mut self, req: &RequestRecord, now: f64) -> AdmissionStats { + let bucket_views = self + .buckets + .iter() + .map(|bucket| bucket.cluster.bucket_view(bucket.id, &bucket.cfg)) + .collect::>(); + let global = self.global_router.route(req, &bucket_views, now); + let bucket = &mut self.buckets[global.chosen_bucket as usize]; + bucket + .cluster + .route_and_admit_with_global(req, now, &global) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{ + BucketConfig, CalibrationConfig, ClusterConfig, Config, GlobalRouterConfig, + GlobalRouterMode, HardwareConfig, MetaStoreConfig, ModelConfig, RouterConfig, RouterMode, + SimConfig, + }; + use crate::trace::RequestRecord; + + fn test_config() -> Config { + Config { + model: ModelConfig { + name: "test".into(), + num_layers: 4, + num_kv_heads: 2, + head_dim: 64, + dtype_bytes: 2, + block_size_tokens: 16, + flops_per_token_prefill: Some(1.0e9), + attn_quadratic_coeff: Some(64.0), + ..Default::default() + }, + hardware: HardwareConfig { + gpu_flops: 1.0e14, + gpu_fp8_flops: 0.0, + gpu_fp4_flops: 0.0, + gpu_mem_bw: 1.0e12, + hbm_bytes: 1.0e9, + dram_bytes: 4.0e9, + host_dram_bw: 5.0e11, + pcie_bw: 32.0e9, + pcie_latency_us: 1.0, + rdma_bw: 12.0e9, + rdma_latency_us: 5.0, + intra_node_tp_bw: 9.0e11, + intra_node_tp_latency_us: 2.0, + tp_degree: 1, + max_batch_slots: 32, + prefill_chunk_tokens: 1024, + }, + calibration: CalibrationConfig::default(), + cluster: ClusterConfig { + num_instances: None, + buckets: vec![ + BucketConfig { + name: "short".into(), + input_length_min: 0, + input_length_max: 32, + num_instances: 2, + }, + BucketConfig { + name: "long".into(), + input_length_min: 33, + input_length_max: 96, + num_instances: 1, + }, + ], + global_router: GlobalRouterConfig { + mode: GlobalRouterMode::StrictInputLength, + length_penalty_weight: 1.0, + load_weight: 1.0, + cache_weight: 1.0, + }, + meta_store: MetaStoreConfig { + ttl_seconds: 1000.0, + }, + router: RouterConfig { + mode: RouterMode::LeastLoaded, + precise_probe_latency_us: 10.0, + precise_probe_topk: 2, + load_alpha: 0.0, + score_alpha: 1.0, + score_beta: 0.1, + prefix_k: 8, + affinity_fan_out: 2, + }, + }, + sim: SimConfig { + trace_path: String::new(), + max_requests: None, + output_dir: String::new(), + sample_interval_s: 0.0, + seed: 7, + input_length_min: None, + input_length_max: None, + }, + } + } + + fn req(req_id: u64, input_len: u32, hashes: &[u64]) -> RequestRecord { + RequestRecord { + req_id, + chat_id: req_id as i64, + parent_chat_id: -1, + turn: 0, + arrival: 0.0, + input_len, + output_len: 16, + hash_ids: hashes.to_vec(), + } + } + + #[test] + fn strict_input_length_routes_into_matching_bucket() { + let cfg = test_config(); + let mut service = BucketedService::new(&cfg, &cfg.model); + let stats = service.route_and_admit(&req(1, 24, &[10, 11]), 0.0); + assert_eq!(stats.bucket, 0); + assert_eq!(stats.decision.chosen_bucket, 0); + assert_eq!(service.bucket(0).instances().len(), 2); + } + + #[test] + fn bucket_meta_store_is_isolated() { + let cfg = test_config(); + let mut service = BucketedService::new(&cfg, &cfg.model); + let _ = service.route_and_admit(&req(1, 24, &[10, 11]), 0.0); + let long_stats = service.route_and_admit(&req(2, 64, &[10, 11, 12, 13]), 1.0); + assert_eq!(long_stats.bucket, 1); + assert_eq!(long_stats.remote_hit_blocks, 0); + assert_eq!(long_stats.l1_hit_blocks, 0); + } +} diff --git a/src/cluster/cluster.rs b/src/cluster/cluster.rs index 035a72e..31f4386 100644 --- a/src/cluster/cluster.rs +++ b/src/cluster/cluster.rs @@ -4,17 +4,18 @@ use anyhow::Result; use crate::cluster::meta_store::MetaStore; -use crate::config::{Config, ModelConfig}; +use crate::config::{BucketConfig, Config, ModelConfig}; use crate::instance::instance::AdmittedRequest; use crate::instance::kv_cache::L1Change; use crate::instance::Instance; -use crate::router::{self, RouteDecision, Router}; +use crate::router::{self, BucketId, BucketView, GlobalRouteDecision, RouteDecision, Router}; use crate::trace::RequestRecord; use crate::ttft::{classify_prefix_tiers, TtftModel}; use crate::types::InstanceId; #[derive(Debug, Clone)] pub struct AdmissionStats { + pub bucket: BucketId, pub instance: InstanceId, pub l0_hit_blocks: u32, pub l1_hit_blocks: u32, @@ -40,38 +41,39 @@ pub struct Cluster { impl Cluster { pub fn new(config: &Config, model: &ModelConfig) -> Result { let total_instances = config.cluster.require_legacy_single_pool("Cluster::new")?; - let mut instances = Vec::with_capacity(total_instances as usize); - for id in 0..total_instances { - instances.push(Instance::new( - id as InstanceId, - model, - &config.hardware, - &config.calibration, - )); - } - let meta_store = MetaStore::new(config.cluster.meta_store.ttl_seconds); - let router = router::build(config, config.sim.seed); - Ok(Self { - instances, - meta_store, - router, - block_size_tokens: model.block_size_tokens, - kv_block_bytes: model.kv_block_bytes(), - ttft_model: TtftModel::new( - &config.hardware, - &config.calibration, - model.kv_block_bytes(), - ), - }) + Self::build_local_cluster(config, model, total_instances) + } + + pub fn new_for_bucket( + config: &Config, + model: &ModelConfig, + _bucket_id: BucketId, + num_instances: u32, + ) -> Result { + let mut local_config = config.clone(); + local_config.cluster.num_instances = Some(num_instances); + local_config.cluster.buckets.clear(); + Self::build_local_cluster(&local_config, model, num_instances) } /// Route + admit a request. Returns the chosen instance plus rich /// per-request stats for metrics. Does NOT schedule the BatchTick — the /// simulator driver does that based on the returned `ready_at`. pub fn route_and_admit(&mut self, req: &RequestRecord, now: f64) -> AdmissionStats { + let global = GlobalRouteDecision::single_bucket(req.req_id, 0); + self.route_and_admit_with_global(req, now, &global) + } + + pub fn route_and_admit_with_global( + &mut self, + req: &RequestRecord, + now: f64, + global: &GlobalRouteDecision, + ) -> AdmissionStats { let decision = self .router - .route(req, &self.instances, &self.meta_store, now); + .route(req, &self.instances, &self.meta_store, now) + .with_global(global); let inst_id = decision.chosen; let probe_overhead_s = decision.probe_overhead_s; let scheduler_overhead_s = self @@ -154,6 +156,7 @@ impl Cluster { let fetch_time_s = (t - effective_now).max(0.0); AdmissionStats { + bucket: decision.chosen_bucket, instance: inst_id, l0_hit_blocks: l0_hits, l1_hit_blocks: l1_hits, @@ -168,6 +171,17 @@ impl Cluster { } } + pub fn bucket_view(&self, bucket_id: BucketId, cfg: &BucketConfig) -> BucketView { + BucketView { + id: bucket_id, + input_length_min: cfg.input_length_min, + input_length_max: cfg.input_length_max, + num_instances: self.instances.len() as u32, + total_queue_len: self.instances.iter().map(Instance::queue_len).sum(), + total_load_blocks: self.instances.iter().map(|inst| inst.kv_blocks_used).sum(), + } + } + fn apply_l1_changes( meta_store: &mut MetaStore, inst_id: InstanceId, @@ -181,6 +195,36 @@ impl Cluster { } } } + + fn build_local_cluster( + config: &Config, + model: &ModelConfig, + num_instances: u32, + ) -> Result { + let mut instances = Vec::with_capacity(num_instances as usize); + for id in 0..num_instances { + instances.push(Instance::new( + id as InstanceId, + model, + &config.hardware, + &config.calibration, + )); + } + let meta_store = MetaStore::new(config.cluster.meta_store.ttl_seconds); + let router = router::build(config, config.sim.seed); + Ok(Self { + instances, + meta_store, + router, + block_size_tokens: model.block_size_tokens, + kv_block_bytes: model.kv_block_bytes(), + ttft_model: TtftModel::new( + &config.hardware, + &config.calibration, + model.kv_block_bytes(), + ), + }) + } } #[cfg(test)] diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs index bc36667..01893e8 100644 --- a/src/cluster/mod.rs +++ b/src/cluster/mod.rs @@ -1,6 +1,8 @@ +pub mod bucketed_service; #[allow(clippy::module_inception)] pub mod cluster; pub mod meta_store; +pub use bucketed_service::BucketedService; pub use cluster::Cluster; pub use meta_store::MetaStore; diff --git a/src/router/adaptive_affinity.rs b/src/router/adaptive_affinity.rs index c1ff154..8b8c7cb 100644 --- a/src/router/adaptive_affinity.rs +++ b/src/router/adaptive_affinity.rs @@ -242,13 +242,13 @@ impl Router for AdaptiveAffinityRouter { self.observe(fp, now); - RouteDecision { - req_id: req.req_id, - mode: "adaptive_affinity", - chosen: instances[chosen_idx].id, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "adaptive_affinity", + instances[chosen_idx].id, + 0.0, candidates, reason, - } + ) } } diff --git a/src/router/cache_affinity.rs b/src/router/cache_affinity.rs index 5e09d33..a42a1af 100644 --- a/src/router/cache_affinity.rs +++ b/src/router/cache_affinity.rs @@ -205,13 +205,13 @@ impl Router for CacheAffinityRouter { } } - RouteDecision { - req_id: req.req_id, - mode: "cache_affinity", - chosen: instances[best_idx].id, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "cache_affinity", + instances[best_idx].id, + 0.0, candidates, - reason: "argmin(α·q − γ·l0_hit − δ·meta_only) + rendezvous tiebreak", - } + "argmin(α·q − γ·l0_hit − δ·meta_only) + rendezvous tiebreak", + ) } } diff --git a/src/router/cache_load.rs b/src/router/cache_load.rs index 142e00c..3101c6f 100644 --- a/src/router/cache_load.rs +++ b/src/router/cache_load.rs @@ -77,13 +77,13 @@ impl Router for CacheLoadRouter { }); } - RouteDecision { - req_id: req.req_id, - mode: "cache_load", - chosen: instances[best_idx].id, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "cache_load", + instances[best_idx].id, + 0.0, candidates, - reason: "least-loaded 1/4, then best local L0 prefix", - } + "least-loaded 1/4, then best local L0 prefix", + ) } } diff --git a/src/router/cache_score.rs b/src/router/cache_score.rs index ab23882..a727f32 100644 --- a/src/router/cache_score.rs +++ b/src/router/cache_score.rs @@ -99,13 +99,13 @@ impl Router for CacheScoreRouter { } } - RouteDecision { - req_id: req.req_id, - mode: "cache_score", - chosen: instances[best_idx].id, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "cache_score", + instances[best_idx].id, + 0.0, candidates, - reason: "argmin 2^(α·load + β·miss)", - } + "argmin 2^(α·load + β·miss)", + ) } } diff --git a/src/router/cache_score_ttl.rs b/src/router/cache_score_ttl.rs index c67ad1d..d411247 100644 --- a/src/router/cache_score_ttl.rs +++ b/src/router/cache_score_ttl.rs @@ -74,13 +74,13 @@ impl Router for CacheScoreTtlRouter { } } - RouteDecision { - req_id: req.req_id, - mode: "cache_score_ttl", - chosen: instances[best_idx].id, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "cache_score_ttl", + instances[best_idx].id, + 0.0, candidates, - reason: "argmin 2^(alpha*load + beta*meta_store_miss)", - } + "argmin 2^(alpha*load + beta*meta_store_miss)", + ) } } diff --git a/src/router/estimated_ttft.rs b/src/router/estimated_ttft.rs index bc206bc..e4b8a69 100644 --- a/src/router/estimated_ttft.rs +++ b/src/router/estimated_ttft.rs @@ -89,13 +89,13 @@ impl Router for EstimatedTtftRouter { } } - RouteDecision { - req_id: req.req_id, - mode: "estimated_ttft", - chosen: best, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "estimated_ttft", + best, + 0.0, candidates, - reason: "argmin(drain + scheduler + kv_prepare + prefill + first_token_tail)", - } + "argmin(drain + scheduler + kv_prepare + prefill + first_token_tail)", + ) } } diff --git a/src/router/global_bucket.rs b/src/router/global_bucket.rs new file mode 100644 index 0000000..0684e17 --- /dev/null +++ b/src/router/global_bucket.rs @@ -0,0 +1,133 @@ +use serde::Serialize; + +use crate::config::{Config, GlobalRouterMode}; +use crate::trace::RequestRecord; + +pub type BucketId = u32; + +#[derive(Debug, Clone, Serialize)] +pub struct BucketView { + pub id: BucketId, + pub input_length_min: u32, + pub input_length_max: u32, + pub num_instances: u32, + pub total_queue_len: u32, + pub total_load_blocks: u32, +} + +#[derive(Debug, Clone, Serialize)] +pub struct BucketCandidate { + pub bucket: BucketId, + pub input_length_min: u32, + pub input_length_max: u32, + pub num_instances: u32, + pub total_queue_len: u32, + pub total_load_blocks: u32, + pub matches_input_len: bool, +} + +#[derive(Debug, Clone, Serialize)] +pub struct GlobalRouteDecision { + pub req_id: u64, + pub mode: &'static str, + pub chosen_bucket: BucketId, + pub candidates: Vec, + pub reason: &'static str, +} + +impl GlobalRouteDecision { + pub fn single_bucket(req_id: u64, chosen_bucket: BucketId) -> Self { + Self { + req_id, + mode: "single_pool", + chosen_bucket, + candidates: Vec::new(), + reason: "single pool uses bucket 0", + } + } +} + +pub trait GlobalRouter: Send { + fn name(&self) -> &'static str; + fn route( + &mut self, + req: &RequestRecord, + buckets: &[BucketView], + now: f64, + ) -> GlobalRouteDecision; +} + +struct StrictInputLengthRouter { + reported_mode: &'static str, + reason: &'static str, +} + +impl StrictInputLengthRouter { + fn new(reported_mode: &'static str, reason: &'static str) -> Self { + Self { + reported_mode, + reason, + } + } +} + +impl GlobalRouter for StrictInputLengthRouter { + fn name(&self) -> &'static str { + self.reported_mode + } + + fn route( + &mut self, + req: &RequestRecord, + buckets: &[BucketView], + _now: f64, + ) -> GlobalRouteDecision { + let candidates = buckets + .iter() + .map(|view| BucketCandidate { + bucket: view.id, + input_length_min: view.input_length_min, + input_length_max: view.input_length_max, + num_instances: view.num_instances, + total_queue_len: view.total_queue_len, + total_load_blocks: view.total_load_blocks, + matches_input_len: view.input_length_min <= req.input_len + && req.input_len <= view.input_length_max, + }) + .collect::>(); + + let matches = candidates + .iter() + .filter(|candidate| candidate.matches_input_len) + .map(|candidate| candidate.bucket) + .collect::>(); + + assert_eq!( + matches.len(), + 1, + "global bucket routing requires exactly one matching bucket for input_len={}", + req.input_len + ); + + GlobalRouteDecision { + req_id: req.req_id, + mode: self.reported_mode, + chosen_bucket: matches[0], + candidates, + reason: self.reason, + } + } +} + +pub fn build(full: &Config) -> Box { + match full.cluster.global_router.mode { + GlobalRouterMode::StrictInputLength => Box::new(StrictInputLengthRouter::new( + "strict_input_length", + "unique bucket range contains input_length", + )) as Box, + GlobalRouterMode::BucketScore => Box::new(StrictInputLengthRouter::new( + "bucket_score", + "bucket_score placeholder falls back to strict_input_length", + )) as Box, + } +} diff --git a/src/router/least_loaded.rs b/src/router/least_loaded.rs index efc0ed8..272cb93 100644 --- a/src/router/least_loaded.rs +++ b/src/router/least_loaded.rs @@ -41,13 +41,13 @@ impl Router for LeastLoadedRouter { best = inst.id; } } - RouteDecision { - req_id: req.req_id, - mode: "least_loaded", - chosen: best, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "least_loaded", + best, + 0.0, candidates, - reason: "argmin(kv_used + alpha * queue_len)", - } + "argmin(kv_used + alpha * queue_len)", + ) } } diff --git a/src/router/least_tokens.rs b/src/router/least_tokens.rs index effdad7..2c06427 100644 --- a/src/router/least_tokens.rs +++ b/src/router/least_tokens.rs @@ -61,13 +61,13 @@ impl Router for LeastTokensRouter { } } - RouteDecision { - req_id: req.req_id, - mode: "least_tokens", - chosen: best, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "least_tokens", + best, + 0.0, candidates, - reason: "argmin(waiting_prefill_tokens)", - } + "argmin(waiting_prefill_tokens)", + ) } } diff --git a/src/router/lineage_affinity.rs b/src/router/lineage_affinity.rs index d23a580..0e9035b 100644 --- a/src/router/lineage_affinity.rs +++ b/src/router/lineage_affinity.rs @@ -231,13 +231,13 @@ impl Router for LineageAffinityRouter { self.request_home .insert(req.chat_id, instances[chosen.idx].id); - RouteDecision { - req_id: req.req_id, - mode: "lineage_affinity", - chosen: instances[chosen.idx].id, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "lineage_affinity", + instances[chosen.idx].id, + 0.0, candidates, reason, - } + ) } } diff --git a/src/router/min_pd.rs b/src/router/min_pd.rs index f801639..2033dee 100644 --- a/src/router/min_pd.rs +++ b/src/router/min_pd.rs @@ -90,13 +90,13 @@ impl Router for MinPdRouter { } } - RouteDecision { - req_id: req.req_id, - mode: "min_pd", - chosen: best, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "min_pd", + best, + 0.0, candidates, - reason: "argmin(P*D), P=local-L0 miss tokens, D=ongoing reqs", - } + "argmin(P*D), P=local-L0 miss tokens, D=ongoing reqs", + ) } } diff --git a/src/router/mod.rs b/src/router/mod.rs index 2a23e3c..e2e7676 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -6,6 +6,7 @@ pub mod cache_load; pub mod cache_score; pub mod cache_score_ttl; pub mod estimated_ttft; +pub mod global_bucket; pub mod least_loaded; pub mod least_tokens; pub mod lineage_affinity; @@ -23,6 +24,8 @@ use crate::instance::Instance; use crate::trace::RequestRecord; use crate::types::InstanceId; +pub use global_bucket::{BucketCandidate, BucketId, BucketView, GlobalRouteDecision, GlobalRouter}; + #[derive(Debug, Clone, Serialize)] pub struct CandidateInfo { pub instance: InstanceId, @@ -34,13 +37,25 @@ pub struct CandidateInfo { #[derive(Debug, Clone, Serialize)] pub struct RouteDecision { pub req_id: u64, + pub global_mode: &'static str, pub mode: &'static str, + pub chosen_bucket: BucketId, pub chosen: InstanceId, pub probe_overhead_s: f64, + pub bucket_candidates: Vec, pub candidates: Vec, pub reason: &'static str, } +impl RouteDecision { + pub fn with_global(mut self, decision: &GlobalRouteDecision) -> Self { + self.global_mode = decision.mode; + self.chosen_bucket = decision.chosen_bucket; + self.bucket_candidates = decision.candidates.clone(); + self + } +} + pub trait Router: Send { fn name(&self) -> &'static str; fn route( @@ -63,6 +78,27 @@ pub(crate) fn local_l0_scores(req: &RequestRecord, instances: &[Instance]) -> Ve .collect() } +pub fn local_route_decision( + req_id: u64, + mode: &'static str, + chosen: InstanceId, + probe_overhead_s: f64, + candidates: Vec, + reason: &'static str, +) -> RouteDecision { + RouteDecision { + req_id, + global_mode: "single_pool", + mode, + chosen_bucket: 0, + chosen, + probe_overhead_s, + bucket_candidates: Vec::new(), + candidates, + reason, + } +} + pub fn build(full: &Config, seed: u64) -> Box { use crate::config::RouterMode::*; let cfg = &full.cluster.router; @@ -122,6 +158,10 @@ pub fn build(full: &Config, seed: u64) -> Box { } } +pub fn build_global(full: &Config) -> Box { + global_bucket::build(full) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/router/precise_aware.rs b/src/router/precise_aware.rs index 9940815..e0d90cd 100644 --- a/src/router/precise_aware.rs +++ b/src/router/precise_aware.rs @@ -62,13 +62,13 @@ impl Router for PreciseRouter { } } - RouteDecision { - req_id: req.req_id, - mode: "precise", - chosen: best, - probe_overhead_s: n as f64 * self.probe_latency_s, + crate::router::local_route_decision( + req.req_id, + "precise", + best, + n as f64 * self.probe_latency_s, candidates, - reason: "exact-probe all instances' L0 cache", - } + "exact-probe all instances' L0 cache", + ) } } diff --git a/src/router/prefix_affinity.rs b/src/router/prefix_affinity.rs index ca6bcb3..bde3443 100644 --- a/src/router/prefix_affinity.rs +++ b/src/router/prefix_affinity.rs @@ -166,13 +166,13 @@ impl Router for PrefixAffinityRouter { reason = "prefix affinity: top-K min drain"; } - RouteDecision { - req_id: req.req_id, - mode: "prefix_affinity", - chosen: instances[best_idx].id, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "prefix_affinity", + instances[best_idx].id, + 0.0, candidates, reason, - } + ) } } diff --git a/src/router/random.rs b/src/router/random.rs index c4bfb17..8f2deae 100644 --- a/src/router/random.rs +++ b/src/router/random.rs @@ -33,19 +33,19 @@ impl Router for RandomRouter { ) -> RouteDecision { let n = instances.len(); let chosen = self.rng.gen_range(0..n) as InstanceId; - RouteDecision { - req_id: req.req_id, - mode: "random", + crate::router::local_route_decision( + req.req_id, + "random", chosen, - probe_overhead_s: 0.0, - candidates: vec![CandidateInfo { + 0.0, + vec![CandidateInfo { instance: chosen, predicted_prefix: 0, load_blocks: instances[chosen as usize].kv_blocks_used, queue_len: instances[chosen as usize].queue_len(), }], - reason: "uniform random", - } + "uniform random", + ) } } @@ -75,18 +75,18 @@ impl Router for RoundRobinRouter { let n = instances.len() as u32; let chosen = self.next % n; self.next = self.next.wrapping_add(1); - RouteDecision { - req_id: req.req_id, - mode: "round_robin", + crate::router::local_route_decision( + req.req_id, + "round_robin", chosen, - probe_overhead_s: 0.0, - candidates: vec![CandidateInfo { + 0.0, + vec![CandidateInfo { instance: chosen, predicted_prefix: 0, load_blocks: instances[chosen as usize].kv_blocks_used, queue_len: instances[chosen as usize].queue_len(), }], - reason: "round robin", - } + "round robin", + ) } } diff --git a/src/router/ttl_aware.rs b/src/router/ttl_aware.rs index 04481d0..6206ae7 100644 --- a/src/router/ttl_aware.rs +++ b/src/router/ttl_aware.rs @@ -46,13 +46,13 @@ impl Router for TtlAwareRouter { best = inst.id; } } - RouteDecision { - req_id: req.req_id, - mode: "ttl_aware", - chosen: best, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "ttl_aware", + best, + 0.0, candidates, - reason: "max meta_store prefix, tie -> least loaded", - } + "max meta_store prefix, tie -> least loaded", + ) } } From 3a84c1506841becac0dd7cc1b7299ff105698a0a Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 17 Apr 2026 15:15:18 +0800 Subject: [PATCH 7/9] fix: harden bucket routing review follow-up --- src/cluster/bucketed_service.rs | 61 +++++++++++++++++++++++++++++---- src/router/global_bucket.rs | 38 +++++++++++++------- src/router/mod.rs | 9 +++-- 3 files changed, 85 insertions(+), 23 deletions(-) diff --git a/src/cluster/bucketed_service.rs b/src/cluster/bucketed_service.rs index 104f77a..4720965 100644 --- a/src/cluster/bucketed_service.rs +++ b/src/cluster/bucketed_service.rs @@ -1,3 +1,5 @@ +use anyhow::Result; + use super::cluster::{AdmissionStats, Cluster}; use crate::config::{BucketConfig, Config, ModelConfig}; use crate::instance::Instance; @@ -46,17 +48,17 @@ impl BucketedService { &self.buckets[bucket_id as usize] } - pub fn route_and_admit(&mut self, req: &RequestRecord, now: f64) -> AdmissionStats { + pub fn route_and_admit(&mut self, req: &RequestRecord, now: f64) -> Result { let bucket_views = self .buckets .iter() .map(|bucket| bucket.cluster.bucket_view(bucket.id, &bucket.cfg)) .collect::>(); - let global = self.global_router.route(req, &bucket_views, now); + let global = self.global_router.route(req, &bucket_views, now)?; let bucket = &mut self.buckets[global.chosen_bucket as usize]; - bucket + Ok(bucket .cluster - .route_and_admit_with_global(req, now, &global) + .route_and_admit_with_global(req, now, &global)) } } @@ -167,9 +169,19 @@ mod tests { fn strict_input_length_routes_into_matching_bucket() { let cfg = test_config(); let mut service = BucketedService::new(&cfg, &cfg.model); - let stats = service.route_and_admit(&req(1, 24, &[10, 11]), 0.0); + let stats = service + .route_and_admit(&req(1, 24, &[10, 11]), 0.0) + .unwrap(); assert_eq!(stats.bucket, 0); assert_eq!(stats.decision.chosen_bucket, 0); + assert_eq!( + stats.decision.global_reason, + "unique bucket range contains input_length" + ); + assert_eq!( + stats.decision.local_reason, + "argmin(kv_used + alpha * queue_len)" + ); assert_eq!(service.bucket(0).instances().len(), 2); } @@ -177,10 +189,45 @@ mod tests { fn bucket_meta_store_is_isolated() { let cfg = test_config(); let mut service = BucketedService::new(&cfg, &cfg.model); - let _ = service.route_and_admit(&req(1, 24, &[10, 11]), 0.0); - let long_stats = service.route_and_admit(&req(2, 64, &[10, 11, 12, 13]), 1.0); + let _ = service + .route_and_admit(&req(1, 24, &[10, 11]), 0.0) + .unwrap(); + let long_stats = service + .route_and_admit(&req(2, 64, &[10, 11, 12, 13]), 1.0) + .unwrap(); assert_eq!(long_stats.bucket, 1); assert_eq!(long_stats.remote_hit_blocks, 0); assert_eq!(long_stats.l1_hit_blocks, 0); } + + #[test] + fn unmatched_input_length_returns_recoverable_error() { + let mut cfg = test_config(); + cfg.cluster.buckets[1].input_length_min = 40; + let mut service = BucketedService::new(&cfg, &cfg.model); + + let err = service + .route_and_admit(&req(3, 36, &[20, 21, 22]), 0.0) + .unwrap_err(); + + assert!(err.to_string().contains("no bucket")); + assert!(err.to_string().contains("input_length=36")); + } + + #[test] + fn bucket_score_placeholder_reports_strict_fallback() { + let mut cfg = test_config(); + cfg.cluster.global_router.mode = GlobalRouterMode::BucketScore; + let mut service = BucketedService::new(&cfg, &cfg.model); + + let stats = service + .route_and_admit(&req(4, 24, &[30, 31]), 0.0) + .unwrap(); + + assert_eq!(stats.decision.global_mode, "strict_input_length"); + assert!(stats + .decision + .global_reason + .contains("bucket_score is not implemented")); + } } diff --git a/src/router/global_bucket.rs b/src/router/global_bucket.rs index 0684e17..f2f047a 100644 --- a/src/router/global_bucket.rs +++ b/src/router/global_bucket.rs @@ -1,3 +1,4 @@ +use anyhow::{anyhow, Result}; use serde::Serialize; use crate::config::{Config, GlobalRouterMode}; @@ -54,7 +55,7 @@ pub trait GlobalRouter: Send { req: &RequestRecord, buckets: &[BucketView], now: f64, - ) -> GlobalRouteDecision; + ) -> Result; } struct StrictInputLengthRouter { @@ -81,7 +82,7 @@ impl GlobalRouter for StrictInputLengthRouter { req: &RequestRecord, buckets: &[BucketView], _now: f64, - ) -> GlobalRouteDecision { + ) -> Result { let candidates = buckets .iter() .map(|view| BucketCandidate { @@ -102,20 +103,31 @@ impl GlobalRouter for StrictInputLengthRouter { .map(|candidate| candidate.bucket) .collect::>(); - assert_eq!( - matches.len(), - 1, - "global bucket routing requires exactly one matching bucket for input_len={}", - req.input_len - ); + let chosen_bucket = match matches.as_slice() { + [bucket] => *bucket, + [] => { + return Err(anyhow!( + "cluster.global_router.mode={} has no bucket for input_length={}", + self.reported_mode, + req.input_len + )); + } + _ => { + return Err(anyhow!( + "cluster.global_router.mode={} matched multiple buckets for input_length={}", + self.reported_mode, + req.input_len + )); + } + }; - GlobalRouteDecision { + Ok(GlobalRouteDecision { req_id: req.req_id, mode: self.reported_mode, - chosen_bucket: matches[0], + chosen_bucket, candidates, reason: self.reason, - } + }) } } @@ -126,8 +138,8 @@ pub fn build(full: &Config) -> Box { "unique bucket range contains input_length", )) as Box, GlobalRouterMode::BucketScore => Box::new(StrictInputLengthRouter::new( - "bucket_score", - "bucket_score placeholder falls back to strict_input_length", + "strict_input_length", + "bucket_score is not implemented in Task 2; falling back to strict_input_length", )) as Box, } } diff --git a/src/router/mod.rs b/src/router/mod.rs index e2e7676..ef68dd5 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -39,17 +39,19 @@ pub struct RouteDecision { pub req_id: u64, pub global_mode: &'static str, pub mode: &'static str, + pub global_reason: &'static str, + pub local_reason: &'static str, pub chosen_bucket: BucketId, pub chosen: InstanceId, pub probe_overhead_s: f64, pub bucket_candidates: Vec, pub candidates: Vec, - pub reason: &'static str, } impl RouteDecision { pub fn with_global(mut self, decision: &GlobalRouteDecision) -> Self { self.global_mode = decision.mode; + self.global_reason = decision.reason; self.chosen_bucket = decision.chosen_bucket; self.bucket_candidates = decision.candidates.clone(); self @@ -90,12 +92,13 @@ pub fn local_route_decision( req_id, global_mode: "single_pool", mode, + global_reason: "single pool uses bucket 0", + local_reason: reason, chosen_bucket: 0, chosen, probe_overhead_s, bucket_candidates: Vec::new(), candidates, - reason, } } @@ -393,7 +396,7 @@ mod tests { let mut router = PrefixAffinityRouter::new(&cfg); let decision = router.route(&req, &instances, &meta, 0.0); - assert_eq!(decision.reason, "affinity fallback: min(drain+fetch)"); + assert_eq!(decision.local_reason, "affinity fallback: min(drain+fetch)"); assert_eq!(decision.chosen, 1); } From b5a6fb964c60953f53254e6d5d5c23e9b7932c51 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 17 Apr 2026 17:52:49 +0800 Subject: [PATCH 8/9] feat: wire bucket identities through driver outputs --- src/config.rs | 5 ++- src/driver.rs | 53 ++++++++++++++++----------- src/main.rs | 3 +- src/metrics/per_request.rs | 2 ++ src/metrics/timeseries.rs | 1 + src/oracle.rs | 12 ++++--- src/sim/engine.rs | 35 +++++++++++++++--- src/sim/events.rs | 6 +++- tests/smoke.rs | 74 +++++++++++++++++++++++++++++++++----- 9 files changed, 148 insertions(+), 43 deletions(-) diff --git a/src/config.rs b/src/config.rs index 36e4947..087b960 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1288,7 +1288,10 @@ sim: ); let cfg = Config::from_yaml_path(&path).unwrap(); - assert_eq!(cfg.cluster.global_router.mode, GlobalRouterMode::BucketScore); + assert_eq!( + cfg.cluster.global_router.mode, + GlobalRouterMode::BucketScore + ); assert!((cfg.cluster.global_router.length_penalty_weight - 1.5).abs() < 1e-12); assert!((cfg.cluster.global_router.load_weight - 0.75).abs() < 1e-12); assert!((cfg.cluster.global_router.cache_weight - 2.25).abs() < 1e-12); diff --git a/src/driver.rs b/src/driver.rs index dbc6b5e..0494942 100644 --- a/src/driver.rs +++ b/src/driver.rs @@ -6,7 +6,7 @@ use std::collections::{HashMap, VecDeque}; use std::path::Path; use std::sync::{Arc, Mutex}; -use crate::cluster::Cluster; +use crate::cluster::BucketedService; use crate::config::{Config, RouterMode}; use crate::metrics::ablation::AblationRow; use crate::metrics::per_request::{PerRequestRow, PerRequestWriter}; @@ -37,7 +37,9 @@ pub struct RunOutputs { #[derive(Debug, Clone)] struct InflightInfo { arrival: f64, + bucket: u32, instance: u32, + length_bucket_match: bool, total_blocks: u32, l0_hit_blocks: u32, l1_hit_blocks: u32, @@ -49,10 +51,7 @@ struct InflightInfo { } pub fn run(config: &Config, output_subdir: Option<&str>) -> Result { - config - .cluster - .require_legacy_single_pool("driver run")?; - let mut cluster = Cluster::new(config, &config.model)?; + let mut service = BucketedService::new(config, &config.model); let mut q = EventQueue::new(); // Output directory @@ -111,13 +110,16 @@ pub fn run(config: &Config, output_subdir: Option<&str>) -> Result { Some(r) => r.clone(), None => continue, }; - let stats = cluster.route_and_admit(&req, now); + let stats = service.route_and_admit(&req, now)?; rt_writer.write(&stats.decision)?; + let strict_bucket = config.cluster.bucket_index_for_input_len(req.input_len)?; inflight.insert( req_id, InflightInfo { arrival: req.arrival, + bucket: stats.bucket, instance: stats.instance, + length_bucket_match: stats.bucket as usize == strict_bucket, total_blocks: req.hash_ids.len() as u32, l0_hit_blocks: stats.l0_hit_blocks, l1_hit_blocks: stats.l1_hit_blocks, @@ -128,20 +130,23 @@ pub fn run(config: &Config, output_subdir: Option<&str>) -> Result { probe_overhead_s: stats.probe_overhead_s, }, ); - let inst = &mut cluster.instances[stats.instance as usize]; + let inst = &mut service.buckets[stats.bucket as usize].cluster.instances + [stats.instance as usize]; if !inst.tick_scheduled { inst.tick_scheduled = true; let when = stats.ready_at.max(now); q.schedule( when, Event::BatchTick { + bucket: stats.bucket, instance: stats.instance, }, ); } } - Event::BatchTick { instance } => { - let inst = &mut cluster.instances[instance as usize]; + Event::BatchTick { bucket, instance } => { + let inst = + &mut service.buckets[bucket as usize].cluster.instances[instance as usize]; inst.tick_scheduled = false; let result = inst.step(now); for (rid, ttft, end) in result.completed { @@ -151,7 +156,9 @@ pub fn run(config: &Config, output_subdir: Option<&str>) -> Result { arrival: info.arrival, ttft, e2e: end - info.arrival, + bucket: info.bucket, instance: info.instance, + length_bucket_match: info.length_bucket_match, total_blocks: info.total_blocks, l0_hit_blocks: info.l0_hit_blocks, l1_hit_blocks: info.l1_hit_blocks, @@ -166,24 +173,28 @@ pub fn run(config: &Config, output_subdir: Option<&str>) -> Result { } } if let Some(next) = result.next_tick { - let inst = &mut cluster.instances[instance as usize]; + let inst = + &mut service.buckets[bucket as usize].cluster.instances[instance as usize]; if !inst.tick_scheduled { inst.tick_scheduled = true; - q.schedule(next.max(now), Event::BatchTick { instance }); + q.schedule(next.max(now), Event::BatchTick { bucket, instance }); } } } Event::Sample => { - for inst in &cluster.instances { - let busy = if inst.queue_len() > 0 { 1 } else { 0 }; - ts_writer.write(&TimeseriesRow { - t: now, - instance: inst.id, - queue_len: inst.queue_len(), - kv_blocks_used: inst.kv_blocks_used, - kv_blocks_total: inst.hbm_block_budget, - busy, - })?; + for bucket in &service.buckets { + for inst in &bucket.cluster.instances { + let busy = if inst.queue_len() > 0 { 1 } else { 0 }; + ts_writer.write(&TimeseriesRow { + t: now, + bucket: bucket.id, + instance: inst.id, + queue_len: inst.queue_len(), + kv_blocks_used: inst.kv_blocks_used, + kv_blocks_total: inst.hbm_block_budget, + busy, + })?; + } } } Event::Stop => break, diff --git a/src/main.rs b/src/main.rs index b174148..af59237 100644 --- a/src/main.rs +++ b/src/main.rs @@ -471,8 +471,7 @@ fn cmd_oracle( out_path: Option<&std::path::Path>, ) -> Result<()> { let cfg = load(path, overrides)?; - cfg.cluster - .require_legacy_single_pool("oracle analysis")?; + cfg.cluster.require_legacy_single_pool("oracle analysis")?; let block_bytes = cfg.model.kv_block_bytes() as f64; let per_instance_blocks = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as u64; let aggregate_blocks = per_instance_blocks * cfg.cluster.total_instances() as u64; diff --git a/src/metrics/per_request.rs b/src/metrics/per_request.rs index b4ccdd8..9d2a792 100644 --- a/src/metrics/per_request.rs +++ b/src/metrics/per_request.rs @@ -8,7 +8,9 @@ pub struct PerRequestRow { pub arrival: f64, pub ttft: f64, pub e2e: f64, + pub bucket: u32, pub instance: u32, + pub length_bucket_match: bool, pub total_blocks: u32, pub l0_hit_blocks: u32, pub l1_hit_blocks: u32, diff --git a/src/metrics/timeseries.rs b/src/metrics/timeseries.rs index cb670e0..3851465 100644 --- a/src/metrics/timeseries.rs +++ b/src/metrics/timeseries.rs @@ -5,6 +5,7 @@ use std::path::Path; #[derive(Debug, Clone, Serialize)] pub struct TimeseriesRow { pub t: f64, + pub bucket: u32, pub instance: u32, pub queue_len: u32, pub kv_blocks_used: u32, diff --git a/src/oracle.rs b/src/oracle.rs index 064a2d2..590736f 100644 --- a/src/oracle.rs +++ b/src/oracle.rs @@ -196,7 +196,12 @@ fn build_next_use(records: &[RequestRecord]) -> Vec> { /// Implementation: lazy-deletion max-heap keyed by next-use index. Each /// cache entry has a version; the heap may contain stale entries from /// previous insertions, which we skip on pop. -fn run_belady(records: &[RequestRecord], next_use: &[Vec], capacity: usize, mask: &[bool]) -> u64 { +fn run_belady( + records: &[RequestRecord], + next_use: &[Vec], + capacity: usize, + mask: &[bool], +) -> u64 { if capacity == 0 { return 0; } @@ -327,10 +332,7 @@ mod tests { // req 0 populates blocks [1,2,3] but is not counted. // req 1 has prefix [1,2,3,4] — the first 3 blocks are cache hits // because req 0 populated them, even though req 0 is masked out. - let recs = vec![ - req(0, 0.0, vec![1, 2, 3]), - req(1, 1.0, vec![1, 2, 3, 4]), - ]; + let recs = vec![req(0, 0.0, vec![1, 2, 3]), req(1, 1.0, vec![1, 2, 3, 4])]; let mask = vec![false, true]; let out = analyze(&recs, 100, Some(&mask)); // Only req 1 is counted: total = 4, hits = 3 (prefix [1,2,3] hit) diff --git a/src/sim/engine.rs b/src/sim/engine.rs index 7a515e6..83651a3 100644 --- a/src/sim/engine.rs +++ b/src/sim/engine.rs @@ -91,11 +91,24 @@ mod tests { q.schedule( 2.0, Event::BatchTick { + bucket: 0, instance: 0 as InstanceId, }, ); - q.schedule(1.0, Event::BatchTick { instance: 1 }); - q.schedule(1.5, Event::BatchTick { instance: 2 }); + q.schedule( + 1.0, + Event::BatchTick { + bucket: 0, + instance: 1, + }, + ); + q.schedule( + 1.5, + Event::BatchTick { + bucket: 0, + instance: 2, + }, + ); let (t1, _) = q.pop().unwrap(); let (t2, _) = q.pop().unwrap(); let (t3, _) = q.pop().unwrap(); @@ -107,12 +120,24 @@ mod tests { #[test] fn equal_time_fifo() { let mut q = EventQueue::new(); - q.schedule(1.0, Event::BatchTick { instance: 7 }); - q.schedule(1.0, Event::BatchTick { instance: 8 }); + q.schedule( + 1.0, + Event::BatchTick { + bucket: 0, + instance: 7, + }, + ); + q.schedule( + 1.0, + Event::BatchTick { + bucket: 1, + instance: 8, + }, + ); let (_, e1) = q.pop().unwrap(); let (_, e2) = q.pop().unwrap(); match (e1, e2) { - (Event::BatchTick { instance: a }, Event::BatchTick { instance: b }) => { + (Event::BatchTick { instance: a, .. }, Event::BatchTick { instance: b, .. }) => { assert_eq!(a, 7); assert_eq!(b, 8); } diff --git a/src/sim/events.rs b/src/sim/events.rs index e369fa2..c8847a3 100644 --- a/src/sim/events.rs +++ b/src/sim/events.rs @@ -1,5 +1,6 @@ //! Event types for the discrete-event engine. +use crate::router::BucketId; use crate::types::{InstanceId, ReqId}; #[derive(Debug)] @@ -7,7 +8,10 @@ pub enum Event { /// New trace request arrives at the cluster router. Arrival { req_id: ReqId }, /// Per-instance scheduler tick (continuous batching). - BatchTick { instance: InstanceId }, + BatchTick { + bucket: BucketId, + instance: InstanceId, + }, /// Periodic time-series sample of all instances. Sample, /// Stop the simulation early (used internally). diff --git a/tests/smoke.rs b/tests/smoke.rs index 76374f4..8a23f8b 100644 --- a/tests/smoke.rs +++ b/tests/smoke.rs @@ -296,7 +296,69 @@ fn ablation_parallel_matches_serial() { } #[test] -fn bucketed_configs_are_rejected_by_legacy_runtime_paths() { +fn strict_bucket_run_emits_bucket_fields_in_outputs() { + let tmp = std::env::temp_dir().join("kvcache_sim_bucket_outputs"); + let _ = std::fs::remove_dir_all(&tmp); + std::fs::create_dir_all(&tmp).unwrap(); + let trace_path = tmp.join("trace.jsonl"); + + let mut f = std::fs::File::create(&trace_path).unwrap(); + writeln!( + f, + "{}", + serde_json::json!({ + "chat_id": 1, + "parent_chat_id": -1, + "timestamp": 0.0, + "input_length": 32, + "output_length": 16, + "type": "text", + "turn": 0, + "hash_ids": [1, 2] + }) + ) + .unwrap(); + writeln!( + f, + "{}", + serde_json::json!({ + "chat_id": 2, + "parent_chat_id": -1, + "timestamp": 0.1, + "input_length": 80, + "output_length": 16, + "type": "text", + "turn": 0, + "hash_ids": [3, 4, 5, 6, 7] + }) + ) + .unwrap(); + + let mut cfg = bucketed_config( + trace_path.to_str().unwrap(), + tmp.to_str().unwrap(), + RouterMode::LeastLoaded, + ); + cfg.cluster.global_router.mode = GlobalRouterMode::StrictInputLength; + cfg.sim.sample_interval_s = 0.05; + + let _ = driver::run(&cfg, Some("strict_bucket")).expect("bucketed run"); + + let per_request = std::fs::read_to_string(tmp.join("strict_bucket/per_request.csv")).unwrap(); + assert!(per_request.contains("bucket")); + assert!(per_request.contains("length_bucket_match")); + + let instances = std::fs::read_to_string(tmp.join("strict_bucket/instances.csv")).unwrap(); + assert!(instances.contains("bucket")); + + let routing_log = std::fs::read_to_string(tmp.join("strict_bucket/routing_log.jsonl")).unwrap(); + assert!(routing_log.contains("\"chosen_bucket\"")); + assert!(routing_log.contains("\"bucket_candidates\"")); + assert!(routing_log.contains("\"global_reason\"")); +} + +#[test] +fn bucketed_configs_are_rejected_by_legacy_fixed_placement_paths() { let tmp = std::env::temp_dir().join("kvcache_sim_bucketed_reject"); let _ = std::fs::remove_dir_all(&tmp); std::fs::create_dir_all(&tmp).unwrap(); @@ -309,13 +371,9 @@ fn bucketed_configs_are_rejected_by_legacy_runtime_paths() { RouterMode::Random, ); - let result = driver::run(&cfg, Some("bucketed_guard")); - assert!(result.is_err(), "bucketed run should fail"); - let err = result.err().unwrap(); - assert!(err.to_string().contains("cluster.buckets")); - - let err = driver::ablate_fixed_placement(&cfg, &[RouterMode::Random], &[ReplayEvictPolicy::Lru]) - .expect_err("bucketed ablation should fail"); + let err = + driver::ablate_fixed_placement(&cfg, &[RouterMode::Random], &[ReplayEvictPolicy::Lru]) + .expect_err("bucketed ablation should fail"); assert!(err.to_string().contains("cluster.buckets")); let err = replay::replay_fixed_placement( From 43ada0cfc0f82a6347fd8a430bbcf4163af03c00 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 17 Apr 2026 17:55:54 +0800 Subject: [PATCH 9/9] feat: add bucket score global router --- src/cluster/bucketed_service.rs | 19 +-- src/cluster/cluster.rs | 15 +- src/router/global_bucket.rs | 234 +++++++++++++++++++++++++++++++- tests/smoke.rs | 66 +++++++++ 4 files changed, 311 insertions(+), 23 deletions(-) diff --git a/src/cluster/bucketed_service.rs b/src/cluster/bucketed_service.rs index 4720965..633f851 100644 --- a/src/cluster/bucketed_service.rs +++ b/src/cluster/bucketed_service.rs @@ -52,7 +52,7 @@ impl BucketedService { let bucket_views = self .buckets .iter() - .map(|bucket| bucket.cluster.bucket_view(bucket.id, &bucket.cfg)) + .map(|bucket| bucket.cluster.bucket_view(bucket.id, &bucket.cfg, req, now)) .collect::>(); let global = self.global_router.route(req, &bucket_views, now)?; let bucket = &mut self.buckets[global.chosen_bucket as usize]; @@ -213,21 +213,4 @@ mod tests { assert!(err.to_string().contains("no bucket")); assert!(err.to_string().contains("input_length=36")); } - - #[test] - fn bucket_score_placeholder_reports_strict_fallback() { - let mut cfg = test_config(); - cfg.cluster.global_router.mode = GlobalRouterMode::BucketScore; - let mut service = BucketedService::new(&cfg, &cfg.model); - - let stats = service - .route_and_admit(&req(4, 24, &[30, 31]), 0.0) - .unwrap(); - - assert_eq!(stats.decision.global_mode, "strict_input_length"); - assert!(stats - .decision - .global_reason - .contains("bucket_score is not implemented")); - } } diff --git a/src/cluster/cluster.rs b/src/cluster/cluster.rs index 31f4386..322b867 100644 --- a/src/cluster/cluster.rs +++ b/src/cluster/cluster.rs @@ -171,7 +171,19 @@ impl Cluster { } } - pub fn bucket_view(&self, bucket_id: BucketId, cfg: &BucketConfig) -> BucketView { + pub fn bucket_view( + &self, + bucket_id: BucketId, + cfg: &BucketConfig, + req: &RequestRecord, + now: f64, + ) -> BucketView { + let predicted_prefix = self + .meta_store + .score_prefix(&req.hash_ids, now, self.instances.len()) + .into_iter() + .max() + .unwrap_or(0); BucketView { id: bucket_id, input_length_min: cfg.input_length_min, @@ -179,6 +191,7 @@ impl Cluster { num_instances: self.instances.len() as u32, total_queue_len: self.instances.iter().map(Instance::queue_len).sum(), total_load_blocks: self.instances.iter().map(|inst| inst.kv_blocks_used).sum(), + predicted_prefix, } } diff --git a/src/router/global_bucket.rs b/src/router/global_bucket.rs index f2f047a..2522e7d 100644 --- a/src/router/global_bucket.rs +++ b/src/router/global_bucket.rs @@ -14,6 +14,7 @@ pub struct BucketView { pub num_instances: u32, pub total_queue_len: u32, pub total_load_blocks: u32, + pub predicted_prefix: u32, } #[derive(Debug, Clone, Serialize)] @@ -24,7 +25,9 @@ pub struct BucketCandidate { pub num_instances: u32, pub total_queue_len: u32, pub total_load_blocks: u32, + pub predicted_prefix: u32, pub matches_input_len: bool, + pub score: f64, } #[derive(Debug, Clone, Serialize)] @@ -92,8 +95,16 @@ impl GlobalRouter for StrictInputLengthRouter { num_instances: view.num_instances, total_queue_len: view.total_queue_len, total_load_blocks: view.total_load_blocks, + predicted_prefix: view.predicted_prefix, matches_input_len: view.input_length_min <= req.input_len && req.input_len <= view.input_length_max, + score: if view.input_length_min <= req.input_len + && req.input_len <= view.input_length_max + { + 0.0 + } else { + f64::INFINITY + }, }) .collect::>(); @@ -131,15 +142,230 @@ impl GlobalRouter for StrictInputLengthRouter { } } +struct BucketScoreRouter { + length_penalty_weight: f64, + load_weight: f64, + cache_weight: f64, +} + +impl BucketScoreRouter { + fn new(full: &Config) -> Self { + Self { + length_penalty_weight: full.cluster.global_router.length_penalty_weight, + load_weight: full.cluster.global_router.load_weight, + cache_weight: full.cluster.global_router.cache_weight, + } + } + + fn length_penalty(&self, req: &RequestRecord, bucket: &BucketView) -> f64 { + if req.input_len < bucket.input_length_min { + (bucket.input_length_min - req.input_len) as f64 + } else if req.input_len > bucket.input_length_max { + (req.input_len - bucket.input_length_max) as f64 + } else { + 0.0 + } + } +} + +impl GlobalRouter for BucketScoreRouter { + fn name(&self) -> &'static str { + "bucket_score" + } + + fn route( + &mut self, + req: &RequestRecord, + buckets: &[BucketView], + _now: f64, + ) -> Result { + let mut chosen_bucket = None; + let mut best_score = f64::INFINITY; + let mut candidates = Vec::with_capacity(buckets.len()); + + for bucket in buckets { + let length_penalty = self.length_penalty(req, bucket); + let miss = req + .hash_ids + .len() + .saturating_sub(bucket.predicted_prefix as usize) as f64; + let score = self.length_penalty_weight * length_penalty + + self.load_weight * bucket.total_queue_len as f64 + + self.cache_weight * miss; + + candidates.push(BucketCandidate { + bucket: bucket.id, + input_length_min: bucket.input_length_min, + input_length_max: bucket.input_length_max, + num_instances: bucket.num_instances, + total_queue_len: bucket.total_queue_len, + total_load_blocks: bucket.total_load_blocks, + predicted_prefix: bucket.predicted_prefix, + matches_input_len: bucket.input_length_min <= req.input_len + && req.input_len <= bucket.input_length_max, + score, + }); + + let better = score < best_score + || (score == best_score && chosen_bucket.is_none_or(|best| bucket.id < best)); + if better { + best_score = score; + chosen_bucket = Some(bucket.id); + } + } + + Ok(GlobalRouteDecision { + req_id: req.req_id, + mode: self.name(), + chosen_bucket: chosen_bucket.ok_or_else(|| anyhow!("no buckets available"))?, + candidates, + reason: "weighted length/load/cache bucket score", + }) + } +} + pub fn build(full: &Config) -> Box { match full.cluster.global_router.mode { GlobalRouterMode::StrictInputLength => Box::new(StrictInputLengthRouter::new( "strict_input_length", "unique bucket range contains input_length", )) as Box, - GlobalRouterMode::BucketScore => Box::new(StrictInputLengthRouter::new( - "strict_input_length", - "bucket_score is not implemented in Task 2; falling back to strict_input_length", - )) as Box, + GlobalRouterMode::BucketScore => { + Box::new(BucketScoreRouter::new(full)) as Box + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{ + ClusterConfig, GlobalRouterConfig, MetaStoreConfig, RouterConfig, RouterMode, + }; + + fn cfg() -> Config { + Config { + model: crate::config::ModelConfig::default(), + hardware: crate::config::HardwareConfig { + gpu_flops: 1.0, + gpu_fp8_flops: 0.0, + gpu_fp4_flops: 0.0, + gpu_mem_bw: 1.0, + hbm_bytes: 1.0, + dram_bytes: 1.0, + host_dram_bw: 1.0, + pcie_bw: 1.0, + pcie_latency_us: 1.0, + rdma_bw: 1.0, + rdma_latency_us: 1.0, + intra_node_tp_bw: 1.0, + intra_node_tp_latency_us: 1.0, + tp_degree: 1, + max_batch_slots: 1, + prefill_chunk_tokens: 1, + }, + calibration: crate::config::CalibrationConfig::default(), + cluster: ClusterConfig { + num_instances: None, + buckets: Vec::new(), + global_router: GlobalRouterConfig { + mode: GlobalRouterMode::BucketScore, + length_penalty_weight: 1.0, + load_weight: 1.0, + cache_weight: 1.0, + }, + meta_store: MetaStoreConfig { ttl_seconds: 1.0 }, + router: RouterConfig { + mode: RouterMode::LeastLoaded, + precise_probe_latency_us: 1.0, + precise_probe_topk: 1, + load_alpha: 1.0, + score_alpha: 1.0, + score_beta: 1.0, + prefix_k: 8, + affinity_fan_out: 1, + }, + }, + sim: crate::config::SimConfig { + trace_path: String::new(), + max_requests: None, + output_dir: String::new(), + sample_interval_s: 0.0, + seed: 0, + input_length_min: None, + input_length_max: None, + }, + } + } + + fn req(input_len: u32) -> RequestRecord { + RequestRecord { + req_id: 1, + chat_id: 0, + parent_chat_id: -1, + turn: 0, + arrival: 0.0, + input_len, + output_len: 16, + hash_ids: vec![10, 11, 12], + } + } + + #[test] + fn bucket_score_prefers_matching_bucket_when_load_is_equal() { + let mut router = BucketScoreRouter::new(&cfg()); + let buckets = vec![ + BucketView { + id: 0, + input_length_min: 0, + input_length_max: 32, + num_instances: 2, + total_queue_len: 1, + total_load_blocks: 0, + predicted_prefix: 0, + }, + BucketView { + id: 1, + input_length_min: 33, + input_length_max: 96, + num_instances: 2, + total_queue_len: 1, + total_load_blocks: 0, + predicted_prefix: 0, + }, + ]; + let decision = router.route(&req(24), &buckets, 0.0).unwrap(); + assert_eq!(decision.chosen_bucket, 0); + } + + #[test] + fn bucket_score_can_override_length_match_when_load_gap_is_large() { + let mut full = cfg(); + full.cluster.global_router.load_weight = 5.0; + full.cluster.global_router.cache_weight = 1.0; + full.cluster.global_router.length_penalty_weight = 1.0; + let mut router = BucketScoreRouter::new(&full); + let buckets = vec![ + BucketView { + id: 0, + input_length_min: 0, + input_length_max: 32, + num_instances: 2, + total_queue_len: 20, + total_load_blocks: 0, + predicted_prefix: 0, + }, + BucketView { + id: 1, + input_length_min: 33, + input_length_max: 96, + num_instances: 2, + total_queue_len: 0, + total_load_blocks: 0, + predicted_prefix: 2, + }, + ]; + let decision = router.route(&req(24), &buckets, 0.0).unwrap(); + assert_eq!(decision.chosen_bucket, 1); } } diff --git a/tests/smoke.rs b/tests/smoke.rs index 8a23f8b..bd5ec38 100644 --- a/tests/smoke.rs +++ b/tests/smoke.rs @@ -385,3 +385,69 @@ fn bucketed_configs_are_rejected_by_legacy_fixed_placement_paths() { .expect_err("bucketed replay should fail"); assert!(err.to_string().contains("cluster.buckets")); } + +#[test] +fn bucket_score_can_deviate_from_strict_length_bucket() { + let tmp = std::env::temp_dir().join("kvcache_sim_bucket_score"); + let _ = std::fs::remove_dir_all(&tmp); + std::fs::create_dir_all(&tmp).unwrap(); + let trace_path = tmp.join("trace.jsonl"); + + let mut f = std::fs::File::create(&trace_path).unwrap(); + for req_id in 0..3 { + writeln!( + f, + "{}", + serde_json::json!({ + "chat_id": req_id, + "parent_chat_id": -1, + "timestamp": 0.0, + "input_length": 24, + "output_length": 16, + "type": "text", + "turn": 0, + "hash_ids": [100 + req_id, 200 + req_id] + }) + ) + .unwrap(); + } + + let mut strict_cfg = bucketed_config( + trace_path.to_str().unwrap(), + tmp.to_str().unwrap(), + RouterMode::LeastLoaded, + ); + strict_cfg.cluster.buckets = vec![ + BucketConfig { + name: "short".into(), + input_length_min: 0, + input_length_max: 32, + num_instances: 1, + }, + BucketConfig { + name: "long".into(), + input_length_min: 33, + input_length_max: 96, + num_instances: 1, + }, + ]; + strict_cfg.cluster.global_router.mode = GlobalRouterMode::StrictInputLength; + + let mut score_cfg = strict_cfg.clone(); + score_cfg.cluster.global_router.mode = GlobalRouterMode::BucketScore; + score_cfg.cluster.global_router.length_penalty_weight = 1.0; + score_cfg.cluster.global_router.load_weight = 10.0; + score_cfg.cluster.global_router.cache_weight = 0.0; + + let _ = driver::run(&strict_cfg, Some("strict_score_cmp")).expect("strict run"); + let _ = driver::run(&score_cfg, Some("bucket_score_cmp")).expect("bucket score run"); + + let strict_log = + std::fs::read_to_string(tmp.join("strict_score_cmp/routing_log.jsonl")).unwrap(); + let score_log = + std::fs::read_to_string(tmp.join("bucket_score_cmp/routing_log.jsonl")).unwrap(); + + assert!(strict_log.contains("\"chosen_bucket\":0")); + assert!(score_log.contains("\"global_mode\":\"bucket_score\"")); + assert!(score_log.contains("\"chosen_bucket\":1")); +}