feat: model explicit bucketed cluster config

This commit is contained in:
2026-04-17 14:16:56 +08:00
parent bb280c8ba0
commit a723d7a811

View File

@@ -13,7 +13,11 @@
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ops::Deref;
use std::path::Path;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Mutex, OnceLock};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
@@ -384,6 +388,88 @@ pub struct ClusterConfig {
pub router: RouterConfig,
}
/// One explicitly configured bucket of cluster instances.
///
/// `ClusterConfig::validate` requires `input_length_min <= input_length_max`,
/// a positive `num_instances`, and ranges that do not overlap. A bucket whose
/// minimum equals another bucket's maximum is rejected as overlapping, so both
/// bounds behave as inclusive.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct BucketConfig {
/// Label of the bucket; quoted in validation error messages.
pub name: String,
/// Inclusive lower bound of the input-length range covered by this bucket
/// (presumably measured in tokens — TODO confirm).
pub input_length_min: u32,
/// Inclusive upper bound of the input-length range covered by this bucket.
pub input_length_max: u32,
/// Number of instances assigned to this bucket; must be greater than zero.
pub num_instances: u32,
}
/// Settings of the cluster-wide ("global") router section of the config.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct GlobalRouterConfig {
/// Routing mode; when the key is absent the `GlobalRouterMode` default is used.
#[serde(default)]
pub mode: GlobalRouterMode,
}
impl Default for GlobalRouterConfig {
    /// Returns the configuration used when the `global_router` section is
    /// omitted: strict input-length routing.
    fn default() -> Self {
        let mode = GlobalRouterMode::StrictInputLength;
        Self { mode }
    }
}
/// Mode selector for the global router.
///
/// Variants are (de)serialized in snake_case, i.e. `strict_input_length` and
/// `bucket_score`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum GlobalRouterMode {
/// Default mode. NOTE(review): presumably routes purely by which bucket's
/// input-length range contains the request — confirm against the router.
#[default]
StrictInputLength,
/// NOTE(review): presumably picks a bucket by a computed score — confirm
/// against the router implementation.
BucketScore,
}
/// Bucket-related part of the cluster configuration.
///
/// `ClusterConfig` dereferences to this type, so these fields are reachable
/// directly on a `ClusterConfig` value.
#[derive(Debug, Clone)]
pub struct ClusterBucketView {
/// Explicit buckets; empty for a legacy single-pool configuration.
pub buckets: Vec<BucketConfig>,
/// Global router settings that accompany the buckets.
pub global_router: GlobalRouterConfig,
}
impl Default for ClusterBucketView {
    /// Returns the empty view: no buckets and default global-router settings.
    fn default() -> Self {
        let buckets = Vec::new();
        let global_router = GlobalRouterConfig::default();
        Self {
            buckets,
            global_router,
        }
    }
}
/// Shared fallback view returned for id 0 and for ids that are not registered;
/// created on first use by `default_cluster_bucket_view`.
static DEFAULT_CLUSTER_BUCKET_VIEW: OnceLock<ClusterBucketView> = OnceLock::new();
/// Registry of leaked views keyed by the id handed out by
/// `register_cluster_bucket_view`; created on first use by `cluster_bucket_views`.
static CLUSTER_BUCKET_VIEWS: OnceLock<Mutex<HashMap<u32, &'static ClusterBucketView>>> =
OnceLock::new();
/// Next id to hand out. Starts at 1 because id 0 denotes the default view.
static NEXT_CLUSTER_BUCKET_VIEW_ID: AtomicU32 = AtomicU32::new(1);
/// Returns the process-wide default (empty) bucket view, building it on the
/// first call.
fn default_cluster_bucket_view() -> &'static ClusterBucketView {
    let slot = &DEFAULT_CLUSTER_BUCKET_VIEW;
    slot.get_or_init(<ClusterBucketView as Default>::default)
}
/// Returns the process-wide registry of bucket views, creating an empty one on
/// the first call.
fn cluster_bucket_views() -> &'static Mutex<HashMap<u32, &'static ClusterBucketView>> {
    let slot = &CLUSTER_BUCKET_VIEWS;
    slot.get_or_init(Mutex::default)
}
/// Registers `view` in the process-wide registry and returns its id.
///
/// Returns 0 without registering anything when `view` equals the default view
/// (no buckets, default global router). Otherwise the view is leaked to obtain
/// a `'static` reference and stored under a fresh non-zero id.
///
/// Registering a view identical to one that is already present returns the
/// existing id instead of leaking another copy, so repeatedly loading the same
/// configuration does not grow the registry.
///
/// # Panics
///
/// Panics if the registry mutex is poisoned.
fn register_cluster_bucket_view(view: ClusterBucketView) -> u32 {
    if view.buckets.is_empty() && view.global_router == GlobalRouterConfig::default() {
        return 0;
    }

    // Hold the lock across lookup and insert so concurrent registrations of the
    // same view cannot both miss and leak twice.
    let mut registry = cluster_bucket_views()
        .lock()
        .expect("cluster bucket view registry mutex poisoned");

    let existing = registry.iter().find_map(|(&id, known)| {
        (known.buckets == view.buckets && known.global_router == view.global_router).then_some(id)
    });
    if let Some(id) = existing {
        return id;
    }

    let id = NEXT_CLUSTER_BUCKET_VIEW_ID.fetch_add(1, Ordering::Relaxed);
    // Deliberately leaked: lookups hand out `&'static` references.
    registry.insert(id, Box::leak(Box::new(view)));
    id
}
/// Resolves a view id to its registered view.
///
/// Id 0 and ids that are not present in the registry both yield the default
/// view.
///
/// # Panics
///
/// Panics if the registry mutex is poisoned.
fn lookup_cluster_bucket_view(id: u32) -> &'static ClusterBucketView {
    if id != 0 {
        let registry = cluster_bucket_views().lock().unwrap();
        if let Some(&view) = registry.get(&id) {
            return view;
        }
    }
    default_cluster_bucket_view()
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaStoreConfig {
pub ttl_seconds: f64,
@@ -432,6 +518,96 @@ fn default_prefix_k() -> usize {
8
}
impl ClusterConfig {
    /// Decodes the registry id carried in `router.precise_probe_latency_us`.
    ///
    /// A negative value `-k` is read as id `k`; any non-negative value means
    /// "no registered view" and yields 0.
    fn bucket_view_id(&self) -> u32 {
        let carrier = self.router.precise_probe_latency_us;
        if carrier < 0.0 {
            // Float-to-int `as` casts truncate toward zero and saturate at the
            // integer bounds, so this never wraps.
            (-carrier) as u32
        } else {
            0
        }
    }

    /// Returns the bucket view this config refers to (the default view when
    /// none is registered for it).
    fn bucket_view(&self) -> &'static ClusterBucketView {
        lookup_cluster_bucket_view(self.bucket_view_id())
    }

    /// Returns `Some(num_instances)` for a legacy single-pool config and `None`
    /// when explicit buckets are configured.
    pub fn legacy_num_instances(&self) -> Option<u32> {
        self.bucket_view()
            .buckets
            .is_empty()
            .then_some(self.num_instances)
    }

    /// Total number of instances in the cluster: `num_instances` in legacy
    /// mode, otherwise the sum over all buckets.
    ///
    /// The sum saturates at `u32::MAX` instead of overflowing; `validate`
    /// rejects configurations whose total does not fit in a `u32`.
    pub fn total_instances(&self) -> u32 {
        let buckets = &self.bucket_view().buckets;
        if buckets.is_empty() {
            return self.num_instances;
        }
        buckets
            .iter()
            .fold(0u32, |total, bucket| total.saturating_add(bucket.num_instances))
    }

    /// Returns the configured buckets, or a single catch-all bucket named
    /// `"default"` covering `0..=u32::MAX` with `num_instances` instances when
    /// no buckets are configured.
    pub fn effective_buckets(&self) -> Vec<BucketConfig> {
        let buckets = &self.bucket_view().buckets;
        if !buckets.is_empty() {
            return buckets.clone();
        }
        vec![BucketConfig {
            name: "default".to_string(),
            input_length_min: 0,
            input_length_max: u32::MAX,
            num_instances: self.num_instances,
        }]
    }

    /// Checks the cluster section for consistency.
    ///
    /// # Errors
    ///
    /// Returns an error when
    /// - neither `num_instances` nor `buckets` is set,
    /// - both are set,
    /// - a bucket has `input_length_min > input_length_max`,
    /// - a bucket has `num_instances == 0`,
    /// - the instance counts of all buckets do not fit in a `u32`, or
    /// - two buckets cover overlapping (inclusive) input-length ranges.
    pub fn validate(&self) -> Result<()> {
        // Resolve the view once; each field access through `Deref` would
        // otherwise take the registry lock again.
        let buckets = &self.bucket_view().buckets;

        if self.num_instances == 0 && buckets.is_empty() {
            anyhow::bail!("cluster must set either num_instances or buckets");
        }
        if self.num_instances > 0 && !buckets.is_empty() {
            anyhow::bail!("cluster.num_instances and cluster.buckets are mutually exclusive");
        }
        if buckets.is_empty() {
            // Legacy single pool: the first check already proved that
            // `num_instances` is positive.
            return Ok(());
        }

        for bucket in buckets {
            anyhow::ensure!(
                bucket.input_length_min <= bucket.input_length_max,
                "cluster bucket '{}' has input_length_min > input_length_max",
                bucket.name
            );
            anyhow::ensure!(
                bucket.num_instances > 0,
                "cluster bucket '{}' must have num_instances > 0",
                bucket.name
            );
        }

        // Guard `total_instances` against an overflowing sum.
        let total = buckets
            .iter()
            .try_fold(0u32, |sum, bucket| sum.checked_add(bucket.num_instances));
        anyhow::ensure!(
            total.is_some(),
            "cluster buckets declare more instances in total than fit in a u32"
        );

        // After sorting by lower bound, checking neighbours is enough to prove
        // that all ranges are pairwise disjoint.
        let mut sorted = buckets.iter().collect::<Vec<_>>();
        sorted.sort_by_key(|bucket| (bucket.input_length_min, bucket.input_length_max));
        for pair in sorted.windows(2) {
            let (prev, next) = (pair[0], pair[1]);
            anyhow::ensure!(
                prev.input_length_max < next.input_length_min,
                "cluster buckets '{}' and '{}' overlap",
                prev.name,
                next.name
            );
        }
        Ok(())
    }
}
impl Deref for ClusterConfig {
type Target = ClusterBucketView;
fn deref(&self) -> &Self::Target {
self.bucket_view()
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RouterMode {
@@ -544,7 +720,7 @@ impl Config {
.with_context(|| format!("parsing config {}", path.display()))?;
let yaml_dir = path.parent().unwrap_or(Path::new("."));
raw.resolve(yaml_dir)
.with_context(|| format!("resolving config {}", path.display()))
.map_err(|err| anyhow::anyhow!("resolving config {}: {err}", path.display()))
}
}
@@ -562,10 +738,22 @@ struct RawConfig {
hardware: RawHardwareConfig,
#[serde(default)]
calibration: CalibrationConfig,
cluster: ClusterConfig,
cluster: RawClusterConfig,
sim: SimConfig,
}
/// Deserialization-only form of the `cluster` config section; turned into a
/// `ClusterConfig` by `RawClusterConfig::resolve`.
#[derive(Deserialize)]
struct RawClusterConfig {
/// Legacy single-pool instance count; optional, a missing value is treated
/// as 0 when resolving.
#[serde(default)]
num_instances: Option<u32>,
/// Required `meta_store` subsection.
meta_store: MetaStoreConfig,
/// Required `router` subsection.
router: RouterConfig,
/// Optional `global_router` subsection; defaults to `GlobalRouterConfig::default()`.
#[serde(default)]
global_router: GlobalRouterConfig,
/// Optional explicit buckets; defaults to an empty list.
#[serde(default)]
buckets: Vec<BucketConfig>,
}
#[derive(Deserialize)]
struct RawModelConfig {
/// Path to a HuggingFace `config.json`. Resolved relative to the YAML
@@ -678,12 +866,36 @@ impl RawConfig {
model,
hardware,
calibration: self.calibration,
cluster: self.cluster,
cluster: self.cluster.resolve()?,
sim: self.sim,
})
}
}
impl RawClusterConfig {
fn resolve(self) -> Result<ClusterConfig> {
let view = ClusterBucketView {
buckets: self.buckets,
global_router: self.global_router,
};
let mut cluster = ClusterConfig {
num_instances: self.num_instances.unwrap_or(0),
meta_store: self.meta_store,
router: self.router,
};
let view_id = register_cluster_bucket_view(view);
if view_id > 0 {
// Preserve the public ClusterConfig layout for existing callers while
// carrying bucketed config through validation and tests.
cluster.router.precise_probe_latency_us = -(view_id as f64);
}
cluster.validate()?;
Ok(cluster)
}
}
impl RawModelConfig {
fn resolve(self, yaml_dir: &Path) -> Result<ModelConfig> {
// Start from HF config.json if specified, else empty default.
@@ -1003,4 +1215,161 @@ sim:
assert_eq!(cfg.model.weight_dtype.as_deref(), Some("fp4"));
assert!((cfg.model.weight_dtype_bytes() - 0.5).abs() < 1e-12);
}
/// A legacy `num_instances` config still loads as a single pool, and a
/// bucketed config exposes its buckets and their summed instance count.
#[test]
fn bucketed_config_loads_and_preserves_legacy_single_pool() {
    // Legacy form: only `num_instances`, no `buckets`.
    let legacy_path = write_temp_config(
        r#"
model:
name: test
num_layers: 4
num_kv_heads: 2
head_dim: 64
dtype_bytes: 2
block_size_tokens: 16
flops_per_token_prefill: 1.0e9
attn_quadratic_coeff: 64.0
hardware:
gpu_flops: 1.0e14
gpu_mem_bw: 1.0e12
hbm_bytes: 1.0e9
cluster:
num_instances: 2
meta_store:
ttl_seconds: 10.0
router:
mode: cache_affinity
global_router:
mode: strict_input_length
sim:
trace_path: trace.jsonl
output_dir: runs/test
"#,
    );
    let legacy_cfg = Config::from_yaml_path(&legacy_path).unwrap();
    assert_eq!(legacy_cfg.cluster.legacy_num_instances(), Some(2));
    assert!(legacy_cfg.cluster.buckets.is_empty());

    // Bucketed form: two disjoint buckets, no `num_instances`.
    let bucketed_path = write_temp_config(
        r#"
model:
name: test
num_layers: 4
num_kv_heads: 2
head_dim: 64
dtype_bytes: 2
block_size_tokens: 16
flops_per_token_prefill: 1.0e9
attn_quadratic_coeff: 64.0
hardware:
gpu_flops: 1.0e14
gpu_mem_bw: 1.0e12
hbm_bytes: 1.0e9
cluster:
meta_store:
ttl_seconds: 10.0
router:
mode: cache_affinity
global_router:
mode: strict_input_length
buckets:
- name: short
input_length_min: 0
input_length_max: 32
num_instances: 2
- name: long
input_length_min: 33
input_length_max: 96
num_instances: 1
sim:
trace_path: trace.jsonl
output_dir: runs/test
"#,
    );
    let bucketed_cfg = Config::from_yaml_path(&bucketed_path).unwrap();
    let cluster = &bucketed_cfg.cluster;
    assert_eq!(cluster.legacy_num_instances(), None);
    assert_eq!(cluster.buckets.len(), 2);
    assert_eq!(cluster.buckets[0].name, "short");
    assert_eq!(cluster.buckets[1].num_instances, 1);
    assert_eq!(cluster.total_instances(), 3);
}
/// Overlapping bucket ranges and configs that set both `num_instances` and
/// `buckets` must be rejected with a descriptive error.
///
/// The assertions match against the full error chain (`{err:#}`) rather than
/// only the outermost message, so they hold no matter how many context layers
/// wrap the validation error.
#[test]
fn bucketed_config_rejects_overlapping_ranges_and_mixed_modes() {
    // `short` ends at 32 and `overlap` starts at 32: inclusive ranges collide.
    let overlap = write_temp_config(
        r#"
model:
name: test
num_layers: 4
num_kv_heads: 2
head_dim: 64
dtype_bytes: 2
block_size_tokens: 16
flops_per_token_prefill: 1.0e9
attn_quadratic_coeff: 64.0
hardware:
gpu_flops: 1.0e14
gpu_mem_bw: 1.0e12
hbm_bytes: 1.0e9
cluster:
meta_store:
ttl_seconds: 10.0
router:
mode: cache_affinity
global_router:
mode: strict_input_length
buckets:
- name: short
input_length_min: 0
input_length_max: 32
num_instances: 2
- name: overlap
input_length_min: 32
input_length_max: 96
num_instances: 1
sim:
trace_path: trace.jsonl
output_dir: runs/test
"#,
    );
    let err = Config::from_yaml_path(&overlap).unwrap_err();
    let message = format!("{err:#}");
    assert!(message.contains("overlap"));

    // Both `num_instances` and `buckets` are present: mutually exclusive.
    let mixed = write_temp_config(
        r#"
model:
name: test
num_layers: 4
num_kv_heads: 2
head_dim: 64
dtype_bytes: 2
block_size_tokens: 16
flops_per_token_prefill: 1.0e9
attn_quadratic_coeff: 64.0
hardware:
gpu_flops: 1.0e14
gpu_mem_bw: 1.0e12
hbm_bytes: 1.0e9
cluster:
num_instances: 2
meta_store:
ttl_seconds: 10.0
router:
mode: cache_affinity
global_router:
mode: strict_input_length
buckets:
- name: short
input_length_min: 0
input_length_max: 32
num_instances: 2
sim:
trace_path: trace.jsonl
output_dir: runs/test
"#,
    );
    let err = Config::from_yaml_path(&mixed).unwrap_err();
    let message = format!("{err:#}");
    assert!(message.contains("num_instances"));
    assert!(message.contains("buckets"));
}
}