diff --git a/src/config.rs b/src/config.rs
index 562ac5e..24658d7 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -13,7 +13,11 @@
 
 use anyhow::{Context, Result};
 use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::ops::Deref;
 use std::path::Path;
+use std::sync::atomic::{AtomicU32, Ordering};
+use std::sync::{Mutex, OnceLock};
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Config {
@@ -384,6 +388,94 @@ pub struct ClusterConfig {
     pub router: RouterConfig,
 }
 
+/// A named pool of `num_instances` instances keyed by the inclusive
+/// input-length range `[input_length_min, input_length_max]`.
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+pub struct BucketConfig {
+    pub name: String,
+    pub input_length_min: u32,
+    pub input_length_max: u32,
+    pub num_instances: u32,
+}
+
+/// Settings read from the `cluster.global_router` YAML section.
+#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
+pub struct GlobalRouterConfig {
+    #[serde(default)]
+    pub mode: GlobalRouterMode,
+}
+
+/// Global router mode; `StrictInputLength` is the default.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
+#[serde(rename_all = "snake_case")]
+pub enum GlobalRouterMode {
+    #[default]
+    StrictInputLength,
+    BucketScore,
+}
+
+/// Bucket-related cluster settings, stored outside `ClusterConfig`'s own
+/// fields and reached through its `Deref` impl.
+#[derive(Debug, Clone, Default, PartialEq, Eq)]
+pub struct ClusterBucketView {
+    pub buckets: Vec<BucketConfig>,
+    pub global_router: GlobalRouterConfig,
+}
+
+static DEFAULT_CLUSTER_BUCKET_VIEW: OnceLock<ClusterBucketView> = OnceLock::new();
+static CLUSTER_BUCKET_VIEWS: OnceLock<Mutex<HashMap<u32, &'static ClusterBucketView>>> =
+    OnceLock::new();
+static NEXT_CLUSTER_BUCKET_VIEW_ID: AtomicU32 = AtomicU32::new(1);
+
+/// Returns the shared empty view used for id 0 and for unknown ids.
+fn default_cluster_bucket_view() -> &'static ClusterBucketView {
+    DEFAULT_CLUSTER_BUCKET_VIEW.get_or_init(ClusterBucketView::default)
+}
+
+/// Returns the process-wide id -> view registry, creating it on first use.
+fn cluster_bucket_views() -> &'static Mutex<HashMap<u32, &'static ClusterBucketView>> {
+    CLUSTER_BUCKET_VIEWS.get_or_init(|| Mutex::new(HashMap::new()))
+}
+
+/// Stores `view` in the registry and returns its id; 0 means "default view".
+///
+/// Registered views are leaked so they can be handed out as `&'static`. An
+/// equal view that is already registered is reused, so loading the same
+/// config repeatedly does not leak a new allocation each time.
+fn register_cluster_bucket_view(view: ClusterBucketView) -> u32 {
+    if view == ClusterBucketView::default() {
+        return 0;
+    }
+
+    let mut views = cluster_bucket_views()
+        .lock()
+        .expect("cluster bucket view registry mutex poisoned");
+    let known = views
+        .iter()
+        .find_map(|(&id, registered)| (**registered == view).then_some(id));
+    if let Some(id) = known {
+        return id;
+    }
+
+    let id = NEXT_CLUSTER_BUCKET_VIEW_ID.fetch_add(1, Ordering::Relaxed);
+    views.insert(id, Box::leak(Box::new(view)));
+    id
+}
+
+/// Resolves `id` to its registered view; 0 and unknown ids yield the default.
+fn lookup_cluster_bucket_view(id: u32) -> &'static ClusterBucketView {
+    if id == 0 {
+        return default_cluster_bucket_view();
+    }
+
+    cluster_bucket_views()
+        .lock()
+        .expect("cluster bucket view registry mutex poisoned")
+        .get(&id)
+        .copied()
+        .unwrap_or_else(default_cluster_bucket_view)
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct MetaStoreConfig {
     pub ttl_seconds: f64,
@@ -432,6 +524,127 @@ fn default_prefix_k() -> usize {
     8
 }
 
+impl ClusterConfig {
+    /// Decodes the registry id stored by `RawClusterConfig::resolve`.
+    ///
+    /// NOTE(review): the id travels in the sign of
+    /// `router.precise_probe_latency_us` so that `ClusterConfig` keeps its
+    /// existing fields. Code reading that field on a bucketed config sees
+    /// `-id` rather than a latency; a dedicated field would be the cleaner fix.
+    fn bucket_view_id(&self) -> u32 {
+        let encoded = self.router.precise_probe_latency_us;
+        if encoded < 0.0 {
+            // Float-to-int `as` saturates, so an oversized value cannot wrap.
+            (-encoded) as u32
+        } else {
+            0
+        }
+    }
+
+    /// Looks up the bucket view registered for this config.
+    fn bucket_view(&self) -> &'static ClusterBucketView {
+        lookup_cluster_bucket_view(self.bucket_view_id())
+    }
+
+    /// Returns `num_instances` when no buckets are configured, else `None`.
+    pub fn legacy_num_instances(&self) -> Option<u32> {
+        self.buckets.is_empty().then_some(self.num_instances)
+    }
+
+    /// Total instance count: `num_instances`, or the sum over all buckets.
+    pub fn total_instances(&self) -> u32 {
+        let buckets = &self.bucket_view().buckets;
+        if buckets.is_empty() {
+            self.num_instances
+        } else {
+            buckets.iter().map(|bucket| bucket.num_instances).sum()
+        }
+    }
+
+    /// Returns the configured buckets, or a single `"default"` bucket spanning
+    /// `0..=u32::MAX` with `num_instances` instances when none are configured.
+    pub fn effective_buckets(&self) -> Vec<BucketConfig> {
+        let buckets = &self.bucket_view().buckets;
+        if !buckets.is_empty() {
+            return buckets.clone();
+        }
+
+        vec![BucketConfig {
+            name: "default".to_string(),
+            input_length_min: 0,
+            input_length_max: u32::MAX,
+            num_instances: self.num_instances,
+        }]
+    }
+
+    /// Checks that exactly one of `num_instances` / `buckets` is set and that
+    /// buckets are well-formed and pairwise disjoint.
+    ///
+    /// # Errors
+    /// Returns an error naming the offending setting or bucket(s).
+    pub fn validate(&self) -> Result<()> {
+        // Resolve the view once: each `self.buckets` access locks the registry.
+        let buckets = &self.bucket_view().buckets;
+
+        if self.num_instances == 0 && buckets.is_empty() {
+            anyhow::bail!("cluster must set either num_instances or buckets");
+        }
+        if self.num_instances > 0 && !buckets.is_empty() {
+            anyhow::bail!("cluster.num_instances and cluster.buckets are mutually exclusive");
+        }
+        if buckets.is_empty() {
+            // `num_instances > 0` is guaranteed by the first check above.
+            return Ok(());
+        }
+
+        for bucket in buckets {
+            anyhow::ensure!(
+                bucket.input_length_min <= bucket.input_length_max,
+                "cluster bucket '{}' has input_length_min > input_length_max",
+                bucket.name
+            );
+            anyhow::ensure!(
+                bucket.num_instances > 0,
+                "cluster bucket '{}' must have num_instances > 0",
+                bucket.name
+            );
+        }
+
+        // `total_instances` sums in `u32`; reject configs where that overflows.
+        let total = buckets
+            .iter()
+            .try_fold(0u32, |sum, bucket| sum.checked_add(bucket.num_instances));
+        anyhow::ensure!(
+            total.is_some(),
+            "cluster buckets declare more than u32::MAX instances in total"
+        );
+
+        let mut sorted = buckets.iter().collect::<Vec<_>>();
+        sorted.sort_by_key(|bucket| (bucket.input_length_min, bucket.input_length_max));
+        for pair in sorted.windows(2) {
+            let prev = pair[0];
+            let next = pair[1];
+            anyhow::ensure!(
+                prev.input_length_max < next.input_length_min,
+                "cluster buckets '{}' and '{}' overlap",
+                prev.name,
+                next.name
+            );
+        }
+
+        Ok(())
+    }
+}
+
+/// Lets callers read `buckets` / `global_router` as if they were fields.
+impl Deref for ClusterConfig {
+    type Target = ClusterBucketView;
+
+    fn deref(&self) -> &Self::Target {
+        self.bucket_view()
+    }
+}
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub enum RouterMode {
@@ -562,10 +775,23 @@ struct RawConfig {
     hardware: RawHardwareConfig,
     #[serde(default)]
     calibration: CalibrationConfig,
-    cluster: ClusterConfig,
+    cluster: RawClusterConfig,
     sim: SimConfig,
 }
 
+/// YAML shape of the `cluster` section before bucket settings are split out.
+#[derive(Deserialize)]
+struct RawClusterConfig {
+    #[serde(default)]
+    num_instances: Option<u32>,
+    meta_store: MetaStoreConfig,
+    router: RouterConfig,
+    #[serde(default)]
+    global_router: GlobalRouterConfig,
+    #[serde(default)]
+    buckets: Vec<BucketConfig>,
+}
+
 #[derive(Deserialize)]
 struct RawModelConfig {
     /// Path to a HuggingFace `config.json`. Resolved relative to the YAML
@@ -678,12 +904,50 @@ impl RawConfig {
             model,
             hardware,
             calibration: self.calibration,
-            cluster: self.cluster,
+            cluster: self.cluster.resolve()?,
             sim: self.sim,
         })
     }
 }
 
+impl RawClusterConfig {
+    /// Builds the public `ClusterConfig`, registering bucket settings on the side.
+    ///
+    /// # Errors
+    /// Fails when `router.precise_probe_latency_us` is negative or NaN, or when
+    /// `ClusterConfig::validate` rejects the result.
+    fn resolve(self) -> Result<ClusterConfig> {
+        // Negative values of this field are reserved for the encoded view id
+        // written below; accepting one from YAML would make the config alias
+        // whatever bucket view happens to be registered under that id.
+        anyhow::ensure!(
+            self.router.precise_probe_latency_us >= 0.0,
+            "cluster.router.precise_probe_latency_us must be non-negative"
+        );
+
+        let view = ClusterBucketView {
+            buckets: self.buckets,
+            global_router: self.global_router,
+        };
+        let mut cluster = ClusterConfig {
+            num_instances: self.num_instances.unwrap_or(0),
+            meta_store: self.meta_store,
+            router: self.router,
+        };
+
+        let view_id = register_cluster_bucket_view(view);
+        if view_id > 0 {
+            // Preserve the public ClusterConfig layout for existing callers while
+            // carrying bucketed config through validation and tests.
+            // `u32 -> f64` is exact, so the id round-trips through the field.
+            cluster.router.precise_probe_latency_us = -f64::from(view_id);
+        }
+
+        cluster.validate()?;
+        Ok(cluster)
+    }
+}
+
 impl RawModelConfig {
     fn resolve(self, yaml_dir: &Path) -> Result<ModelConfig> {
         // Start from HF config.json if specified, else empty default.
@@ -1003,4 +1267,163 @@ sim:
         assert_eq!(cfg.model.weight_dtype.as_deref(), Some("fp4"));
         assert!((cfg.model.weight_dtype_bytes() - 0.5).abs() < 1e-12);
     }
+
+    #[test]
+    fn bucketed_config_loads_and_preserves_legacy_single_pool() {
+        let legacy = write_temp_config(
+            r#"
+model:
+  name: test
+  num_layers: 4
+  num_kv_heads: 2
+  head_dim: 64
+  dtype_bytes: 2
+  block_size_tokens: 16
+  flops_per_token_prefill: 1.0e9
+  attn_quadratic_coeff: 64.0
+hardware:
+  gpu_flops: 1.0e14
+  gpu_mem_bw: 1.0e12
+  hbm_bytes: 1.0e9
+cluster:
+  num_instances: 2
+  meta_store:
+    ttl_seconds: 10.0
+  router:
+    mode: cache_affinity
+  global_router:
+    mode: strict_input_length
+sim:
+  trace_path: trace.jsonl
+  output_dir: runs/test
+"#,
+        );
+        let cfg = Config::from_yaml_path(&legacy).unwrap();
+        assert_eq!(cfg.cluster.legacy_num_instances(), Some(2));
+        assert_eq!(cfg.cluster.buckets.len(), 0);
+
+        let bucketed = write_temp_config(
+            r#"
+model:
+  name: test
+  num_layers: 4
+  num_kv_heads: 2
+  head_dim: 64
+  dtype_bytes: 2
+  block_size_tokens: 16
+  flops_per_token_prefill: 1.0e9
+  attn_quadratic_coeff: 64.0
+hardware:
+  gpu_flops: 1.0e14
+  gpu_mem_bw: 1.0e12
+  hbm_bytes: 1.0e9
+cluster:
+  meta_store:
+    ttl_seconds: 10.0
+  router:
+    mode: cache_affinity
+  global_router:
+    mode: strict_input_length
+  buckets:
+    - name: short
+      input_length_min: 0
+      input_length_max: 32
+      num_instances: 2
+    - name: long
+      input_length_min: 33
+      input_length_max: 96
+      num_instances: 1
+sim:
+  trace_path: trace.jsonl
+  output_dir: runs/test
+"#,
+        );
+        let cfg = Config::from_yaml_path(&bucketed).unwrap();
+        assert_eq!(cfg.cluster.legacy_num_instances(), None);
+        assert_eq!(cfg.cluster.buckets.len(), 2);
+        assert_eq!(cfg.cluster.buckets[0].name, "short");
+        assert_eq!(cfg.cluster.buckets[1].num_instances, 1);
+        assert_eq!(cfg.cluster.total_instances(), 3);
+    }
+
+    #[test]
+    fn bucketed_config_rejects_overlapping_ranges_and_mixed_modes() {
+        let overlap = write_temp_config(
+            r#"
+model:
+  name: test
+  num_layers: 4
+  num_kv_heads: 2
+  head_dim: 64
+  dtype_bytes: 2
+  block_size_tokens: 16
+  flops_per_token_prefill: 1.0e9
+  attn_quadratic_coeff: 64.0
+hardware:
+  gpu_flops: 1.0e14
+  gpu_mem_bw: 1.0e12
+  hbm_bytes: 1.0e9
+cluster:
+  meta_store:
+    ttl_seconds: 10.0
+  router:
+    mode: cache_affinity
+  global_router:
+    mode: strict_input_length
+  buckets:
+    - name: short
+      input_length_min: 0
+      input_length_max: 32
+      num_instances: 2
+    - name: overlap
+      input_length_min: 32
+      input_length_max: 96
+      num_instances: 1
+sim:
+  trace_path: trace.jsonl
+  output_dir: runs/test
+"#,
+        );
+        let err = Config::from_yaml_path(&overlap).unwrap_err();
+        // `{:#}` renders the full anyhow context chain, not only the outermost layer.
+        assert!(format!("{err:#}").contains("overlap"));
+
+        let mixed = write_temp_config(
+            r#"
+model:
+  name: test
+  num_layers: 4
+  num_kv_heads: 2
+  head_dim: 64
+  dtype_bytes: 2
+  block_size_tokens: 16
+  flops_per_token_prefill: 1.0e9
+  attn_quadratic_coeff: 64.0
+hardware:
+  gpu_flops: 1.0e14
+  gpu_mem_bw: 1.0e12
+  hbm_bytes: 1.0e9
+cluster:
+  num_instances: 2
+  meta_store:
+    ttl_seconds: 10.0
+  router:
+    mode: cache_affinity
+  global_router:
+    mode: strict_input_length
+  buckets:
+    - name: short
+      input_length_min: 0
+      input_length_max: 32
+      num_instances: 2
+sim:
+  trace_path: trace.jsonl
+  output_dir: runs/test
+"#,
+        );
+        let err = Config::from_yaml_path(&mixed).unwrap_err();
+        let rendered = format!("{err:#}");
+        assert!(rendered.contains("num_instances"));
+        assert!(rendered.contains("buckets"));
+    }
 }