use anyhow::Result; use super::cluster::{AdmissionStats, Cluster}; use crate::config::{BucketConfig, Config, ModelConfig}; use crate::instance::Instance; use crate::router::{self, BucketId, GlobalRouter}; use crate::trace::RequestRecord; pub struct ServiceBucket { pub id: BucketId, pub cfg: BucketConfig, pub cluster: Cluster, } impl ServiceBucket { pub fn instances(&self) -> &[Instance] { &self.cluster.instances } } pub struct BucketedService { pub buckets: Vec, pub global_router: Box, } impl BucketedService { pub fn new(config: &Config, model: &ModelConfig) -> Self { let buckets = config .cluster .effective_buckets() .into_iter() .enumerate() .map(|(idx, cfg)| ServiceBucket { id: idx as BucketId, cluster: Cluster::new_for_bucket(config, model, idx as BucketId, cfg.num_instances) .expect("bucket-local cluster construction should succeed"), cfg, }) .collect(); Self { buckets, global_router: router::build_global(config), } } pub fn bucket(&self, bucket_id: BucketId) -> &ServiceBucket { &self.buckets[bucket_id as usize] } pub fn route_and_admit(&mut self, req: &RequestRecord, now: f64) -> Result { let bucket_views = self .buckets .iter() .map(|bucket| bucket.cluster.bucket_view(bucket.id, &bucket.cfg, req, now)) .collect::>(); let global = self.global_router.route(req, &bucket_views, now)?; let bucket = &mut self.buckets[global.chosen_bucket as usize]; Ok(bucket .cluster .route_and_admit_with_global(req, now, &global)) } } #[cfg(test)] mod tests { use super::*; use crate::config::{ BucketConfig, CalibrationConfig, ClusterConfig, Config, GlobalRouterConfig, GlobalRouterMode, HardwareConfig, MetaStoreConfig, ModelConfig, RouterConfig, RouterMode, SimConfig, }; use crate::trace::RequestRecord; fn test_config() -> Config { Config { model: ModelConfig { name: "test".into(), num_layers: 4, num_kv_heads: 2, head_dim: 64, dtype_bytes: 2, block_size_tokens: 16, flops_per_token_prefill: Some(1.0e9), attn_quadratic_coeff: Some(64.0), ..Default::default() }, hardware: HardwareConfig { gpu_flops: 1.0e14, gpu_fp8_flops: 0.0, gpu_fp4_flops: 0.0, gpu_mem_bw: 1.0e12, hbm_bytes: 1.0e9, dram_bytes: 4.0e9, host_dram_bw: 5.0e11, pcie_bw: 32.0e9, pcie_latency_us: 1.0, rdma_bw: 12.0e9, rdma_latency_us: 5.0, intra_node_tp_bw: 9.0e11, intra_node_tp_latency_us: 2.0, tp_degree: 1, max_batch_slots: 32, prefill_chunk_tokens: 1024, }, calibration: CalibrationConfig::default(), cluster: ClusterConfig { num_instances: None, buckets: vec![ BucketConfig { name: "short".into(), input_length_min: 0, input_length_max: 32, num_instances: 2, }, BucketConfig { name: "long".into(), input_length_min: 33, input_length_max: 96, num_instances: 1, }, ], global_router: GlobalRouterConfig { mode: GlobalRouterMode::StrictInputLength, length_penalty_weight: 1.0, load_weight: 1.0, cache_weight: 1.0, }, meta_store: MetaStoreConfig { ttl_seconds: 1000.0, }, router: RouterConfig { mode: RouterMode::LeastLoaded, precise_probe_latency_us: 10.0, precise_probe_topk: 2, load_alpha: 0.0, score_alpha: 1.0, score_beta: 0.1, prefix_k: 8, affinity_fan_out: 2, }, }, sim: SimConfig { trace_path: String::new(), max_requests: None, output_dir: String::new(), sample_interval_s: 0.0, seed: 7, input_length_min: None, input_length_max: None, }, } } fn req(req_id: u64, input_len: u32, hashes: &[u64]) -> RequestRecord { RequestRecord { req_id, chat_id: req_id as i64, parent_chat_id: -1, turn: 0, arrival: 0.0, input_len, output_len: 16, hash_ids: hashes.to_vec(), } } #[test] fn strict_input_length_routes_into_matching_bucket() { let cfg = test_config(); let mut service = BucketedService::new(&cfg, &cfg.model); let stats = service .route_and_admit(&req(1, 24, &[10, 11]), 0.0) .unwrap(); assert_eq!(stats.bucket, 0); assert_eq!(stats.decision.chosen_bucket, 0); assert_eq!( stats.decision.global_reason, "unique bucket range contains input_length" ); assert_eq!( stats.decision.local_reason, "argmin(kv_used + alpha * queue_len)" ); assert_eq!(service.bucket(0).instances().len(), 2); } #[test] fn bucket_meta_store_is_isolated() { let cfg = test_config(); let mut service = BucketedService::new(&cfg, &cfg.model); let _ = service .route_and_admit(&req(1, 24, &[10, 11]), 0.0) .unwrap(); let long_stats = service .route_and_admit(&req(2, 64, &[10, 11, 12, 13]), 1.0) .unwrap(); assert_eq!(long_stats.bucket, 1); assert_eq!(long_stats.remote_hit_blocks, 0); assert_eq!(long_stats.l1_hit_blocks, 0); } #[test] fn unmatched_input_length_returns_recoverable_error() { let mut cfg = test_config(); cfg.cluster.buckets[1].input_length_min = 40; let mut service = BucketedService::new(&cfg, &cfg.model); let err = service .route_and_admit(&req(3, 36, &[20, 21, 22]), 0.0) .unwrap_err(); assert!(err.to_string().contains("no bucket")); assert!(err.to_string().contains("input_length=36")); } }