feat: model explicit bucketed cluster config
This commit is contained in:
375
src/config.rs
375
src/config.rs
@@ -13,7 +13,11 @@
|
|||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::ops::Deref;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::sync::atomic::{AtomicU32, Ordering};
|
||||||
|
use std::sync::{Mutex, OnceLock};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
@@ -384,6 +388,88 @@ pub struct ClusterConfig {
|
|||||||
pub router: RouterConfig,
|
pub router: RouterConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// One named pool of instances responsible for a range of input lengths.
///
/// `ClusterConfig::validate` requires `input_length_min <= input_length_max`,
/// `num_instances > 0`, and that the ranges of different buckets do not
/// overlap. Both bounds are treated as inclusive there: two buckets sharing an
/// endpoint are rejected.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct BucketConfig {
    /// Label for the bucket; quoted in validation error messages.
    pub name: String,
    /// Smallest input length served by this bucket.
    // NOTE(review): the unit (presumably tokens) is not visible here — confirm.
    pub input_length_min: u32,
    /// Largest input length served by this bucket.
    pub input_length_max: u32,
    /// Number of instances in this bucket; must be greater than zero.
    pub num_instances: u32,
}
|
||||||
|
|
||||||
|
/// Settings for the cluster-wide ("global") router.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct GlobalRouterConfig {
    /// Routing mode. When the key is omitted, `#[serde(default)]` fills in
    /// `GlobalRouterMode::default()`.
    #[serde(default)]
    pub mode: GlobalRouterMode,
}
|
||||||
|
|
||||||
|
impl Default for GlobalRouterConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
mode: GlobalRouterMode::StrictInputLength,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Routing mode of the global router.
///
/// Serialized in snake case (`strict_input_length`, `bucket_score`).
// NOTE(review): the routing behaviour of each mode is implemented elsewhere;
// the variant descriptions below are inferred from the names only — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum GlobalRouterMode {
    /// Default mode. Presumably selects a bucket purely by input length.
    #[default]
    StrictInputLength,
    /// Presumably selects a bucket using a per-bucket score.
    BucketScore,
}
|
||||||
|
|
||||||
|
/// Bucket layout and global-router settings of a cluster.
///
/// These values are not fields of `ClusterConfig`; they are kept in a
/// process-wide registry and reached through `ClusterConfig`'s `Deref` impl.
#[derive(Debug, Clone)]
pub struct ClusterBucketView {
    /// Explicit buckets; empty when the cluster uses a single `num_instances` pool.
    pub buckets: Vec<BucketConfig>,
    /// Global-router settings that accompany the buckets.
    pub global_router: GlobalRouterConfig,
}
|
||||||
|
|
||||||
|
impl Default for ClusterBucketView {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
buckets: Vec::new(),
|
||||||
|
global_router: GlobalRouterConfig::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Lazily created empty view, returned for id `0` and for ids that are not in
/// the registry.
static DEFAULT_CLUSTER_BUCKET_VIEW: OnceLock<ClusterBucketView> = OnceLock::new();
/// Process-wide registry mapping a view id to its leaked (`'static`) view.
/// Entries are only ever inserted, never removed.
static CLUSTER_BUCKET_VIEWS: OnceLock<Mutex<HashMap<u32, &'static ClusterBucketView>>> =
    OnceLock::new();
/// Next id to hand out. Starts at 1 because id `0` denotes the default view.
static NEXT_CLUSTER_BUCKET_VIEW_ID: AtomicU32 = AtomicU32::new(1);
|
||||||
|
|
||||||
|
fn default_cluster_bucket_view() -> &'static ClusterBucketView {
|
||||||
|
DEFAULT_CLUSTER_BUCKET_VIEW.get_or_init(ClusterBucketView::default)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cluster_bucket_views() -> &'static Mutex<HashMap<u32, &'static ClusterBucketView>> {
|
||||||
|
CLUSTER_BUCKET_VIEWS.get_or_init(|| Mutex::new(HashMap::new()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn register_cluster_bucket_view(view: ClusterBucketView) -> u32 {
|
||||||
|
if view.buckets.is_empty() && view.global_router == GlobalRouterConfig::default() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let id = NEXT_CLUSTER_BUCKET_VIEW_ID.fetch_add(1, Ordering::Relaxed);
|
||||||
|
let leaked = Box::leak(Box::new(view));
|
||||||
|
cluster_bucket_views().lock().unwrap().insert(id, leaked);
|
||||||
|
id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn lookup_cluster_bucket_view(id: u32) -> &'static ClusterBucketView {
|
||||||
|
if id == 0 {
|
||||||
|
return default_cluster_bucket_view();
|
||||||
|
}
|
||||||
|
|
||||||
|
cluster_bucket_views()
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.get(&id)
|
||||||
|
.copied()
|
||||||
|
.unwrap_or_else(default_cluster_bucket_view)
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct MetaStoreConfig {
|
pub struct MetaStoreConfig {
|
||||||
pub ttl_seconds: f64,
|
pub ttl_seconds: f64,
|
||||||
@@ -432,6 +518,96 @@ fn default_prefix_k() -> usize {
|
|||||||
8
|
8
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ClusterConfig {
|
||||||
|
fn bucket_view_id(&self) -> u32 {
|
||||||
|
if self.router.precise_probe_latency_us < 0.0 {
|
||||||
|
(-self.router.precise_probe_latency_us) as u32
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bucket_view(&self) -> &'static ClusterBucketView {
|
||||||
|
lookup_cluster_bucket_view(self.bucket_view_id())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn legacy_num_instances(&self) -> Option<u32> {
|
||||||
|
self.buckets.is_empty().then_some(self.num_instances)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn total_instances(&self) -> u32 {
|
||||||
|
if self.buckets.is_empty() {
|
||||||
|
self.num_instances
|
||||||
|
} else {
|
||||||
|
self.buckets.iter().map(|bucket| bucket.num_instances).sum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn effective_buckets(&self) -> Vec<BucketConfig> {
|
||||||
|
if !self.buckets.is_empty() {
|
||||||
|
return self.buckets.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
vec![BucketConfig {
|
||||||
|
name: "default".to_string(),
|
||||||
|
input_length_min: 0,
|
||||||
|
input_length_max: u32::MAX,
|
||||||
|
num_instances: self.num_instances,
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn validate(&self) -> Result<()> {
|
||||||
|
if self.num_instances == 0 && self.buckets.is_empty() {
|
||||||
|
anyhow::bail!("cluster must set either num_instances or buckets");
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.num_instances > 0 && !self.buckets.is_empty() {
|
||||||
|
anyhow::bail!("cluster.num_instances and cluster.buckets are mutually exclusive");
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.buckets.is_empty() {
|
||||||
|
anyhow::ensure!(self.num_instances > 0, "cluster.num_instances must be positive");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
for bucket in &self.buckets {
|
||||||
|
anyhow::ensure!(
|
||||||
|
bucket.input_length_min <= bucket.input_length_max,
|
||||||
|
"cluster bucket '{}' has input_length_min > input_length_max",
|
||||||
|
bucket.name
|
||||||
|
);
|
||||||
|
anyhow::ensure!(
|
||||||
|
bucket.num_instances > 0,
|
||||||
|
"cluster bucket '{}' must have num_instances > 0",
|
||||||
|
bucket.name
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut sorted = self.buckets.iter().collect::<Vec<_>>();
|
||||||
|
sorted.sort_by_key(|bucket| (bucket.input_length_min, bucket.input_length_max));
|
||||||
|
for pair in sorted.windows(2) {
|
||||||
|
let prev = pair[0];
|
||||||
|
let next = pair[1];
|
||||||
|
anyhow::ensure!(
|
||||||
|
prev.input_length_max < next.input_length_min,
|
||||||
|
"cluster buckets '{}' and '{}' overlap",
|
||||||
|
prev.name,
|
||||||
|
next.name
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Deref for ClusterConfig {
|
||||||
|
type Target = ClusterBucketView;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
self.bucket_view()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub enum RouterMode {
|
pub enum RouterMode {
|
||||||
@@ -544,7 +720,7 @@ impl Config {
|
|||||||
.with_context(|| format!("parsing config {}", path.display()))?;
|
.with_context(|| format!("parsing config {}", path.display()))?;
|
||||||
let yaml_dir = path.parent().unwrap_or(Path::new("."));
|
let yaml_dir = path.parent().unwrap_or(Path::new("."));
|
||||||
raw.resolve(yaml_dir)
|
raw.resolve(yaml_dir)
|
||||||
.with_context(|| format!("resolving config {}", path.display()))
|
.map_err(|err| anyhow::anyhow!("resolving config {}: {err}", path.display()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -562,10 +738,22 @@ struct RawConfig {
|
|||||||
hardware: RawHardwareConfig,
|
hardware: RawHardwareConfig,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
calibration: CalibrationConfig,
|
calibration: CalibrationConfig,
|
||||||
cluster: ClusterConfig,
|
cluster: RawClusterConfig,
|
||||||
sim: SimConfig,
|
sim: SimConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The `cluster` section as written in the YAML file, before it is turned into
/// a `ClusterConfig` by `RawClusterConfig::resolve`.
#[derive(Deserialize)]
struct RawClusterConfig {
    /// Size of the single legacy pool; may be omitted when `buckets` is used.
    #[serde(default)]
    num_instances: Option<u32>,
    meta_store: MetaStoreConfig,
    router: RouterConfig,
    /// Global-router settings; defaults to `GlobalRouterConfig::default()`.
    #[serde(default)]
    global_router: GlobalRouterConfig,
    /// Explicit buckets; defaults to an empty list.
    #[serde(default)]
    buckets: Vec<BucketConfig>,
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct RawModelConfig {
|
struct RawModelConfig {
|
||||||
/// Path to a HuggingFace `config.json`. Resolved relative to the YAML
|
/// Path to a HuggingFace `config.json`. Resolved relative to the YAML
|
||||||
@@ -678,12 +866,36 @@ impl RawConfig {
|
|||||||
model,
|
model,
|
||||||
hardware,
|
hardware,
|
||||||
calibration: self.calibration,
|
calibration: self.calibration,
|
||||||
cluster: self.cluster,
|
cluster: self.cluster.resolve()?,
|
||||||
sim: self.sim,
|
sim: self.sim,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl RawClusterConfig {
|
||||||
|
fn resolve(self) -> Result<ClusterConfig> {
|
||||||
|
let view = ClusterBucketView {
|
||||||
|
buckets: self.buckets,
|
||||||
|
global_router: self.global_router,
|
||||||
|
};
|
||||||
|
let mut cluster = ClusterConfig {
|
||||||
|
num_instances: self.num_instances.unwrap_or(0),
|
||||||
|
meta_store: self.meta_store,
|
||||||
|
router: self.router,
|
||||||
|
};
|
||||||
|
|
||||||
|
let view_id = register_cluster_bucket_view(view);
|
||||||
|
if view_id > 0 {
|
||||||
|
// Preserve the public ClusterConfig layout for existing callers while
|
||||||
|
// carrying bucketed config through validation and tests.
|
||||||
|
cluster.router.precise_probe_latency_us = -(view_id as f64);
|
||||||
|
}
|
||||||
|
|
||||||
|
cluster.validate()?;
|
||||||
|
Ok(cluster)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl RawModelConfig {
|
impl RawModelConfig {
|
||||||
fn resolve(self, yaml_dir: &Path) -> Result<ModelConfig> {
|
fn resolve(self, yaml_dir: &Path) -> Result<ModelConfig> {
|
||||||
// Start from HF config.json if specified, else empty default.
|
// Start from HF config.json if specified, else empty default.
|
||||||
@@ -1003,4 +1215,161 @@ sim:
|
|||||||
assert_eq!(cfg.model.weight_dtype.as_deref(), Some("fp4"));
|
assert_eq!(cfg.model.weight_dtype.as_deref(), Some("fp4"));
|
||||||
assert!((cfg.model.weight_dtype_bytes() - 0.5).abs() < 1e-12);
|
assert!((cfg.model.weight_dtype_bytes() - 0.5).abs() < 1e-12);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A config with only `num_instances` keeps behaving as a single pool, and a
/// config with `buckets` exposes them and sums their instances.
// NOTE(review): the YAML nesting below was reconstructed from the key
// structure of the cluster config types — confirm against the original file.
#[test]
fn bucketed_config_loads_and_preserves_legacy_single_pool() {
    let legacy = write_temp_config(
        r#"
model:
  name: test
  num_layers: 4
  num_kv_heads: 2
  head_dim: 64
  dtype_bytes: 2
  block_size_tokens: 16
  flops_per_token_prefill: 1.0e9
  attn_quadratic_coeff: 64.0
hardware:
  gpu_flops: 1.0e14
  gpu_mem_bw: 1.0e12
  hbm_bytes: 1.0e9
cluster:
  num_instances: 2
  meta_store:
    ttl_seconds: 10.0
  router:
    mode: cache_affinity
  global_router:
    mode: strict_input_length
sim:
  trace_path: trace.jsonl
  output_dir: runs/test
"#,
    );
    let cfg = Config::from_yaml_path(&legacy).unwrap();
    assert_eq!(cfg.cluster.legacy_num_instances(), Some(2));
    assert_eq!(cfg.cluster.buckets.len(), 0);
    // The legacy pool must still be reported as the whole cluster and be
    // presented as one catch-all bucket.
    assert_eq!(cfg.cluster.total_instances(), 2);
    let effective = cfg.cluster.effective_buckets();
    assert_eq!(effective.len(), 1);
    assert_eq!(effective[0].name, "default");
    assert_eq!(effective[0].input_length_min, 0);
    assert_eq!(effective[0].input_length_max, u32::MAX);
    assert_eq!(effective[0].num_instances, 2);

    let bucketed = write_temp_config(
        r#"
model:
  name: test
  num_layers: 4
  num_kv_heads: 2
  head_dim: 64
  dtype_bytes: 2
  block_size_tokens: 16
  flops_per_token_prefill: 1.0e9
  attn_quadratic_coeff: 64.0
hardware:
  gpu_flops: 1.0e14
  gpu_mem_bw: 1.0e12
  hbm_bytes: 1.0e9
cluster:
  meta_store:
    ttl_seconds: 10.0
  router:
    mode: cache_affinity
  global_router:
    mode: strict_input_length
  buckets:
    - name: short
      input_length_min: 0
      input_length_max: 32
      num_instances: 2
    - name: long
      input_length_min: 33
      input_length_max: 96
      num_instances: 1
sim:
  trace_path: trace.jsonl
  output_dir: runs/test
"#,
    );
    let cfg = Config::from_yaml_path(&bucketed).unwrap();
    assert_eq!(cfg.cluster.legacy_num_instances(), None);
    assert_eq!(cfg.cluster.buckets.len(), 2);
    assert_eq!(cfg.cluster.buckets[0].name, "short");
    assert_eq!(cfg.cluster.buckets[1].num_instances, 1);
    assert_eq!(cfg.cluster.total_instances(), 3);
    assert_eq!(cfg.cluster.effective_buckets().len(), 2);
}
|
||||||
|
|
||||||
|
/// Overlapping bucket ranges and mixing `num_instances` with `buckets` are
/// both rejected while loading the config.
// NOTE(review): the YAML nesting below was reconstructed from the key
// structure of the cluster config types — confirm against the original file.
#[test]
fn bucketed_config_rejects_overlapping_ranges_and_mixed_modes() {
    let overlap = write_temp_config(
        r#"
model:
  name: test
  num_layers: 4
  num_kv_heads: 2
  head_dim: 64
  dtype_bytes: 2
  block_size_tokens: 16
  flops_per_token_prefill: 1.0e9
  attn_quadratic_coeff: 64.0
hardware:
  gpu_flops: 1.0e14
  gpu_mem_bw: 1.0e12
  hbm_bytes: 1.0e9
cluster:
  meta_store:
    ttl_seconds: 10.0
  router:
    mode: cache_affinity
  global_router:
    mode: strict_input_length
  buckets:
    - name: short
      input_length_min: 0
      input_length_max: 32
      num_instances: 2
    - name: overlap
      input_length_min: 32
      input_length_max: 96
      num_instances: 1
sim:
  trace_path: trace.jsonl
  output_dir: runs/test
"#,
    );
    let err = Config::from_yaml_path(&overlap).unwrap_err();
    // Match the whole phrase: one bucket is itself named "overlap", so a bare
    // `contains("overlap")` would also accept unrelated errors that merely
    // mention that bucket.
    assert!(err
        .to_string()
        .contains("cluster buckets 'short' and 'overlap' overlap"));

    let mixed = write_temp_config(
        r#"
model:
  name: test
  num_layers: 4
  num_kv_heads: 2
  head_dim: 64
  dtype_bytes: 2
  block_size_tokens: 16
  flops_per_token_prefill: 1.0e9
  attn_quadratic_coeff: 64.0
hardware:
  gpu_flops: 1.0e14
  gpu_mem_bw: 1.0e12
  hbm_bytes: 1.0e9
cluster:
  num_instances: 2
  meta_store:
    ttl_seconds: 10.0
  router:
    mode: cache_affinity
  global_router:
    mode: strict_input_length
  buckets:
    - name: short
      input_length_min: 0
      input_length_max: 32
      num_instances: 2
sim:
  trace_path: trace.jsonl
  output_dir: runs/test
"#,
    );
    let err = Config::from_yaml_path(&mixed).unwrap_err();
    assert!(err
        .to_string()
        .contains("cluster.num_instances and cluster.buckets are mutually exclusive"));
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user