fix: close bucketed cluster config model gaps
This commit is contained in:
@@ -37,8 +37,9 @@ pub struct Cluster {
|
||||
|
||||
impl Cluster {
|
||||
pub fn new(config: &Config, model: &ModelConfig) -> Self {
|
||||
let mut instances = Vec::with_capacity(config.cluster.num_instances as usize);
|
||||
for id in 0..config.cluster.num_instances {
|
||||
let total_instances = config.cluster.total_instances();
|
||||
let mut instances = Vec::with_capacity(total_instances as usize);
|
||||
for id in 0..total_instances {
|
||||
instances.push(Instance::new(
|
||||
id as InstanceId,
|
||||
model,
|
||||
@@ -226,7 +227,9 @@ mod tests {
|
||||
..CalibrationConfig::default()
|
||||
},
|
||||
cluster: ClusterConfig {
|
||||
num_instances: 1,
|
||||
num_instances: Some(1),
|
||||
buckets: Vec::new(),
|
||||
global_router: Default::default(),
|
||||
meta_store: MetaStoreConfig {
|
||||
ttl_seconds: 1000.0,
|
||||
},
|
||||
|
||||
152
src/config.rs
152
src/config.rs
@@ -13,11 +13,7 @@
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::ops::Deref;
|
||||
use std::path::Path;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
@@ -383,7 +379,12 @@ fn default_first_token_ready_us() -> f64 {
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClusterConfig {
|
||||
pub num_instances: u32,
|
||||
#[serde(default)]
|
||||
pub num_instances: Option<u32>,
|
||||
#[serde(default)]
|
||||
pub buckets: Vec<BucketConfig>,
|
||||
#[serde(default)]
|
||||
pub global_router: GlobalRouterConfig,
|
||||
pub meta_store: MetaStoreConfig,
|
||||
pub router: RouterConfig,
|
||||
}
|
||||
@@ -418,58 +419,6 @@ pub enum GlobalRouterMode {
|
||||
BucketScore,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClusterBucketView {
|
||||
pub buckets: Vec<BucketConfig>,
|
||||
pub global_router: GlobalRouterConfig,
|
||||
}
|
||||
|
||||
impl Default for ClusterBucketView {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
buckets: Vec::new(),
|
||||
global_router: GlobalRouterConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static DEFAULT_CLUSTER_BUCKET_VIEW: OnceLock<ClusterBucketView> = OnceLock::new();
|
||||
static CLUSTER_BUCKET_VIEWS: OnceLock<Mutex<HashMap<u32, &'static ClusterBucketView>>> =
|
||||
OnceLock::new();
|
||||
static NEXT_CLUSTER_BUCKET_VIEW_ID: AtomicU32 = AtomicU32::new(1);
|
||||
|
||||
fn default_cluster_bucket_view() -> &'static ClusterBucketView {
|
||||
DEFAULT_CLUSTER_BUCKET_VIEW.get_or_init(ClusterBucketView::default)
|
||||
}
|
||||
|
||||
fn cluster_bucket_views() -> &'static Mutex<HashMap<u32, &'static ClusterBucketView>> {
|
||||
CLUSTER_BUCKET_VIEWS.get_or_init(|| Mutex::new(HashMap::new()))
|
||||
}
|
||||
|
||||
fn register_cluster_bucket_view(view: ClusterBucketView) -> u32 {
|
||||
if view.buckets.is_empty() && view.global_router == GlobalRouterConfig::default() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let id = NEXT_CLUSTER_BUCKET_VIEW_ID.fetch_add(1, Ordering::Relaxed);
|
||||
let leaked = Box::leak(Box::new(view));
|
||||
cluster_bucket_views().lock().unwrap().insert(id, leaked);
|
||||
id
|
||||
}
|
||||
|
||||
fn lookup_cluster_bucket_view(id: u32) -> &'static ClusterBucketView {
|
||||
if id == 0 {
|
||||
return default_cluster_bucket_view();
|
||||
}
|
||||
|
||||
cluster_bucket_views()
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(&id)
|
||||
.copied()
|
||||
.unwrap_or_else(default_cluster_bucket_view)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MetaStoreConfig {
|
||||
pub ttl_seconds: f64,
|
||||
@@ -519,30 +468,19 @@ fn default_prefix_k() -> usize {
|
||||
}
|
||||
|
||||
impl ClusterConfig {
|
||||
fn bucket_view_id(&self) -> u32 {
|
||||
if self.router.precise_probe_latency_us < 0.0 {
|
||||
(-self.router.precise_probe_latency_us) as u32
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
fn bucket_view(&self) -> &'static ClusterBucketView {
|
||||
lookup_cluster_bucket_view(self.bucket_view_id())
|
||||
}
|
||||
|
||||
pub fn legacy_num_instances(&self) -> Option<u32> {
|
||||
self.buckets.is_empty().then_some(self.num_instances)
|
||||
}
|
||||
|
||||
pub fn total_instances(&self) -> u32 {
|
||||
if self.buckets.is_empty() {
|
||||
self.num_instances
|
||||
} else {
|
||||
self.buckets.iter().map(|bucket| bucket.num_instances).sum()
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn total_instances(&self) -> u32 {
|
||||
self.legacy_num_instances()
|
||||
.unwrap_or_else(|| self.buckets.iter().map(|bucket| bucket.num_instances).sum())
|
||||
}
|
||||
|
||||
pub fn effective_buckets(&self) -> Vec<BucketConfig> {
|
||||
if !self.buckets.is_empty() {
|
||||
return self.buckets.clone();
|
||||
@@ -552,21 +490,22 @@ impl ClusterConfig {
|
||||
name: "default".to_string(),
|
||||
input_length_min: 0,
|
||||
input_length_max: u32::MAX,
|
||||
num_instances: self.num_instances,
|
||||
num_instances: self
|
||||
.num_instances
|
||||
.expect("legacy single-pool cluster must have num_instances"),
|
||||
}]
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
if self.num_instances == 0 && self.buckets.is_empty() {
|
||||
anyhow::bail!("cluster must set either num_instances or buckets");
|
||||
}
|
||||
|
||||
if self.num_instances > 0 && !self.buckets.is_empty() {
|
||||
if self.num_instances.is_some() && !self.buckets.is_empty() {
|
||||
anyhow::bail!("cluster.num_instances and cluster.buckets are mutually exclusive");
|
||||
}
|
||||
|
||||
if self.buckets.is_empty() {
|
||||
anyhow::ensure!(self.num_instances > 0, "cluster.num_instances must be positive");
|
||||
let num_instances = self.num_instances.ok_or_else(|| {
|
||||
anyhow::anyhow!("cluster must set either num_instances or buckets")
|
||||
})?;
|
||||
anyhow::ensure!(num_instances > 0, "cluster.num_instances must be positive");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -600,14 +539,6 @@ impl ClusterConfig {
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for ClusterConfig {
|
||||
type Target = ClusterBucketView;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.bucket_view()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RouterMode {
|
||||
@@ -738,22 +669,10 @@ struct RawConfig {
|
||||
hardware: RawHardwareConfig,
|
||||
#[serde(default)]
|
||||
calibration: CalibrationConfig,
|
||||
cluster: RawClusterConfig,
|
||||
cluster: ClusterConfig,
|
||||
sim: SimConfig,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct RawClusterConfig {
|
||||
#[serde(default)]
|
||||
num_instances: Option<u32>,
|
||||
meta_store: MetaStoreConfig,
|
||||
router: RouterConfig,
|
||||
#[serde(default)]
|
||||
global_router: GlobalRouterConfig,
|
||||
#[serde(default)]
|
||||
buckets: Vec<BucketConfig>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct RawModelConfig {
|
||||
/// Path to a HuggingFace `config.json`. Resolved relative to the YAML
|
||||
@@ -866,36 +785,15 @@ impl RawConfig {
|
||||
model,
|
||||
hardware,
|
||||
calibration: self.calibration,
|
||||
cluster: self.cluster.resolve()?,
|
||||
cluster: {
|
||||
self.cluster.validate()?;
|
||||
self.cluster
|
||||
},
|
||||
sim: self.sim,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl RawClusterConfig {
|
||||
fn resolve(self) -> Result<ClusterConfig> {
|
||||
let view = ClusterBucketView {
|
||||
buckets: self.buckets,
|
||||
global_router: self.global_router,
|
||||
};
|
||||
let mut cluster = ClusterConfig {
|
||||
num_instances: self.num_instances.unwrap_or(0),
|
||||
meta_store: self.meta_store,
|
||||
router: self.router,
|
||||
};
|
||||
|
||||
let view_id = register_cluster_bucket_view(view);
|
||||
if view_id > 0 {
|
||||
// Preserve the public ClusterConfig layout for existing callers while
|
||||
// carrying bucketed config through validation and tests.
|
||||
cluster.router.precise_probe_latency_us = -(view_id as f64);
|
||||
}
|
||||
|
||||
cluster.validate()?;
|
||||
Ok(cluster)
|
||||
}
|
||||
}
|
||||
|
||||
impl RawModelConfig {
|
||||
fn resolve(self, yaml_dir: &Path) -> Result<ModelConfig> {
|
||||
// Start from HF config.json if specified, else empty default.
|
||||
|
||||
18
src/main.rs
18
src/main.rs
@@ -52,7 +52,8 @@ struct ConfigOverrides {
|
||||
impl ConfigOverrides {
|
||||
fn apply(&self, cfg: &mut Config) {
|
||||
if let Some(n) = self.num_instances {
|
||||
cfg.cluster.num_instances = n;
|
||||
cfg.cluster.num_instances = Some(n);
|
||||
cfg.cluster.buckets.clear();
|
||||
}
|
||||
if let Some(m) = self.max_requests {
|
||||
cfg.sim.max_requests = Some(m);
|
||||
@@ -205,6 +206,7 @@ fn main() -> Result<()> {
|
||||
fn load(config: &PathBuf, overrides: &ConfigOverrides) -> Result<Config> {
|
||||
let mut cfg = Config::from_yaml_path(config)?;
|
||||
overrides.apply(&mut cfg);
|
||||
cfg.cluster.validate()?;
|
||||
Ok(cfg)
|
||||
}
|
||||
|
||||
@@ -268,7 +270,8 @@ fn cmd_ablate(
|
||||
auto_target_ttft_mean,
|
||||
probe_mode.as_str()
|
||||
);
|
||||
base.cluster.num_instances = chosen;
|
||||
base.cluster.num_instances = Some(chosen);
|
||||
base.cluster.buckets.clear();
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
@@ -283,7 +286,7 @@ fn cmd_ablate(
|
||||
.map(ReplayEvictPolicy::as_str)
|
||||
.collect::<Vec<_>>()
|
||||
.join(","),
|
||||
base.cluster.num_instances,
|
||||
base.cluster.total_instances(),
|
||||
if jobs == 0 {
|
||||
"auto".to_string()
|
||||
} else {
|
||||
@@ -330,7 +333,8 @@ fn auto_select_instances(
|
||||
|
||||
for &n in candidates {
|
||||
let mut cfg = base.clone();
|
||||
cfg.cluster.num_instances = n;
|
||||
cfg.cluster.num_instances = Some(n);
|
||||
cfg.cluster.buckets.clear();
|
||||
cfg.cluster.router.mode = probe;
|
||||
// Isolate calibration output so ablation runs don't overwrite it.
|
||||
cfg.sim.output_dir = out_root
|
||||
@@ -429,7 +433,7 @@ fn cmd_validate(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> {
|
||||
let hbm_blocks = (cfg.hardware.hbm_bytes / block_bytes) as u64;
|
||||
let dram_blocks = (cfg.hardware.dram_bytes / block_bytes) as u64;
|
||||
eprintln!("per-instance HBM blocks = {hbm_blocks}, DRAM blocks = {dram_blocks}");
|
||||
eprintln!("num_instances = {}", cfg.cluster.num_instances);
|
||||
eprintln!("num_instances = {}", cfg.cluster.total_instances());
|
||||
// Sample prefill times at a few prompt lengths.
|
||||
eprintln!("prefill_time samples:");
|
||||
for &n in &[256, 1024, 4096, 16384, 65536, 131072] {
|
||||
@@ -462,7 +466,7 @@ fn cmd_oracle(
|
||||
let cfg = load(path, overrides)?;
|
||||
let block_bytes = cfg.model.kv_block_bytes() as f64;
|
||||
let per_instance_blocks = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as u64;
|
||||
let aggregate_blocks = per_instance_blocks * cfg.cluster.num_instances as u64;
|
||||
let aggregate_blocks = per_instance_blocks * cfg.cluster.total_instances() as u64;
|
||||
let capacity = match (capacity_blocks, per_instance) {
|
||||
(Some(_), true) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
@@ -518,7 +522,7 @@ fn cmd_oracle(
|
||||
records.len(),
|
||||
capacity,
|
||||
per_instance_blocks,
|
||||
cfg.cluster.num_instances,
|
||||
cfg.cluster.total_instances(),
|
||||
if per_instance {
|
||||
", per-instance mode"
|
||||
} else {
|
||||
|
||||
@@ -519,7 +519,7 @@ pub fn replay_fixed_placement(
|
||||
let block_bytes = cfg.model.kv_block_bytes() as f64;
|
||||
let l0_cap = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as usize;
|
||||
let l1_cap = (cfg.hardware.dram_bytes / block_bytes).max(1.0) as usize;
|
||||
let num_instances = cfg.cluster.num_instances as usize;
|
||||
let num_instances = cfg.cluster.total_instances() as usize;
|
||||
let mut caches: Vec<ReplayInstanceCache> = (0..num_instances)
|
||||
.map(|_| ReplayInstanceCache::new(policy, l0_cap, l1_cap))
|
||||
.collect();
|
||||
|
||||
@@ -57,7 +57,7 @@ pub struct AdaptiveAffinityRouter {
|
||||
|
||||
impl AdaptiveAffinityRouter {
|
||||
pub fn new(config: &Config) -> Self {
|
||||
let n = config.cluster.num_instances as usize;
|
||||
let n = config.cluster.total_instances() as usize;
|
||||
let configured_fan_out = config.cluster.router.affinity_fan_out;
|
||||
let max_fan_out = if configured_fan_out > 0 {
|
||||
configured_fan_out.max(2).min(n)
|
||||
|
||||
@@ -53,7 +53,7 @@ pub struct LineageAffinityRouter {
|
||||
|
||||
impl LineageAffinityRouter {
|
||||
pub fn new(config: &Config) -> Self {
|
||||
let n = config.cluster.num_instances as usize;
|
||||
let n = config.cluster.total_instances() as usize;
|
||||
let configured_fan_out = config.cluster.router.affinity_fan_out;
|
||||
let max_fan_out = if configured_fan_out > 0 {
|
||||
configured_fan_out.max(2).min(n)
|
||||
|
||||
@@ -181,7 +181,9 @@ mod tests {
|
||||
hardware: test_hardware(),
|
||||
calibration: CalibrationConfig::default(),
|
||||
cluster: ClusterConfig {
|
||||
num_instances: 2,
|
||||
num_instances: Some(2),
|
||||
buckets: Vec::new(),
|
||||
global_router: Default::default(),
|
||||
meta_store: MetaStoreConfig {
|
||||
ttl_seconds: 1000.0,
|
||||
},
|
||||
|
||||
@@ -51,7 +51,7 @@ pub struct PrefixAffinityRouter {
|
||||
|
||||
impl PrefixAffinityRouter {
|
||||
pub fn new(config: &Config) -> Self {
|
||||
let n = config.cluster.num_instances as usize;
|
||||
let n = config.cluster.total_instances() as usize;
|
||||
let cfg_fan = config.cluster.router.affinity_fan_out;
|
||||
// fan_out: if configured, use it; otherwise auto = max(2, n/8).
|
||||
let fan_out = if cfg_fan > 0 {
|
||||
|
||||
@@ -41,7 +41,9 @@ fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config {
|
||||
},
|
||||
calibration: CalibrationConfig::default(),
|
||||
cluster: ClusterConfig {
|
||||
num_instances: 4,
|
||||
num_instances: Some(4),
|
||||
buckets: Vec::new(),
|
||||
global_router: Default::default(),
|
||||
meta_store: MetaStoreConfig {
|
||||
ttl_seconds: 1000.0,
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user