diff --git a/src/cluster/cluster.rs b/src/cluster/cluster.rs index 9927cdc..31d42fb 100644 --- a/src/cluster/cluster.rs +++ b/src/cluster/cluster.rs @@ -37,8 +37,9 @@ pub struct Cluster { impl Cluster { pub fn new(config: &Config, model: &ModelConfig) -> Self { - let mut instances = Vec::with_capacity(config.cluster.num_instances as usize); - for id in 0..config.cluster.num_instances { + let total_instances = config.cluster.total_instances(); + let mut instances = Vec::with_capacity(total_instances as usize); + for id in 0..total_instances { instances.push(Instance::new( id as InstanceId, model, @@ -226,7 +227,9 @@ mod tests { ..CalibrationConfig::default() }, cluster: ClusterConfig { - num_instances: 1, + num_instances: Some(1), + buckets: Vec::new(), + global_router: Default::default(), meta_store: MetaStoreConfig { ttl_seconds: 1000.0, }, diff --git a/src/config.rs b/src/config.rs index 24658d7..f646dcb 100644 --- a/src/config.rs +++ b/src/config.rs @@ -13,11 +13,7 @@ use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::ops::Deref; use std::path::Path; -use std::sync::atomic::{AtomicU32, Ordering}; -use std::sync::{Mutex, OnceLock}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Config { @@ -383,7 +379,12 @@ fn default_first_token_ready_us() -> f64 { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ClusterConfig { - pub num_instances: u32, + #[serde(default)] + pub num_instances: Option, + #[serde(default)] + pub buckets: Vec, + #[serde(default)] + pub global_router: GlobalRouterConfig, pub meta_store: MetaStoreConfig, pub router: RouterConfig, } @@ -418,58 +419,6 @@ pub enum GlobalRouterMode { BucketScore, } -#[derive(Debug, Clone)] -pub struct ClusterBucketView { - pub buckets: Vec, - pub global_router: GlobalRouterConfig, -} - -impl Default for ClusterBucketView { - fn default() -> Self { - Self { - buckets: Vec::new(), - global_router: GlobalRouterConfig::default(), - } - } -} - -static DEFAULT_CLUSTER_BUCKET_VIEW: OnceLock = OnceLock::new(); -static CLUSTER_BUCKET_VIEWS: OnceLock>> = - OnceLock::new(); -static NEXT_CLUSTER_BUCKET_VIEW_ID: AtomicU32 = AtomicU32::new(1); - -fn default_cluster_bucket_view() -> &'static ClusterBucketView { - DEFAULT_CLUSTER_BUCKET_VIEW.get_or_init(ClusterBucketView::default) -} - -fn cluster_bucket_views() -> &'static Mutex> { - CLUSTER_BUCKET_VIEWS.get_or_init(|| Mutex::new(HashMap::new())) -} - -fn register_cluster_bucket_view(view: ClusterBucketView) -> u32 { - if view.buckets.is_empty() && view.global_router == GlobalRouterConfig::default() { - return 0; - } - - let id = NEXT_CLUSTER_BUCKET_VIEW_ID.fetch_add(1, Ordering::Relaxed); - let leaked = Box::leak(Box::new(view)); - cluster_bucket_views().lock().unwrap().insert(id, leaked); - id -} - -fn lookup_cluster_bucket_view(id: u32) -> &'static ClusterBucketView { - if id == 0 { - return default_cluster_bucket_view(); - } - - cluster_bucket_views() - .lock() - .unwrap() - .get(&id) - .copied() - .unwrap_or_else(default_cluster_bucket_view) -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MetaStoreConfig { pub ttl_seconds: f64, @@ -519,30 +468,19 @@ fn default_prefix_k() -> usize { } impl ClusterConfig { - fn bucket_view_id(&self) -> u32 { - if self.router.precise_probe_latency_us < 0.0 { - (-self.router.precise_probe_latency_us) as u32 - } else { - 0 - } - } - - fn bucket_view(&self) -> &'static ClusterBucketView { - lookup_cluster_bucket_view(self.bucket_view_id()) - } - pub fn legacy_num_instances(&self) -> Option { - self.buckets.is_empty().then_some(self.num_instances) - } - - pub fn total_instances(&self) -> u32 { if self.buckets.is_empty() { self.num_instances } else { - self.buckets.iter().map(|bucket| bucket.num_instances).sum() + None } } + pub fn total_instances(&self) -> u32 { + self.legacy_num_instances() + .unwrap_or_else(|| self.buckets.iter().map(|bucket| bucket.num_instances).sum()) + } + pub fn effective_buckets(&self) -> Vec { if !self.buckets.is_empty() { return self.buckets.clone(); @@ -552,21 +490,22 @@ impl ClusterConfig { name: "default".to_string(), input_length_min: 0, input_length_max: u32::MAX, - num_instances: self.num_instances, + num_instances: self + .num_instances + .expect("legacy single-pool cluster must have num_instances"), }] } pub fn validate(&self) -> Result<()> { - if self.num_instances == 0 && self.buckets.is_empty() { - anyhow::bail!("cluster must set either num_instances or buckets"); - } - - if self.num_instances > 0 && !self.buckets.is_empty() { + if self.num_instances.is_some() && !self.buckets.is_empty() { anyhow::bail!("cluster.num_instances and cluster.buckets are mutually exclusive"); } if self.buckets.is_empty() { - anyhow::ensure!(self.num_instances > 0, "cluster.num_instances must be positive"); + let num_instances = self.num_instances.ok_or_else(|| { + anyhow::anyhow!("cluster must set either num_instances or buckets") + })?; + anyhow::ensure!(num_instances > 0, "cluster.num_instances must be positive"); return Ok(()); } @@ -600,14 +539,6 @@ impl ClusterConfig { } } -impl Deref for ClusterConfig { - type Target = ClusterBucketView; - - fn deref(&self) -> &Self::Target { - self.bucket_view() - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum RouterMode { @@ -738,22 +669,10 @@ struct RawConfig { hardware: RawHardwareConfig, #[serde(default)] calibration: CalibrationConfig, - cluster: RawClusterConfig, + cluster: ClusterConfig, sim: SimConfig, } -#[derive(Deserialize)] -struct RawClusterConfig { - #[serde(default)] - num_instances: Option, - meta_store: MetaStoreConfig, - router: RouterConfig, - #[serde(default)] - global_router: GlobalRouterConfig, - #[serde(default)] - buckets: Vec, -} - #[derive(Deserialize)] struct RawModelConfig { /// Path to a HuggingFace `config.json`. Resolved relative to the YAML @@ -866,36 +785,15 @@ impl RawConfig { model, hardware, calibration: self.calibration, - cluster: self.cluster.resolve()?, + cluster: { + self.cluster.validate()?; + self.cluster + }, sim: self.sim, }) } } -impl RawClusterConfig { - fn resolve(self) -> Result { - let view = ClusterBucketView { - buckets: self.buckets, - global_router: self.global_router, - }; - let mut cluster = ClusterConfig { - num_instances: self.num_instances.unwrap_or(0), - meta_store: self.meta_store, - router: self.router, - }; - - let view_id = register_cluster_bucket_view(view); - if view_id > 0 { - // Preserve the public ClusterConfig layout for existing callers while - // carrying bucketed config through validation and tests. - cluster.router.precise_probe_latency_us = -(view_id as f64); - } - - cluster.validate()?; - Ok(cluster) - } -} - impl RawModelConfig { fn resolve(self, yaml_dir: &Path) -> Result { // Start from HF config.json if specified, else empty default. diff --git a/src/main.rs b/src/main.rs index c0b6e6f..37d6677 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,7 +52,8 @@ struct ConfigOverrides { impl ConfigOverrides { fn apply(&self, cfg: &mut Config) { if let Some(n) = self.num_instances { - cfg.cluster.num_instances = n; + cfg.cluster.num_instances = Some(n); + cfg.cluster.buckets.clear(); } if let Some(m) = self.max_requests { cfg.sim.max_requests = Some(m); @@ -205,6 +206,7 @@ fn main() -> Result<()> { fn load(config: &PathBuf, overrides: &ConfigOverrides) -> Result { let mut cfg = Config::from_yaml_path(config)?; overrides.apply(&mut cfg); + cfg.cluster.validate()?; Ok(cfg) } @@ -268,7 +270,8 @@ fn cmd_ablate( auto_target_ttft_mean, probe_mode.as_str() ); - base.cluster.num_instances = chosen; + base.cluster.num_instances = Some(chosen); + base.cluster.buckets.clear(); } eprintln!( @@ -283,7 +286,7 @@ fn cmd_ablate( .map(ReplayEvictPolicy::as_str) .collect::>() .join(","), - base.cluster.num_instances, + base.cluster.total_instances(), if jobs == 0 { "auto".to_string() } else { @@ -330,7 +333,8 @@ fn auto_select_instances( for &n in candidates { let mut cfg = base.clone(); - cfg.cluster.num_instances = n; + cfg.cluster.num_instances = Some(n); + cfg.cluster.buckets.clear(); cfg.cluster.router.mode = probe; // Isolate calibration output so ablation runs don't overwrite it. cfg.sim.output_dir = out_root @@ -429,7 +433,7 @@ fn cmd_validate(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> { let hbm_blocks = (cfg.hardware.hbm_bytes / block_bytes) as u64; let dram_blocks = (cfg.hardware.dram_bytes / block_bytes) as u64; eprintln!("per-instance HBM blocks = {hbm_blocks}, DRAM blocks = {dram_blocks}"); - eprintln!("num_instances = {}", cfg.cluster.num_instances); + eprintln!("num_instances = {}", cfg.cluster.total_instances()); // Sample prefill times at a few prompt lengths. eprintln!("prefill_time samples:"); for &n in &[256, 1024, 4096, 16384, 65536, 131072] { @@ -462,7 +466,7 @@ fn cmd_oracle( let cfg = load(path, overrides)?; let block_bytes = cfg.model.kv_block_bytes() as f64; let per_instance_blocks = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as u64; - let aggregate_blocks = per_instance_blocks * cfg.cluster.num_instances as u64; + let aggregate_blocks = per_instance_blocks * cfg.cluster.total_instances() as u64; let capacity = match (capacity_blocks, per_instance) { (Some(_), true) => { return Err(anyhow::anyhow!( @@ -518,7 +522,7 @@ fn cmd_oracle( records.len(), capacity, per_instance_blocks, - cfg.cluster.num_instances, + cfg.cluster.total_instances(), if per_instance { ", per-instance mode" } else { diff --git a/src/replay.rs b/src/replay.rs index 13a7cb5..2199dc8 100644 --- a/src/replay.rs +++ b/src/replay.rs @@ -519,7 +519,7 @@ pub fn replay_fixed_placement( let block_bytes = cfg.model.kv_block_bytes() as f64; let l0_cap = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as usize; let l1_cap = (cfg.hardware.dram_bytes / block_bytes).max(1.0) as usize; - let num_instances = cfg.cluster.num_instances as usize; + let num_instances = cfg.cluster.total_instances() as usize; let mut caches: Vec = (0..num_instances) .map(|_| ReplayInstanceCache::new(policy, l0_cap, l1_cap)) .collect(); diff --git a/src/router/adaptive_affinity.rs b/src/router/adaptive_affinity.rs index 7b94239..c1ff154 100644 --- a/src/router/adaptive_affinity.rs +++ b/src/router/adaptive_affinity.rs @@ -57,7 +57,7 @@ pub struct AdaptiveAffinityRouter { impl AdaptiveAffinityRouter { pub fn new(config: &Config) -> Self { - let n = config.cluster.num_instances as usize; + let n = config.cluster.total_instances() as usize; let configured_fan_out = config.cluster.router.affinity_fan_out; let max_fan_out = if configured_fan_out > 0 { configured_fan_out.max(2).min(n) diff --git a/src/router/lineage_affinity.rs b/src/router/lineage_affinity.rs index f8d51e6..d23a580 100644 --- a/src/router/lineage_affinity.rs +++ b/src/router/lineage_affinity.rs @@ -53,7 +53,7 @@ pub struct LineageAffinityRouter { impl LineageAffinityRouter { pub fn new(config: &Config) -> Self { - let n = config.cluster.num_instances as usize; + let n = config.cluster.total_instances() as usize; let configured_fan_out = config.cluster.router.affinity_fan_out; let max_fan_out = if configured_fan_out > 0 { configured_fan_out.max(2).min(n) diff --git a/src/router/mod.rs b/src/router/mod.rs index 7773e3d..2a23e3c 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -181,7 +181,9 @@ mod tests { hardware: test_hardware(), calibration: CalibrationConfig::default(), cluster: ClusterConfig { - num_instances: 2, + num_instances: Some(2), + buckets: Vec::new(), + global_router: Default::default(), meta_store: MetaStoreConfig { ttl_seconds: 1000.0, }, diff --git a/src/router/prefix_affinity.rs b/src/router/prefix_affinity.rs index 8f99b7a..ca6bcb3 100644 --- a/src/router/prefix_affinity.rs +++ b/src/router/prefix_affinity.rs @@ -51,7 +51,7 @@ pub struct PrefixAffinityRouter { impl PrefixAffinityRouter { pub fn new(config: &Config) -> Self { - let n = config.cluster.num_instances as usize; + let n = config.cluster.total_instances() as usize; let cfg_fan = config.cluster.router.affinity_fan_out; // fan_out: if configured, use it; otherwise auto = max(2, n/8). let fan_out = if cfg_fan > 0 { diff --git a/tests/smoke.rs b/tests/smoke.rs index 700408d..33c81db 100644 --- a/tests/smoke.rs +++ b/tests/smoke.rs @@ -41,7 +41,9 @@ fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config { }, calibration: CalibrationConfig::default(), cluster: ClusterConfig { - num_instances: 4, + num_instances: Some(4), + buckets: Vec::new(), + global_router: Default::default(), meta_store: MetaStoreConfig { ttl_seconds: 1000.0, },