From 43ada0cfc0f82a6347fd8a430bbcf4163af03c00 Mon Sep 17 00:00:00 2001
From: Gahow Wang
Date: Fri, 17 Apr 2026 17:55:54 +0800
Subject: [PATCH] feat: add bucket score global router

Replace the BucketScore placeholder, which fell back to
StrictInputLengthRouter, with a real BucketScoreRouter.

- Cluster::bucket_view now takes the request and the current time and
  fills BucketView::predicted_prefix with the maximum value returned by
  meta_store.score_prefix for the request's hash ids (0 if it returns
  nothing).
- BucketView and BucketCandidate carry predicted_prefix; BucketCandidate
  also records the per-bucket score. The strict router reports 0.0 for
  a bucket whose range contains input_len and +inf otherwise.
- BucketScoreRouter scores each bucket as
    length_penalty_weight * length_penalty
    + load_weight * total_queue_len
    + cache_weight * (hash_ids.len() - predicted_prefix, saturating)
  and picks the lowest score, breaking exact ties by lowest bucket id.
- Drop the bucketed_service test that asserted the old fallback and add
  unit and smoke tests for the new router.
---
 src/cluster/bucketed_service.rs |  19 +--
 src/cluster/cluster.rs          |  15 +-
 src/router/global_bucket.rs     | 245 +++++++++++++++++++++++++++++++-
 tests/smoke.rs                  |  67 +++++++++
 4 files changed, 323 insertions(+), 23 deletions(-)

diff --git a/src/cluster/bucketed_service.rs b/src/cluster/bucketed_service.rs
index 4720965..633f851 100644
--- a/src/cluster/bucketed_service.rs
+++ b/src/cluster/bucketed_service.rs
@@ -52,7 +52,7 @@ impl BucketedService {
         let bucket_views = self
             .buckets
             .iter()
-            .map(|bucket| bucket.cluster.bucket_view(bucket.id, &bucket.cfg))
+            .map(|bucket| bucket.cluster.bucket_view(bucket.id, &bucket.cfg, req, now))
             .collect::<Vec<_>>();
         let global = self.global_router.route(req, &bucket_views, now)?;
         let bucket = &mut self.buckets[global.chosen_bucket as usize];
@@ -213,21 +213,4 @@ mod tests {
         assert!(err.to_string().contains("no bucket"));
         assert!(err.to_string().contains("input_length=36"));
     }
-
-    #[test]
-    fn bucket_score_placeholder_reports_strict_fallback() {
-        let mut cfg = test_config();
-        cfg.cluster.global_router.mode = GlobalRouterMode::BucketScore;
-        let mut service = BucketedService::new(&cfg, &cfg.model);
-
-        let stats = service
-            .route_and_admit(&req(4, 24, &[30, 31]), 0.0)
-            .unwrap();
-
-        assert_eq!(stats.decision.global_mode, "strict_input_length");
-        assert!(stats
-            .decision
-            .global_reason
-            .contains("bucket_score is not implemented"));
-    }
 }
diff --git a/src/cluster/cluster.rs b/src/cluster/cluster.rs
index 31f4386..322b867 100644
--- a/src/cluster/cluster.rs
+++ b/src/cluster/cluster.rs
@@ -171,7 +171,19 @@ impl Cluster {
         }
     }
 
-    pub fn bucket_view(&self, bucket_id: BucketId, cfg: &BucketConfig) -> BucketView {
+    pub fn bucket_view(
+        &self,
+        bucket_id: BucketId,
+        cfg: &BucketConfig,
+        req: &RequestRecord,
+        now: f64,
+    ) -> BucketView {
+        let predicted_prefix = self
+            .meta_store
+            .score_prefix(&req.hash_ids, now, self.instances.len())
+            .into_iter()
+            .max()
+            .unwrap_or(0);
         BucketView {
             id: bucket_id,
             input_length_min: cfg.input_length_min,
@@ -179,6 +191,7 @@ impl Cluster {
             num_instances: self.instances.len() as u32,
             total_queue_len: self.instances.iter().map(Instance::queue_len).sum(),
             total_load_blocks: self.instances.iter().map(|inst| inst.kv_blocks_used).sum(),
+            predicted_prefix,
         }
     }
 
diff --git a/src/router/global_bucket.rs b/src/router/global_bucket.rs
index f2f047a..2522e7d 100644
--- a/src/router/global_bucket.rs
+++ b/src/router/global_bucket.rs
@@ -14,6 +14,7 @@ pub struct BucketView {
     pub num_instances: u32,
     pub total_queue_len: u32,
     pub total_load_blocks: u32,
+    pub predicted_prefix: u32,
 }
 
 #[derive(Debug, Clone, Serialize)]
@@ -24,7 +25,9 @@ pub struct BucketCandidate {
     pub num_instances: u32,
     pub total_queue_len: u32,
     pub total_load_blocks: u32,
+    pub predicted_prefix: u32,
     pub matches_input_len: bool,
+    pub score: f64,
 }
 
 #[derive(Debug, Clone, Serialize)]
@@ -92,8 +95,16 @@ impl GlobalRouter for StrictInputLengthRouter {
                 num_instances: view.num_instances,
                 total_queue_len: view.total_queue_len,
                 total_load_blocks: view.total_load_blocks,
+                predicted_prefix: view.predicted_prefix,
                 matches_input_len: view.input_length_min <= req.input_len
                     && req.input_len <= view.input_length_max,
+                score: if view.input_length_min <= req.input_len
+                    && req.input_len <= view.input_length_max
+                {
+                    0.0
+                } else {
+                    f64::INFINITY
+                },
             })
             .collect::<Vec<_>>();
 
@@ -131,15 +142,241 @@ impl GlobalRouter for StrictInputLengthRouter {
     }
 }
 
+/// Global router that picks the bucket with the lowest weighted score.
+///
+/// Each bucket is scored as `length_penalty_weight * length_penalty +
+/// load_weight * total_queue_len + cache_weight * miss`, where `miss` is the
+/// number of request hash ids not covered by the bucket's `predicted_prefix`.
+/// Exact ties are resolved in favour of the lowest bucket id.
+struct BucketScoreRouter {
+    length_penalty_weight: f64,
+    load_weight: f64,
+    cache_weight: f64,
+}
+
+impl BucketScoreRouter {
+    /// Copies the three score weights out of `cluster.global_router`.
+    fn new(full: &Config) -> Self {
+        Self {
+            length_penalty_weight: full.cluster.global_router.length_penalty_weight,
+            load_weight: full.cluster.global_router.load_weight,
+            cache_weight: full.cluster.global_router.cache_weight,
+        }
+    }
+
+    /// Distance from `req.input_len` to the bucket's inclusive
+    /// `[input_length_min, input_length_max]` range; `0.0` when inside it.
+    fn length_penalty(&self, req: &RequestRecord, bucket: &BucketView) -> f64 {
+        if req.input_len < bucket.input_length_min {
+            (bucket.input_length_min - req.input_len) as f64
+        } else if req.input_len > bucket.input_length_max {
+            (req.input_len - bucket.input_length_max) as f64
+        } else {
+            0.0
+        }
+    }
+}
+
+impl GlobalRouter for BucketScoreRouter {
+    fn name(&self) -> &'static str {
+        "bucket_score"
+    }
+
+    fn route(
+        &mut self,
+        req: &RequestRecord,
+        buckets: &[BucketView],
+        _now: f64,
+    ) -> Result<GlobalRouteDecision> {
+        let mut chosen_bucket = None;
+        let mut best_score = f64::INFINITY;
+        let mut candidates = Vec::with_capacity(buckets.len());
+
+        for bucket in buckets {
+            let length_penalty = self.length_penalty(req, bucket);
+            // Request hash ids not covered by the bucket's predicted prefix.
+            let miss = req
+                .hash_ids
+                .len()
+                .saturating_sub(bucket.predicted_prefix as usize) as f64;
+            let score = self.length_penalty_weight * length_penalty
+                + self.load_weight * bucket.total_queue_len as f64
+                + self.cache_weight * miss;
+
+            candidates.push(BucketCandidate {
+                bucket: bucket.id,
+                input_length_min: bucket.input_length_min,
+                input_length_max: bucket.input_length_max,
+                num_instances: bucket.num_instances,
+                total_queue_len: bucket.total_queue_len,
+                total_load_blocks: bucket.total_load_blocks,
+                predicted_prefix: bucket.predicted_prefix,
+                matches_input_len: bucket.input_length_min <= req.input_len
+                    && req.input_len <= bucket.input_length_max,
+                score,
+            });
+
+            // Lower score wins; an exact tie goes to the lowest bucket id.
+            let better = score < best_score
+                || (score == best_score && chosen_bucket.is_none_or(|best| bucket.id < best));
+            if better {
+                best_score = score;
+                chosen_bucket = Some(bucket.id);
+            }
+        }
+
+        Ok(GlobalRouteDecision {
+            req_id: req.req_id,
+            mode: self.name(),
+            chosen_bucket: chosen_bucket.ok_or_else(|| anyhow!("no buckets available"))?,
+            candidates,
+            reason: "weighted length/load/cache bucket score",
+        })
+    }
+}
+
 pub fn build(full: &Config) -> Box<dyn GlobalRouter> {
     match full.cluster.global_router.mode {
         GlobalRouterMode::StrictInputLength => Box::new(StrictInputLengthRouter::new(
             "strict_input_length",
             "unique bucket range contains input_length",
         )) as Box<dyn GlobalRouter>,
-        GlobalRouterMode::BucketScore => Box::new(StrictInputLengthRouter::new(
-            "strict_input_length",
-            "bucket_score is not implemented in Task 2; falling back to strict_input_length",
-        )) as Box<dyn GlobalRouter>,
+        GlobalRouterMode::BucketScore => {
+            Box::new(BucketScoreRouter::new(full)) as Box<dyn GlobalRouter>
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::config::{
+        ClusterConfig, GlobalRouterConfig, MetaStoreConfig, RouterConfig, RouterMode,
+    };
+
+    fn cfg() -> Config {
+        Config {
+            model: crate::config::ModelConfig::default(),
+            hardware: crate::config::HardwareConfig {
+                gpu_flops: 1.0,
+                gpu_fp8_flops: 0.0,
+                gpu_fp4_flops: 0.0,
+                gpu_mem_bw: 1.0,
+                hbm_bytes: 1.0,
+                dram_bytes: 1.0,
+                host_dram_bw: 1.0,
+                pcie_bw: 1.0,
+                pcie_latency_us: 1.0,
+                rdma_bw: 1.0,
+                rdma_latency_us: 1.0,
+                intra_node_tp_bw: 1.0,
+                intra_node_tp_latency_us: 1.0,
+                tp_degree: 1,
+                max_batch_slots: 1,
+                prefill_chunk_tokens: 1,
+            },
+            calibration: crate::config::CalibrationConfig::default(),
+            cluster: ClusterConfig {
+                num_instances: None,
+                buckets: Vec::new(),
+                global_router: GlobalRouterConfig {
+                    mode: GlobalRouterMode::BucketScore,
+                    length_penalty_weight: 1.0,
+                    load_weight: 1.0,
+                    cache_weight: 1.0,
+                },
+                meta_store: MetaStoreConfig { ttl_seconds: 1.0 },
+                router: RouterConfig {
+                    mode: RouterMode::LeastLoaded,
+                    precise_probe_latency_us: 1.0,
+                    precise_probe_topk: 1,
+                    load_alpha: 1.0,
+                    score_alpha: 1.0,
+                    score_beta: 1.0,
+                    prefix_k: 8,
+                    affinity_fan_out: 1,
+                },
+            },
+            sim: crate::config::SimConfig {
+                trace_path: String::new(),
+                max_requests: None,
+                output_dir: String::new(),
+                sample_interval_s: 0.0,
+                seed: 0,
+                input_length_min: None,
+                input_length_max: None,
+            },
+        }
+    }
+
+    fn req(input_len: u32) -> RequestRecord {
+        RequestRecord {
+            req_id: 1,
+            chat_id: 0,
+            parent_chat_id: -1,
+            turn: 0,
+            arrival: 0.0,
+            input_len,
+            output_len: 16,
+            hash_ids: vec![10, 11, 12],
+        }
+    }
+
+    #[test]
+    fn bucket_score_prefers_matching_bucket_when_load_is_equal() {
+        let mut router = BucketScoreRouter::new(&cfg());
+        let buckets = vec![
+            BucketView {
+                id: 0,
+                input_length_min: 0,
+                input_length_max: 32,
+                num_instances: 2,
+                total_queue_len: 1,
+                total_load_blocks: 0,
+                predicted_prefix: 0,
+            },
+            BucketView {
+                id: 1,
+                input_length_min: 33,
+                input_length_max: 96,
+                num_instances: 2,
+                total_queue_len: 1,
+                total_load_blocks: 0,
+                predicted_prefix: 0,
+            },
+        ];
+        let decision = router.route(&req(24), &buckets, 0.0).unwrap();
+        assert_eq!(decision.chosen_bucket, 0);
+    }
+
+    #[test]
+    fn bucket_score_can_override_length_match_when_load_gap_is_large() {
+        let mut full = cfg();
+        full.cluster.global_router.load_weight = 5.0;
+        full.cluster.global_router.cache_weight = 1.0;
+        full.cluster.global_router.length_penalty_weight = 1.0;
+        let mut router = BucketScoreRouter::new(&full);
+        let buckets = vec![
+            BucketView {
+                id: 0,
+                input_length_min: 0,
+                input_length_max: 32,
+                num_instances: 2,
+                total_queue_len: 20,
+                total_load_blocks: 0,
+                predicted_prefix: 0,
+            },
+            BucketView {
+                id: 1,
+                input_length_min: 33,
+                input_length_max: 96,
+                num_instances: 2,
+                total_queue_len: 0,
+                total_load_blocks: 0,
+                predicted_prefix: 2,
+            },
+        ];
+        let decision = router.route(&req(24), &buckets, 0.0).unwrap();
+        assert_eq!(decision.chosen_bucket, 1);
     }
 }
diff --git a/tests/smoke.rs b/tests/smoke.rs
index 8a23f8b..bd5ec38 100644
--- a/tests/smoke.rs
+++ b/tests/smoke.rs
@@ -385,3 +385,70 @@ fn bucketed_configs_are_rejected_by_legacy_fixed_placement_paths() {
         .expect_err("bucketed replay should fail");
     assert!(err.to_string().contains("cluster.buckets"));
 }
+
+#[test]
+fn bucket_score_can_deviate_from_strict_length_bucket() {
+    let tmp = std::env::temp_dir().join("kvcache_sim_bucket_score");
+    let _ = std::fs::remove_dir_all(&tmp);
+    std::fs::create_dir_all(&tmp).unwrap();
+    let trace_path = tmp.join("trace.jsonl");
+
+    let mut f = std::fs::File::create(&trace_path).unwrap();
+    for req_id in 0..3 {
+        writeln!(
+            f,
+            "{}",
+            serde_json::json!({
+                "chat_id": req_id,
+                "parent_chat_id": -1,
+                "timestamp": 0.0,
+                "input_length": 24,
+                "output_length": 16,
+                "type": "text",
+                "turn": 0,
+                "hash_ids": [100 + req_id, 200 + req_id]
+            })
+        )
+        .unwrap();
+    }
+
+    let mut strict_cfg = bucketed_config(
+        trace_path.to_str().unwrap(),
+        tmp.to_str().unwrap(),
+        RouterMode::LeastLoaded,
+    );
+    strict_cfg.cluster.buckets = vec![
+        BucketConfig {
+            name: "short".into(),
+            input_length_min: 0,
+            input_length_max: 32,
+            num_instances: 1,
+        },
+        BucketConfig {
+            name: "long".into(),
+            input_length_min: 33,
+            input_length_max: 96,
+            num_instances: 1,
+        },
+    ];
+    strict_cfg.cluster.global_router.mode = GlobalRouterMode::StrictInputLength;
+
+    let mut score_cfg = strict_cfg.clone();
+    score_cfg.cluster.global_router.mode = GlobalRouterMode::BucketScore;
+    score_cfg.cluster.global_router.length_penalty_weight = 1.0;
+    score_cfg.cluster.global_router.load_weight = 10.0;
+    score_cfg.cluster.global_router.cache_weight = 0.0;
+
+    let _ = driver::run(&strict_cfg, Some("strict_score_cmp")).expect("strict run");
+    let _ = driver::run(&score_cfg, Some("bucket_score_cmp")).expect("bucket score run");
+
+    let strict_log =
+        std::fs::read_to_string(tmp.join("strict_score_cmp/routing_log.jsonl")).unwrap();
+    let score_log =
+        std::fs::read_to_string(tmp.join("bucket_score_cmp/routing_log.jsonl")).unwrap();
+
+    assert!(strict_log.contains("\"chosen_bucket\":0"));
+    assert!(!strict_log.contains("\"chosen_bucket\":1"));
+    assert!(score_log.contains("\"global_mode\":\"bucket_score\""));
+    assert!(score_log.contains("\"chosen_bucket\":1"));
+}