feat: add bucket score global router
This commit is contained in:
@@ -52,7 +52,7 @@ impl BucketedService {
|
||||
let bucket_views = self
|
||||
.buckets
|
||||
.iter()
|
||||
.map(|bucket| bucket.cluster.bucket_view(bucket.id, &bucket.cfg))
|
||||
.map(|bucket| bucket.cluster.bucket_view(bucket.id, &bucket.cfg, req, now))
|
||||
.collect::<Vec<_>>();
|
||||
let global = self.global_router.route(req, &bucket_views, now)?;
|
||||
let bucket = &mut self.buckets[global.chosen_bucket as usize];
|
||||
@@ -213,21 +213,4 @@ mod tests {
|
||||
assert!(err.to_string().contains("no bucket"));
|
||||
assert!(err.to_string().contains("input_length=36"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bucket_score_placeholder_reports_strict_fallback() {
|
||||
let mut cfg = test_config();
|
||||
cfg.cluster.global_router.mode = GlobalRouterMode::BucketScore;
|
||||
let mut service = BucketedService::new(&cfg, &cfg.model);
|
||||
|
||||
let stats = service
|
||||
.route_and_admit(&req(4, 24, &[30, 31]), 0.0)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(stats.decision.global_mode, "strict_input_length");
|
||||
assert!(stats
|
||||
.decision
|
||||
.global_reason
|
||||
.contains("bucket_score is not implemented"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,7 +171,19 @@ impl Cluster {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bucket_view(&self, bucket_id: BucketId, cfg: &BucketConfig) -> BucketView {
|
||||
pub fn bucket_view(
|
||||
&self,
|
||||
bucket_id: BucketId,
|
||||
cfg: &BucketConfig,
|
||||
req: &RequestRecord,
|
||||
now: f64,
|
||||
) -> BucketView {
|
||||
let predicted_prefix = self
|
||||
.meta_store
|
||||
.score_prefix(&req.hash_ids, now, self.instances.len())
|
||||
.into_iter()
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
BucketView {
|
||||
id: bucket_id,
|
||||
input_length_min: cfg.input_length_min,
|
||||
@@ -179,6 +191,7 @@ impl Cluster {
|
||||
num_instances: self.instances.len() as u32,
|
||||
total_queue_len: self.instances.iter().map(Instance::queue_len).sum(),
|
||||
total_load_blocks: self.instances.iter().map(|inst| inst.kv_blocks_used).sum(),
|
||||
predicted_prefix,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ pub struct BucketView {
|
||||
pub num_instances: u32,
|
||||
pub total_queue_len: u32,
|
||||
pub total_load_blocks: u32,
|
||||
pub predicted_prefix: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
@@ -24,7 +25,9 @@ pub struct BucketCandidate {
|
||||
pub num_instances: u32,
|
||||
pub total_queue_len: u32,
|
||||
pub total_load_blocks: u32,
|
||||
pub predicted_prefix: u32,
|
||||
pub matches_input_len: bool,
|
||||
pub score: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
@@ -92,8 +95,16 @@ impl GlobalRouter for StrictInputLengthRouter {
|
||||
num_instances: view.num_instances,
|
||||
total_queue_len: view.total_queue_len,
|
||||
total_load_blocks: view.total_load_blocks,
|
||||
predicted_prefix: view.predicted_prefix,
|
||||
matches_input_len: view.input_length_min <= req.input_len
|
||||
&& req.input_len <= view.input_length_max,
|
||||
score: if view.input_length_min <= req.input_len
|
||||
&& req.input_len <= view.input_length_max
|
||||
{
|
||||
0.0
|
||||
} else {
|
||||
f64::INFINITY
|
||||
},
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
@@ -131,15 +142,230 @@ impl GlobalRouter for StrictInputLengthRouter {
|
||||
}
|
||||
}
|
||||
|
||||
struct BucketScoreRouter {
|
||||
length_penalty_weight: f64,
|
||||
load_weight: f64,
|
||||
cache_weight: f64,
|
||||
}
|
||||
|
||||
impl BucketScoreRouter {
|
||||
fn new(full: &Config) -> Self {
|
||||
Self {
|
||||
length_penalty_weight: full.cluster.global_router.length_penalty_weight,
|
||||
load_weight: full.cluster.global_router.load_weight,
|
||||
cache_weight: full.cluster.global_router.cache_weight,
|
||||
}
|
||||
}
|
||||
|
||||
fn length_penalty(&self, req: &RequestRecord, bucket: &BucketView) -> f64 {
|
||||
if req.input_len < bucket.input_length_min {
|
||||
(bucket.input_length_min - req.input_len) as f64
|
||||
} else if req.input_len > bucket.input_length_max {
|
||||
(req.input_len - bucket.input_length_max) as f64
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalRouter for BucketScoreRouter {
|
||||
fn name(&self) -> &'static str {
|
||||
"bucket_score"
|
||||
}
|
||||
|
||||
fn route(
|
||||
&mut self,
|
||||
req: &RequestRecord,
|
||||
buckets: &[BucketView],
|
||||
_now: f64,
|
||||
) -> Result<GlobalRouteDecision> {
|
||||
let mut chosen_bucket = None;
|
||||
let mut best_score = f64::INFINITY;
|
||||
let mut candidates = Vec::with_capacity(buckets.len());
|
||||
|
||||
for bucket in buckets {
|
||||
let length_penalty = self.length_penalty(req, bucket);
|
||||
let miss = req
|
||||
.hash_ids
|
||||
.len()
|
||||
.saturating_sub(bucket.predicted_prefix as usize) as f64;
|
||||
let score = self.length_penalty_weight * length_penalty
|
||||
+ self.load_weight * bucket.total_queue_len as f64
|
||||
+ self.cache_weight * miss;
|
||||
|
||||
candidates.push(BucketCandidate {
|
||||
bucket: bucket.id,
|
||||
input_length_min: bucket.input_length_min,
|
||||
input_length_max: bucket.input_length_max,
|
||||
num_instances: bucket.num_instances,
|
||||
total_queue_len: bucket.total_queue_len,
|
||||
total_load_blocks: bucket.total_load_blocks,
|
||||
predicted_prefix: bucket.predicted_prefix,
|
||||
matches_input_len: bucket.input_length_min <= req.input_len
|
||||
&& req.input_len <= bucket.input_length_max,
|
||||
score,
|
||||
});
|
||||
|
||||
let better = score < best_score
|
||||
|| (score == best_score && chosen_bucket.is_none_or(|best| bucket.id < best));
|
||||
if better {
|
||||
best_score = score;
|
||||
chosen_bucket = Some(bucket.id);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(GlobalRouteDecision {
|
||||
req_id: req.req_id,
|
||||
mode: self.name(),
|
||||
chosen_bucket: chosen_bucket.ok_or_else(|| anyhow!("no buckets available"))?,
|
||||
candidates,
|
||||
reason: "weighted length/load/cache bucket score",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build(full: &Config) -> Box<dyn GlobalRouter> {
|
||||
match full.cluster.global_router.mode {
|
||||
GlobalRouterMode::StrictInputLength => Box::new(StrictInputLengthRouter::new(
|
||||
"strict_input_length",
|
||||
"unique bucket range contains input_length",
|
||||
)) as Box<dyn GlobalRouter>,
|
||||
GlobalRouterMode::BucketScore => Box::new(StrictInputLengthRouter::new(
|
||||
"strict_input_length",
|
||||
"bucket_score is not implemented in Task 2; falling back to strict_input_length",
|
||||
)) as Box<dyn GlobalRouter>,
|
||||
GlobalRouterMode::BucketScore => {
|
||||
Box::new(BucketScoreRouter::new(full)) as Box<dyn GlobalRouter>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::{
|
||||
ClusterConfig, GlobalRouterConfig, MetaStoreConfig, RouterConfig, RouterMode,
|
||||
};
|
||||
|
||||
fn cfg() -> Config {
|
||||
Config {
|
||||
model: crate::config::ModelConfig::default(),
|
||||
hardware: crate::config::HardwareConfig {
|
||||
gpu_flops: 1.0,
|
||||
gpu_fp8_flops: 0.0,
|
||||
gpu_fp4_flops: 0.0,
|
||||
gpu_mem_bw: 1.0,
|
||||
hbm_bytes: 1.0,
|
||||
dram_bytes: 1.0,
|
||||
host_dram_bw: 1.0,
|
||||
pcie_bw: 1.0,
|
||||
pcie_latency_us: 1.0,
|
||||
rdma_bw: 1.0,
|
||||
rdma_latency_us: 1.0,
|
||||
intra_node_tp_bw: 1.0,
|
||||
intra_node_tp_latency_us: 1.0,
|
||||
tp_degree: 1,
|
||||
max_batch_slots: 1,
|
||||
prefill_chunk_tokens: 1,
|
||||
},
|
||||
calibration: crate::config::CalibrationConfig::default(),
|
||||
cluster: ClusterConfig {
|
||||
num_instances: None,
|
||||
buckets: Vec::new(),
|
||||
global_router: GlobalRouterConfig {
|
||||
mode: GlobalRouterMode::BucketScore,
|
||||
length_penalty_weight: 1.0,
|
||||
load_weight: 1.0,
|
||||
cache_weight: 1.0,
|
||||
},
|
||||
meta_store: MetaStoreConfig { ttl_seconds: 1.0 },
|
||||
router: RouterConfig {
|
||||
mode: RouterMode::LeastLoaded,
|
||||
precise_probe_latency_us: 1.0,
|
||||
precise_probe_topk: 1,
|
||||
load_alpha: 1.0,
|
||||
score_alpha: 1.0,
|
||||
score_beta: 1.0,
|
||||
prefix_k: 8,
|
||||
affinity_fan_out: 1,
|
||||
},
|
||||
},
|
||||
sim: crate::config::SimConfig {
|
||||
trace_path: String::new(),
|
||||
max_requests: None,
|
||||
output_dir: String::new(),
|
||||
sample_interval_s: 0.0,
|
||||
seed: 0,
|
||||
input_length_min: None,
|
||||
input_length_max: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn req(input_len: u32) -> RequestRecord {
|
||||
RequestRecord {
|
||||
req_id: 1,
|
||||
chat_id: 0,
|
||||
parent_chat_id: -1,
|
||||
turn: 0,
|
||||
arrival: 0.0,
|
||||
input_len,
|
||||
output_len: 16,
|
||||
hash_ids: vec![10, 11, 12],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bucket_score_prefers_matching_bucket_when_load_is_equal() {
|
||||
let mut router = BucketScoreRouter::new(&cfg());
|
||||
let buckets = vec![
|
||||
BucketView {
|
||||
id: 0,
|
||||
input_length_min: 0,
|
||||
input_length_max: 32,
|
||||
num_instances: 2,
|
||||
total_queue_len: 1,
|
||||
total_load_blocks: 0,
|
||||
predicted_prefix: 0,
|
||||
},
|
||||
BucketView {
|
||||
id: 1,
|
||||
input_length_min: 33,
|
||||
input_length_max: 96,
|
||||
num_instances: 2,
|
||||
total_queue_len: 1,
|
||||
total_load_blocks: 0,
|
||||
predicted_prefix: 0,
|
||||
},
|
||||
];
|
||||
let decision = router.route(&req(24), &buckets, 0.0).unwrap();
|
||||
assert_eq!(decision.chosen_bucket, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bucket_score_can_override_length_match_when_load_gap_is_large() {
|
||||
let mut full = cfg();
|
||||
full.cluster.global_router.load_weight = 5.0;
|
||||
full.cluster.global_router.cache_weight = 1.0;
|
||||
full.cluster.global_router.length_penalty_weight = 1.0;
|
||||
let mut router = BucketScoreRouter::new(&full);
|
||||
let buckets = vec![
|
||||
BucketView {
|
||||
id: 0,
|
||||
input_length_min: 0,
|
||||
input_length_max: 32,
|
||||
num_instances: 2,
|
||||
total_queue_len: 20,
|
||||
total_load_blocks: 0,
|
||||
predicted_prefix: 0,
|
||||
},
|
||||
BucketView {
|
||||
id: 1,
|
||||
input_length_min: 33,
|
||||
input_length_max: 96,
|
||||
num_instances: 2,
|
||||
total_queue_len: 0,
|
||||
total_load_blocks: 0,
|
||||
predicted_prefix: 2,
|
||||
},
|
||||
];
|
||||
let decision = router.route(&req(24), &buckets, 0.0).unwrap();
|
||||
assert_eq!(decision.chosen_bucket, 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,3 +385,69 @@ fn bucketed_configs_are_rejected_by_legacy_fixed_placement_paths() {
|
||||
.expect_err("bucketed replay should fail");
|
||||
assert!(err.to_string().contains("cluster.buckets"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bucket_score_can_deviate_from_strict_length_bucket() {
|
||||
let tmp = std::env::temp_dir().join("kvcache_sim_bucket_score");
|
||||
let _ = std::fs::remove_dir_all(&tmp);
|
||||
std::fs::create_dir_all(&tmp).unwrap();
|
||||
let trace_path = tmp.join("trace.jsonl");
|
||||
|
||||
let mut f = std::fs::File::create(&trace_path).unwrap();
|
||||
for req_id in 0..3 {
|
||||
writeln!(
|
||||
f,
|
||||
"{}",
|
||||
serde_json::json!({
|
||||
"chat_id": req_id,
|
||||
"parent_chat_id": -1,
|
||||
"timestamp": 0.0,
|
||||
"input_length": 24,
|
||||
"output_length": 16,
|
||||
"type": "text",
|
||||
"turn": 0,
|
||||
"hash_ids": [100 + req_id, 200 + req_id]
|
||||
})
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let mut strict_cfg = bucketed_config(
|
||||
trace_path.to_str().unwrap(),
|
||||
tmp.to_str().unwrap(),
|
||||
RouterMode::LeastLoaded,
|
||||
);
|
||||
strict_cfg.cluster.buckets = vec![
|
||||
BucketConfig {
|
||||
name: "short".into(),
|
||||
input_length_min: 0,
|
||||
input_length_max: 32,
|
||||
num_instances: 1,
|
||||
},
|
||||
BucketConfig {
|
||||
name: "long".into(),
|
||||
input_length_min: 33,
|
||||
input_length_max: 96,
|
||||
num_instances: 1,
|
||||
},
|
||||
];
|
||||
strict_cfg.cluster.global_router.mode = GlobalRouterMode::StrictInputLength;
|
||||
|
||||
let mut score_cfg = strict_cfg.clone();
|
||||
score_cfg.cluster.global_router.mode = GlobalRouterMode::BucketScore;
|
||||
score_cfg.cluster.global_router.length_penalty_weight = 1.0;
|
||||
score_cfg.cluster.global_router.load_weight = 10.0;
|
||||
score_cfg.cluster.global_router.cache_weight = 0.0;
|
||||
|
||||
let _ = driver::run(&strict_cfg, Some("strict_score_cmp")).expect("strict run");
|
||||
let _ = driver::run(&score_cfg, Some("bucket_score_cmp")).expect("bucket score run");
|
||||
|
||||
let strict_log =
|
||||
std::fs::read_to_string(tmp.join("strict_score_cmp/routing_log.jsonl")).unwrap();
|
||||
let score_log =
|
||||
std::fs::read_to_string(tmp.join("bucket_score_cmp/routing_log.jsonl")).unwrap();
|
||||
|
||||
assert!(strict_log.contains("\"chosen_bucket\":0"));
|
||||
assert!(score_log.contains("\"global_mode\":\"bucket_score\""));
|
||||
assert!(score_log.contains("\"chosen_bucket\":1"));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user