feat: add bucket score global router
This commit is contained in:
@@ -52,7 +52,7 @@ impl BucketedService {
|
|||||||
let bucket_views = self
|
let bucket_views = self
|
||||||
.buckets
|
.buckets
|
||||||
.iter()
|
.iter()
|
||||||
.map(|bucket| bucket.cluster.bucket_view(bucket.id, &bucket.cfg))
|
.map(|bucket| bucket.cluster.bucket_view(bucket.id, &bucket.cfg, req, now))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let global = self.global_router.route(req, &bucket_views, now)?;
|
let global = self.global_router.route(req, &bucket_views, now)?;
|
||||||
let bucket = &mut self.buckets[global.chosen_bucket as usize];
|
let bucket = &mut self.buckets[global.chosen_bucket as usize];
|
||||||
@@ -213,21 +213,4 @@ mod tests {
|
|||||||
assert!(err.to_string().contains("no bucket"));
|
assert!(err.to_string().contains("no bucket"));
|
||||||
assert!(err.to_string().contains("input_length=36"));
|
assert!(err.to_string().contains("input_length=36"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn bucket_score_placeholder_reports_strict_fallback() {
|
|
||||||
let mut cfg = test_config();
|
|
||||||
cfg.cluster.global_router.mode = GlobalRouterMode::BucketScore;
|
|
||||||
let mut service = BucketedService::new(&cfg, &cfg.model);
|
|
||||||
|
|
||||||
let stats = service
|
|
||||||
.route_and_admit(&req(4, 24, &[30, 31]), 0.0)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(stats.decision.global_mode, "strict_input_length");
|
|
||||||
assert!(stats
|
|
||||||
.decision
|
|
||||||
.global_reason
|
|
||||||
.contains("bucket_score is not implemented"));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -171,7 +171,19 @@ impl Cluster {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn bucket_view(&self, bucket_id: BucketId, cfg: &BucketConfig) -> BucketView {
|
pub fn bucket_view(
|
||||||
|
&self,
|
||||||
|
bucket_id: BucketId,
|
||||||
|
cfg: &BucketConfig,
|
||||||
|
req: &RequestRecord,
|
||||||
|
now: f64,
|
||||||
|
) -> BucketView {
|
||||||
|
let predicted_prefix = self
|
||||||
|
.meta_store
|
||||||
|
.score_prefix(&req.hash_ids, now, self.instances.len())
|
||||||
|
.into_iter()
|
||||||
|
.max()
|
||||||
|
.unwrap_or(0);
|
||||||
BucketView {
|
BucketView {
|
||||||
id: bucket_id,
|
id: bucket_id,
|
||||||
input_length_min: cfg.input_length_min,
|
input_length_min: cfg.input_length_min,
|
||||||
@@ -179,6 +191,7 @@ impl Cluster {
|
|||||||
num_instances: self.instances.len() as u32,
|
num_instances: self.instances.len() as u32,
|
||||||
total_queue_len: self.instances.iter().map(Instance::queue_len).sum(),
|
total_queue_len: self.instances.iter().map(Instance::queue_len).sum(),
|
||||||
total_load_blocks: self.instances.iter().map(|inst| inst.kv_blocks_used).sum(),
|
total_load_blocks: self.instances.iter().map(|inst| inst.kv_blocks_used).sum(),
|
||||||
|
predicted_prefix,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ pub struct BucketView {
|
|||||||
pub num_instances: u32,
|
pub num_instances: u32,
|
||||||
pub total_queue_len: u32,
|
pub total_queue_len: u32,
|
||||||
pub total_load_blocks: u32,
|
pub total_load_blocks: u32,
|
||||||
|
pub predicted_prefix: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
@@ -24,7 +25,9 @@ pub struct BucketCandidate {
|
|||||||
pub num_instances: u32,
|
pub num_instances: u32,
|
||||||
pub total_queue_len: u32,
|
pub total_queue_len: u32,
|
||||||
pub total_load_blocks: u32,
|
pub total_load_blocks: u32,
|
||||||
|
pub predicted_prefix: u32,
|
||||||
pub matches_input_len: bool,
|
pub matches_input_len: bool,
|
||||||
|
pub score: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
@@ -92,8 +95,16 @@ impl GlobalRouter for StrictInputLengthRouter {
|
|||||||
num_instances: view.num_instances,
|
num_instances: view.num_instances,
|
||||||
total_queue_len: view.total_queue_len,
|
total_queue_len: view.total_queue_len,
|
||||||
total_load_blocks: view.total_load_blocks,
|
total_load_blocks: view.total_load_blocks,
|
||||||
|
predicted_prefix: view.predicted_prefix,
|
||||||
matches_input_len: view.input_length_min <= req.input_len
|
matches_input_len: view.input_length_min <= req.input_len
|
||||||
&& req.input_len <= view.input_length_max,
|
&& req.input_len <= view.input_length_max,
|
||||||
|
score: if view.input_length_min <= req.input_len
|
||||||
|
&& req.input_len <= view.input_length_max
|
||||||
|
{
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
f64::INFINITY
|
||||||
|
},
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
@@ -131,15 +142,230 @@ impl GlobalRouter for StrictInputLengthRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Global router that scores every bucket and picks the lowest score.
///
/// The three fields are the multipliers that `route` applies to the
/// length-mismatch, queue-length and cache-miss terms of the score.
struct BucketScoreRouter {
|
||||||
|
/// Multiplier for the value returned by `length_penalty`.
length_penalty_weight: f64,
|
||||||
|
/// Multiplier for a bucket's `total_queue_len`.
load_weight: f64,
|
||||||
|
/// Multiplier for the count of request hash ids beyond `predicted_prefix`.
cache_weight: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BucketScoreRouter {
|
||||||
|
/// Copies the three score weights out of `full.cluster.global_router`.
fn new(full: &Config) -> Self {
|
||||||
|
Self {
|
||||||
|
length_penalty_weight: full.cluster.global_router.length_penalty_weight,
|
||||||
|
load_weight: full.cluster.global_router.load_weight,
|
||||||
|
cache_weight: full.cluster.global_router.cache_weight,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns how far `req.input_len` lies outside the bucket's
/// `[input_length_min, input_length_max]` range, or `0.0` when it is inside.
fn length_penalty(&self, req: &RequestRecord, bucket: &BucketView) -> f64 {
|
||||||
|
// Request is shorter than the bucket's lower bound: distance to that bound.
if req.input_len < bucket.input_length_min {
|
||||||
|
(bucket.input_length_min - req.input_len) as f64
|
||||||
|
// Request is longer than the bucket's upper bound: distance to that bound.
} else if req.input_len > bucket.input_length_max {
|
||||||
|
(req.input_len - bucket.input_length_max) as f64
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GlobalRouter for BucketScoreRouter {
|
||||||
|
/// Returns the mode label `"bucket_score"`; `route` stores it in the decision's `mode`.
fn name(&self) -> &'static str {
|
||||||
|
"bucket_score"
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Scores every bucket view and chooses the one with the lowest score.
///
/// The score is `length_penalty_weight * length_penalty + load_weight *
/// total_queue_len + cache_weight * miss`, where `miss` is the number of
/// request hash ids beyond the bucket's `predicted_prefix`. Equal scores are
/// resolved in favour of the smaller bucket id. Every bucket is also recorded
/// as a `BucketCandidate` in the returned decision.
///
/// # Errors
///
/// Returns the "no buckets available" error when no bucket was selected,
/// which is the case when `buckets` is empty.
fn route(
|
||||||
|
&mut self,
|
||||||
|
req: &RequestRecord,
|
||||||
|
buckets: &[BucketView],
|
||||||
|
// Unused by this router (underscore-prefixed).
_now: f64,
|
||||||
|
) -> Result<GlobalRouteDecision> {
|
||||||
|
let mut chosen_bucket = None;
|
||||||
|
let mut best_score = f64::INFINITY;
|
||||||
|
let mut candidates = Vec::with_capacity(buckets.len());
|
||||||
|
|
||||||
|
for bucket in buckets {
|
||||||
|
let length_penalty = self.length_penalty(req, bucket);
|
||||||
|
// Hash ids not covered by the predicted prefix; saturates at zero when
// the prediction is at least as long as the request's hash list.
let miss = req
|
||||||
|
.hash_ids
|
||||||
|
.len()
|
||||||
|
.saturating_sub(bucket.predicted_prefix as usize) as f64;
|
||||||
|
// Lower is better. NOTE(review): only `total_queue_len` feeds the load
// term; `total_load_blocks` and `num_instances` are reported but not scored.
let score = self.length_penalty_weight * length_penalty
|
||||||
|
+ self.load_weight * bucket.total_queue_len as f64
|
||||||
|
+ self.cache_weight * miss;
|
||||||
|
|
||||||
|
candidates.push(BucketCandidate {
|
||||||
|
bucket: bucket.id,
|
||||||
|
input_length_min: bucket.input_length_min,
|
||||||
|
input_length_max: bucket.input_length_max,
|
||||||
|
num_instances: bucket.num_instances,
|
||||||
|
total_queue_len: bucket.total_queue_len,
|
||||||
|
total_load_blocks: bucket.total_load_blocks,
|
||||||
|
predicted_prefix: bucket.predicted_prefix,
|
||||||
|
matches_input_len: bucket.input_length_min <= req.input_len
|
||||||
|
&& req.input_len <= bucket.input_length_max,
|
||||||
|
score,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Strictly lower score wins; on an exact tie the smaller bucket id wins.
// The `is_none_or` arm also lets a first bucket scoring exactly
// `f64::INFINITY` be selected.
// NOTE(review): a NaN score fails both comparisons, so such a bucket is
// never chosen; if every score is NaN the "no buckets available" error is
// returned even though `buckets` is non-empty.
let better = score < best_score
|
||||||
|
|| (score == best_score && chosen_bucket.is_none_or(|best| bucket.id < best));
|
||||||
|
if better {
|
||||||
|
best_score = score;
|
||||||
|
chosen_bucket = Some(bucket.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(GlobalRouteDecision {
|
||||||
|
req_id: req.req_id,
|
||||||
|
mode: self.name(),
|
||||||
|
// `None` here means the loop never selected a bucket.
chosen_bucket: chosen_bucket.ok_or_else(|| anyhow!("no buckets available"))?,
|
||||||
|
candidates,
|
||||||
|
reason: "weighted length/load/cache bucket score",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn build(full: &Config) -> Box<dyn GlobalRouter> {
|
pub fn build(full: &Config) -> Box<dyn GlobalRouter> {
|
||||||
match full.cluster.global_router.mode {
|
match full.cluster.global_router.mode {
|
||||||
GlobalRouterMode::StrictInputLength => Box::new(StrictInputLengthRouter::new(
|
GlobalRouterMode::StrictInputLength => Box::new(StrictInputLengthRouter::new(
|
||||||
"strict_input_length",
|
"strict_input_length",
|
||||||
"unique bucket range contains input_length",
|
"unique bucket range contains input_length",
|
||||||
)) as Box<dyn GlobalRouter>,
|
)) as Box<dyn GlobalRouter>,
|
||||||
GlobalRouterMode::BucketScore => Box::new(StrictInputLengthRouter::new(
|
GlobalRouterMode::BucketScore => {
|
||||||
"strict_input_length",
|
Box::new(BucketScoreRouter::new(full)) as Box<dyn GlobalRouter>
|
||||||
"bucket_score is not implemented in Task 2; falling back to strict_input_length",
|
}
|
||||||
)) as Box<dyn GlobalRouter>,
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::config::{
|
||||||
|
ClusterConfig, GlobalRouterConfig, MetaStoreConfig, RouterConfig, RouterMode,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Test fixture: a `Config` with `GlobalRouterMode::BucketScore`, all three
/// global-router weights set to `1.0`, and an empty bucket list.
/// The remaining fields are filled with small placeholder values.
fn cfg() -> Config {
|
||||||
|
Config {
|
||||||
|
model: crate::config::ModelConfig::default(),
|
||||||
|
hardware: crate::config::HardwareConfig {
|
||||||
|
gpu_flops: 1.0,
|
||||||
|
gpu_fp8_flops: 0.0,
|
||||||
|
gpu_fp4_flops: 0.0,
|
||||||
|
gpu_mem_bw: 1.0,
|
||||||
|
hbm_bytes: 1.0,
|
||||||
|
dram_bytes: 1.0,
|
||||||
|
host_dram_bw: 1.0,
|
||||||
|
pcie_bw: 1.0,
|
||||||
|
pcie_latency_us: 1.0,
|
||||||
|
rdma_bw: 1.0,
|
||||||
|
rdma_latency_us: 1.0,
|
||||||
|
intra_node_tp_bw: 1.0,
|
||||||
|
intra_node_tp_latency_us: 1.0,
|
||||||
|
tp_degree: 1,
|
||||||
|
max_batch_slots: 1,
|
||||||
|
prefill_chunk_tokens: 1,
|
||||||
|
},
|
||||||
|
calibration: crate::config::CalibrationConfig::default(),
|
||||||
|
cluster: ClusterConfig {
|
||||||
|
num_instances: None,
|
||||||
|
buckets: Vec::new(),
|
||||||
|
// The only section read by `BucketScoreRouter::new`.
global_router: GlobalRouterConfig {
|
||||||
|
mode: GlobalRouterMode::BucketScore,
|
||||||
|
length_penalty_weight: 1.0,
|
||||||
|
load_weight: 1.0,
|
||||||
|
cache_weight: 1.0,
|
||||||
|
},
|
||||||
|
meta_store: MetaStoreConfig { ttl_seconds: 1.0 },
|
||||||
|
router: RouterConfig {
|
||||||
|
mode: RouterMode::LeastLoaded,
|
||||||
|
precise_probe_latency_us: 1.0,
|
||||||
|
precise_probe_topk: 1,
|
||||||
|
load_alpha: 1.0,
|
||||||
|
score_alpha: 1.0,
|
||||||
|
score_beta: 1.0,
|
||||||
|
prefix_k: 8,
|
||||||
|
affinity_fan_out: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
sim: crate::config::SimConfig {
|
||||||
|
trace_path: String::new(),
|
||||||
|
max_requests: None,
|
||||||
|
output_dir: String::new(),
|
||||||
|
sample_interval_s: 0.0,
|
||||||
|
seed: 0,
|
||||||
|
input_length_min: None,
|
||||||
|
input_length_max: None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test fixture: a `RequestRecord` with the given `input_len`, `req_id` 1
/// and three hash ids (`10, 11, 12`).
fn req(input_len: u32) -> RequestRecord {
|
||||||
|
RequestRecord {
|
||||||
|
req_id: 1,
|
||||||
|
chat_id: 0,
|
||||||
|
parent_chat_id: -1,
|
||||||
|
turn: 0,
|
||||||
|
arrival: 0.0,
|
||||||
|
input_len,
|
||||||
|
output_len: 16,
|
||||||
|
// Three ids, so the cache-miss term is at most 3 for this request.
hash_ids: vec![10, 11, 12],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// With all weights at 1.0 and identical queue length / predicted prefix, only the
// length penalty differs: bucket 0 scores 0 + 1 + 3 = 4, bucket 1 scores 9 + 1 + 3 = 13.
#[test]
|
||||||
|
fn bucket_score_prefers_matching_bucket_when_load_is_equal() {
|
||||||
|
let mut router = BucketScoreRouter::new(&cfg());
|
||||||
|
let buckets = vec![
|
||||||
|
// Bucket whose range [0, 32] contains the request's input_len of 24.
BucketView {
|
||||||
|
id: 0,
|
||||||
|
input_length_min: 0,
|
||||||
|
input_length_max: 32,
|
||||||
|
num_instances: 2,
|
||||||
|
total_queue_len: 1,
|
||||||
|
total_load_blocks: 0,
|
||||||
|
predicted_prefix: 0,
|
||||||
|
},
|
||||||
|
// Bucket whose range [33, 96] starts 9 above the request's input_len.
BucketView {
|
||||||
|
id: 1,
|
||||||
|
input_length_min: 33,
|
||||||
|
input_length_max: 96,
|
||||||
|
num_instances: 2,
|
||||||
|
total_queue_len: 1,
|
||||||
|
total_load_blocks: 0,
|
||||||
|
predicted_prefix: 0,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
let decision = router.route(&req(24), &buckets, 0.0).unwrap();
|
||||||
|
assert_eq!(decision.chosen_bucket, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// With load_weight 5.0, the length-matching bucket 0 scores 0 + 5 * 20 + 3 = 103,
// while the out-of-range bucket 1 scores 9 + 0 + (3 - 2) = 10, so bucket 1 wins.
#[test]
|
||||||
|
fn bucket_score_can_override_length_match_when_load_gap_is_large() {
|
||||||
|
let mut full = cfg();
|
||||||
|
full.cluster.global_router.load_weight = 5.0;
|
||||||
|
full.cluster.global_router.cache_weight = 1.0;
|
||||||
|
full.cluster.global_router.length_penalty_weight = 1.0;
|
||||||
|
let mut router = BucketScoreRouter::new(&full);
|
||||||
|
let buckets = vec![
|
||||||
|
// Length-matching bucket with a queue of 20.
BucketView {
|
||||||
|
id: 0,
|
||||||
|
input_length_min: 0,
|
||||||
|
input_length_max: 32,
|
||||||
|
num_instances: 2,
|
||||||
|
total_queue_len: 20,
|
||||||
|
total_load_blocks: 0,
|
||||||
|
predicted_prefix: 0,
|
||||||
|
},
|
||||||
|
// Non-matching bucket with an empty queue and a predicted prefix of 2.
BucketView {
|
||||||
|
id: 1,
|
||||||
|
input_length_min: 33,
|
||||||
|
input_length_max: 96,
|
||||||
|
num_instances: 2,
|
||||||
|
total_queue_len: 0,
|
||||||
|
total_load_blocks: 0,
|
||||||
|
predicted_prefix: 2,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
let decision = router.route(&req(24), &buckets, 0.0).unwrap();
|
||||||
|
assert_eq!(decision.chosen_bucket, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -385,3 +385,69 @@ fn bucketed_configs_are_rejected_by_legacy_fixed_placement_paths() {
|
|||||||
.expect_err("bucketed replay should fail");
|
.expect_err("bucketed replay should fail");
|
||||||
assert!(err.to_string().contains("cluster.buckets"));
|
assert!(err.to_string().contains("cluster.buckets"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// End-to-end comparison: replays the same three-request trace once with the strict
// input-length router and once with the bucket-score router, then inspects the
// routing logs written by each run.
#[test]
|
||||||
|
fn bucket_score_can_deviate_from_strict_length_bucket() {
|
||||||
|
// NOTE(review): fixed directory name under the system temp dir; concurrent runs of
// this test on one machine would share it — confirm that is acceptable.
let tmp = std::env::temp_dir().join("kvcache_sim_bucket_score");
|
||||||
|
// Best-effort cleanup of a previous run; the result is deliberately ignored.
let _ = std::fs::remove_dir_all(&tmp);
|
||||||
|
std::fs::create_dir_all(&tmp).unwrap();
|
||||||
|
let trace_path = tmp.join("trace.jsonl");
|
||||||
|
|
||||||
|
let mut f = std::fs::File::create(&trace_path).unwrap();
|
||||||
|
// Three requests, all at timestamp 0.0 with input_length 24 and distinct hash ids.
for req_id in 0..3 {
|
||||||
|
writeln!(
|
||||||
|
f,
|
||||||
|
"{}",
|
||||||
|
serde_json::json!({
|
||||||
|
"chat_id": req_id,
|
||||||
|
"parent_chat_id": -1,
|
||||||
|
"timestamp": 0.0,
|
||||||
|
"input_length": 24,
|
||||||
|
"output_length": 16,
|
||||||
|
"type": "text",
|
||||||
|
"turn": 0,
|
||||||
|
"hash_ids": [100 + req_id, 200 + req_id]
|
||||||
|
})
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut strict_cfg = bucketed_config(
|
||||||
|
trace_path.to_str().unwrap(),
|
||||||
|
tmp.to_str().unwrap(),
|
||||||
|
RouterMode::LeastLoaded,
|
||||||
|
);
|
||||||
|
// Two single-instance buckets; only "short" ([0, 32]) contains input_length 24.
strict_cfg.cluster.buckets = vec![
|
||||||
|
BucketConfig {
|
||||||
|
name: "short".into(),
|
||||||
|
input_length_min: 0,
|
||||||
|
input_length_max: 32,
|
||||||
|
num_instances: 1,
|
||||||
|
},
|
||||||
|
BucketConfig {
|
||||||
|
name: "long".into(),
|
||||||
|
input_length_min: 33,
|
||||||
|
input_length_max: 96,
|
||||||
|
num_instances: 1,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
strict_cfg.cluster.global_router.mode = GlobalRouterMode::StrictInputLength;
|
||||||
|
|
||||||
|
// Same setup, but scored routing: load dominates (10.0), cache term disabled (0.0).
// Presumably one queued request on "short" (10 * 1) then outweighs the length
// penalty of 9 for "long" — depends on queue accounting not visible here.
let mut score_cfg = strict_cfg.clone();
|
||||||
|
score_cfg.cluster.global_router.mode = GlobalRouterMode::BucketScore;
|
||||||
|
score_cfg.cluster.global_router.length_penalty_weight = 1.0;
|
||||||
|
score_cfg.cluster.global_router.load_weight = 10.0;
|
||||||
|
score_cfg.cluster.global_router.cache_weight = 0.0;
|
||||||
|
|
||||||
|
let _ = driver::run(&strict_cfg, Some("strict_score_cmp")).expect("strict run");
|
||||||
|
let _ = driver::run(&score_cfg, Some("bucket_score_cmp")).expect("bucket score run");
|
||||||
|
|
||||||
|
let strict_log =
|
||||||
|
std::fs::read_to_string(tmp.join("strict_score_cmp/routing_log.jsonl")).unwrap();
|
||||||
|
let score_log =
|
||||||
|
std::fs::read_to_string(tmp.join("bucket_score_cmp/routing_log.jsonl")).unwrap();
|
||||||
|
|
||||||
|
// Substring checks on the raw log text: the strict run used bucket 0 at least once,
// and the scored run reports its mode and used bucket 1 at least once.
assert!(strict_log.contains("\"chosen_bucket\":0"));
|
||||||
|
assert!(score_log.contains("\"global_mode\":\"bucket_score\""));
|
||||||
|
assert!(score_log.contains("\"chosen_bucket\":1"));
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user