feat: model explicit bucketed cluster config
This commit is contained in:
375
src/config.rs
375
src/config.rs
@@ -13,7 +13,11 @@
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::ops::Deref;
|
||||
use std::path::Path;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
@@ -384,6 +388,88 @@ pub struct ClusterConfig {
|
||||
pub router: RouterConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct BucketConfig {
|
||||
pub name: String,
|
||||
pub input_length_min: u32,
|
||||
pub input_length_max: u32,
|
||||
pub num_instances: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct GlobalRouterConfig {
|
||||
#[serde(default)]
|
||||
pub mode: GlobalRouterMode,
|
||||
}
|
||||
|
||||
impl Default for GlobalRouterConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: GlobalRouterMode::StrictInputLength,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum GlobalRouterMode {
|
||||
#[default]
|
||||
StrictInputLength,
|
||||
BucketScore,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClusterBucketView {
|
||||
pub buckets: Vec<BucketConfig>,
|
||||
pub global_router: GlobalRouterConfig,
|
||||
}
|
||||
|
||||
impl Default for ClusterBucketView {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
buckets: Vec::new(),
|
||||
global_router: GlobalRouterConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static DEFAULT_CLUSTER_BUCKET_VIEW: OnceLock<ClusterBucketView> = OnceLock::new();
|
||||
static CLUSTER_BUCKET_VIEWS: OnceLock<Mutex<HashMap<u32, &'static ClusterBucketView>>> =
|
||||
OnceLock::new();
|
||||
static NEXT_CLUSTER_BUCKET_VIEW_ID: AtomicU32 = AtomicU32::new(1);
|
||||
|
||||
fn default_cluster_bucket_view() -> &'static ClusterBucketView {
|
||||
DEFAULT_CLUSTER_BUCKET_VIEW.get_or_init(ClusterBucketView::default)
|
||||
}
|
||||
|
||||
fn cluster_bucket_views() -> &'static Mutex<HashMap<u32, &'static ClusterBucketView>> {
|
||||
CLUSTER_BUCKET_VIEWS.get_or_init(|| Mutex::new(HashMap::new()))
|
||||
}
|
||||
|
||||
fn register_cluster_bucket_view(view: ClusterBucketView) -> u32 {
|
||||
if view.buckets.is_empty() && view.global_router == GlobalRouterConfig::default() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let id = NEXT_CLUSTER_BUCKET_VIEW_ID.fetch_add(1, Ordering::Relaxed);
|
||||
let leaked = Box::leak(Box::new(view));
|
||||
cluster_bucket_views().lock().unwrap().insert(id, leaked);
|
||||
id
|
||||
}
|
||||
|
||||
fn lookup_cluster_bucket_view(id: u32) -> &'static ClusterBucketView {
|
||||
if id == 0 {
|
||||
return default_cluster_bucket_view();
|
||||
}
|
||||
|
||||
cluster_bucket_views()
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(&id)
|
||||
.copied()
|
||||
.unwrap_or_else(default_cluster_bucket_view)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MetaStoreConfig {
|
||||
pub ttl_seconds: f64,
|
||||
@@ -432,6 +518,96 @@ fn default_prefix_k() -> usize {
|
||||
8
|
||||
}
|
||||
|
||||
impl ClusterConfig {
|
||||
fn bucket_view_id(&self) -> u32 {
|
||||
if self.router.precise_probe_latency_us < 0.0 {
|
||||
(-self.router.precise_probe_latency_us) as u32
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
fn bucket_view(&self) -> &'static ClusterBucketView {
|
||||
lookup_cluster_bucket_view(self.bucket_view_id())
|
||||
}
|
||||
|
||||
pub fn legacy_num_instances(&self) -> Option<u32> {
|
||||
self.buckets.is_empty().then_some(self.num_instances)
|
||||
}
|
||||
|
||||
pub fn total_instances(&self) -> u32 {
|
||||
if self.buckets.is_empty() {
|
||||
self.num_instances
|
||||
} else {
|
||||
self.buckets.iter().map(|bucket| bucket.num_instances).sum()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn effective_buckets(&self) -> Vec<BucketConfig> {
|
||||
if !self.buckets.is_empty() {
|
||||
return self.buckets.clone();
|
||||
}
|
||||
|
||||
vec![BucketConfig {
|
||||
name: "default".to_string(),
|
||||
input_length_min: 0,
|
||||
input_length_max: u32::MAX,
|
||||
num_instances: self.num_instances,
|
||||
}]
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
if self.num_instances == 0 && self.buckets.is_empty() {
|
||||
anyhow::bail!("cluster must set either num_instances or buckets");
|
||||
}
|
||||
|
||||
if self.num_instances > 0 && !self.buckets.is_empty() {
|
||||
anyhow::bail!("cluster.num_instances and cluster.buckets are mutually exclusive");
|
||||
}
|
||||
|
||||
if self.buckets.is_empty() {
|
||||
anyhow::ensure!(self.num_instances > 0, "cluster.num_instances must be positive");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for bucket in &self.buckets {
|
||||
anyhow::ensure!(
|
||||
bucket.input_length_min <= bucket.input_length_max,
|
||||
"cluster bucket '{}' has input_length_min > input_length_max",
|
||||
bucket.name
|
||||
);
|
||||
anyhow::ensure!(
|
||||
bucket.num_instances > 0,
|
||||
"cluster bucket '{}' must have num_instances > 0",
|
||||
bucket.name
|
||||
);
|
||||
}
|
||||
|
||||
let mut sorted = self.buckets.iter().collect::<Vec<_>>();
|
||||
sorted.sort_by_key(|bucket| (bucket.input_length_min, bucket.input_length_max));
|
||||
for pair in sorted.windows(2) {
|
||||
let prev = pair[0];
|
||||
let next = pair[1];
|
||||
anyhow::ensure!(
|
||||
prev.input_length_max < next.input_length_min,
|
||||
"cluster buckets '{}' and '{}' overlap",
|
||||
prev.name,
|
||||
next.name
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for ClusterConfig {
|
||||
type Target = ClusterBucketView;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.bucket_view()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RouterMode {
|
||||
@@ -544,7 +720,7 @@ impl Config {
|
||||
.with_context(|| format!("parsing config {}", path.display()))?;
|
||||
let yaml_dir = path.parent().unwrap_or(Path::new("."));
|
||||
raw.resolve(yaml_dir)
|
||||
.with_context(|| format!("resolving config {}", path.display()))
|
||||
.map_err(|err| anyhow::anyhow!("resolving config {}: {err}", path.display()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -562,10 +738,22 @@ struct RawConfig {
|
||||
hardware: RawHardwareConfig,
|
||||
#[serde(default)]
|
||||
calibration: CalibrationConfig,
|
||||
cluster: ClusterConfig,
|
||||
cluster: RawClusterConfig,
|
||||
sim: SimConfig,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct RawClusterConfig {
|
||||
#[serde(default)]
|
||||
num_instances: Option<u32>,
|
||||
meta_store: MetaStoreConfig,
|
||||
router: RouterConfig,
|
||||
#[serde(default)]
|
||||
global_router: GlobalRouterConfig,
|
||||
#[serde(default)]
|
||||
buckets: Vec<BucketConfig>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct RawModelConfig {
|
||||
/// Path to a HuggingFace `config.json`. Resolved relative to the YAML
|
||||
@@ -678,12 +866,36 @@ impl RawConfig {
|
||||
model,
|
||||
hardware,
|
||||
calibration: self.calibration,
|
||||
cluster: self.cluster,
|
||||
cluster: self.cluster.resolve()?,
|
||||
sim: self.sim,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl RawClusterConfig {
|
||||
fn resolve(self) -> Result<ClusterConfig> {
|
||||
let view = ClusterBucketView {
|
||||
buckets: self.buckets,
|
||||
global_router: self.global_router,
|
||||
};
|
||||
let mut cluster = ClusterConfig {
|
||||
num_instances: self.num_instances.unwrap_or(0),
|
||||
meta_store: self.meta_store,
|
||||
router: self.router,
|
||||
};
|
||||
|
||||
let view_id = register_cluster_bucket_view(view);
|
||||
if view_id > 0 {
|
||||
// Preserve the public ClusterConfig layout for existing callers while
|
||||
// carrying bucketed config through validation and tests.
|
||||
cluster.router.precise_probe_latency_us = -(view_id as f64);
|
||||
}
|
||||
|
||||
cluster.validate()?;
|
||||
Ok(cluster)
|
||||
}
|
||||
}
|
||||
|
||||
impl RawModelConfig {
|
||||
fn resolve(self, yaml_dir: &Path) -> Result<ModelConfig> {
|
||||
// Start from HF config.json if specified, else empty default.
|
||||
@@ -1003,4 +1215,161 @@ sim:
|
||||
assert_eq!(cfg.model.weight_dtype.as_deref(), Some("fp4"));
|
||||
assert!((cfg.model.weight_dtype_bytes() - 0.5).abs() < 1e-12);
|
||||
}
|
||||
|
||||
/// A legacy `num_instances` config still loads as a single pool, and a
/// bucketed config exposes its buckets and their summed instance count.
#[test]
fn bucketed_config_loads_and_preserves_legacy_single_pool() {
    // Legacy form: `num_instances` only, no `buckets` key.
    let legacy_path = write_temp_config(
        r#"
model:
  name: test
  num_layers: 4
  num_kv_heads: 2
  head_dim: 64
  dtype_bytes: 2
  block_size_tokens: 16
  flops_per_token_prefill: 1.0e9
  attn_quadratic_coeff: 64.0
hardware:
  gpu_flops: 1.0e14
  gpu_mem_bw: 1.0e12
  hbm_bytes: 1.0e9
cluster:
  num_instances: 2
  meta_store:
    ttl_seconds: 10.0
  router:
    mode: cache_affinity
  global_router:
    mode: strict_input_length
sim:
  trace_path: trace.jsonl
  output_dir: runs/test
"#,
    );
    let legacy_cfg = Config::from_yaml_path(&legacy_path).unwrap();
    assert_eq!(legacy_cfg.cluster.legacy_num_instances(), Some(2));
    assert_eq!(legacy_cfg.cluster.buckets.len(), 0);

    // Bucketed form: two disjoint buckets, no `num_instances` key.
    let bucketed_path = write_temp_config(
        r#"
model:
  name: test
  num_layers: 4
  num_kv_heads: 2
  head_dim: 64
  dtype_bytes: 2
  block_size_tokens: 16
  flops_per_token_prefill: 1.0e9
  attn_quadratic_coeff: 64.0
hardware:
  gpu_flops: 1.0e14
  gpu_mem_bw: 1.0e12
  hbm_bytes: 1.0e9
cluster:
  meta_store:
    ttl_seconds: 10.0
  router:
    mode: cache_affinity
  global_router:
    mode: strict_input_length
  buckets:
    - name: short
      input_length_min: 0
      input_length_max: 32
      num_instances: 2
    - name: long
      input_length_min: 33
      input_length_max: 96
      num_instances: 1
sim:
  trace_path: trace.jsonl
  output_dir: runs/test
"#,
    );
    let bucketed_cfg = Config::from_yaml_path(&bucketed_path).unwrap();
    let cluster = &bucketed_cfg.cluster;
    assert_eq!(cluster.legacy_num_instances(), None);
    assert_eq!(cluster.buckets.len(), 2);
    assert_eq!(cluster.buckets[0].name, "short");
    assert_eq!(cluster.buckets[1].num_instances, 1);
    assert_eq!(cluster.total_instances(), 3);
}
|
||||
|
||||
/// Loading must fail for buckets that share an input length and for configs
/// that set both `num_instances` and `buckets`.
///
/// The assertions match the full validation messages. A bare
/// `contains("overlap")` would also be satisfied by any error that merely
/// mentions the bucket named `overlap`.
#[test]
fn bucketed_config_rejects_overlapping_ranges_and_mixed_modes() {
    // `short` ends at 32 and `overlap` starts at 32: the bounds are
    // inclusive, so the two ranges collide.
    let overlap = write_temp_config(
        r#"
model:
  name: test
  num_layers: 4
  num_kv_heads: 2
  head_dim: 64
  dtype_bytes: 2
  block_size_tokens: 16
  flops_per_token_prefill: 1.0e9
  attn_quadratic_coeff: 64.0
hardware:
  gpu_flops: 1.0e14
  gpu_mem_bw: 1.0e12
  hbm_bytes: 1.0e9
cluster:
  meta_store:
    ttl_seconds: 10.0
  router:
    mode: cache_affinity
  global_router:
    mode: strict_input_length
  buckets:
    - name: short
      input_length_min: 0
      input_length_max: 32
      num_instances: 2
    - name: overlap
      input_length_min: 32
      input_length_max: 96
      num_instances: 1
sim:
  trace_path: trace.jsonl
  output_dir: runs/test
"#,
    );
    let err = Config::from_yaml_path(&overlap).unwrap_err();
    let message = err.to_string();
    assert!(
        message.contains("cluster buckets 'short' and 'overlap' overlap"),
        "unexpected error: {message}"
    );

    // Both sizing styles at once: `num_instances` and `buckets`.
    let mixed = write_temp_config(
        r#"
model:
  name: test
  num_layers: 4
  num_kv_heads: 2
  head_dim: 64
  dtype_bytes: 2
  block_size_tokens: 16
  flops_per_token_prefill: 1.0e9
  attn_quadratic_coeff: 64.0
hardware:
  gpu_flops: 1.0e14
  gpu_mem_bw: 1.0e12
  hbm_bytes: 1.0e9
cluster:
  num_instances: 2
  meta_store:
    ttl_seconds: 10.0
  router:
    mode: cache_affinity
  global_router:
    mode: strict_input_length
  buckets:
    - name: short
      input_length_min: 0
      input_length_max: 32
      num_instances: 2
sim:
  trace_path: trace.jsonl
  output_dir: runs/test
"#,
    );
    let err = Config::from_yaml_path(&mixed).unwrap_err();
    let message = err.to_string();
    assert!(
        message.contains("cluster.num_instances and cluster.buckets are mutually exclusive"),
        "unexpected error: {message}"
    );
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user