From 96019082cca9cd68d061fd5ec9bf05602370bc52 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 17 Apr 2026 14:50:47 +0800 Subject: [PATCH] fix: complete global router config and recoverable cluster init --- src/cluster/cluster.rs | 32 ++++++++------------ src/config.rs | 68 +++++++++++++++++++++++++++++++++++++++++- src/driver.rs | 2 +- 3 files changed, 80 insertions(+), 22 deletions(-) diff --git a/src/cluster/cluster.rs b/src/cluster/cluster.rs index 122945e..035a72e 100644 --- a/src/cluster/cluster.rs +++ b/src/cluster/cluster.rs @@ -1,6 +1,8 @@ //! Cluster: routes arrivals, performs the L0 / L1 / remote-RDMA fetch chain //! described in the design diagram, and bookkeeps the global meta store. +use anyhow::Result; + use crate::cluster::meta_store::MetaStore; use crate::config::{Config, ModelConfig}; use crate::instance::instance::AdmittedRequest; @@ -36,11 +38,8 @@ pub struct Cluster { } impl Cluster { - pub fn new(config: &Config, model: &ModelConfig) -> Self { - let total_instances = config - .cluster - .require_legacy_single_pool("Cluster::new") - .unwrap_or_else(|err| panic!("{err}")); + pub fn new(config: &Config, model: &ModelConfig) -> Result { + let total_instances = config.cluster.require_legacy_single_pool("Cluster::new")?; let mut instances = Vec::with_capacity(total_instances as usize); for id in 0..total_instances { instances.push(Instance::new( @@ -52,7 +51,7 @@ impl Cluster { } let meta_store = MetaStore::new(config.cluster.meta_store.ttl_seconds); let router = router::build(config, config.sim.seed); - Self { + Ok(Self { instances, meta_store, router, @@ -63,7 +62,7 @@ impl Cluster { &config.calibration, model.kv_block_bytes(), ), - } + }) } /// Route + admit a request. Returns the chosen instance plus rich @@ -262,7 +261,7 @@ mod tests { #[test] fn l1_ready_at_includes_dram_and_transform_overhead() { let cfg = test_config(RouterMode::EstimatedTtft); - let mut cluster = Cluster::new(&cfg, &cfg.model); + let mut cluster = Cluster::new(&cfg, &cfg.model).unwrap(); let req = RequestRecord { req_id: 1, chat_id: 0, @@ -300,17 +299,10 @@ mod tests { num_instances: 2, }]; - let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - Cluster::new(&cfg, &cfg.model); - })) - .expect_err("bucketed Cluster::new should panic"); - - let msg = panic - .downcast_ref::() - .cloned() - .or_else(|| panic.downcast_ref::<&str>().map(|s| (*s).to_string())) - .expect("panic payload should be a string"); - assert!(msg.contains("Cluster::new")); - assert!(msg.contains("cluster.buckets")); + let result = Cluster::new(&cfg, &cfg.model); + assert!(result.is_err(), "bucketed Cluster::new should fail"); + let err = result.err().unwrap(); + assert!(err.to_string().contains("Cluster::new")); + assert!(err.to_string().contains("cluster.buckets")); } } diff --git a/src/config.rs b/src/config.rs index 5be0ba6..36e4947 100644 --- a/src/config.rs +++ b/src/config.rs @@ -397,16 +397,25 @@ pub struct BucketConfig { pub num_instances: u32, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct GlobalRouterConfig { #[serde(default)] pub mode: GlobalRouterMode, + #[serde(default = "default_global_router_length_penalty_weight")] + pub length_penalty_weight: f64, + #[serde(default = "default_global_router_load_weight")] + pub load_weight: f64, + #[serde(default = "default_global_router_cache_weight")] + pub cache_weight: f64, } impl Default for GlobalRouterConfig { fn default() -> Self { Self { mode: GlobalRouterMode::StrictInputLength, + length_penalty_weight: default_global_router_length_penalty_weight(), + load_weight: default_global_router_load_weight(), + cache_weight: default_global_router_cache_weight(), } } } @@ -428,6 +437,18 @@ impl GlobalRouterMode { } } +fn default_global_router_length_penalty_weight() -> f64 { + 1.0 +} + +fn default_global_router_load_weight() -> f64 { + 1.0 +} + +fn default_global_router_cache_weight() -> f64 { + 1.0 +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MetaStoreConfig { pub ttl_seconds: f64, @@ -1228,6 +1249,51 @@ sim: assert_eq!(cfg.cluster.total_instances(), 3); } + #[test] + fn bucketed_config_deserializes_global_router_weights() { + let path = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: bucket_score + length_penalty_weight: 1.5 + load_weight: 0.75 + cache_weight: 2.25 + buckets: + - name: short + input_length_min: 0 + input_length_max: 32 + num_instances: 2 +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + + let cfg = Config::from_yaml_path(&path).unwrap(); + assert_eq!(cfg.cluster.global_router.mode, GlobalRouterMode::BucketScore); + assert!((cfg.cluster.global_router.length_penalty_weight - 1.5).abs() < 1e-12); + assert!((cfg.cluster.global_router.load_weight - 0.75).abs() < 1e-12); + assert!((cfg.cluster.global_router.cache_weight - 2.25).abs() < 1e-12); + } + #[test] fn bucketed_config_rejects_overlapping_ranges_and_mixed_modes() { let overlap = write_temp_config( diff --git a/src/driver.rs b/src/driver.rs index 05e9531..dbc6b5e 100644 --- a/src/driver.rs +++ b/src/driver.rs @@ -52,7 +52,7 @@ pub fn run(config: &Config, output_subdir: Option<&str>) -> Result { config .cluster .require_legacy_single_pool("driver run")?; - let mut cluster = Cluster::new(config, &config.model); + let mut cluster = Cluster::new(config, &config.model)?; let mut q = EventQueue::new(); // Output directory