feat: add bucket score global router
This commit is contained in:
@@ -385,3 +385,69 @@ fn bucketed_configs_are_rejected_by_legacy_fixed_placement_paths() {
|
||||
.expect_err("bucketed replay should fail");
|
||||
assert!(err.to_string().contains("cluster.buckets"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bucket_score_can_deviate_from_strict_length_bucket() {
|
||||
let tmp = std::env::temp_dir().join("kvcache_sim_bucket_score");
|
||||
let _ = std::fs::remove_dir_all(&tmp);
|
||||
std::fs::create_dir_all(&tmp).unwrap();
|
||||
let trace_path = tmp.join("trace.jsonl");
|
||||
|
||||
let mut f = std::fs::File::create(&trace_path).unwrap();
|
||||
for req_id in 0..3 {
|
||||
writeln!(
|
||||
f,
|
||||
"{}",
|
||||
serde_json::json!({
|
||||
"chat_id": req_id,
|
||||
"parent_chat_id": -1,
|
||||
"timestamp": 0.0,
|
||||
"input_length": 24,
|
||||
"output_length": 16,
|
||||
"type": "text",
|
||||
"turn": 0,
|
||||
"hash_ids": [100 + req_id, 200 + req_id]
|
||||
})
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let mut strict_cfg = bucketed_config(
|
||||
trace_path.to_str().unwrap(),
|
||||
tmp.to_str().unwrap(),
|
||||
RouterMode::LeastLoaded,
|
||||
);
|
||||
strict_cfg.cluster.buckets = vec![
|
||||
BucketConfig {
|
||||
name: "short".into(),
|
||||
input_length_min: 0,
|
||||
input_length_max: 32,
|
||||
num_instances: 1,
|
||||
},
|
||||
BucketConfig {
|
||||
name: "long".into(),
|
||||
input_length_min: 33,
|
||||
input_length_max: 96,
|
||||
num_instances: 1,
|
||||
},
|
||||
];
|
||||
strict_cfg.cluster.global_router.mode = GlobalRouterMode::StrictInputLength;
|
||||
|
||||
let mut score_cfg = strict_cfg.clone();
|
||||
score_cfg.cluster.global_router.mode = GlobalRouterMode::BucketScore;
|
||||
score_cfg.cluster.global_router.length_penalty_weight = 1.0;
|
||||
score_cfg.cluster.global_router.load_weight = 10.0;
|
||||
score_cfg.cluster.global_router.cache_weight = 0.0;
|
||||
|
||||
let _ = driver::run(&strict_cfg, Some("strict_score_cmp")).expect("strict run");
|
||||
let _ = driver::run(&score_cfg, Some("bucket_score_cmp")).expect("bucket score run");
|
||||
|
||||
let strict_log =
|
||||
std::fs::read_to_string(tmp.join("strict_score_cmp/routing_log.jsonl")).unwrap();
|
||||
let score_log =
|
||||
std::fs::read_to_string(tmp.join("bucket_score_cmp/routing_log.jsonl")).unwrap();
|
||||
|
||||
assert!(strict_log.contains("\"chosen_bucket\":0"));
|
||||
assert!(score_log.contains("\"global_mode\":\"bucket_score\""));
|
||||
assert!(score_log.contains("\"chosen_bucket\":1"));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user