fix: close bucketed cluster config model gaps

This commit is contained in:
2026-04-17 14:21:34 +08:00
parent a723d7a811
commit d8a0796506
9 changed files with 52 additions and 143 deletions

View File

@@ -37,8 +37,9 @@ pub struct Cluster {
impl Cluster { impl Cluster {
pub fn new(config: &Config, model: &ModelConfig) -> Self { pub fn new(config: &Config, model: &ModelConfig) -> Self {
let mut instances = Vec::with_capacity(config.cluster.num_instances as usize); let total_instances = config.cluster.total_instances();
for id in 0..config.cluster.num_instances { let mut instances = Vec::with_capacity(total_instances as usize);
for id in 0..total_instances {
instances.push(Instance::new( instances.push(Instance::new(
id as InstanceId, id as InstanceId,
model, model,
@@ -226,7 +227,9 @@ mod tests {
..CalibrationConfig::default() ..CalibrationConfig::default()
}, },
cluster: ClusterConfig { cluster: ClusterConfig {
num_instances: 1, num_instances: Some(1),
buckets: Vec::new(),
global_router: Default::default(),
meta_store: MetaStoreConfig { meta_store: MetaStoreConfig {
ttl_seconds: 1000.0, ttl_seconds: 1000.0,
}, },

View File

@@ -13,11 +13,7 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ops::Deref;
use std::path::Path; use std::path::Path;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Mutex, OnceLock};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config { pub struct Config {
@@ -383,7 +379,12 @@ fn default_first_token_ready_us() -> f64 {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterConfig { pub struct ClusterConfig {
pub num_instances: u32, #[serde(default)]
pub num_instances: Option<u32>,
#[serde(default)]
pub buckets: Vec<BucketConfig>,
#[serde(default)]
pub global_router: GlobalRouterConfig,
pub meta_store: MetaStoreConfig, pub meta_store: MetaStoreConfig,
pub router: RouterConfig, pub router: RouterConfig,
} }
@@ -418,58 +419,6 @@ pub enum GlobalRouterMode {
BucketScore, BucketScore,
} }
#[derive(Debug, Clone)]
pub struct ClusterBucketView {
pub buckets: Vec<BucketConfig>,
pub global_router: GlobalRouterConfig,
}
impl Default for ClusterBucketView {
fn default() -> Self {
Self {
buckets: Vec::new(),
global_router: GlobalRouterConfig::default(),
}
}
}
static DEFAULT_CLUSTER_BUCKET_VIEW: OnceLock<ClusterBucketView> = OnceLock::new();
static CLUSTER_BUCKET_VIEWS: OnceLock<Mutex<HashMap<u32, &'static ClusterBucketView>>> =
OnceLock::new();
static NEXT_CLUSTER_BUCKET_VIEW_ID: AtomicU32 = AtomicU32::new(1);
fn default_cluster_bucket_view() -> &'static ClusterBucketView {
DEFAULT_CLUSTER_BUCKET_VIEW.get_or_init(ClusterBucketView::default)
}
fn cluster_bucket_views() -> &'static Mutex<HashMap<u32, &'static ClusterBucketView>> {
CLUSTER_BUCKET_VIEWS.get_or_init(|| Mutex::new(HashMap::new()))
}
fn register_cluster_bucket_view(view: ClusterBucketView) -> u32 {
if view.buckets.is_empty() && view.global_router == GlobalRouterConfig::default() {
return 0;
}
let id = NEXT_CLUSTER_BUCKET_VIEW_ID.fetch_add(1, Ordering::Relaxed);
let leaked = Box::leak(Box::new(view));
cluster_bucket_views().lock().unwrap().insert(id, leaked);
id
}
fn lookup_cluster_bucket_view(id: u32) -> &'static ClusterBucketView {
if id == 0 {
return default_cluster_bucket_view();
}
cluster_bucket_views()
.lock()
.unwrap()
.get(&id)
.copied()
.unwrap_or_else(default_cluster_bucket_view)
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaStoreConfig { pub struct MetaStoreConfig {
pub ttl_seconds: f64, pub ttl_seconds: f64,
@@ -519,30 +468,19 @@ fn default_prefix_k() -> usize {
} }
impl ClusterConfig { impl ClusterConfig {
fn bucket_view_id(&self) -> u32 {
if self.router.precise_probe_latency_us < 0.0 {
(-self.router.precise_probe_latency_us) as u32
} else {
0
}
}
fn bucket_view(&self) -> &'static ClusterBucketView {
lookup_cluster_bucket_view(self.bucket_view_id())
}
pub fn legacy_num_instances(&self) -> Option<u32> { pub fn legacy_num_instances(&self) -> Option<u32> {
self.buckets.is_empty().then_some(self.num_instances)
}
pub fn total_instances(&self) -> u32 {
if self.buckets.is_empty() { if self.buckets.is_empty() {
self.num_instances self.num_instances
} else { } else {
self.buckets.iter().map(|bucket| bucket.num_instances).sum() None
} }
} }
pub fn total_instances(&self) -> u32 {
self.legacy_num_instances()
.unwrap_or_else(|| self.buckets.iter().map(|bucket| bucket.num_instances).sum())
}
pub fn effective_buckets(&self) -> Vec<BucketConfig> { pub fn effective_buckets(&self) -> Vec<BucketConfig> {
if !self.buckets.is_empty() { if !self.buckets.is_empty() {
return self.buckets.clone(); return self.buckets.clone();
@@ -552,21 +490,22 @@ impl ClusterConfig {
name: "default".to_string(), name: "default".to_string(),
input_length_min: 0, input_length_min: 0,
input_length_max: u32::MAX, input_length_max: u32::MAX,
num_instances: self.num_instances, num_instances: self
.num_instances
.expect("legacy single-pool cluster must have num_instances"),
}] }]
} }
pub fn validate(&self) -> Result<()> { pub fn validate(&self) -> Result<()> {
if self.num_instances == 0 && self.buckets.is_empty() { if self.num_instances.is_some() && !self.buckets.is_empty() {
anyhow::bail!("cluster must set either num_instances or buckets");
}
if self.num_instances > 0 && !self.buckets.is_empty() {
anyhow::bail!("cluster.num_instances and cluster.buckets are mutually exclusive"); anyhow::bail!("cluster.num_instances and cluster.buckets are mutually exclusive");
} }
if self.buckets.is_empty() { if self.buckets.is_empty() {
anyhow::ensure!(self.num_instances > 0, "cluster.num_instances must be positive"); let num_instances = self.num_instances.ok_or_else(|| {
anyhow::anyhow!("cluster must set either num_instances or buckets")
})?;
anyhow::ensure!(num_instances > 0, "cluster.num_instances must be positive");
return Ok(()); return Ok(());
} }
@@ -600,14 +539,6 @@ impl ClusterConfig {
} }
} }
impl Deref for ClusterConfig {
type Target = ClusterBucketView;
fn deref(&self) -> &Self::Target {
self.bucket_view()
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum RouterMode { pub enum RouterMode {
@@ -738,22 +669,10 @@ struct RawConfig {
hardware: RawHardwareConfig, hardware: RawHardwareConfig,
#[serde(default)] #[serde(default)]
calibration: CalibrationConfig, calibration: CalibrationConfig,
cluster: RawClusterConfig, cluster: ClusterConfig,
sim: SimConfig, sim: SimConfig,
} }
#[derive(Deserialize)]
struct RawClusterConfig {
#[serde(default)]
num_instances: Option<u32>,
meta_store: MetaStoreConfig,
router: RouterConfig,
#[serde(default)]
global_router: GlobalRouterConfig,
#[serde(default)]
buckets: Vec<BucketConfig>,
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct RawModelConfig { struct RawModelConfig {
/// Path to a HuggingFace `config.json`. Resolved relative to the YAML /// Path to a HuggingFace `config.json`. Resolved relative to the YAML
@@ -866,36 +785,15 @@ impl RawConfig {
model, model,
hardware, hardware,
calibration: self.calibration, calibration: self.calibration,
cluster: self.cluster.resolve()?, cluster: {
self.cluster.validate()?;
self.cluster
},
sim: self.sim, sim: self.sim,
}) })
} }
} }
impl RawClusterConfig {
fn resolve(self) -> Result<ClusterConfig> {
let view = ClusterBucketView {
buckets: self.buckets,
global_router: self.global_router,
};
let mut cluster = ClusterConfig {
num_instances: self.num_instances.unwrap_or(0),
meta_store: self.meta_store,
router: self.router,
};
let view_id = register_cluster_bucket_view(view);
if view_id > 0 {
// Preserve the public ClusterConfig layout for existing callers while
// carrying bucketed config through validation and tests.
cluster.router.precise_probe_latency_us = -(view_id as f64);
}
cluster.validate()?;
Ok(cluster)
}
}
impl RawModelConfig { impl RawModelConfig {
fn resolve(self, yaml_dir: &Path) -> Result<ModelConfig> { fn resolve(self, yaml_dir: &Path) -> Result<ModelConfig> {
// Start from HF config.json if specified, else empty default. // Start from HF config.json if specified, else empty default.

View File

@@ -52,7 +52,8 @@ struct ConfigOverrides {
impl ConfigOverrides { impl ConfigOverrides {
fn apply(&self, cfg: &mut Config) { fn apply(&self, cfg: &mut Config) {
if let Some(n) = self.num_instances { if let Some(n) = self.num_instances {
cfg.cluster.num_instances = n; cfg.cluster.num_instances = Some(n);
cfg.cluster.buckets.clear();
} }
if let Some(m) = self.max_requests { if let Some(m) = self.max_requests {
cfg.sim.max_requests = Some(m); cfg.sim.max_requests = Some(m);
@@ -205,6 +206,7 @@ fn main() -> Result<()> {
fn load(config: &PathBuf, overrides: &ConfigOverrides) -> Result<Config> { fn load(config: &PathBuf, overrides: &ConfigOverrides) -> Result<Config> {
let mut cfg = Config::from_yaml_path(config)?; let mut cfg = Config::from_yaml_path(config)?;
overrides.apply(&mut cfg); overrides.apply(&mut cfg);
cfg.cluster.validate()?;
Ok(cfg) Ok(cfg)
} }
@@ -268,7 +270,8 @@ fn cmd_ablate(
auto_target_ttft_mean, auto_target_ttft_mean,
probe_mode.as_str() probe_mode.as_str()
); );
base.cluster.num_instances = chosen; base.cluster.num_instances = Some(chosen);
base.cluster.buckets.clear();
} }
eprintln!( eprintln!(
@@ -283,7 +286,7 @@ fn cmd_ablate(
.map(ReplayEvictPolicy::as_str) .map(ReplayEvictPolicy::as_str)
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(","), .join(","),
base.cluster.num_instances, base.cluster.total_instances(),
if jobs == 0 { if jobs == 0 {
"auto".to_string() "auto".to_string()
} else { } else {
@@ -330,7 +333,8 @@ fn auto_select_instances(
for &n in candidates { for &n in candidates {
let mut cfg = base.clone(); let mut cfg = base.clone();
cfg.cluster.num_instances = n; cfg.cluster.num_instances = Some(n);
cfg.cluster.buckets.clear();
cfg.cluster.router.mode = probe; cfg.cluster.router.mode = probe;
// Isolate calibration output so ablation runs don't overwrite it. // Isolate calibration output so ablation runs don't overwrite it.
cfg.sim.output_dir = out_root cfg.sim.output_dir = out_root
@@ -429,7 +433,7 @@ fn cmd_validate(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> {
let hbm_blocks = (cfg.hardware.hbm_bytes / block_bytes) as u64; let hbm_blocks = (cfg.hardware.hbm_bytes / block_bytes) as u64;
let dram_blocks = (cfg.hardware.dram_bytes / block_bytes) as u64; let dram_blocks = (cfg.hardware.dram_bytes / block_bytes) as u64;
eprintln!("per-instance HBM blocks = {hbm_blocks}, DRAM blocks = {dram_blocks}"); eprintln!("per-instance HBM blocks = {hbm_blocks}, DRAM blocks = {dram_blocks}");
eprintln!("num_instances = {}", cfg.cluster.num_instances); eprintln!("num_instances = {}", cfg.cluster.total_instances());
// Sample prefill times at a few prompt lengths. // Sample prefill times at a few prompt lengths.
eprintln!("prefill_time samples:"); eprintln!("prefill_time samples:");
for &n in &[256, 1024, 4096, 16384, 65536, 131072] { for &n in &[256, 1024, 4096, 16384, 65536, 131072] {
@@ -462,7 +466,7 @@ fn cmd_oracle(
let cfg = load(path, overrides)?; let cfg = load(path, overrides)?;
let block_bytes = cfg.model.kv_block_bytes() as f64; let block_bytes = cfg.model.kv_block_bytes() as f64;
let per_instance_blocks = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as u64; let per_instance_blocks = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as u64;
let aggregate_blocks = per_instance_blocks * cfg.cluster.num_instances as u64; let aggregate_blocks = per_instance_blocks * cfg.cluster.total_instances() as u64;
let capacity = match (capacity_blocks, per_instance) { let capacity = match (capacity_blocks, per_instance) {
(Some(_), true) => { (Some(_), true) => {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
@@ -518,7 +522,7 @@ fn cmd_oracle(
records.len(), records.len(),
capacity, capacity,
per_instance_blocks, per_instance_blocks,
cfg.cluster.num_instances, cfg.cluster.total_instances(),
if per_instance { if per_instance {
", per-instance mode" ", per-instance mode"
} else { } else {

View File

@@ -519,7 +519,7 @@ pub fn replay_fixed_placement(
let block_bytes = cfg.model.kv_block_bytes() as f64; let block_bytes = cfg.model.kv_block_bytes() as f64;
let l0_cap = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as usize; let l0_cap = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as usize;
let l1_cap = (cfg.hardware.dram_bytes / block_bytes).max(1.0) as usize; let l1_cap = (cfg.hardware.dram_bytes / block_bytes).max(1.0) as usize;
let num_instances = cfg.cluster.num_instances as usize; let num_instances = cfg.cluster.total_instances() as usize;
let mut caches: Vec<ReplayInstanceCache> = (0..num_instances) let mut caches: Vec<ReplayInstanceCache> = (0..num_instances)
.map(|_| ReplayInstanceCache::new(policy, l0_cap, l1_cap)) .map(|_| ReplayInstanceCache::new(policy, l0_cap, l1_cap))
.collect(); .collect();

View File

@@ -57,7 +57,7 @@ pub struct AdaptiveAffinityRouter {
impl AdaptiveAffinityRouter { impl AdaptiveAffinityRouter {
pub fn new(config: &Config) -> Self { pub fn new(config: &Config) -> Self {
let n = config.cluster.num_instances as usize; let n = config.cluster.total_instances() as usize;
let configured_fan_out = config.cluster.router.affinity_fan_out; let configured_fan_out = config.cluster.router.affinity_fan_out;
let max_fan_out = if configured_fan_out > 0 { let max_fan_out = if configured_fan_out > 0 {
configured_fan_out.max(2).min(n) configured_fan_out.max(2).min(n)

View File

@@ -53,7 +53,7 @@ pub struct LineageAffinityRouter {
impl LineageAffinityRouter { impl LineageAffinityRouter {
pub fn new(config: &Config) -> Self { pub fn new(config: &Config) -> Self {
let n = config.cluster.num_instances as usize; let n = config.cluster.total_instances() as usize;
let configured_fan_out = config.cluster.router.affinity_fan_out; let configured_fan_out = config.cluster.router.affinity_fan_out;
let max_fan_out = if configured_fan_out > 0 { let max_fan_out = if configured_fan_out > 0 {
configured_fan_out.max(2).min(n) configured_fan_out.max(2).min(n)

View File

@@ -181,7 +181,9 @@ mod tests {
hardware: test_hardware(), hardware: test_hardware(),
calibration: CalibrationConfig::default(), calibration: CalibrationConfig::default(),
cluster: ClusterConfig { cluster: ClusterConfig {
num_instances: 2, num_instances: Some(2),
buckets: Vec::new(),
global_router: Default::default(),
meta_store: MetaStoreConfig { meta_store: MetaStoreConfig {
ttl_seconds: 1000.0, ttl_seconds: 1000.0,
}, },

View File

@@ -51,7 +51,7 @@ pub struct PrefixAffinityRouter {
impl PrefixAffinityRouter { impl PrefixAffinityRouter {
pub fn new(config: &Config) -> Self { pub fn new(config: &Config) -> Self {
let n = config.cluster.num_instances as usize; let n = config.cluster.total_instances() as usize;
let cfg_fan = config.cluster.router.affinity_fan_out; let cfg_fan = config.cluster.router.affinity_fan_out;
// fan_out: if configured, use it; otherwise auto = max(2, n/8). // fan_out: if configured, use it; otherwise auto = max(2, n/8).
let fan_out = if cfg_fan > 0 { let fan_out = if cfg_fan > 0 {

View File

@@ -41,7 +41,9 @@ fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config {
}, },
calibration: CalibrationConfig::default(), calibration: CalibrationConfig::default(),
cluster: ClusterConfig { cluster: ClusterConfig {
num_instances: 4, num_instances: Some(4),
buckets: Vec::new(),
global_router: Default::default(),
meta_store: MetaStoreConfig { meta_store: MetaStoreConfig {
ttl_seconds: 1000.0, ttl_seconds: 1000.0,
}, },