diff --git a/src/cluster/bucketed_service.rs b/src/cluster/bucketed_service.rs new file mode 100644 index 0000000..633f851 --- /dev/null +++ b/src/cluster/bucketed_service.rs @@ -0,0 +1,216 @@ +use anyhow::Result; + +use super::cluster::{AdmissionStats, Cluster}; +use crate::config::{BucketConfig, Config, ModelConfig}; +use crate::instance::Instance; +use crate::router::{self, BucketId, GlobalRouter}; +use crate::trace::RequestRecord; + +pub struct ServiceBucket { + pub id: BucketId, + pub cfg: BucketConfig, + pub cluster: Cluster, +} + +impl ServiceBucket { + pub fn instances(&self) -> &[Instance] { + &self.cluster.instances + } +} + +pub struct BucketedService { + pub buckets: Vec, + pub global_router: Box, +} + +impl BucketedService { + pub fn new(config: &Config, model: &ModelConfig) -> Self { + let buckets = config + .cluster + .effective_buckets() + .into_iter() + .enumerate() + .map(|(idx, cfg)| ServiceBucket { + id: idx as BucketId, + cluster: Cluster::new_for_bucket(config, model, idx as BucketId, cfg.num_instances) + .expect("bucket-local cluster construction should succeed"), + cfg, + }) + .collect(); + + Self { + buckets, + global_router: router::build_global(config), + } + } + + pub fn bucket(&self, bucket_id: BucketId) -> &ServiceBucket { + &self.buckets[bucket_id as usize] + } + + pub fn route_and_admit(&mut self, req: &RequestRecord, now: f64) -> Result { + let bucket_views = self + .buckets + .iter() + .map(|bucket| bucket.cluster.bucket_view(bucket.id, &bucket.cfg, req, now)) + .collect::>(); + let global = self.global_router.route(req, &bucket_views, now)?; + let bucket = &mut self.buckets[global.chosen_bucket as usize]; + Ok(bucket + .cluster + .route_and_admit_with_global(req, now, &global)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{ + BucketConfig, CalibrationConfig, ClusterConfig, Config, GlobalRouterConfig, + GlobalRouterMode, HardwareConfig, MetaStoreConfig, ModelConfig, RouterConfig, RouterMode, + SimConfig, + }; + use crate::trace::RequestRecord; + + fn test_config() -> Config { + Config { + model: ModelConfig { + name: "test".into(), + num_layers: 4, + num_kv_heads: 2, + head_dim: 64, + dtype_bytes: 2, + block_size_tokens: 16, + flops_per_token_prefill: Some(1.0e9), + attn_quadratic_coeff: Some(64.0), + ..Default::default() + }, + hardware: HardwareConfig { + gpu_flops: 1.0e14, + gpu_fp8_flops: 0.0, + gpu_fp4_flops: 0.0, + gpu_mem_bw: 1.0e12, + hbm_bytes: 1.0e9, + dram_bytes: 4.0e9, + host_dram_bw: 5.0e11, + pcie_bw: 32.0e9, + pcie_latency_us: 1.0, + rdma_bw: 12.0e9, + rdma_latency_us: 5.0, + intra_node_tp_bw: 9.0e11, + intra_node_tp_latency_us: 2.0, + tp_degree: 1, + max_batch_slots: 32, + prefill_chunk_tokens: 1024, + }, + calibration: CalibrationConfig::default(), + cluster: ClusterConfig { + num_instances: None, + buckets: vec![ + BucketConfig { + name: "short".into(), + input_length_min: 0, + input_length_max: 32, + num_instances: 2, + }, + BucketConfig { + name: "long".into(), + input_length_min: 33, + input_length_max: 96, + num_instances: 1, + }, + ], + global_router: GlobalRouterConfig { + mode: GlobalRouterMode::StrictInputLength, + length_penalty_weight: 1.0, + load_weight: 1.0, + cache_weight: 1.0, + }, + meta_store: MetaStoreConfig { + ttl_seconds: 1000.0, + }, + router: RouterConfig { + mode: RouterMode::LeastLoaded, + precise_probe_latency_us: 10.0, + precise_probe_topk: 2, + load_alpha: 0.0, + score_alpha: 1.0, + score_beta: 0.1, + prefix_k: 8, + affinity_fan_out: 2, + }, + }, + sim: SimConfig { + trace_path: String::new(), + max_requests: None, + output_dir: String::new(), + sample_interval_s: 0.0, + seed: 7, + input_length_min: None, + input_length_max: None, + }, + } + } + + fn req(req_id: u64, input_len: u32, hashes: &[u64]) -> RequestRecord { + RequestRecord { + req_id, + chat_id: req_id as i64, + parent_chat_id: -1, + turn: 0, + arrival: 0.0, + input_len, + output_len: 16, + hash_ids: hashes.to_vec(), + } + } + + #[test] + fn strict_input_length_routes_into_matching_bucket() { + let cfg = test_config(); + let mut service = BucketedService::new(&cfg, &cfg.model); + let stats = service + .route_and_admit(&req(1, 24, &[10, 11]), 0.0) + .unwrap(); + assert_eq!(stats.bucket, 0); + assert_eq!(stats.decision.chosen_bucket, 0); + assert_eq!( + stats.decision.global_reason, + "unique bucket range contains input_length" + ); + assert_eq!( + stats.decision.local_reason, + "argmin(kv_used + alpha * queue_len)" + ); + assert_eq!(service.bucket(0).instances().len(), 2); + } + + #[test] + fn bucket_meta_store_is_isolated() { + let cfg = test_config(); + let mut service = BucketedService::new(&cfg, &cfg.model); + let _ = service + .route_and_admit(&req(1, 24, &[10, 11]), 0.0) + .unwrap(); + let long_stats = service + .route_and_admit(&req(2, 64, &[10, 11, 12, 13]), 1.0) + .unwrap(); + assert_eq!(long_stats.bucket, 1); + assert_eq!(long_stats.remote_hit_blocks, 0); + assert_eq!(long_stats.l1_hit_blocks, 0); + } + + #[test] + fn unmatched_input_length_returns_recoverable_error() { + let mut cfg = test_config(); + cfg.cluster.buckets[1].input_length_min = 40; + let mut service = BucketedService::new(&cfg, &cfg.model); + + let err = service + .route_and_admit(&req(3, 36, &[20, 21, 22]), 0.0) + .unwrap_err(); + + assert!(err.to_string().contains("no bucket")); + assert!(err.to_string().contains("input_length=36")); + } +} diff --git a/src/cluster/cluster.rs b/src/cluster/cluster.rs index 9927cdc..322b867 100644 --- a/src/cluster/cluster.rs +++ b/src/cluster/cluster.rs @@ -1,18 +1,21 @@ //! Cluster: routes arrivals, performs the L0 / L1 / remote-RDMA fetch chain //! described in the design diagram, and bookkeeps the global meta store. +use anyhow::Result; + use crate::cluster::meta_store::MetaStore; -use crate::config::{Config, ModelConfig}; +use crate::config::{BucketConfig, Config, ModelConfig}; use crate::instance::instance::AdmittedRequest; use crate::instance::kv_cache::L1Change; use crate::instance::Instance; -use crate::router::{self, RouteDecision, Router}; +use crate::router::{self, BucketId, BucketView, GlobalRouteDecision, RouteDecision, Router}; use crate::trace::RequestRecord; use crate::ttft::{classify_prefix_tiers, TtftModel}; use crate::types::InstanceId; #[derive(Debug, Clone)] pub struct AdmissionStats { + pub bucket: BucketId, pub instance: InstanceId, pub l0_hit_blocks: u32, pub l1_hit_blocks: u32, @@ -36,39 +39,41 @@ pub struct Cluster { } impl Cluster { - pub fn new(config: &Config, model: &ModelConfig) -> Self { - let mut instances = Vec::with_capacity(config.cluster.num_instances as usize); - for id in 0..config.cluster.num_instances { - instances.push(Instance::new( - id as InstanceId, - model, - &config.hardware, - &config.calibration, - )); - } - let meta_store = MetaStore::new(config.cluster.meta_store.ttl_seconds); - let router = router::build(config, config.sim.seed); - Self { - instances, - meta_store, - router, - block_size_tokens: model.block_size_tokens, - kv_block_bytes: model.kv_block_bytes(), - ttft_model: TtftModel::new( - &config.hardware, - &config.calibration, - model.kv_block_bytes(), - ), - } + pub fn new(config: &Config, model: &ModelConfig) -> Result { + let total_instances = config.cluster.require_legacy_single_pool("Cluster::new")?; + Self::build_local_cluster(config, model, total_instances) + } + + pub fn new_for_bucket( + config: &Config, + model: &ModelConfig, + _bucket_id: BucketId, + num_instances: u32, + ) -> Result { + let mut local_config = config.clone(); + local_config.cluster.num_instances = Some(num_instances); + local_config.cluster.buckets.clear(); + Self::build_local_cluster(&local_config, model, num_instances) } /// Route + admit a request. Returns the chosen instance plus rich /// per-request stats for metrics. Does NOT schedule the BatchTick — the /// simulator driver does that based on the returned `ready_at`. pub fn route_and_admit(&mut self, req: &RequestRecord, now: f64) -> AdmissionStats { + let global = GlobalRouteDecision::single_bucket(req.req_id, 0); + self.route_and_admit_with_global(req, now, &global) + } + + pub fn route_and_admit_with_global( + &mut self, + req: &RequestRecord, + now: f64, + global: &GlobalRouteDecision, + ) -> AdmissionStats { let decision = self .router - .route(req, &self.instances, &self.meta_store, now); + .route(req, &self.instances, &self.meta_store, now) + .with_global(global); let inst_id = decision.chosen; let probe_overhead_s = decision.probe_overhead_s; let scheduler_overhead_s = self @@ -151,6 +156,7 @@ impl Cluster { let fetch_time_s = (t - effective_now).max(0.0); AdmissionStats { + bucket: decision.chosen_bucket, instance: inst_id, l0_hit_blocks: l0_hits, l1_hit_blocks: l1_hits, @@ -165,6 +171,30 @@ impl Cluster { } } + pub fn bucket_view( + &self, + bucket_id: BucketId, + cfg: &BucketConfig, + req: &RequestRecord, + now: f64, + ) -> BucketView { + let predicted_prefix = self + .meta_store + .score_prefix(&req.hash_ids, now, self.instances.len()) + .into_iter() + .max() + .unwrap_or(0); + BucketView { + id: bucket_id, + input_length_min: cfg.input_length_min, + input_length_max: cfg.input_length_max, + num_instances: self.instances.len() as u32, + total_queue_len: self.instances.iter().map(Instance::queue_len).sum(), + total_load_blocks: self.instances.iter().map(|inst| inst.kv_blocks_used).sum(), + predicted_prefix, + } + } + fn apply_l1_changes( meta_store: &mut MetaStore, inst_id: InstanceId, @@ -178,14 +208,44 @@ impl Cluster { } } } + + fn build_local_cluster( + config: &Config, + model: &ModelConfig, + num_instances: u32, + ) -> Result { + let mut instances = Vec::with_capacity(num_instances as usize); + for id in 0..num_instances { + instances.push(Instance::new( + id as InstanceId, + model, + &config.hardware, + &config.calibration, + )); + } + let meta_store = MetaStore::new(config.cluster.meta_store.ttl_seconds); + let router = router::build(config, config.sim.seed); + Ok(Self { + instances, + meta_store, + router, + block_size_tokens: model.block_size_tokens, + kv_block_bytes: model.kv_block_bytes(), + ttft_model: TtftModel::new( + &config.hardware, + &config.calibration, + model.kv_block_bytes(), + ), + }) + } } #[cfg(test)] mod tests { use super::*; use crate::config::{ - CalibrationConfig, ClusterConfig, Config, HardwareConfig, MetaStoreConfig, ModelConfig, - RouterConfig, RouterMode, SimConfig, + BucketConfig, CalibrationConfig, ClusterConfig, Config, HardwareConfig, MetaStoreConfig, + ModelConfig, RouterConfig, RouterMode, SimConfig, }; use crate::trace::RequestRecord; @@ -226,7 +286,9 @@ mod tests { ..CalibrationConfig::default() }, cluster: ClusterConfig { - num_instances: 1, + num_instances: Some(1), + buckets: Vec::new(), + global_router: Default::default(), meta_store: MetaStoreConfig { ttl_seconds: 1000.0, }, @@ -256,7 +318,7 @@ mod tests { #[test] fn l1_ready_at_includes_dram_and_transform_overhead() { let cfg = test_config(RouterMode::EstimatedTtft); - let mut cluster = Cluster::new(&cfg, &cfg.model); + let mut cluster = Cluster::new(&cfg, &cfg.model).unwrap(); let req = RequestRecord { req_id: 1, chat_id: 0, @@ -282,4 +344,22 @@ mod tests { assert!(stats.ready_at > pure_pcie); } + + #[test] + fn cluster_new_rejects_bucketed_configs() { + let mut cfg = test_config(RouterMode::EstimatedTtft); + cfg.cluster.num_instances = None; + cfg.cluster.buckets = vec![BucketConfig { + name: "short".into(), + input_length_min: 0, + input_length_max: 64, + num_instances: 2, + }]; + + let result = Cluster::new(&cfg, &cfg.model); + assert!(result.is_err(), "bucketed Cluster::new should fail"); + let err = result.err().unwrap(); + assert!(err.to_string().contains("Cluster::new")); + assert!(err.to_string().contains("cluster.buckets")); + } } diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs index bc36667..01893e8 100644 --- a/src/cluster/mod.rs +++ b/src/cluster/mod.rs @@ -1,6 +1,8 @@ +pub mod bucketed_service; #[allow(clippy::module_inception)] pub mod cluster; pub mod meta_store; +pub use bucketed_service::BucketedService; pub use cluster::Cluster; pub use meta_store::MetaStore; diff --git a/src/config.rs b/src/config.rs index 562ac5e..087b960 100644 --- a/src/config.rs +++ b/src/config.rs @@ -379,11 +379,76 @@ fn default_first_token_ready_us() -> f64 { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ClusterConfig { - pub num_instances: u32, + #[serde(default)] + pub num_instances: Option, + #[serde(default)] + pub buckets: Vec, + #[serde(default)] + pub global_router: GlobalRouterConfig, pub meta_store: MetaStoreConfig, pub router: RouterConfig, } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct BucketConfig { + pub name: String, + pub input_length_min: u32, + pub input_length_max: u32, + pub num_instances: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct GlobalRouterConfig { + #[serde(default)] + pub mode: GlobalRouterMode, + #[serde(default = "default_global_router_length_penalty_weight")] + pub length_penalty_weight: f64, + #[serde(default = "default_global_router_load_weight")] + pub load_weight: f64, + #[serde(default = "default_global_router_cache_weight")] + pub cache_weight: f64, +} + +impl Default for GlobalRouterConfig { + fn default() -> Self { + Self { + mode: GlobalRouterMode::StrictInputLength, + length_penalty_weight: default_global_router_length_penalty_weight(), + load_weight: default_global_router_load_weight(), + cache_weight: default_global_router_cache_weight(), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum GlobalRouterMode { + #[default] + StrictInputLength, + BucketScore, +} + +impl GlobalRouterMode { + pub fn as_str(&self) -> &'static str { + match self { + Self::StrictInputLength => "strict_input_length", + Self::BucketScore => "bucket_score", + } + } +} + +fn default_global_router_length_penalty_weight() -> f64 { + 1.0 +} + +fn default_global_router_load_weight() -> f64 { + 1.0 +} + +fn default_global_router_cache_weight() -> f64 { + 1.0 +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MetaStoreConfig { pub ttl_seconds: f64, @@ -432,6 +497,105 @@ fn default_prefix_k() -> usize { 8 } +impl ClusterConfig { + pub fn require_legacy_single_pool(&self, context: &str) -> Result { + if !self.buckets.is_empty() { + anyhow::bail!("{context} does not support cluster.buckets until Task 2 lands"); + } + + self.legacy_num_instances() + .ok_or_else(|| anyhow::anyhow!("{context} requires cluster.num_instances")) + } + + pub fn legacy_num_instances(&self) -> Option { + if self.buckets.is_empty() { + self.num_instances + } else { + None + } + } + + pub fn total_instances(&self) -> u32 { + self.legacy_num_instances() + .unwrap_or_else(|| self.buckets.iter().map(|bucket| bucket.num_instances).sum()) + } + + pub fn bucket_index_for_input_len(&self, input_len: u32) -> Result { + if self.buckets.is_empty() { + return Ok(0); + } + + self.buckets + .iter() + .position(|bucket| { + bucket.input_length_min <= input_len && input_len <= bucket.input_length_max + }) + .ok_or_else(|| { + anyhow::anyhow!( + "cluster.global_router.mode={} has no bucket for input_length={input_len}", + self.global_router.mode.as_str() + ) + }) + } + + pub fn effective_buckets(&self) -> Vec { + if !self.buckets.is_empty() { + return self.buckets.clone(); + } + + vec![BucketConfig { + name: "default".to_string(), + input_length_min: 0, + input_length_max: u32::MAX, + num_instances: self + .num_instances + .expect("legacy single-pool cluster must have num_instances"), + }] + } + + pub fn validate(&self) -> Result<()> { + if self.num_instances.is_some() && !self.buckets.is_empty() { + anyhow::bail!("cluster.num_instances and cluster.buckets are mutually exclusive"); + } + + if self.buckets.is_empty() { + let num_instances = self.num_instances.ok_or_else(|| { + anyhow::anyhow!("cluster must set either num_instances or buckets") + })?; + anyhow::ensure!(num_instances > 0, "cluster.num_instances must be positive"); + return Ok(()); + } + + for bucket in &self.buckets { + anyhow::ensure!( + bucket.input_length_min <= bucket.input_length_max, + "cluster bucket '{}' has input_length_min > input_length_max", + bucket.name + ); + anyhow::ensure!( + bucket.num_instances > 0, + "cluster bucket '{}' must have num_instances > 0", + bucket.name + ); + } + + let mut sorted = self.buckets.iter().collect::>(); + sorted.sort_by_key(|bucket| (bucket.input_length_min, bucket.input_length_max)); + for pair in sorted.windows(2) { + let prev = pair[0]; + let next = pair[1]; + anyhow::ensure!( + prev.input_length_max < next.input_length_min, + "cluster buckets '{}' and '{}' overlap", + prev.name, + next.name + ); + } + + Ok(()) + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum RouterMode { @@ -544,7 +708,7 @@ impl Config { .with_context(|| format!("parsing config {}", path.display()))?; let yaml_dir = path.parent().unwrap_or(Path::new(".")); raw.resolve(yaml_dir) - .with_context(|| format!("resolving config {}", path.display())) + .map_err(|err| anyhow::anyhow!("resolving config {}: {err}", path.display())) } } @@ -678,7 +842,10 @@ impl RawConfig { model, hardware, calibration: self.calibration, - cluster: self.cluster, + cluster: { + self.cluster.validate()?; + self.cluster + }, sim: self.sim, }) } @@ -1003,4 +1170,298 @@ sim: assert_eq!(cfg.model.weight_dtype.as_deref(), Some("fp4")); assert!((cfg.model.weight_dtype_bytes() - 0.5).abs() < 1e-12); } + + #[test] + fn bucketed_config_loads_and_preserves_legacy_single_pool() { + let legacy = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + num_instances: 2 + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: strict_input_length +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + let cfg = Config::from_yaml_path(&legacy).unwrap(); + assert_eq!(cfg.cluster.legacy_num_instances(), Some(2)); + assert_eq!(cfg.cluster.buckets.len(), 0); + + let bucketed = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: strict_input_length + buckets: + - name: short + input_length_min: 0 + input_length_max: 32 + num_instances: 2 + - name: long + input_length_min: 33 + input_length_max: 96 + num_instances: 1 +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + let cfg = Config::from_yaml_path(&bucketed).unwrap(); + assert_eq!(cfg.cluster.legacy_num_instances(), None); + assert_eq!(cfg.cluster.buckets.len(), 2); + assert_eq!(cfg.cluster.buckets[0].name, "short"); + assert_eq!(cfg.cluster.buckets[1].num_instances, 1); + assert_eq!(cfg.cluster.total_instances(), 3); + } + + #[test] + fn bucketed_config_deserializes_global_router_weights() { + let path = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: bucket_score + length_penalty_weight: 1.5 + load_weight: 0.75 + cache_weight: 2.25 + buckets: + - name: short + input_length_min: 0 + input_length_max: 32 + num_instances: 2 +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + + let cfg = Config::from_yaml_path(&path).unwrap(); + assert_eq!( + cfg.cluster.global_router.mode, + GlobalRouterMode::BucketScore + ); + assert!((cfg.cluster.global_router.length_penalty_weight - 1.5).abs() < 1e-12); + assert!((cfg.cluster.global_router.load_weight - 0.75).abs() < 1e-12); + assert!((cfg.cluster.global_router.cache_weight - 2.25).abs() < 1e-12); + } + + #[test] + fn bucketed_config_rejects_overlapping_ranges_and_mixed_modes() { + let overlap = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: strict_input_length + buckets: + - name: short + input_length_min: 0 + input_length_max: 32 + num_instances: 2 + - name: overlap + input_length_min: 32 + input_length_max: 96 + num_instances: 1 +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + let err = Config::from_yaml_path(&overlap).unwrap_err(); + assert!(err.to_string().contains("overlap")); + + let mixed = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + num_instances: 2 + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: strict_input_length + buckets: + - name: short + input_length_min: 0 + input_length_max: 32 + num_instances: 2 +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + let err = Config::from_yaml_path(&mixed).unwrap_err(); + assert!(err.to_string().contains("num_instances")); + assert!(err.to_string().contains("buckets")); + } + + #[test] + fn bucketed_config_reports_unmatched_input_length_gaps_clearly() { + let gapful = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: strict_input_length + buckets: + - name: short + input_length_min: 0 + input_length_max: 32 + num_instances: 2 + - name: long + input_length_min: 64 + input_length_max: 96 + num_instances: 1 +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + let cfg = Config::from_yaml_path(&gapful).unwrap(); + + let err = cfg.cluster.bucket_index_for_input_len(40).unwrap_err(); + assert!(err.to_string().contains("40")); + assert!(err.to_string().contains("no bucket")); + } + + #[test] + fn bucketed_config_requires_legacy_single_pool_for_legacy_runtime_paths() { + let bucketed = write_temp_config( + r#" +model: + name: test + num_layers: 4 + num_kv_heads: 2 + head_dim: 64 + dtype_bytes: 2 + block_size_tokens: 16 + flops_per_token_prefill: 1.0e9 + attn_quadratic_coeff: 64.0 +hardware: + gpu_flops: 1.0e14 + gpu_mem_bw: 1.0e12 + hbm_bytes: 1.0e9 +cluster: + meta_store: + ttl_seconds: 10.0 + router: + mode: cache_affinity + global_router: + mode: strict_input_length + buckets: + - name: short + input_length_min: 0 + input_length_max: 32 + num_instances: 2 +sim: + trace_path: trace.jsonl + output_dir: runs/test +"#, + ); + let cfg = Config::from_yaml_path(&bucketed).unwrap(); + + let err = cfg + .cluster + .require_legacy_single_pool("driver run") + .unwrap_err(); + assert!(err.to_string().contains("driver run")); + assert!(err.to_string().contains("cluster.buckets")); + } } diff --git a/src/driver.rs b/src/driver.rs index 6c1c064..0494942 100644 --- a/src/driver.rs +++ b/src/driver.rs @@ -6,7 +6,7 @@ use std::collections::{HashMap, VecDeque}; use std::path::Path; use std::sync::{Arc, Mutex}; -use crate::cluster::Cluster; +use crate::cluster::BucketedService; use crate::config::{Config, RouterMode}; use crate::metrics::ablation::AblationRow; use crate::metrics::per_request::{PerRequestRow, PerRequestWriter}; @@ -37,7 +37,9 @@ pub struct RunOutputs { #[derive(Debug, Clone)] struct InflightInfo { arrival: f64, + bucket: u32, instance: u32, + length_bucket_match: bool, total_blocks: u32, l0_hit_blocks: u32, l1_hit_blocks: u32, @@ -49,7 +51,7 @@ struct InflightInfo { } pub fn run(config: &Config, output_subdir: Option<&str>) -> Result { - let mut cluster = Cluster::new(config, &config.model); + let mut service = BucketedService::new(config, &config.model); let mut q = EventQueue::new(); // Output directory @@ -108,13 +110,16 @@ pub fn run(config: &Config, output_subdir: Option<&str>) -> Result { Some(r) => r.clone(), None => continue, }; - let stats = cluster.route_and_admit(&req, now); + let stats = service.route_and_admit(&req, now)?; rt_writer.write(&stats.decision)?; + let strict_bucket = config.cluster.bucket_index_for_input_len(req.input_len)?; inflight.insert( req_id, InflightInfo { arrival: req.arrival, + bucket: stats.bucket, instance: stats.instance, + length_bucket_match: stats.bucket as usize == strict_bucket, total_blocks: req.hash_ids.len() as u32, l0_hit_blocks: stats.l0_hit_blocks, l1_hit_blocks: stats.l1_hit_blocks, @@ -125,20 +130,23 @@ pub fn run(config: &Config, output_subdir: Option<&str>) -> Result { probe_overhead_s: stats.probe_overhead_s, }, ); - let inst = &mut cluster.instances[stats.instance as usize]; + let inst = &mut service.buckets[stats.bucket as usize].cluster.instances + [stats.instance as usize]; if !inst.tick_scheduled { inst.tick_scheduled = true; let when = stats.ready_at.max(now); q.schedule( when, Event::BatchTick { + bucket: stats.bucket, instance: stats.instance, }, ); } } - Event::BatchTick { instance } => { - let inst = &mut cluster.instances[instance as usize]; + Event::BatchTick { bucket, instance } => { + let inst = + &mut service.buckets[bucket as usize].cluster.instances[instance as usize]; inst.tick_scheduled = false; let result = inst.step(now); for (rid, ttft, end) in result.completed { @@ -148,7 +156,9 @@ pub fn run(config: &Config, output_subdir: Option<&str>) -> Result { arrival: info.arrival, ttft, e2e: end - info.arrival, + bucket: info.bucket, instance: info.instance, + length_bucket_match: info.length_bucket_match, total_blocks: info.total_blocks, l0_hit_blocks: info.l0_hit_blocks, l1_hit_blocks: info.l1_hit_blocks, @@ -163,24 +173,28 @@ pub fn run(config: &Config, output_subdir: Option<&str>) -> Result { } } if let Some(next) = result.next_tick { - let inst = &mut cluster.instances[instance as usize]; + let inst = + &mut service.buckets[bucket as usize].cluster.instances[instance as usize]; if !inst.tick_scheduled { inst.tick_scheduled = true; - q.schedule(next.max(now), Event::BatchTick { instance }); + q.schedule(next.max(now), Event::BatchTick { bucket, instance }); } } } Event::Sample => { - for inst in &cluster.instances { - let busy = if inst.queue_len() > 0 { 1 } else { 0 }; - ts_writer.write(&TimeseriesRow { - t: now, - instance: inst.id, - queue_len: inst.queue_len(), - kv_blocks_used: inst.kv_blocks_used, - kv_blocks_total: inst.hbm_block_budget, - busy, - })?; + for bucket in &service.buckets { + for inst in &bucket.cluster.instances { + let busy = if inst.queue_len() > 0 { 1 } else { 0 }; + ts_writer.write(&TimeseriesRow { + t: now, + bucket: bucket.id, + instance: inst.id, + queue_len: inst.queue_len(), + kv_blocks_used: inst.kv_blocks_used, + kv_blocks_total: inst.hbm_block_budget, + busy, + })?; + } } } Event::Stop => break, @@ -217,6 +231,8 @@ pub fn ablate_fixed_placement_with_parallelism( evict_policies: &[ReplayEvictPolicy], jobs: usize, ) -> Result> { + base.cluster + .require_legacy_single_pool("fixed-placement ablation")?; let mut out = Vec::new(); for &policy in evict_policies { if policy != ReplayEvictPolicy::Lru { diff --git a/src/main.rs b/src/main.rs index c0b6e6f..af59237 100644 --- a/src/main.rs +++ b/src/main.rs @@ -50,9 +50,14 @@ struct ConfigOverrides { } impl ConfigOverrides { - fn apply(&self, cfg: &mut Config) { + fn apply(&self, cfg: &mut Config) -> Result<()> { if let Some(n) = self.num_instances { - cfg.cluster.num_instances = n; + if !cfg.cluster.buckets.is_empty() { + anyhow::bail!( + "--num-instances does not support cluster.buckets until Task 2 lands" + ); + } + cfg.cluster.num_instances = Some(n); } if let Some(m) = self.max_requests { cfg.sim.max_requests = Some(m); @@ -78,6 +83,7 @@ impl ConfigOverrides { if let Some(hi) = self.input_length_max { cfg.sim.input_length_max = Some(hi); } + Ok(()) } } @@ -204,7 +210,8 @@ fn main() -> Result<()> { fn load(config: &PathBuf, overrides: &ConfigOverrides) -> Result { let mut cfg = Config::from_yaml_path(config)?; - overrides.apply(&mut cfg); + overrides.apply(&mut cfg)?; + cfg.cluster.validate()?; Ok(cfg) } @@ -268,7 +275,8 @@ fn cmd_ablate( auto_target_ttft_mean, probe_mode.as_str() ); - base.cluster.num_instances = chosen; + base.cluster.num_instances = Some(chosen); + base.cluster.buckets.clear(); } eprintln!( @@ -283,7 +291,7 @@ fn cmd_ablate( .map(ReplayEvictPolicy::as_str) .collect::>() .join(","), - base.cluster.num_instances, + base.cluster.total_instances(), if jobs == 0 { "auto".to_string() } else { @@ -310,6 +318,9 @@ fn auto_select_instances( probe: RouterMode, target_ttft_mean: f64, ) -> Result { + base.cluster + .require_legacy_single_pool("auto-instances calibration")?; + #[derive(serde::Serialize)] struct CalibRow { num_instances: u32, @@ -330,7 +341,7 @@ fn auto_select_instances( for &n in candidates { let mut cfg = base.clone(); - cfg.cluster.num_instances = n; + cfg.cluster.num_instances = Some(n); cfg.cluster.router.mode = probe; // Isolate calibration output so ablation runs don't overwrite it. cfg.sim.output_dir = out_root @@ -429,7 +440,7 @@ fn cmd_validate(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> { let hbm_blocks = (cfg.hardware.hbm_bytes / block_bytes) as u64; let dram_blocks = (cfg.hardware.dram_bytes / block_bytes) as u64; eprintln!("per-instance HBM blocks = {hbm_blocks}, DRAM blocks = {dram_blocks}"); - eprintln!("num_instances = {}", cfg.cluster.num_instances); + eprintln!("num_instances = {}", cfg.cluster.total_instances()); // Sample prefill times at a few prompt lengths. eprintln!("prefill_time samples:"); for &n in &[256, 1024, 4096, 16384, 65536, 131072] { @@ -460,9 +471,10 @@ fn cmd_oracle( out_path: Option<&std::path::Path>, ) -> Result<()> { let cfg = load(path, overrides)?; + cfg.cluster.require_legacy_single_pool("oracle analysis")?; let block_bytes = cfg.model.kv_block_bytes() as f64; let per_instance_blocks = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as u64; - let aggregate_blocks = per_instance_blocks * cfg.cluster.num_instances as u64; + let aggregate_blocks = per_instance_blocks * cfg.cluster.total_instances() as u64; let capacity = match (capacity_blocks, per_instance) { (Some(_), true) => { return Err(anyhow::anyhow!( @@ -518,7 +530,7 @@ fn cmd_oracle( records.len(), capacity, per_instance_blocks, - cfg.cluster.num_instances, + cfg.cluster.total_instances(), if per_instance { ", per-instance mode" } else { @@ -541,3 +553,123 @@ fn cmd_oracle( eprintln!("[oracle] wrote {}", target.display()); Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use kvcache_simulator::config::{ + BucketConfig, CalibrationConfig, ClusterConfig, GlobalRouterConfig, HardwareConfig, + MetaStoreConfig, ModelConfig, RouterConfig, SimConfig, + }; + + fn bucketed_config(out_dir: &str) -> Config { + Config { + model: ModelConfig { + name: "test".into(), + num_layers: 4, + num_kv_heads: 2, + head_dim: 64, + dtype_bytes: 2, + block_size_tokens: 16, + flops_per_token_prefill: Some(1.0e9), + attn_quadratic_coeff: Some(64.0), + ..Default::default() + }, + hardware: HardwareConfig { + gpu_flops: 1.0e14, + gpu_fp8_flops: 0.0, + gpu_fp4_flops: 0.0, + gpu_mem_bw: 1.0e12, + hbm_bytes: 1.0e9, + dram_bytes: 4.0e9, + host_dram_bw: 5.0e11, + pcie_bw: 32.0e9, + pcie_latency_us: 1.0, + rdma_bw: 12.0e9, + rdma_latency_us: 5.0, + intra_node_tp_bw: 9.0e11, + intra_node_tp_latency_us: 2.0, + tp_degree: 1, + max_batch_slots: 32, + prefill_chunk_tokens: 1024, + }, + calibration: CalibrationConfig::default(), + cluster: ClusterConfig { + num_instances: None, + buckets: vec![ + BucketConfig { + name: "short".into(), + input_length_min: 0, + input_length_max: 64, + num_instances: 2, + }, + BucketConfig { + name: "long".into(), + input_length_min: 65, + input_length_max: 128, + num_instances: 1, + }, + ], + global_router: GlobalRouterConfig::default(), + meta_store: MetaStoreConfig { + ttl_seconds: 1000.0, + }, + router: RouterConfig { + mode: RouterMode::Random, + precise_probe_latency_us: 10.0, + precise_probe_topk: 4, + load_alpha: 0.1, + score_alpha: 1.0, + score_beta: 0.1, + prefix_k: 8, + affinity_fan_out: 0, + }, + }, + sim: SimConfig { + trace_path: "unused.jsonl".into(), + max_requests: None, + output_dir: out_dir.into(), + sample_interval_s: 0.0, + seed: 7, + input_length_min: None, + input_length_max: None, + }, + } + } + + #[test] + fn num_instances_override_rejects_bucketed_configs() { + let mut cfg = bucketed_config(std::env::temp_dir().to_str().unwrap()); + let overrides = ConfigOverrides { + num_instances: Some(8), + ..ConfigOverrides::default() + }; + + let err = overrides.apply(&mut cfg).unwrap_err(); + assert!(err.to_string().contains("--num-instances")); + assert!(err.to_string().contains("cluster.buckets")); + } + + #[test] + fn auto_instances_rejects_bucketed_configs() { + let cfg = bucketed_config(std::env::temp_dir().to_str().unwrap()); + + let err = auto_select_instances(&cfg, &[4, 8], RouterMode::Random, 1.0).unwrap_err(); + assert!(err.to_string().contains("auto-instances")); + assert!(err.to_string().contains("cluster.buckets")); + } + + #[test] + fn oracle_rejects_bucketed_configs() { + let tmp = std::env::temp_dir().join("kvcache_sim_main_tests"); + let _ = std::fs::remove_dir_all(&tmp); + std::fs::create_dir_all(&tmp).unwrap(); + let path = tmp.join("bucketed.yaml"); + let cfg = bucketed_config(tmp.to_str().unwrap()); + std::fs::write(&path, serde_yaml::to_string(&cfg).unwrap()).unwrap(); + + let err = cmd_oracle(&path, &ConfigOverrides::default(), None, false, None).unwrap_err(); + assert!(err.to_string().contains("oracle analysis")); + assert!(err.to_string().contains("cluster.buckets")); + } +} diff --git a/src/metrics/per_request.rs b/src/metrics/per_request.rs index b4ccdd8..9d2a792 100644 --- a/src/metrics/per_request.rs +++ b/src/metrics/per_request.rs @@ -8,7 +8,9 @@ pub struct PerRequestRow { pub arrival: f64, pub ttft: f64, pub e2e: f64, + pub bucket: u32, pub instance: u32, + pub length_bucket_match: bool, pub total_blocks: u32, pub l0_hit_blocks: u32, pub l1_hit_blocks: u32, diff --git a/src/metrics/timeseries.rs b/src/metrics/timeseries.rs index cb670e0..3851465 100644 --- a/src/metrics/timeseries.rs +++ b/src/metrics/timeseries.rs @@ -5,6 +5,7 @@ use std::path::Path; #[derive(Debug, Clone, Serialize)] pub struct TimeseriesRow { pub t: f64, + pub bucket: u32, pub instance: u32, pub queue_len: u32, pub kv_blocks_used: u32, diff --git a/src/oracle.rs b/src/oracle.rs index 064a2d2..590736f 100644 --- a/src/oracle.rs +++ b/src/oracle.rs @@ -196,7 +196,12 @@ fn build_next_use(records: &[RequestRecord]) -> Vec> { /// Implementation: lazy-deletion max-heap keyed by next-use index. Each /// cache entry has a version; the heap may contain stale entries from /// previous insertions, which we skip on pop. -fn run_belady(records: &[RequestRecord], next_use: &[Vec], capacity: usize, mask: &[bool]) -> u64 { +fn run_belady( + records: &[RequestRecord], + next_use: &[Vec], + capacity: usize, + mask: &[bool], +) -> u64 { if capacity == 0 { return 0; } @@ -327,10 +332,7 @@ mod tests { // req 0 populates blocks [1,2,3] but is not counted. // req 1 has prefix [1,2,3,4] — the first 3 blocks are cache hits // because req 0 populated them, even though req 0 is masked out. - let recs = vec![ - req(0, 0.0, vec![1, 2, 3]), - req(1, 1.0, vec![1, 2, 3, 4]), - ]; + let recs = vec![req(0, 0.0, vec![1, 2, 3]), req(1, 1.0, vec![1, 2, 3, 4])]; let mask = vec![false, true]; let out = analyze(&recs, 100, Some(&mask)); // Only req 1 is counted: total = 4, hits = 3 (prefix [1,2,3] hit) diff --git a/src/replay.rs b/src/replay.rs index 13a7cb5..a26b724 100644 --- a/src/replay.rs +++ b/src/replay.rs @@ -496,6 +496,8 @@ pub fn replay_fixed_placement( placements: &[PlacementEntry], policy: ReplayEvictPolicy, ) -> Result { + cfg.cluster + .require_legacy_single_pool("fixed-placement replay")?; if records.len() != placements.len() { return Err(anyhow!( "records/placements length mismatch: {} vs {}", @@ -519,7 +521,7 @@ pub fn replay_fixed_placement( let block_bytes = cfg.model.kv_block_bytes() as f64; let l0_cap = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as usize; let l1_cap = (cfg.hardware.dram_bytes / block_bytes).max(1.0) as usize; - let num_instances = cfg.cluster.num_instances as usize; + let num_instances = cfg.cluster.total_instances() as usize; let mut caches: Vec = (0..num_instances) .map(|_| ReplayInstanceCache::new(policy, l0_cap, l1_cap)) .collect(); diff --git a/src/router/adaptive_affinity.rs b/src/router/adaptive_affinity.rs index 7b94239..8b8c7cb 100644 --- a/src/router/adaptive_affinity.rs +++ b/src/router/adaptive_affinity.rs @@ -57,7 +57,7 @@ pub struct AdaptiveAffinityRouter { impl AdaptiveAffinityRouter { pub fn new(config: &Config) -> Self { - let n = config.cluster.num_instances as usize; + let n = config.cluster.total_instances() as usize; let configured_fan_out = config.cluster.router.affinity_fan_out; let max_fan_out = if configured_fan_out > 0 { configured_fan_out.max(2).min(n) @@ -242,13 +242,13 @@ impl Router for AdaptiveAffinityRouter { self.observe(fp, now); - RouteDecision { - req_id: req.req_id, - mode: "adaptive_affinity", - chosen: instances[chosen_idx].id, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "adaptive_affinity", + instances[chosen_idx].id, + 0.0, candidates, reason, - } + ) } } diff --git a/src/router/cache_affinity.rs b/src/router/cache_affinity.rs index 5e09d33..a42a1af 100644 --- a/src/router/cache_affinity.rs +++ b/src/router/cache_affinity.rs @@ -205,13 +205,13 @@ impl Router for CacheAffinityRouter { } } - RouteDecision { - req_id: req.req_id, - mode: "cache_affinity", - chosen: instances[best_idx].id, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "cache_affinity", + instances[best_idx].id, + 0.0, candidates, - reason: "argmin(α·q − γ·l0_hit − δ·meta_only) + rendezvous tiebreak", - } + "argmin(α·q − γ·l0_hit − δ·meta_only) + rendezvous tiebreak", + ) } } diff --git a/src/router/cache_load.rs b/src/router/cache_load.rs index 142e00c..3101c6f 100644 --- a/src/router/cache_load.rs +++ b/src/router/cache_load.rs @@ -77,13 +77,13 @@ impl Router for CacheLoadRouter { }); } - RouteDecision { - req_id: req.req_id, - mode: "cache_load", - chosen: instances[best_idx].id, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "cache_load", + instances[best_idx].id, + 0.0, candidates, - reason: "least-loaded 1/4, then best local L0 prefix", - } + "least-loaded 1/4, then best local L0 prefix", + ) } } diff --git a/src/router/cache_score.rs b/src/router/cache_score.rs index ab23882..a727f32 100644 --- a/src/router/cache_score.rs +++ b/src/router/cache_score.rs @@ -99,13 +99,13 @@ impl Router for CacheScoreRouter { } } - RouteDecision { - req_id: req.req_id, - mode: "cache_score", - chosen: instances[best_idx].id, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "cache_score", + instances[best_idx].id, + 0.0, candidates, - reason: "argmin 2^(α·load + β·miss)", - } + "argmin 2^(α·load + β·miss)", + ) } } diff --git a/src/router/cache_score_ttl.rs b/src/router/cache_score_ttl.rs index c67ad1d..d411247 100644 --- a/src/router/cache_score_ttl.rs +++ b/src/router/cache_score_ttl.rs @@ -74,13 +74,13 @@ impl Router for CacheScoreTtlRouter { } } - RouteDecision { - req_id: req.req_id, - mode: "cache_score_ttl", - chosen: instances[best_idx].id, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "cache_score_ttl", + instances[best_idx].id, + 0.0, candidates, - reason: "argmin 2^(alpha*load + beta*meta_store_miss)", - } + "argmin 2^(alpha*load + beta*meta_store_miss)", + ) } } diff --git a/src/router/estimated_ttft.rs b/src/router/estimated_ttft.rs index bc206bc..e4b8a69 100644 --- a/src/router/estimated_ttft.rs +++ b/src/router/estimated_ttft.rs @@ -89,13 +89,13 @@ impl Router for EstimatedTtftRouter { } } - RouteDecision { - req_id: req.req_id, - mode: "estimated_ttft", - chosen: best, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "estimated_ttft", + best, + 0.0, candidates, - reason: "argmin(drain + scheduler + kv_prepare + prefill + first_token_tail)", - } + "argmin(drain + scheduler + kv_prepare + prefill + first_token_tail)", + ) } } diff --git a/src/router/global_bucket.rs b/src/router/global_bucket.rs new file mode 100644 index 0000000..2522e7d --- /dev/null +++ b/src/router/global_bucket.rs @@ -0,0 +1,371 @@ +use anyhow::{anyhow, Result}; +use serde::Serialize; + +use crate::config::{Config, GlobalRouterMode}; +use crate::trace::RequestRecord; + +pub type BucketId = u32; + +#[derive(Debug, Clone, Serialize)] +pub struct BucketView { + pub id: BucketId, + pub input_length_min: u32, + pub input_length_max: u32, + pub num_instances: u32, + pub total_queue_len: u32, + pub total_load_blocks: u32, + pub predicted_prefix: u32, +} + +#[derive(Debug, Clone, Serialize)] +pub struct BucketCandidate { + pub bucket: BucketId, + pub input_length_min: u32, + pub input_length_max: u32, + pub num_instances: u32, + pub total_queue_len: u32, + pub total_load_blocks: u32, + pub predicted_prefix: u32, + pub matches_input_len: bool, + pub score: f64, +} + +#[derive(Debug, Clone, Serialize)] +pub struct GlobalRouteDecision { + pub req_id: u64, + pub mode: &'static str, + pub chosen_bucket: BucketId, + pub candidates: Vec, + pub reason: &'static str, +} + +impl GlobalRouteDecision { + pub fn single_bucket(req_id: u64, chosen_bucket: BucketId) -> Self { + Self { + req_id, + mode: "single_pool", + chosen_bucket, + candidates: Vec::new(), + reason: "single pool uses bucket 0", + } + } +} + +pub trait GlobalRouter: Send { + fn name(&self) -> &'static str; + fn route( + &mut self, + req: &RequestRecord, + buckets: &[BucketView], + now: f64, + ) -> Result; +} + +struct StrictInputLengthRouter { + reported_mode: &'static str, + reason: &'static str, +} + +impl StrictInputLengthRouter { + fn new(reported_mode: &'static str, reason: &'static str) -> Self { + Self { + reported_mode, + reason, + } + } +} + +impl GlobalRouter for StrictInputLengthRouter { + fn name(&self) -> &'static str { + self.reported_mode + } + + fn route( + &mut self, + req: &RequestRecord, + buckets: &[BucketView], + _now: f64, + ) -> Result { + let candidates = buckets + .iter() + .map(|view| BucketCandidate { + bucket: view.id, + input_length_min: view.input_length_min, + input_length_max: view.input_length_max, + num_instances: view.num_instances, + total_queue_len: view.total_queue_len, + total_load_blocks: view.total_load_blocks, + predicted_prefix: view.predicted_prefix, + matches_input_len: view.input_length_min <= req.input_len + && req.input_len <= view.input_length_max, + score: if view.input_length_min <= req.input_len + && req.input_len <= view.input_length_max + { + 0.0 + } else { + f64::INFINITY + }, + }) + .collect::>(); + + let matches = candidates + .iter() + .filter(|candidate| candidate.matches_input_len) + .map(|candidate| candidate.bucket) + .collect::>(); + + let chosen_bucket = match matches.as_slice() { + [bucket] => *bucket, + [] => { + return Err(anyhow!( + "cluster.global_router.mode={} has no bucket for input_length={}", + self.reported_mode, + req.input_len + )); + } + _ => { + return Err(anyhow!( + "cluster.global_router.mode={} matched multiple buckets for input_length={}", + self.reported_mode, + req.input_len + )); + } + }; + + Ok(GlobalRouteDecision { + req_id: req.req_id, + mode: self.reported_mode, + chosen_bucket, + candidates, + reason: self.reason, + }) + } +} + +struct BucketScoreRouter { + length_penalty_weight: f64, + load_weight: f64, + cache_weight: f64, +} + +impl BucketScoreRouter { + fn new(full: &Config) -> Self { + Self { + length_penalty_weight: full.cluster.global_router.length_penalty_weight, + load_weight: full.cluster.global_router.load_weight, + cache_weight: full.cluster.global_router.cache_weight, + } + } + + fn length_penalty(&self, req: &RequestRecord, bucket: &BucketView) -> f64 { + if req.input_len < bucket.input_length_min { + (bucket.input_length_min - req.input_len) as f64 + } else if req.input_len > bucket.input_length_max { + (req.input_len - bucket.input_length_max) as f64 + } else { + 0.0 + } + } +} + +impl GlobalRouter for BucketScoreRouter { + fn name(&self) -> &'static str { + "bucket_score" + } + + fn route( + &mut self, + req: &RequestRecord, + buckets: &[BucketView], + _now: f64, + ) -> Result { + let mut chosen_bucket = None; + let mut best_score = f64::INFINITY; + let mut candidates = Vec::with_capacity(buckets.len()); + + for bucket in buckets { + let length_penalty = self.length_penalty(req, bucket); + let miss = req + .hash_ids + .len() + .saturating_sub(bucket.predicted_prefix as usize) as f64; + let score = self.length_penalty_weight * length_penalty + + self.load_weight * bucket.total_queue_len as f64 + + self.cache_weight * miss; + + candidates.push(BucketCandidate { + bucket: bucket.id, + input_length_min: bucket.input_length_min, + input_length_max: bucket.input_length_max, + num_instances: bucket.num_instances, + total_queue_len: bucket.total_queue_len, + total_load_blocks: bucket.total_load_blocks, + predicted_prefix: bucket.predicted_prefix, + matches_input_len: bucket.input_length_min <= req.input_len + && req.input_len <= bucket.input_length_max, + score, + }); + + let better = score < best_score + || (score == best_score && chosen_bucket.is_none_or(|best| bucket.id < best)); + if better { + best_score = score; + chosen_bucket = Some(bucket.id); + } + } + + Ok(GlobalRouteDecision { + req_id: req.req_id, + mode: self.name(), + chosen_bucket: chosen_bucket.ok_or_else(|| anyhow!("no buckets available"))?, + candidates, + reason: "weighted length/load/cache bucket score", + }) + } +} + +pub fn build(full: &Config) -> Box { + match full.cluster.global_router.mode { + GlobalRouterMode::StrictInputLength => Box::new(StrictInputLengthRouter::new( + "strict_input_length", + "unique bucket range contains input_length", + )) as Box, + GlobalRouterMode::BucketScore => { + Box::new(BucketScoreRouter::new(full)) as Box + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{ + ClusterConfig, GlobalRouterConfig, MetaStoreConfig, RouterConfig, RouterMode, + }; + + fn cfg() -> Config { + Config { + model: crate::config::ModelConfig::default(), + hardware: crate::config::HardwareConfig { + gpu_flops: 1.0, + gpu_fp8_flops: 0.0, + gpu_fp4_flops: 0.0, + gpu_mem_bw: 1.0, + hbm_bytes: 1.0, + dram_bytes: 1.0, + host_dram_bw: 1.0, + pcie_bw: 1.0, + pcie_latency_us: 1.0, + rdma_bw: 1.0, + rdma_latency_us: 1.0, + intra_node_tp_bw: 1.0, + intra_node_tp_latency_us: 1.0, + tp_degree: 1, + max_batch_slots: 1, + prefill_chunk_tokens: 1, + }, + calibration: crate::config::CalibrationConfig::default(), + cluster: ClusterConfig { + num_instances: None, + buckets: Vec::new(), + global_router: GlobalRouterConfig { + mode: GlobalRouterMode::BucketScore, + length_penalty_weight: 1.0, + load_weight: 1.0, + cache_weight: 1.0, + }, + meta_store: MetaStoreConfig { ttl_seconds: 1.0 }, + router: RouterConfig { + mode: RouterMode::LeastLoaded, + precise_probe_latency_us: 1.0, + precise_probe_topk: 1, + load_alpha: 1.0, + score_alpha: 1.0, + score_beta: 1.0, + prefix_k: 8, + affinity_fan_out: 1, + }, + }, + sim: crate::config::SimConfig { + trace_path: String::new(), + max_requests: None, + output_dir: String::new(), + sample_interval_s: 0.0, + seed: 0, + input_length_min: None, + input_length_max: None, + }, + } + } + + fn req(input_len: u32) -> RequestRecord { + RequestRecord { + req_id: 1, + chat_id: 0, + parent_chat_id: -1, + turn: 0, + arrival: 0.0, + input_len, + output_len: 16, + hash_ids: vec![10, 11, 12], + } + } + + #[test] + fn bucket_score_prefers_matching_bucket_when_load_is_equal() { + let mut router = BucketScoreRouter::new(&cfg()); + let buckets = vec![ + BucketView { + id: 0, + input_length_min: 0, + input_length_max: 32, + num_instances: 2, + total_queue_len: 1, + total_load_blocks: 0, + predicted_prefix: 0, + }, + BucketView { + id: 1, + input_length_min: 33, + input_length_max: 96, + num_instances: 2, + total_queue_len: 1, + total_load_blocks: 0, + predicted_prefix: 0, + }, + ]; + let decision = router.route(&req(24), &buckets, 0.0).unwrap(); + assert_eq!(decision.chosen_bucket, 0); + } + + #[test] + fn bucket_score_can_override_length_match_when_load_gap_is_large() { + let mut full = cfg(); + full.cluster.global_router.load_weight = 5.0; + full.cluster.global_router.cache_weight = 1.0; + full.cluster.global_router.length_penalty_weight = 1.0; + let mut router = BucketScoreRouter::new(&full); + let buckets = vec![ + BucketView { + id: 0, + input_length_min: 0, + input_length_max: 32, + num_instances: 2, + total_queue_len: 20, + total_load_blocks: 0, + predicted_prefix: 0, + }, + BucketView { + id: 1, + input_length_min: 33, + input_length_max: 96, + num_instances: 2, + total_queue_len: 0, + total_load_blocks: 0, + predicted_prefix: 2, + }, + ]; + let decision = router.route(&req(24), &buckets, 0.0).unwrap(); + assert_eq!(decision.chosen_bucket, 1); + } +} diff --git a/src/router/least_loaded.rs b/src/router/least_loaded.rs index efc0ed8..272cb93 100644 --- a/src/router/least_loaded.rs +++ b/src/router/least_loaded.rs @@ -41,13 +41,13 @@ impl Router for LeastLoadedRouter { best = inst.id; } } - RouteDecision { - req_id: req.req_id, - mode: "least_loaded", - chosen: best, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "least_loaded", + best, + 0.0, candidates, - reason: "argmin(kv_used + alpha * queue_len)", - } + "argmin(kv_used + alpha * queue_len)", + ) } } diff --git a/src/router/least_tokens.rs b/src/router/least_tokens.rs index effdad7..2c06427 100644 --- a/src/router/least_tokens.rs +++ b/src/router/least_tokens.rs @@ -61,13 +61,13 @@ impl Router for LeastTokensRouter { } } - RouteDecision { - req_id: req.req_id, - mode: "least_tokens", - chosen: best, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "least_tokens", + best, + 0.0, candidates, - reason: "argmin(waiting_prefill_tokens)", - } + "argmin(waiting_prefill_tokens)", + ) } } diff --git a/src/router/lineage_affinity.rs b/src/router/lineage_affinity.rs index f8d51e6..0e9035b 100644 --- a/src/router/lineage_affinity.rs +++ b/src/router/lineage_affinity.rs @@ -53,7 +53,7 @@ pub struct LineageAffinityRouter { impl LineageAffinityRouter { pub fn new(config: &Config) -> Self { - let n = config.cluster.num_instances as usize; + let n = config.cluster.total_instances() as usize; let configured_fan_out = config.cluster.router.affinity_fan_out; let max_fan_out = if configured_fan_out > 0 { configured_fan_out.max(2).min(n) @@ -231,13 +231,13 @@ impl Router for LineageAffinityRouter { self.request_home .insert(req.chat_id, instances[chosen.idx].id); - RouteDecision { - req_id: req.req_id, - mode: "lineage_affinity", - chosen: instances[chosen.idx].id, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "lineage_affinity", + instances[chosen.idx].id, + 0.0, candidates, reason, - } + ) } } diff --git a/src/router/min_pd.rs b/src/router/min_pd.rs index f801639..2033dee 100644 --- a/src/router/min_pd.rs +++ b/src/router/min_pd.rs @@ -90,13 +90,13 @@ impl Router for MinPdRouter { } } - RouteDecision { - req_id: req.req_id, - mode: "min_pd", - chosen: best, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "min_pd", + best, + 0.0, candidates, - reason: "argmin(P*D), P=local-L0 miss tokens, D=ongoing reqs", - } + "argmin(P*D), P=local-L0 miss tokens, D=ongoing reqs", + ) } } diff --git a/src/router/mod.rs b/src/router/mod.rs index 7773e3d..ef68dd5 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -6,6 +6,7 @@ pub mod cache_load; pub mod cache_score; pub mod cache_score_ttl; pub mod estimated_ttft; +pub mod global_bucket; pub mod least_loaded; pub mod least_tokens; pub mod lineage_affinity; @@ -23,6 +24,8 @@ use crate::instance::Instance; use crate::trace::RequestRecord; use crate::types::InstanceId; +pub use global_bucket::{BucketCandidate, BucketId, BucketView, GlobalRouteDecision, GlobalRouter}; + #[derive(Debug, Clone, Serialize)] pub struct CandidateInfo { pub instance: InstanceId, @@ -34,11 +37,25 @@ pub struct CandidateInfo { #[derive(Debug, Clone, Serialize)] pub struct RouteDecision { pub req_id: u64, + pub global_mode: &'static str, pub mode: &'static str, + pub global_reason: &'static str, + pub local_reason: &'static str, + pub chosen_bucket: BucketId, pub chosen: InstanceId, pub probe_overhead_s: f64, + pub bucket_candidates: Vec, pub candidates: Vec, - pub reason: &'static str, +} + +impl RouteDecision { + pub fn with_global(mut self, decision: &GlobalRouteDecision) -> Self { + self.global_mode = decision.mode; + self.global_reason = decision.reason; + self.chosen_bucket = decision.chosen_bucket; + self.bucket_candidates = decision.candidates.clone(); + self + } } pub trait Router: Send { @@ -63,6 +80,28 @@ pub(crate) fn local_l0_scores(req: &RequestRecord, instances: &[Instance]) -> Ve .collect() } +pub fn local_route_decision( + req_id: u64, + mode: &'static str, + chosen: InstanceId, + probe_overhead_s: f64, + candidates: Vec, + reason: &'static str, +) -> RouteDecision { + RouteDecision { + req_id, + global_mode: "single_pool", + mode, + global_reason: "single pool uses bucket 0", + local_reason: reason, + chosen_bucket: 0, + chosen, + probe_overhead_s, + bucket_candidates: Vec::new(), + candidates, + } +} + pub fn build(full: &Config, seed: u64) -> Box { use crate::config::RouterMode::*; let cfg = &full.cluster.router; @@ -122,6 +161,10 @@ pub fn build(full: &Config, seed: u64) -> Box { } } +pub fn build_global(full: &Config) -> Box { + global_bucket::build(full) +} + #[cfg(test)] mod tests { use super::*; @@ -181,7 +224,9 @@ mod tests { hardware: test_hardware(), calibration: CalibrationConfig::default(), cluster: ClusterConfig { - num_instances: 2, + num_instances: Some(2), + buckets: Vec::new(), + global_router: Default::default(), meta_store: MetaStoreConfig { ttl_seconds: 1000.0, }, @@ -351,7 +396,7 @@ mod tests { let mut router = PrefixAffinityRouter::new(&cfg); let decision = router.route(&req, &instances, &meta, 0.0); - assert_eq!(decision.reason, "affinity fallback: min(drain+fetch)"); + assert_eq!(decision.local_reason, "affinity fallback: min(drain+fetch)"); assert_eq!(decision.chosen, 1); } diff --git a/src/router/precise_aware.rs b/src/router/precise_aware.rs index 9940815..e0d90cd 100644 --- a/src/router/precise_aware.rs +++ b/src/router/precise_aware.rs @@ -62,13 +62,13 @@ impl Router for PreciseRouter { } } - RouteDecision { - req_id: req.req_id, - mode: "precise", - chosen: best, - probe_overhead_s: n as f64 * self.probe_latency_s, + crate::router::local_route_decision( + req.req_id, + "precise", + best, + n as f64 * self.probe_latency_s, candidates, - reason: "exact-probe all instances' L0 cache", - } + "exact-probe all instances' L0 cache", + ) } } diff --git a/src/router/prefix_affinity.rs b/src/router/prefix_affinity.rs index 8f99b7a..bde3443 100644 --- a/src/router/prefix_affinity.rs +++ b/src/router/prefix_affinity.rs @@ -51,7 +51,7 @@ pub struct PrefixAffinityRouter { impl PrefixAffinityRouter { pub fn new(config: &Config) -> Self { - let n = config.cluster.num_instances as usize; + let n = config.cluster.total_instances() as usize; let cfg_fan = config.cluster.router.affinity_fan_out; // fan_out: if configured, use it; otherwise auto = max(2, n/8). let fan_out = if cfg_fan > 0 { @@ -166,13 +166,13 @@ impl Router for PrefixAffinityRouter { reason = "prefix affinity: top-K min drain"; } - RouteDecision { - req_id: req.req_id, - mode: "prefix_affinity", - chosen: instances[best_idx].id, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "prefix_affinity", + instances[best_idx].id, + 0.0, candidates, reason, - } + ) } } diff --git a/src/router/random.rs b/src/router/random.rs index c4bfb17..8f2deae 100644 --- a/src/router/random.rs +++ b/src/router/random.rs @@ -33,19 +33,19 @@ impl Router for RandomRouter { ) -> RouteDecision { let n = instances.len(); let chosen = self.rng.gen_range(0..n) as InstanceId; - RouteDecision { - req_id: req.req_id, - mode: "random", + crate::router::local_route_decision( + req.req_id, + "random", chosen, - probe_overhead_s: 0.0, - candidates: vec![CandidateInfo { + 0.0, + vec![CandidateInfo { instance: chosen, predicted_prefix: 0, load_blocks: instances[chosen as usize].kv_blocks_used, queue_len: instances[chosen as usize].queue_len(), }], - reason: "uniform random", - } + "uniform random", + ) } } @@ -75,18 +75,18 @@ impl Router for RoundRobinRouter { let n = instances.len() as u32; let chosen = self.next % n; self.next = self.next.wrapping_add(1); - RouteDecision { - req_id: req.req_id, - mode: "round_robin", + crate::router::local_route_decision( + req.req_id, + "round_robin", chosen, - probe_overhead_s: 0.0, - candidates: vec![CandidateInfo { + 0.0, + vec![CandidateInfo { instance: chosen, predicted_prefix: 0, load_blocks: instances[chosen as usize].kv_blocks_used, queue_len: instances[chosen as usize].queue_len(), }], - reason: "round robin", - } + "round robin", + ) } } diff --git a/src/router/ttl_aware.rs b/src/router/ttl_aware.rs index 04481d0..6206ae7 100644 --- a/src/router/ttl_aware.rs +++ b/src/router/ttl_aware.rs @@ -46,13 +46,13 @@ impl Router for TtlAwareRouter { best = inst.id; } } - RouteDecision { - req_id: req.req_id, - mode: "ttl_aware", - chosen: best, - probe_overhead_s: 0.0, + crate::router::local_route_decision( + req.req_id, + "ttl_aware", + best, + 0.0, candidates, - reason: "max meta_store prefix, tie -> least loaded", - } + "max meta_store prefix, tie -> least loaded", + ) } } diff --git a/src/sim/engine.rs b/src/sim/engine.rs index 7a515e6..83651a3 100644 --- a/src/sim/engine.rs +++ b/src/sim/engine.rs @@ -91,11 +91,24 @@ mod tests { q.schedule( 2.0, Event::BatchTick { + bucket: 0, instance: 0 as InstanceId, }, ); - q.schedule(1.0, Event::BatchTick { instance: 1 }); - q.schedule(1.5, Event::BatchTick { instance: 2 }); + q.schedule( + 1.0, + Event::BatchTick { + bucket: 0, + instance: 1, + }, + ); + q.schedule( + 1.5, + Event::BatchTick { + bucket: 0, + instance: 2, + }, + ); let (t1, _) = q.pop().unwrap(); let (t2, _) = q.pop().unwrap(); let (t3, _) = q.pop().unwrap(); @@ -107,12 +120,24 @@ mod tests { #[test] fn equal_time_fifo() { let mut q = EventQueue::new(); - q.schedule(1.0, Event::BatchTick { instance: 7 }); - q.schedule(1.0, Event::BatchTick { instance: 8 }); + q.schedule( + 1.0, + Event::BatchTick { + bucket: 0, + instance: 7, + }, + ); + q.schedule( + 1.0, + Event::BatchTick { + bucket: 1, + instance: 8, + }, + ); let (_, e1) = q.pop().unwrap(); let (_, e2) = q.pop().unwrap(); match (e1, e2) { - (Event::BatchTick { instance: a }, Event::BatchTick { instance: b }) => { + (Event::BatchTick { instance: a, .. }, Event::BatchTick { instance: b, .. }) => { assert_eq!(a, 7); assert_eq!(b, 8); } diff --git a/src/sim/events.rs b/src/sim/events.rs index e369fa2..c8847a3 100644 --- a/src/sim/events.rs +++ b/src/sim/events.rs @@ -1,5 +1,6 @@ //! Event types for the discrete-event engine. +use crate::router::BucketId; use crate::types::{InstanceId, ReqId}; #[derive(Debug)] @@ -7,7 +8,10 @@ pub enum Event { /// New trace request arrives at the cluster router. Arrival { req_id: ReqId }, /// Per-instance scheduler tick (continuous batching). - BatchTick { instance: InstanceId }, + BatchTick { + bucket: BucketId, + instance: InstanceId, + }, /// Periodic time-series sample of all instances. Sample, /// Stop the simulation early (used internally). diff --git a/tests/smoke.rs b/tests/smoke.rs index 700408d..bd5ec38 100644 --- a/tests/smoke.rs +++ b/tests/smoke.rs @@ -6,7 +6,7 @@ use std::io::Write; use kvcache_simulator::config::*; use kvcache_simulator::driver; -use kvcache_simulator::replay::ReplayEvictPolicy; +use kvcache_simulator::replay::{self, PlacementEntry, ReplayEvictPolicy}; fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config { Config { @@ -41,7 +41,9 @@ fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config { }, calibration: CalibrationConfig::default(), cluster: ClusterConfig { - num_instances: 4, + num_instances: Some(4), + buckets: Vec::new(), + global_router: Default::default(), meta_store: MetaStoreConfig { ttl_seconds: 1000.0, }, @@ -68,6 +70,26 @@ fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config { } } +fn bucketed_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config { + let mut cfg = base_config(trace_path, out_dir, mode); + cfg.cluster.num_instances = None; + cfg.cluster.buckets = vec![ + BucketConfig { + name: "short".into(), + input_length_min: 0, + input_length_max: 64, + num_instances: 2, + }, + BucketConfig { + name: "long".into(), + input_length_min: 65, + input_length_max: 128, + num_instances: 1, + }, + ]; + cfg +} + fn write_synthetic_trace(path: &std::path::Path) { // 5 distinct conversations, each with 8 turns. Within a conversation, // turn k+1 reuses the prefix of turn k (shared first ~10 blocks) and @@ -272,3 +294,160 @@ fn ablation_parallel_matches_serial() { assert!((lhs.miss_rate - rhs.miss_rate).abs() < 1e-12); } } + +#[test] +fn strict_bucket_run_emits_bucket_fields_in_outputs() { + let tmp = std::env::temp_dir().join("kvcache_sim_bucket_outputs"); + let _ = std::fs::remove_dir_all(&tmp); + std::fs::create_dir_all(&tmp).unwrap(); + let trace_path = tmp.join("trace.jsonl"); + + let mut f = std::fs::File::create(&trace_path).unwrap(); + writeln!( + f, + "{}", + serde_json::json!({ + "chat_id": 1, + "parent_chat_id": -1, + "timestamp": 0.0, + "input_length": 32, + "output_length": 16, + "type": "text", + "turn": 0, + "hash_ids": [1, 2] + }) + ) + .unwrap(); + writeln!( + f, + "{}", + serde_json::json!({ + "chat_id": 2, + "parent_chat_id": -1, + "timestamp": 0.1, + "input_length": 80, + "output_length": 16, + "type": "text", + "turn": 0, + "hash_ids": [3, 4, 5, 6, 7] + }) + ) + .unwrap(); + + let mut cfg = bucketed_config( + trace_path.to_str().unwrap(), + tmp.to_str().unwrap(), + RouterMode::LeastLoaded, + ); + cfg.cluster.global_router.mode = GlobalRouterMode::StrictInputLength; + cfg.sim.sample_interval_s = 0.05; + + let _ = driver::run(&cfg, Some("strict_bucket")).expect("bucketed run"); + + let per_request = std::fs::read_to_string(tmp.join("strict_bucket/per_request.csv")).unwrap(); + assert!(per_request.contains("bucket")); + assert!(per_request.contains("length_bucket_match")); + + let instances = std::fs::read_to_string(tmp.join("strict_bucket/instances.csv")).unwrap(); + assert!(instances.contains("bucket")); + + let routing_log = std::fs::read_to_string(tmp.join("strict_bucket/routing_log.jsonl")).unwrap(); + assert!(routing_log.contains("\"chosen_bucket\"")); + assert!(routing_log.contains("\"bucket_candidates\"")); + assert!(routing_log.contains("\"global_reason\"")); +} + +#[test] +fn bucketed_configs_are_rejected_by_legacy_fixed_placement_paths() { + let tmp = std::env::temp_dir().join("kvcache_sim_bucketed_reject"); + let _ = std::fs::remove_dir_all(&tmp); + std::fs::create_dir_all(&tmp).unwrap(); + let trace_path = tmp.join("trace.jsonl"); + write_synthetic_trace(&trace_path); + + let cfg = bucketed_config( + trace_path.to_str().unwrap(), + tmp.to_str().unwrap(), + RouterMode::Random, + ); + + let err = + driver::ablate_fixed_placement(&cfg, &[RouterMode::Random], &[ReplayEvictPolicy::Lru]) + .expect_err("bucketed ablation should fail"); + assert!(err.to_string().contains("cluster.buckets")); + + let err = replay::replay_fixed_placement( + &cfg, + &[], + &Vec::::new(), + ReplayEvictPolicy::Lru, + ) + .expect_err("bucketed replay should fail"); + assert!(err.to_string().contains("cluster.buckets")); +} + +#[test] +fn bucket_score_can_deviate_from_strict_length_bucket() { + let tmp = std::env::temp_dir().join("kvcache_sim_bucket_score"); + let _ = std::fs::remove_dir_all(&tmp); + std::fs::create_dir_all(&tmp).unwrap(); + let trace_path = tmp.join("trace.jsonl"); + + let mut f = std::fs::File::create(&trace_path).unwrap(); + for req_id in 0..3 { + writeln!( + f, + "{}", + serde_json::json!({ + "chat_id": req_id, + "parent_chat_id": -1, + "timestamp": 0.0, + "input_length": 24, + "output_length": 16, + "type": "text", + "turn": 0, + "hash_ids": [100 + req_id, 200 + req_id] + }) + ) + .unwrap(); + } + + let mut strict_cfg = bucketed_config( + trace_path.to_str().unwrap(), + tmp.to_str().unwrap(), + RouterMode::LeastLoaded, + ); + strict_cfg.cluster.buckets = vec![ + BucketConfig { + name: "short".into(), + input_length_min: 0, + input_length_max: 32, + num_instances: 1, + }, + BucketConfig { + name: "long".into(), + input_length_min: 33, + input_length_max: 96, + num_instances: 1, + }, + ]; + strict_cfg.cluster.global_router.mode = GlobalRouterMode::StrictInputLength; + + let mut score_cfg = strict_cfg.clone(); + score_cfg.cluster.global_router.mode = GlobalRouterMode::BucketScore; + score_cfg.cluster.global_router.length_penalty_weight = 1.0; + score_cfg.cluster.global_router.load_weight = 10.0; + score_cfg.cluster.global_router.cache_weight = 0.0; + + let _ = driver::run(&strict_cfg, Some("strict_score_cmp")).expect("strict run"); + let _ = driver::run(&score_cfg, Some("bucket_score_cmp")).expect("bucket score run"); + + let strict_log = + std::fs::read_to_string(tmp.join("strict_score_cmp/routing_log.jsonl")).unwrap(); + let score_log = + std::fs::read_to_string(tmp.join("bucket_score_cmp/routing_log.jsonl")).unwrap(); + + assert!(strict_log.contains("\"chosen_bucket\":0")); + assert!(score_log.contains("\"global_mode\":\"bucket_score\"")); + assert!(score_log.contains("\"chosen_bucket\":1")); +}