feat: add bucket score global router

This commit is contained in:
2026-04-17 17:55:54 +08:00
parent b5a6fb964c
commit 43ada0cfc0
4 changed files with 311 additions and 23 deletions

View File

@@ -52,7 +52,7 @@ impl BucketedService {
let bucket_views = self let bucket_views = self
.buckets .buckets
.iter() .iter()
.map(|bucket| bucket.cluster.bucket_view(bucket.id, &bucket.cfg)) .map(|bucket| bucket.cluster.bucket_view(bucket.id, &bucket.cfg, req, now))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let global = self.global_router.route(req, &bucket_views, now)?; let global = self.global_router.route(req, &bucket_views, now)?;
let bucket = &mut self.buckets[global.chosen_bucket as usize]; let bucket = &mut self.buckets[global.chosen_bucket as usize];
@@ -213,21 +213,4 @@ mod tests {
assert!(err.to_string().contains("no bucket")); assert!(err.to_string().contains("no bucket"));
assert!(err.to_string().contains("input_length=36")); assert!(err.to_string().contains("input_length=36"));
} }
#[test]
fn bucket_score_placeholder_reports_strict_fallback() {
let mut cfg = test_config();
cfg.cluster.global_router.mode = GlobalRouterMode::BucketScore;
let mut service = BucketedService::new(&cfg, &cfg.model);
let stats = service
.route_and_admit(&req(4, 24, &[30, 31]), 0.0)
.unwrap();
assert_eq!(stats.decision.global_mode, "strict_input_length");
assert!(stats
.decision
.global_reason
.contains("bucket_score is not implemented"));
}
} }

View File

@@ -171,7 +171,19 @@ impl Cluster {
} }
} }
pub fn bucket_view(&self, bucket_id: BucketId, cfg: &BucketConfig) -> BucketView { pub fn bucket_view(
&self,
bucket_id: BucketId,
cfg: &BucketConfig,
req: &RequestRecord,
now: f64,
) -> BucketView {
let predicted_prefix = self
.meta_store
.score_prefix(&req.hash_ids, now, self.instances.len())
.into_iter()
.max()
.unwrap_or(0);
BucketView { BucketView {
id: bucket_id, id: bucket_id,
input_length_min: cfg.input_length_min, input_length_min: cfg.input_length_min,
@@ -179,6 +191,7 @@ impl Cluster {
num_instances: self.instances.len() as u32, num_instances: self.instances.len() as u32,
total_queue_len: self.instances.iter().map(Instance::queue_len).sum(), total_queue_len: self.instances.iter().map(Instance::queue_len).sum(),
total_load_blocks: self.instances.iter().map(|inst| inst.kv_blocks_used).sum(), total_load_blocks: self.instances.iter().map(|inst| inst.kv_blocks_used).sum(),
predicted_prefix,
} }
} }

View File

@@ -14,6 +14,7 @@ pub struct BucketView {
pub num_instances: u32, pub num_instances: u32,
pub total_queue_len: u32, pub total_queue_len: u32,
pub total_load_blocks: u32, pub total_load_blocks: u32,
pub predicted_prefix: u32,
} }
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
@@ -24,7 +25,9 @@ pub struct BucketCandidate {
pub num_instances: u32, pub num_instances: u32,
pub total_queue_len: u32, pub total_queue_len: u32,
pub total_load_blocks: u32, pub total_load_blocks: u32,
pub predicted_prefix: u32,
pub matches_input_len: bool, pub matches_input_len: bool,
pub score: f64,
} }
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
@@ -92,8 +95,16 @@ impl GlobalRouter for StrictInputLengthRouter {
num_instances: view.num_instances, num_instances: view.num_instances,
total_queue_len: view.total_queue_len, total_queue_len: view.total_queue_len,
total_load_blocks: view.total_load_blocks, total_load_blocks: view.total_load_blocks,
predicted_prefix: view.predicted_prefix,
matches_input_len: view.input_length_min <= req.input_len matches_input_len: view.input_length_min <= req.input_len
&& req.input_len <= view.input_length_max, && req.input_len <= view.input_length_max,
score: if view.input_length_min <= req.input_len
&& req.input_len <= view.input_length_max
{
0.0
} else {
f64::INFINITY
},
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@@ -131,15 +142,230 @@ impl GlobalRouter for StrictInputLengthRouter {
} }
} }
struct BucketScoreRouter {
length_penalty_weight: f64,
load_weight: f64,
cache_weight: f64,
}
impl BucketScoreRouter {
fn new(full: &Config) -> Self {
Self {
length_penalty_weight: full.cluster.global_router.length_penalty_weight,
load_weight: full.cluster.global_router.load_weight,
cache_weight: full.cluster.global_router.cache_weight,
}
}
fn length_penalty(&self, req: &RequestRecord, bucket: &BucketView) -> f64 {
if req.input_len < bucket.input_length_min {
(bucket.input_length_min - req.input_len) as f64
} else if req.input_len > bucket.input_length_max {
(req.input_len - bucket.input_length_max) as f64
} else {
0.0
}
}
}
impl GlobalRouter for BucketScoreRouter {
fn name(&self) -> &'static str {
"bucket_score"
}
fn route(
&mut self,
req: &RequestRecord,
buckets: &[BucketView],
_now: f64,
) -> Result<GlobalRouteDecision> {
let mut chosen_bucket = None;
let mut best_score = f64::INFINITY;
let mut candidates = Vec::with_capacity(buckets.len());
for bucket in buckets {
let length_penalty = self.length_penalty(req, bucket);
let miss = req
.hash_ids
.len()
.saturating_sub(bucket.predicted_prefix as usize) as f64;
let score = self.length_penalty_weight * length_penalty
+ self.load_weight * bucket.total_queue_len as f64
+ self.cache_weight * miss;
candidates.push(BucketCandidate {
bucket: bucket.id,
input_length_min: bucket.input_length_min,
input_length_max: bucket.input_length_max,
num_instances: bucket.num_instances,
total_queue_len: bucket.total_queue_len,
total_load_blocks: bucket.total_load_blocks,
predicted_prefix: bucket.predicted_prefix,
matches_input_len: bucket.input_length_min <= req.input_len
&& req.input_len <= bucket.input_length_max,
score,
});
let better = score < best_score
|| (score == best_score && chosen_bucket.is_none_or(|best| bucket.id < best));
if better {
best_score = score;
chosen_bucket = Some(bucket.id);
}
}
Ok(GlobalRouteDecision {
req_id: req.req_id,
mode: self.name(),
chosen_bucket: chosen_bucket.ok_or_else(|| anyhow!("no buckets available"))?,
candidates,
reason: "weighted length/load/cache bucket score",
})
}
}
pub fn build(full: &Config) -> Box<dyn GlobalRouter> { pub fn build(full: &Config) -> Box<dyn GlobalRouter> {
match full.cluster.global_router.mode { match full.cluster.global_router.mode {
GlobalRouterMode::StrictInputLength => Box::new(StrictInputLengthRouter::new( GlobalRouterMode::StrictInputLength => Box::new(StrictInputLengthRouter::new(
"strict_input_length", "strict_input_length",
"unique bucket range contains input_length", "unique bucket range contains input_length",
)) as Box<dyn GlobalRouter>, )) as Box<dyn GlobalRouter>,
GlobalRouterMode::BucketScore => Box::new(StrictInputLengthRouter::new( GlobalRouterMode::BucketScore => {
"strict_input_length", Box::new(BucketScoreRouter::new(full)) as Box<dyn GlobalRouter>
"bucket_score is not implemented in Task 2; falling back to strict_input_length", }
)) as Box<dyn GlobalRouter>, }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{
ClusterConfig, GlobalRouterConfig, MetaStoreConfig, RouterConfig, RouterMode,
};
fn cfg() -> Config {
Config {
model: crate::config::ModelConfig::default(),
hardware: crate::config::HardwareConfig {
gpu_flops: 1.0,
gpu_fp8_flops: 0.0,
gpu_fp4_flops: 0.0,
gpu_mem_bw: 1.0,
hbm_bytes: 1.0,
dram_bytes: 1.0,
host_dram_bw: 1.0,
pcie_bw: 1.0,
pcie_latency_us: 1.0,
rdma_bw: 1.0,
rdma_latency_us: 1.0,
intra_node_tp_bw: 1.0,
intra_node_tp_latency_us: 1.0,
tp_degree: 1,
max_batch_slots: 1,
prefill_chunk_tokens: 1,
},
calibration: crate::config::CalibrationConfig::default(),
cluster: ClusterConfig {
num_instances: None,
buckets: Vec::new(),
global_router: GlobalRouterConfig {
mode: GlobalRouterMode::BucketScore,
length_penalty_weight: 1.0,
load_weight: 1.0,
cache_weight: 1.0,
},
meta_store: MetaStoreConfig { ttl_seconds: 1.0 },
router: RouterConfig {
mode: RouterMode::LeastLoaded,
precise_probe_latency_us: 1.0,
precise_probe_topk: 1,
load_alpha: 1.0,
score_alpha: 1.0,
score_beta: 1.0,
prefix_k: 8,
affinity_fan_out: 1,
},
},
sim: crate::config::SimConfig {
trace_path: String::new(),
max_requests: None,
output_dir: String::new(),
sample_interval_s: 0.0,
seed: 0,
input_length_min: None,
input_length_max: None,
},
}
}
fn req(input_len: u32) -> RequestRecord {
RequestRecord {
req_id: 1,
chat_id: 0,
parent_chat_id: -1,
turn: 0,
arrival: 0.0,
input_len,
output_len: 16,
hash_ids: vec![10, 11, 12],
}
}
#[test]
fn bucket_score_prefers_matching_bucket_when_load_is_equal() {
let mut router = BucketScoreRouter::new(&cfg());
let buckets = vec![
BucketView {
id: 0,
input_length_min: 0,
input_length_max: 32,
num_instances: 2,
total_queue_len: 1,
total_load_blocks: 0,
predicted_prefix: 0,
},
BucketView {
id: 1,
input_length_min: 33,
input_length_max: 96,
num_instances: 2,
total_queue_len: 1,
total_load_blocks: 0,
predicted_prefix: 0,
},
];
let decision = router.route(&req(24), &buckets, 0.0).unwrap();
assert_eq!(decision.chosen_bucket, 0);
}
#[test]
fn bucket_score_can_override_length_match_when_load_gap_is_large() {
let mut full = cfg();
full.cluster.global_router.load_weight = 5.0;
full.cluster.global_router.cache_weight = 1.0;
full.cluster.global_router.length_penalty_weight = 1.0;
let mut router = BucketScoreRouter::new(&full);
let buckets = vec![
BucketView {
id: 0,
input_length_min: 0,
input_length_max: 32,
num_instances: 2,
total_queue_len: 20,
total_load_blocks: 0,
predicted_prefix: 0,
},
BucketView {
id: 1,
input_length_min: 33,
input_length_max: 96,
num_instances: 2,
total_queue_len: 0,
total_load_blocks: 0,
predicted_prefix: 2,
},
];
let decision = router.route(&req(24), &buckets, 0.0).unwrap();
assert_eq!(decision.chosen_bucket, 1);
} }
} }

View File

@@ -385,3 +385,69 @@ fn bucketed_configs_are_rejected_by_legacy_fixed_placement_paths() {
.expect_err("bucketed replay should fail"); .expect_err("bucketed replay should fail");
assert!(err.to_string().contains("cluster.buckets")); assert!(err.to_string().contains("cluster.buckets"));
} }
#[test]
fn bucket_score_can_deviate_from_strict_length_bucket() {
let tmp = std::env::temp_dir().join("kvcache_sim_bucket_score");
let _ = std::fs::remove_dir_all(&tmp);
std::fs::create_dir_all(&tmp).unwrap();
let trace_path = tmp.join("trace.jsonl");
let mut f = std::fs::File::create(&trace_path).unwrap();
for req_id in 0..3 {
writeln!(
f,
"{}",
serde_json::json!({
"chat_id": req_id,
"parent_chat_id": -1,
"timestamp": 0.0,
"input_length": 24,
"output_length": 16,
"type": "text",
"turn": 0,
"hash_ids": [100 + req_id, 200 + req_id]
})
)
.unwrap();
}
let mut strict_cfg = bucketed_config(
trace_path.to_str().unwrap(),
tmp.to_str().unwrap(),
RouterMode::LeastLoaded,
);
strict_cfg.cluster.buckets = vec![
BucketConfig {
name: "short".into(),
input_length_min: 0,
input_length_max: 32,
num_instances: 1,
},
BucketConfig {
name: "long".into(),
input_length_min: 33,
input_length_max: 96,
num_instances: 1,
},
];
strict_cfg.cluster.global_router.mode = GlobalRouterMode::StrictInputLength;
let mut score_cfg = strict_cfg.clone();
score_cfg.cluster.global_router.mode = GlobalRouterMode::BucketScore;
score_cfg.cluster.global_router.length_penalty_weight = 1.0;
score_cfg.cluster.global_router.load_weight = 10.0;
score_cfg.cluster.global_router.cache_weight = 0.0;
let _ = driver::run(&strict_cfg, Some("strict_score_cmp")).expect("strict run");
let _ = driver::run(&score_cfg, Some("bucket_score_cmp")).expect("bucket score run");
let strict_log =
std::fs::read_to_string(tmp.join("strict_score_cmp/routing_log.jsonl")).unwrap();
let score_log =
std::fs::read_to_string(tmp.join("bucket_score_cmp/routing_log.jsonl")).unwrap();
assert!(strict_log.contains("\"chosen_bucket\":0"));
assert!(score_log.contains("\"global_mode\":\"bucket_score\""));
assert!(score_log.contains("\"chosen_bucket\":1"));
}