KV-cache simulator for LLM serving cluster routing research

Discrete-event simulator for evaluating KV cache-aware routing policies
in prefill-disaggregated LLM serving clusters. Models a two-tier KV cache
hierarchy (L0 GPU HBM + L1 CPU DRAM) with RDMA/PCIe link contention,
architecture-derived roofline compute (MoE, MLA, DSA), and a cluster-wide
meta-store for prefix-aware routing decisions.

Includes 11 routing policies (random, round_robin, least_loaded,
least_tokens, ttl_aware, precise, min_pd, cache_load, cache_score,
estimated_ttft, prefix_affinity), HuggingFace config.json auto-parsing,
built-in GPU hardware presets (H100/H800/H20/A100/B200), and ablation
tooling for systematic policy comparison across real Alibaba serving traces.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
commit ec73a95e05 (2026-04-14 01:16:02 +08:00)
52 changed files with 6005 additions and 0 deletions
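A quick orientation before the diff: routing policies are selected by name through `RouterMode::parse` in `src/config.rs` below. A minimal sketch of how the aliases resolve; the crate path `kvsim` is an assumption, since the crate root is not part of this excerpt:

```rust
use anyhow::Result;
use kvsim::config::RouterMode; // hypothetical crate path; the module is src/config.rs below

fn main() -> Result<()> {
    // Long names and short aliases map to the same policy.
    assert_eq!(RouterMode::parse("cache_score")?, RouterMode::CacheScore);
    assert_eq!(RouterMode::parse("cs")?, RouterMode::CacheScore);
    assert_eq!(RouterMode::parse("prefix_affinity")?.as_str(), "prefix_affinity");
    // Unknown names surface a readable error instead of a panic.
    assert!(RouterMode::parse("fifo").is_err());
    Ok(())
}
```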

src/cluster/cluster.rs
@@ -0,0 +1,167 @@
//! Cluster: routes arrivals, performs the L0 / L1 / remote-RDMA fetch chain
//! described in the design diagram, and bookkeeps the global meta store.
use crate::cluster::meta_store::MetaStore;
use crate::config::{Config, ModelConfig};
use crate::instance::instance::AdmittedRequest;
use crate::instance::Instance;
use crate::router::{self, RouteDecision, Router};
use crate::trace::RequestRecord;
use crate::types::InstanceId;
#[derive(Debug, Clone)]
pub struct AdmissionStats {
pub instance: InstanceId,
pub l0_hit_blocks: u32,
pub l1_hit_blocks: u32,
pub remote_hit_blocks: u32,
pub miss_blocks: u32,
pub rdma_bytes: u64,
pub pcie_bytes: u64,
pub fetch_time_s: f64,
pub probe_overhead_s: f64,
pub ready_at: f64,
pub decision: RouteDecision,
}
pub struct Cluster {
pub instances: Vec<Instance>,
pub meta_store: MetaStore,
pub router: Box<dyn Router>,
pub block_size_tokens: u32,
pub kv_block_bytes: u64,
}
impl Cluster {
pub fn new(config: &Config, model: &ModelConfig) -> Self {
let mut instances = Vec::with_capacity(config.cluster.num_instances as usize);
for id in 0..config.cluster.num_instances {
instances.push(Instance::new(id as InstanceId, model, &config.hardware));
}
let meta_store = MetaStore::new(config.cluster.meta_store.ttl_seconds);
let router = router::build(config, config.sim.seed);
Self {
instances,
meta_store,
router,
block_size_tokens: model.block_size_tokens,
kv_block_bytes: model.kv_block_bytes(),
}
}
/// Route + admit a request. Returns the chosen instance plus rich
/// per-request stats for metrics. Does NOT schedule the BatchTick — the
/// simulator driver does that based on the returned `ready_at`.
pub fn route_and_admit(&mut self, req: &RequestRecord, now: f64) -> AdmissionStats {
let decision = self.router.route(req, &self.instances, &self.meta_store, now);
let inst_id = decision.chosen;
let probe_overhead_s = decision.probe_overhead_s;
// The router probe overhead delays the request's effective start time.
let effective_now = now + probe_overhead_s;
let inst = &mut self.instances[inst_id as usize];
let total_blocks = req.hash_ids.len() as u32;
// 1. L0 lookup (touches matched blocks).
let l0_hits = inst.cache.l0.longest_prefix(&req.hash_ids) as u32;
// 2. L1 lookup on the remaining suffix.
let suffix_after_l0 = &req.hash_ids[l0_hits as usize..];
let l1_hits = inst.cache.l1.longest_prefix(suffix_after_l0) as u32;
// L1->L0 transfer cost
let l1_bytes = (l1_hits as u64) * self.kv_block_bytes;
let mut t = effective_now;
if l1_hits > 0 {
t = inst.links.pcie.reserve(t, l1_bytes);
// Promote those blocks into L0
let mut evicted = Vec::new();
inst.cache.l0.insert_blocks(
&suffix_after_l0[..l1_hits as usize],
&mut evicted,
);
}
// 3. Remote v6d lookup for the still-remaining suffix.
let suffix_after_l1 = &suffix_after_l0[l1_hits as usize..];
let mut remote_hit_blocks: u32 = 0;
for &h in suffix_after_l1 {
// A block is remotely available iff some instance other than
// `inst_id` lists it (and not expired).
let owners = self.meta_store.instances_for(h, now);
let any_remote = owners.iter().any(|o| *o != inst_id);
if any_remote {
remote_hit_blocks += 1;
} else {
break; // contiguous prefix - stop on first miss
}
}
let remote_bytes = (remote_hit_blocks as u64) * self.kv_block_bytes;
if remote_hit_blocks > 0 {
// RDMA from peer host -> local DRAM, then PCIe -> GPU
let inst = &mut self.instances[inst_id as usize];
t = inst.links.rdma.reserve(t, remote_bytes);
t = inst.links.pcie.reserve(t, remote_bytes);
// Insert into local L1 (occupies LRU space) AND into L0
let pulled = &suffix_after_l1[..remote_hit_blocks as usize];
let mut evicted_l1 = Vec::new();
inst.cache.l1.insert_blocks(pulled, &mut evicted_l1);
let mut evicted_l0 = Vec::new();
inst.cache.l0.insert_blocks(pulled, &mut evicted_l0);
// The local instance now also owns these blocks - update meta_store.
for &h in pulled {
self.meta_store.insert(h, inst_id, now);
}
}
// 4. Miss = remaining tokens to prefill from scratch.
let miss_blocks = total_blocks - l0_hits - l1_hits - remote_hit_blocks;
let miss_tokens = miss_blocks * self.block_size_tokens;
// In a real system the newly-prefilled blocks would land in L0 once the
// request runs and reach L1 / the meta_store via async writeback. Inserting
// them at admission time is fine here because we track presence, not byte
// movement: the writeback latency is hidden behind request execution, and
// we don't model a meta_store inconsistency window beyond the TTL itself.
let inst = &mut self.instances[inst_id as usize];
let new_input_blocks = &req.hash_ids[(l0_hits + l1_hits + remote_hit_blocks) as usize..];
let mut evicted_l0 = Vec::new();
inst.cache.l0.insert_blocks(new_input_blocks, &mut evicted_l0);
let mut evicted_l1 = Vec::new();
inst.cache.l1.insert_blocks(new_input_blocks, &mut evicted_l1);
for &h in new_input_blocks {
self.meta_store.insert(h, inst_id, now);
}
// 5. Reserve KV slots for this request's prefill residency.
// PD disaggregation: decode runs elsewhere, so only the input
// blocks occupy HBM on this instance.
let reserved_blocks = total_blocks;
let admitted = AdmittedRequest {
req_id: req.req_id,
arrival: req.arrival,
ready_at: t,
prefill_tokens_remaining: miss_tokens,
reserved_blocks,
};
inst.admit(admitted);
let pcie_bytes = l1_bytes + remote_bytes;
let fetch_time_s = (t - effective_now).max(0.0);
AdmissionStats {
instance: inst_id,
l0_hit_blocks: l0_hits,
l1_hit_blocks: l1_hits,
remote_hit_blocks,
miss_blocks,
rdma_bytes: remote_bytes,
pcie_bytes,
fetch_time_s,
probe_overhead_s,
ready_at: t,
decision,
}
}
}
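To build intuition for the fetch chain in `route_and_admit`, here is a back-of-envelope transfer-time sketch under assumed values (roughly 1 MB per KV block, a PCIe Gen5 x16 link, a 200 Gbps NIC). The simulator itself goes through `links.pcie.reserve` / `links.rdma.reserve`, which also add latency and link contention, so treat this as a lower bound rather than the model's formula:

```rust
fn main() {
    // Assumed, illustrative values -- not read from any preset in this commit.
    let kv_block_bytes: f64 = 1.0e6; // ~1 MB per block
    let pcie_bw = 64.0e9;            // PCIe Gen5 x16
    let rdma_bw = 25.0e9;            // 200 Gbps NIC

    let l1_hit_blocks = 128.0;       // promoted DRAM -> HBM over PCIe
    let remote_hit_blocks = 256.0;   // pulled from a peer: RDMA, then PCIe

    let l1_bytes = l1_hit_blocks * kv_block_bytes;
    let remote_bytes = remote_hit_blocks * kv_block_bytes;

    // Serial composition, mirroring route_and_admit: L1 promotion first,
    // then the remote pull crosses RDMA into DRAM and PCIe into HBM.
    let fetch_s = l1_bytes / pcie_bw + remote_bytes / rdma_bw + remote_bytes / pcie_bw;
    println!("~{:.1} ms of raw transfer time", fetch_s * 1e3); // ~16.2 ms
}
```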

src/cluster/meta_store.rs
@@ -0,0 +1,161 @@
//! Global Redis-like KV-cache index.
//!
//! Maps `block_hash -> SmallVec<(instance_id, expires_at)>`. TTL eviction is
//! lazy (on read). The TTL-aware router uses `score_prefix` to score each
//! instance's predicted longest prefix without probing instances directly.
use ahash::AHashMap;
use smallvec::SmallVec;
use crate::types::InstanceId;
#[derive(Debug, Clone, Copy)]
struct Entry {
instance: InstanceId,
expires_at: f64,
}
#[derive(Debug, Default)]
pub struct MetaStore {
ttl_seconds: f64,
map: AHashMap<u64, SmallVec<[Entry; 4]>>,
}
impl MetaStore {
pub fn new(ttl_seconds: f64) -> Self {
Self {
ttl_seconds,
map: AHashMap::with_capacity(1 << 16),
}
}
pub fn ttl(&self) -> f64 {
self.ttl_seconds
}
/// Record that `instance` now holds `block_hash`.
pub fn insert(&mut self, block_hash: u64, instance: InstanceId, now: f64) {
let entry = Entry {
instance,
expires_at: now + self.ttl_seconds,
};
let bucket = self.map.entry(block_hash).or_default();
// refresh existing entry if present
for e in bucket.iter_mut() {
if e.instance == instance {
e.expires_at = entry.expires_at;
return;
}
}
bucket.push(entry);
}
/// Score each candidate instance by the longest leading prefix of
/// `hash_ids` for which the meta store believes that instance still holds
/// every block. Returns scores indexed by instance id.
pub fn score_prefix(&self, hash_ids: &[u64], now: f64, num_instances: usize) -> Vec<u32> {
if hash_ids.is_empty() {
return vec![0; num_instances];
}
// Walk hashes; at each step intersect the still-eligible instance set.
// A Vec<bool> per step is cheap since num_instances is typically <= 1024.
let mut alive: Vec<bool> = vec![false; num_instances];
// First block: seed alive set
let first = hash_ids[0];
let mut any = false;
if let Some(bucket) = self.map.get(&first) {
for e in bucket {
if e.expires_at >= now {
let i = e.instance as usize;
if i < num_instances {
alive[i] = true;
any = true;
}
}
}
}
let mut scores = vec![0u32; num_instances];
if !any {
return scores;
}
for i in 0..num_instances {
if alive[i] {
scores[i] = 1;
}
}
// Subsequent blocks: an instance survives only if the meta store still
// lists it for that block (and not expired).
for (depth, &h) in hash_ids.iter().enumerate().skip(1) {
let bucket = match self.map.get(&h) {
Some(b) => b,
None => break,
};
// mark instances present for this block
let mut present = vec![false; num_instances];
let mut any2 = false;
for e in bucket {
if e.expires_at >= now {
let i = e.instance as usize;
if i < num_instances && alive[i] {
present[i] = true;
any2 = true;
}
}
}
if !any2 {
break;
}
for i in 0..num_instances {
if present[i] {
scores[i] = (depth + 1) as u32;
} else {
alive[i] = false;
}
}
}
scores
}
/// Lookup which (alive) instances claim to hold a given block.
pub fn instances_for(&self, hash: u64, now: f64) -> SmallVec<[InstanceId; 4]> {
let mut out = SmallVec::new();
if let Some(bucket) = self.map.get(&hash) {
for e in bucket {
if e.expires_at >= now {
out.push(e.instance);
}
}
}
out
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn score_prefix_basic() {
let mut m = MetaStore::new(60.0);
m.insert(10, 0, 0.0);
m.insert(11, 0, 0.0);
m.insert(12, 0, 0.0);
m.insert(10, 1, 0.0);
m.insert(11, 1, 0.0);
// instance 1 only has 10,11; instance 0 has 10,11,12
let s = m.score_prefix(&[10, 11, 12, 13], 1.0, 4);
assert_eq!(s[0], 3);
assert_eq!(s[1], 2);
assert_eq!(s[2], 0);
}
#[test]
fn ttl_expiry() {
let mut m = MetaStore::new(1.0);
m.insert(10, 0, 0.0);
let s_now = m.score_prefix(&[10], 0.5, 2);
assert_eq!(s_now[0], 1);
let s_later = m.score_prefix(&[10], 5.0, 2);
assert_eq!(s_later[0], 0);
}
}
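One behavior worth calling out: `insert` refreshes the TTL of an existing `(block, instance)` entry instead of appending a duplicate, so a block that keeps getting re-inserted stays visible to `score_prefix`. A sketch in the style of the tests above (not part of the committed test module):

```rust
#[test]
fn ttl_refresh_on_reinsert() {
    let mut m = MetaStore::new(1.0);
    m.insert(10, 0, 0.0); // expires at t = 1.0
    m.insert(10, 0, 4.5); // same (block, instance): expiry pushed to 5.5
    assert_eq!(m.score_prefix(&[10], 5.0, 1)[0], 1);
    assert_eq!(m.score_prefix(&[10], 6.0, 1)[0], 0);
}
```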

src/cluster/mod.rs
@@ -0,0 +1,6 @@
pub mod meta_store;
#[allow(clippy::module_inception)]
pub mod cluster;
pub use cluster::Cluster;
pub use meta_store::MetaStore;

src/config.rs
@@ -0,0 +1,510 @@
//! Top-level configuration loaded from YAML.
//!
//! Two config styles are supported:
//!
//! **Architecture-derived** (preferred): set `hidden_size`, `num_attention_heads`,
//! `intermediate_size` and the simulator derives all roofline coefficients, KV
//! block sizes, and weight-stream costs from the model architecture. Supports
//! MoE, MLA (Multi-head Latent Attention), and DSA (DeepSeek Sparse Attention).
//!
//! **Legacy manual**: omit the architecture fields and set
//! `flops_per_token_prefill` + `attn_quadratic_coeff` directly. Backward
//! compatible with older YAML configs.
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
pub model: ModelConfig,
pub hardware: HardwareConfig,
pub cluster: ClusterConfig,
pub sim: SimConfig,
}
// ---------------------------------------------------------------------------
// Model
// ---------------------------------------------------------------------------
/// Model architecture + roofline coefficients.
///
/// If `hidden_size` is present the compute model is derived from architecture;
/// otherwise the legacy `flops_per_token_prefill` / `attn_quadratic_coeff`
/// fields are used directly.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ModelConfig {
#[serde(default)]
pub name: String,
pub num_layers: u32,
pub num_kv_heads: u32,
pub head_dim: u32,
pub dtype_bytes: u32,
pub block_size_tokens: u32,
// -- Architecture fields (enable auto-derivation when all three present) --
#[serde(default)]
pub hidden_size: Option<u32>,
#[serde(default)]
pub num_attention_heads: Option<u32>,
#[serde(default)]
pub intermediate_size: Option<u32>,
#[serde(default)]
pub moe: Option<MoeConfig>,
#[serde(default)]
pub mla: Option<MlaConfig>,
#[serde(default)]
pub attention: Option<AttentionConfig>,
// -- Legacy manual coefficients (used when hidden_size is absent) ---------
#[serde(default)]
pub flops_per_token_prefill: Option<f64>,
#[serde(default)]
pub attn_quadratic_coeff: Option<f64>,
#[serde(default)]
pub bytes_per_token_prefill: Option<f64>,
#[serde(default, skip_serializing)]
#[allow(dead_code)]
pub flops_per_token_decode: Option<f64>,
#[serde(default, skip_serializing)]
#[allow(dead_code)]
pub bytes_per_token_decode: Option<f64>,
}
impl ModelConfig {
/// Whether the config is architecture-derived or uses legacy manual knobs.
pub fn is_arch_mode(&self) -> bool {
self.hidden_size.is_some()
}
/// Bytes of KV cache per block.
///
/// For standard / GQA: `2 * L * kv_heads * head_dim * dtype * block_tokens`
/// For MLA: `L * (kv_lora_rank + qk_rope_head_dim) * dtype * block_tokens`
pub fn kv_block_bytes(&self) -> u64 {
if let Some(mla) = &self.mla {
self.num_layers as u64
* (mla.kv_lora_rank + mla.qk_rope_head_dim) as u64
* self.dtype_bytes as u64
* self.block_size_tokens as u64
} else {
2u64
* self.num_layers as u64
* self.num_kv_heads as u64
* self.head_dim as u64
* self.dtype_bytes as u64
* self.block_size_tokens as u64
}
}
}
// -- Sub-configs for MoE / MLA / Attention -----------------------------------
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MoeConfig {
pub num_experts: u32,
pub num_active_experts: u32,
#[serde(default)]
pub num_shared_experts: u32,
/// Per-expert FFN intermediate size (`moe_intermediate_size` in HF).
/// Falls back to parent `intermediate_size` if absent.
#[serde(default)]
pub expert_intermediate_size: Option<u32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MlaConfig {
pub kv_lora_rank: u32,
pub q_lora_rank: u32,
pub qk_nope_head_dim: u32,
pub qk_rope_head_dim: u32,
pub v_head_dim: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AttentionConfig {
Dense,
SlidingWindow {
window_size: u32,
},
Dsa {
/// Tokens within this window attend fully.
dense_window: u32,
/// Beyond the window, attend to every `sparse_stride`-th token.
sparse_stride: u32,
/// Number of initial layers that use dense attention regardless.
#[serde(default)]
first_dense_layers: u32,
},
}
// ---------------------------------------------------------------------------
// Hardware
// ---------------------------------------------------------------------------
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HardwareConfig {
pub gpu_flops: f64,
pub gpu_mem_bw: f64,
pub hbm_bytes: f64,
pub dram_bytes: f64,
pub pcie_bw: f64,
pub pcie_latency_us: f64,
pub rdma_bw: f64,
pub rdma_latency_us: f64,
#[serde(default = "default_max_batch_slots")]
pub max_batch_slots: u32,
#[serde(default = "default_prefill_chunk_tokens")]
pub prefill_chunk_tokens: u32,
}
fn default_max_batch_slots() -> u32 {
256
}
fn default_prefill_chunk_tokens() -> u32 {
2048
}
// ---------------------------------------------------------------------------
// Cluster
// ---------------------------------------------------------------------------
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterConfig {
pub num_instances: u32,
pub meta_store: MetaStoreConfig,
pub router: RouterConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaStoreConfig {
pub ttl_seconds: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig {
pub mode: RouterMode,
#[serde(default = "default_probe_latency_us")]
pub precise_probe_latency_us: f64,
#[serde(default = "default_probe_topk")]
pub precise_probe_topk: u32,
#[serde(default = "default_load_alpha")]
pub load_alpha: f64,
/// Weight for load (queue_len) in cache_score: `2^(α·load + β·miss)`.
#[serde(default = "default_score_alpha")]
pub score_alpha: f64,
/// Weight for cache miss in cache_score: `2^(α·load + β·miss)`.
#[serde(default = "default_score_beta")]
pub score_beta: f64,
/// Number of leading blocks for prefix fingerprint in prefix_affinity.
#[serde(default = "default_prefix_k")]
pub prefix_k: usize,
/// Number of top-affinity instances to consider in prefix_affinity.
/// 0 means auto (n/8, min 2).
#[serde(default)]
pub affinity_fan_out: usize,
}
fn default_probe_latency_us() -> f64 {
50.0
}
fn default_probe_topk() -> u32 {
4
}
fn default_load_alpha() -> f64 {
1.0
}
fn default_score_alpha() -> f64 {
1.0
}
fn default_score_beta() -> f64 {
0.1
}
fn default_prefix_k() -> usize {
8
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RouterMode {
Random,
RoundRobin,
LeastLoaded,
LeastTokens,
TtlAware,
Precise,
MinPd,
CacheLoad,
CacheScore,
EstimatedTtft,
PrefixAffinity,
}
impl RouterMode {
pub fn parse(s: &str) -> Result<Self> {
match s {
"random" => Ok(Self::Random),
"round_robin" | "rr" => Ok(Self::RoundRobin),
"least_loaded" => Ok(Self::LeastLoaded),
"least_tokens" | "lt" => Ok(Self::LeastTokens),
"ttl_aware" | "ttl" => Ok(Self::TtlAware),
"precise" | "precise_aware" => Ok(Self::Precise),
"min_pd" | "minpd" | "pd" => Ok(Self::MinPd),
"cache_load" | "cl" => Ok(Self::CacheLoad),
"cache_score" | "cs" => Ok(Self::CacheScore),
"estimated_ttft" | "ettft" | "optimal" => Ok(Self::EstimatedTtft),
"prefix_affinity" | "affinity" | "pa" => Ok(Self::PrefixAffinity),
other => Err(anyhow::anyhow!("unknown router mode: {other}")),
}
}
pub fn as_str(&self) -> &'static str {
match self {
Self::Random => "random",
Self::RoundRobin => "round_robin",
Self::LeastLoaded => "least_loaded",
Self::LeastTokens => "least_tokens",
Self::TtlAware => "ttl_aware",
Self::Precise => "precise",
Self::MinPd => "min_pd",
Self::CacheLoad => "cache_load",
Self::CacheScore => "cache_score",
Self::EstimatedTtft => "estimated_ttft",
Self::PrefixAffinity => "prefix_affinity",
}
}
}
// ---------------------------------------------------------------------------
// Sim
// ---------------------------------------------------------------------------
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimConfig {
pub trace_path: String,
#[serde(default)]
pub max_requests: Option<u64>,
pub output_dir: String,
#[serde(default = "default_sample_interval")]
pub sample_interval_s: f64,
#[serde(default)]
pub seed: u64,
}
fn default_sample_interval() -> f64 {
1.0
}
impl Config {
/// Load from a YAML file, resolving `config_json` (HF model config) and
/// hardware `type` (preset) references if present.
pub fn from_yaml_path<P: AsRef<Path>>(path: P) -> Result<Self> {
let path = path.as_ref();
let raw_str = std::fs::read_to_string(path)
.with_context(|| format!("reading config {}", path.display()))?;
let raw: RawConfig = serde_yaml::from_str(&raw_str)
.with_context(|| format!("parsing config {}", path.display()))?;
let yaml_dir = path.parent().unwrap_or(Path::new("."));
raw.resolve(yaml_dir)
.with_context(|| format!("resolving config {}", path.display()))
}
}
// ---------------------------------------------------------------------------
// Raw deserialization types — flexible YAML loading
// ---------------------------------------------------------------------------
//
// All model/hardware fields are `Option` so that `config_json` and `type`
// can supply base values, with explicit YAML fields acting as overrides.
// Existing YAML configs (no config_json / type) continue to work unchanged.
#[derive(Deserialize)]
struct RawConfig {
model: RawModelConfig,
hardware: RawHardwareConfig,
cluster: ClusterConfig,
sim: SimConfig,
}
#[derive(Deserialize)]
struct RawModelConfig {
/// Path to a HuggingFace `config.json`. Resolved relative to the YAML
/// file's directory. When present, architecture fields are loaded from
/// the JSON and any explicit YAML fields act as overrides.
#[serde(default)]
config_json: Option<String>,
#[serde(default)]
name: Option<String>,
#[serde(default)]
num_layers: Option<u32>,
#[serde(default)]
num_kv_heads: Option<u32>,
#[serde(default)]
head_dim: Option<u32>,
#[serde(default)]
dtype_bytes: Option<u32>,
#[serde(default)]
block_size_tokens: Option<u32>,
#[serde(default)]
hidden_size: Option<u32>,
#[serde(default)]
num_attention_heads: Option<u32>,
#[serde(default)]
intermediate_size: Option<u32>,
#[serde(default)]
moe: Option<MoeConfig>,
#[serde(default)]
mla: Option<MlaConfig>,
#[serde(default)]
attention: Option<AttentionConfig>,
#[serde(default)]
flops_per_token_prefill: Option<f64>,
#[serde(default)]
attn_quadratic_coeff: Option<f64>,
#[serde(default)]
bytes_per_token_prefill: Option<f64>,
#[serde(default)]
flops_per_token_decode: Option<f64>,
#[serde(default)]
bytes_per_token_decode: Option<f64>,
}
#[derive(Deserialize)]
struct RawHardwareConfig {
/// Hardware preset name (e.g. `"h100"`, `"8xb200"`). When present,
/// specs are loaded from the built-in preset database and any explicit
/// YAML fields override individual values.
#[serde(default, rename = "type")]
hw_type: Option<String>,
#[serde(default)]
gpu_flops: Option<f64>,
#[serde(default)]
gpu_mem_bw: Option<f64>,
#[serde(default)]
hbm_bytes: Option<f64>,
#[serde(default)]
dram_bytes: Option<f64>,
#[serde(default)]
pcie_bw: Option<f64>,
#[serde(default)]
pcie_latency_us: Option<f64>,
#[serde(default)]
rdma_bw: Option<f64>,
#[serde(default)]
rdma_latency_us: Option<f64>,
#[serde(default)]
max_batch_slots: Option<u32>,
#[serde(default)]
prefill_chunk_tokens: Option<u32>,
}
// -- Resolution (merge base + YAML overrides → final Config) ------------------
impl RawConfig {
fn resolve(self, yaml_dir: &Path) -> Result<Config> {
Ok(Config {
model: self.model.resolve(yaml_dir)?,
hardware: self.hardware.resolve()?,
cluster: self.cluster,
sim: self.sim,
})
}
}
impl RawModelConfig {
fn resolve(self, yaml_dir: &Path) -> Result<ModelConfig> {
// Start from HF config.json if specified, else empty default.
let mut m = if let Some(ref cj) = self.config_json {
let cj_path = if Path::new(cj).is_absolute() {
std::path::PathBuf::from(cj)
} else {
yaml_dir.join(cj)
};
crate::hf_config::parse(&cj_path)?
} else {
ModelConfig::default()
};
// Overlay: explicit YAML fields override the base.
if let Some(v) = self.name { m.name = v; }
if let Some(v) = self.num_layers { m.num_layers = v; }
if let Some(v) = self.num_kv_heads { m.num_kv_heads = v; }
if let Some(v) = self.head_dim { m.head_dim = v; }
if let Some(v) = self.dtype_bytes { m.dtype_bytes = v; }
if let Some(v) = self.block_size_tokens { m.block_size_tokens = v; }
if let Some(v) = self.hidden_size { m.hidden_size = Some(v); }
if let Some(v) = self.num_attention_heads { m.num_attention_heads = Some(v); }
if let Some(v) = self.intermediate_size { m.intermediate_size = Some(v); }
if self.moe.is_some() { m.moe = self.moe; }
if self.mla.is_some() { m.mla = self.mla; }
if self.attention.is_some() { m.attention = self.attention; }
if let Some(v) = self.flops_per_token_prefill { m.flops_per_token_prefill = Some(v); }
if let Some(v) = self.attn_quadratic_coeff { m.attn_quadratic_coeff = Some(v); }
if let Some(v) = self.bytes_per_token_prefill { m.bytes_per_token_prefill = Some(v); }
if let Some(v) = self.flops_per_token_decode { m.flops_per_token_decode = Some(v); }
if let Some(v) = self.bytes_per_token_decode { m.bytes_per_token_decode = Some(v); }
// Validate deployment-specific fields that HF config.json never provides.
anyhow::ensure!(
m.dtype_bytes > 0,
"model.dtype_bytes is required (not in HF config.json)"
);
anyhow::ensure!(
m.block_size_tokens > 0,
"model.block_size_tokens is required (not in HF config.json)"
);
Ok(m)
}
}
impl RawHardwareConfig {
fn resolve(self) -> Result<HardwareConfig> {
// Start from preset if specified, else zeros (all must come from YAML).
let mut hw = if let Some(ref t) = self.hw_type {
crate::hardware_presets::resolve(t).ok_or_else(|| {
anyhow::anyhow!(
"unknown hardware preset '{t}'. Available: {}",
crate::hardware_presets::AVAILABLE.join(", ")
)
})?
} else {
HardwareConfig {
gpu_flops: 0.0,
gpu_mem_bw: 0.0,
hbm_bytes: 0.0,
dram_bytes: 0.0,
pcie_bw: 0.0,
pcie_latency_us: 5.0,
rdma_bw: 0.0,
rdma_latency_us: 8.0,
max_batch_slots: default_max_batch_slots(),
prefill_chunk_tokens: default_prefill_chunk_tokens(),
}
};
// Overlay: explicit YAML fields override the preset / defaults.
if let Some(v) = self.gpu_flops { hw.gpu_flops = v; }
if let Some(v) = self.gpu_mem_bw { hw.gpu_mem_bw = v; }
if let Some(v) = self.hbm_bytes { hw.hbm_bytes = v; }
if let Some(v) = self.dram_bytes { hw.dram_bytes = v; }
if let Some(v) = self.pcie_bw { hw.pcie_bw = v; }
if let Some(v) = self.pcie_latency_us { hw.pcie_latency_us = v; }
if let Some(v) = self.rdma_bw { hw.rdma_bw = v; }
if let Some(v) = self.rdma_latency_us { hw.rdma_latency_us = v; }
if let Some(v) = self.max_batch_slots { hw.max_batch_slots = v; }
if let Some(v) = self.prefill_chunk_tokens { hw.prefill_chunk_tokens = v; }
// Validate minimum requirements.
anyhow::ensure!(hw.gpu_flops > 0.0, "hardware.gpu_flops is required");
anyhow::ensure!(hw.gpu_mem_bw > 0.0, "hardware.gpu_mem_bw is required");
anyhow::ensure!(hw.hbm_bytes > 0.0, "hardware.hbm_bytes is required");
Ok(hw)
}
}
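For a sense of scale, `ModelConfig::kv_block_bytes` above works out as follows for two illustrative shapes (the numbers are chosen for the example, not read from any preset in this commit):

```rust
fn main() {
    // GQA: 2 * layers * kv_heads * head_dim * dtype_bytes * block_tokens
    let gqa = 2u64 * 28 * 4 * 128 * 2 * 16;
    assert_eq!(gqa, 917_504); // ~0.9 MB per 16-token block

    // MLA: layers * (kv_lora_rank + qk_rope_head_dim) * dtype_bytes * block_tokens
    let mla = 61u64 * (512 + 64) * 2 * 16;
    assert_eq!(mla, 1_124_352); // ~1.1 MB per block despite far more layers
    println!("gqa={gqa} bytes, mla={mla} bytes");
}
```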

src/driver.rs
@@ -0,0 +1,170 @@
//! Simulation driver: pulls trace records, advances the event queue, runs
//! instance batch ticks, and emits metrics.
use anyhow::Result;
use std::collections::HashMap;
use std::path::Path;
use crate::cluster::Cluster;
use crate::config::Config;
use crate::metrics::per_request::{PerRequestRow, PerRequestWriter};
use crate::metrics::routing_log::RoutingLogWriter;
use crate::metrics::summary::Summary;
use crate::metrics::timeseries::{TimeseriesRow, TimeseriesWriter};
use crate::sim::{Event, EventQueue};
use crate::trace::{RequestRecord, TraceReader};
pub struct RunOutputs {
pub summary: Summary,
pub rows: Vec<PerRequestRow>,
}
#[derive(Debug, Clone)]
struct InflightInfo {
arrival: f64,
instance: u32,
total_blocks: u32,
l0_hit_blocks: u32,
l1_hit_blocks: u32,
remote_hit_blocks: u32,
miss_blocks: u32,
rdma_bytes: u64,
pcie_bytes: u64,
probe_overhead_s: f64,
}
pub fn run(config: &Config, output_subdir: Option<&str>) -> Result<RunOutputs> {
let mut cluster = Cluster::new(config, &config.model);
let mut q = EventQueue::new();
// Output directory
let base = Path::new(&config.sim.output_dir);
let out_dir = match output_subdir {
Some(s) => base.join(s),
None => base.to_path_buf(),
};
std::fs::create_dir_all(&out_dir)?;
let mut req_writer = PerRequestWriter::create(out_dir.join("per_request.csv"))?;
let mut ts_writer = TimeseriesWriter::create(out_dir.join("instances.csv"))?;
let mut rt_writer = RoutingLogWriter::create(out_dir.join("routing_log.jsonl"))?;
let mut trace = TraceReader::open(&config.sim.trace_path, config.sim.max_requests)?;
// Load all records (cheap for moderate traces) so we can index by req_id.
// For very large traces a streaming approach with a peekable iterator
// would be better; this keeps the driver simple.
let records: Vec<RequestRecord> = (&mut trace).collect::<Result<Vec<_>, _>>()?;
let mut by_id: HashMap<u64, RequestRecord> = HashMap::with_capacity(records.len());
for r in &records {
q.schedule(r.arrival, Event::Arrival { req_id: r.req_id });
by_id.insert(r.req_id, r.clone());
}
// Periodic samples
if config.sim.sample_interval_s > 0.0 && !records.is_empty() {
let max_t = records.iter().map(|r| r.arrival).fold(0.0_f64, f64::max);
let mut t = 0.0;
while t <= max_t + 60.0 {
q.schedule(t, Event::Sample);
t += config.sim.sample_interval_s;
}
}
let mut inflight: HashMap<u64, InflightInfo> = HashMap::new();
let mut rows: Vec<PerRequestRow> = Vec::with_capacity(records.len());
while let Some((now, ev)) = q.pop() {
match ev {
Event::Arrival { req_id } => {
let req = match by_id.get(&req_id) {
Some(r) => r.clone(),
None => continue,
};
let stats = cluster.route_and_admit(&req, now);
rt_writer.write(&stats.decision)?;
inflight.insert(
req_id,
InflightInfo {
arrival: req.arrival,
instance: stats.instance,
total_blocks: req.hash_ids.len() as u32,
l0_hit_blocks: stats.l0_hit_blocks,
l1_hit_blocks: stats.l1_hit_blocks,
remote_hit_blocks: stats.remote_hit_blocks,
miss_blocks: stats.miss_blocks,
rdma_bytes: stats.rdma_bytes,
pcie_bytes: stats.pcie_bytes,
probe_overhead_s: stats.probe_overhead_s,
},
);
let inst = &mut cluster.instances[stats.instance as usize];
if !inst.tick_scheduled {
inst.tick_scheduled = true;
let when = stats.ready_at.max(now);
q.schedule(when, Event::BatchTick { instance: stats.instance });
}
}
Event::BatchTick { instance } => {
let inst = &mut cluster.instances[instance as usize];
inst.tick_scheduled = false;
let result = inst.step(now);
for (rid, ttft, end) in result.completed {
if let Some(info) = inflight.remove(&rid) {
let row = PerRequestRow {
req_id: rid,
arrival: info.arrival,
ttft,
e2e: end - info.arrival,
instance: info.instance,
total_blocks: info.total_blocks,
l0_hit_blocks: info.l0_hit_blocks,
l1_hit_blocks: info.l1_hit_blocks,
remote_hit_blocks: info.remote_hit_blocks,
miss_blocks: info.miss_blocks,
rdma_bytes: info.rdma_bytes,
pcie_bytes: info.pcie_bytes,
probe_overhead_s: info.probe_overhead_s,
};
req_writer.write(&row)?;
rows.push(row);
}
}
if let Some(next) = result.next_tick {
let inst = &mut cluster.instances[instance as usize];
if !inst.tick_scheduled {
inst.tick_scheduled = true;
q.schedule(next.max(now), Event::BatchTick { instance });
}
}
}
Event::Sample => {
for inst in &cluster.instances {
let busy = if inst.queue_len() > 0 { 1 } else { 0 };
ts_writer.write(&TimeseriesRow {
t: now,
instance: inst.id,
queue_len: inst.queue_len(),
kv_blocks_used: inst.kv_blocks_used,
kv_blocks_total: inst.hbm_block_budget,
busy,
})?;
}
}
Event::Stop => break,
}
}
req_writer.finish()?;
ts_writer.finish()?;
rt_writer.finish()?;
let sim_duration_s = rows
.iter()
.map(|r| r.arrival + r.e2e)
.fold(0.0_f64, f64::max);
let router_name = config.cluster.router.mode.as_str().to_string();
let summary = Summary::from_rows(&router_name, &rows, sim_duration_s);
let summary_json = serde_json::to_string_pretty(&summary)?;
std::fs::write(out_dir.join("summary.json"), summary_json)?;
Ok(RunOutputs { summary, rows })
}

src/hardware_presets.rs
@@ -0,0 +1,225 @@
//! Built-in hardware presets for common GPU configurations.
//!
//! Presets provide baseline specs for single GPUs and tensor-parallel (TP)
//! groups. All values can be overridden in the YAML config by specifying
//! explicit fields alongside `type`:
//!
//! ```yaml
//! hardware:
//! type: 8xb200
//! hbm_bytes: 500.0e9 # override total HBM with actual KV budget
//! ```
use crate::config::HardwareConfig;
/// All recognized preset names (for help/error messages).
pub const AVAILABLE: &[&str] = &[
"h100",
"h800",
"h20",
"a100-80gb",
"a100-40gb",
"b200",
"2xh100",
"4xh100",
"8xh100",
"2xh800",
"4xh800",
"8xh800",
"2xh20",
"4xh20",
"8xh20",
"2xb200",
"4xb200",
"8xb200",
];
/// Resolve a hardware preset by name.
///
/// Case-insensitive; hyphens, underscores, and spaces are stripped before
/// matching. Accepts `NxGPU` patterns (e.g. `8xb200`).
pub fn resolve(name: &str) -> Option<HardwareConfig> {
let key = normalize(name);
let (count, gpu) = parse_count_gpu(&key);
match gpu.as_str() {
"h100" => Some(make_config(count, &H100)),
"h800" => Some(make_config(count, &H800)),
"h20" => Some(make_config(count, &H20)),
"a10080gb" | "a100" => Some(make_config(count, &A100_80GB)),
"a10040gb" => Some(make_config(count, &A100_40GB)),
"b200" => Some(make_config(count, &B200)),
_ => None,
}
}
// ---------------------------------------------------------------------------
// Internals
// ---------------------------------------------------------------------------
fn normalize(s: &str) -> String {
s.to_ascii_lowercase().replace(['-', '_', ' '], "")
}
/// Parse `"8xh100"` → `(8, "h100")`, `"h100"` → `(1, "h100")`.
fn parse_count_gpu(s: &str) -> (u32, String) {
if let Some(pos) = s.find('x') {
if let Ok(n) = s[..pos].parse::<u32>() {
return (n, s[pos + 1..].to_string());
}
}
(1, s.to_string())
}
// -- Per-GPU base specs (single die, BF16 dense) -----------------------------
struct GpuBase {
flops: f64, // BF16 dense TFLOPS
mem_bw: f64, // HBM bandwidth (B/s)
hbm: f64, // Total HBM (bytes)
pcie_gen: u32, // PCIe generation (4/5/6)
}
const H100: GpuBase = GpuBase {
flops: 9.89e14, // 989 TFLOPS BF16
mem_bw: 3.35e12, // 3.35 TB/s HBM3
hbm: 80.0e9, // 80 GB
pcie_gen: 5,
};
const H800: GpuBase = GpuBase {
flops: 9.89e14, // same die as H100
mem_bw: 3.35e12, // 3.35 TB/s HBM3
hbm: 80.0e9, // 80 GB
pcie_gen: 5,
};
const H20: GpuBase = GpuBase {
flops: 1.48e14, // 148 TFLOPS BF16 (China-export Hopper)
mem_bw: 4.0e12, // 4.0 TB/s HBM3
hbm: 96.0e9, // 96 GB
pcie_gen: 5,
};
const A100_80GB: GpuBase = GpuBase {
flops: 3.12e14, // 312 TFLOPS BF16
mem_bw: 2.0e12, // 2.0 TB/s HBM2e
hbm: 80.0e9, // 80 GB
pcie_gen: 4,
};
const A100_40GB: GpuBase = GpuBase {
flops: 3.12e14, // 312 TFLOPS BF16
mem_bw: 1.555e12, // 1.555 TB/s HBM2e
hbm: 40.0e9, // 40 GB
pcie_gen: 4,
};
const B200: GpuBase = GpuBase {
flops: 2.25e15, // 2250 TFLOPS BF16
mem_bw: 8.0e12, // 8.0 TB/s HBM3e
hbm: 192.0e9, // 192 GB
pcie_gen: 6,
};
/// Build a [`HardwareConfig`] from a base GPU spec × TP count.
///
/// Compute, HBM bandwidth, and HBM capacity scale linearly with `n`.
/// PCIe bandwidth scales linearly (one link per GPU). RDMA bandwidth
/// assumes one NIC below 8 GPUs and two NICs at 8 or more. Server DRAM is a
/// reasonable default based on typical deployment sizes.
fn make_config(n: u32, base: &GpuBase) -> HardwareConfig {
let f = n as f64;
// PCIe per-GPU bandwidth and latency by generation
let (pcie_per_gpu, pcie_lat) = match base.pcie_gen {
6 => (128.0e9, 4.0), // Gen6 x16
5 => (64.0e9, 5.0), // Gen5 x16
_ => (32.0e9, 5.0), // Gen4 x16
};
// RDMA: base NIC speed by PCIe gen, scaled for multi-NIC servers
let (rdma_base, rdma_lat) = match base.pcie_gen {
6 => (50.0e9, 6.0), // 400 Gbps NIC
_ => (25.0e9, 8.0), // 200 Gbps NIC
};
let rdma_scale = if n >= 8 { 2.0 } else { 1.0 };
// Server DRAM: rough defaults by deployment size
let dram = match n {
1 => 512.0e9,
2..=4 => 1.0e12,
_ => 1.5e12,
};
HardwareConfig {
gpu_flops: base.flops * f,
gpu_mem_bw: base.mem_bw * f,
hbm_bytes: base.hbm * f,
dram_bytes: dram,
pcie_bw: pcie_per_gpu * f,
pcie_latency_us: pcie_lat,
rdma_bw: rdma_base * rdma_scale,
rdma_latency_us: rdma_lat,
max_batch_slots: 256,
prefill_chunk_tokens: if n >= 4 { 4096 } else { 2048 },
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn resolve_single_gpu() {
let hw = resolve("h100").unwrap();
assert!((hw.gpu_flops - 9.89e14).abs() < 1e10);
assert!((hw.hbm_bytes - 80e9).abs() < 1e6);
assert_eq!(hw.prefill_chunk_tokens, 2048);
}
#[test]
fn resolve_tp_group() {
let hw = resolve("8xb200").unwrap();
assert!((hw.gpu_flops - 2.25e15 * 8.0).abs() < 1e11);
assert!((hw.hbm_bytes - 192e9 * 8.0).abs() < 1e6);
assert!((hw.pcie_bw - 128e9 * 8.0).abs() < 1e6);
assert_eq!(hw.prefill_chunk_tokens, 4096);
}
#[test]
fn resolve_case_and_separator_insensitive() {
assert!(resolve("H100").is_some());
assert!(resolve("8xB200").is_some());
assert!(resolve("8x-B200").is_some());
assert!(resolve("a100-80gb").is_some());
assert!(resolve("A100_80GB").is_some());
assert!(resolve("a100_80gb").is_some());
}
#[test]
fn resolve_unknown_returns_none() {
assert!(resolve("v100").is_none());
assert!(resolve("tpu-v5").is_none());
assert!(resolve("").is_none());
}
#[test]
fn a100_variants() {
let a80 = resolve("a100-80gb").unwrap();
let a40 = resolve("a100-40gb").unwrap();
assert!((a80.hbm_bytes - 80e9).abs() < 1e6);
assert!((a40.hbm_bytes - 40e9).abs() < 1e6);
assert!(a80.gpu_mem_bw > a40.gpu_mem_bw);
}
#[test]
fn scaling_is_linear() {
let s1 = resolve("h100").unwrap();
let s4 = resolve("4xh100").unwrap();
let s8 = resolve("8xh100").unwrap();
assert!((s4.gpu_flops - s1.gpu_flops * 4.0).abs() < 1.0);
assert!((s8.gpu_flops - s1.gpu_flops * 8.0).abs() < 1.0);
assert!((s4.gpu_mem_bw - s1.gpu_mem_bw * 4.0).abs() < 1.0);
assert!((s8.hbm_bytes - s1.hbm_bytes * 8.0).abs() < 1.0);
}
}
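Reading `make_config` back, a `2xh100` preset resolves to the values below. This is derived by hand from the code above, as a sketch of the scaling rules rather than a separate source of truth:

```rust
fn main() {
    let hw_gpu_flops = 9.89e14 * 2.0;   // compute scales with n
    let hw_hbm_bytes = 80.0e9 * 2.0;    // 160 GB aggregate HBM
    let hw_pcie_bw = 64.0e9 * 2.0;      // one Gen5 x16 link per GPU
    let hw_rdma_bw = 25.0e9;            // single NIC (two NICs only at n >= 8)
    let hw_dram_bytes = 1.0e12;         // 2..=4 GPUs -> 1 TB server DRAM
    let hw_prefill_chunk_tokens = 2048; // 4096-token chunks only at n >= 4
    println!(
        "{hw_gpu_flops} {hw_hbm_bytes} {hw_pcie_bw} {hw_rdma_bw} {hw_dram_bytes} {hw_prefill_chunk_tokens}"
    );
}
```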

src/hf_config.rs
@@ -0,0 +1,193 @@
//! Parse a HuggingFace `config.json` into [`ModelConfig`] fields.
//!
//! Handles common architectures: standard transformer, GQA, MoE, MLA
//! (Multi-head Latent Attention), and DSA (DeepSeek Sparse Attention).
use anyhow::{Context, Result};
use serde_json::Value;
use std::path::Path;
use crate::config::{AttentionConfig, MlaConfig, MoeConfig, ModelConfig};
/// Parse a HuggingFace config.json and return a partially-populated
/// [`ModelConfig`]. The caller must still set `dtype_bytes` and
/// `block_size_tokens` (not part of the HF schema).
pub fn parse(path: &Path) -> Result<ModelConfig> {
let raw = std::fs::read_to_string(path)
.with_context(|| format!("reading config.json at {}", path.display()))?;
let v: Value = serde_json::from_str(&raw)
.with_context(|| format!("parsing config.json at {}", path.display()))?;
parse_value(&v)
}
fn u32_field(v: &Value, key: &str) -> Option<u32> {
v.get(key).and_then(|x| x.as_u64()).map(|x| x as u32)
}
fn parse_value(v: &Value) -> Result<ModelConfig> {
let name = v
.get("model_type")
.and_then(|x| x.as_str())
.unwrap_or("unknown")
.to_string();
let num_layers = u32_field(v, "num_hidden_layers");
let hidden_size = u32_field(v, "hidden_size");
let num_attention_heads = u32_field(v, "num_attention_heads");
let num_kv_heads = u32_field(v, "num_key_value_heads")
.or(num_attention_heads); // default to MHA
let head_dim = u32_field(v, "head_dim").or_else(|| {
// Infer: hidden_size / num_attention_heads
match (hidden_size, num_attention_heads) {
(Some(h), Some(n)) if n > 0 => Some(h / n),
_ => None,
}
});
let intermediate_size = u32_field(v, "intermediate_size");
// --- MoE detection ---
let moe = u32_field(v, "n_routed_experts")
.or_else(|| u32_field(v, "num_local_experts"))
.or_else(|| u32_field(v, "num_experts"))
.map(|num_experts| MoeConfig {
num_experts,
num_active_experts: u32_field(v, "num_experts_per_tok")
.or_else(|| u32_field(v, "num_experts_per_topk"))
.unwrap_or(2),
num_shared_experts: u32_field(v, "n_shared_experts").unwrap_or(0),
expert_intermediate_size: u32_field(v, "moe_intermediate_size"),
});
// --- MLA detection (kv_lora_rank present → MLA) ---
let mla = u32_field(v, "kv_lora_rank").and_then(|kv_lora_rank| {
Some(MlaConfig {
kv_lora_rank,
q_lora_rank: u32_field(v, "q_lora_rank")?,
qk_nope_head_dim: u32_field(v, "qk_nope_head_dim")?,
qk_rope_head_dim: u32_field(v, "qk_rope_head_dim")?,
v_head_dim: u32_field(v, "v_head_dim")?,
})
});
// --- Attention pattern ---
let attention =
if let Some(first_dense) = u32_field(v, "first_k_dense_replace") {
// DSA-style model (GLM-5, DeepSeek-V3).
// dense_window and sparse_stride are typically not in config.json;
// use sensible defaults the user can override in YAML.
Some(AttentionConfig::Dsa {
dense_window: 4096,
sparse_stride: 8,
first_dense_layers: first_dense,
})
} else if let Some(sw) = v
.get("sliding_window")
.and_then(|x| x.as_u64())
.map(|x| x as u32)
{
Some(AttentionConfig::SlidingWindow { window_size: sw })
} else {
None // dense by default
};
Ok(ModelConfig {
name,
num_layers: num_layers.unwrap_or(0),
num_kv_heads: num_kv_heads.unwrap_or(0),
head_dim: head_dim.unwrap_or(0),
hidden_size,
num_attention_heads,
intermediate_size,
moe,
mla,
attention,
// Deployment fields: must come from YAML
dtype_bytes: 0,
block_size_tokens: 0,
..Default::default()
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_dense_model() {
let json = serde_json::json!({
"model_type": "qwen2",
"num_hidden_layers": 28,
"hidden_size": 3584,
"num_attention_heads": 28,
"num_key_value_heads": 4,
"intermediate_size": 18944,
});
let m = parse_value(&json).unwrap();
assert_eq!(m.num_layers, 28);
assert_eq!(m.hidden_size, Some(3584));
assert_eq!(m.num_kv_heads, 4);
assert_eq!(m.head_dim, 128); // 3584 / 28
assert!(m.moe.is_none());
assert!(m.mla.is_none());
assert!(m.attention.is_none());
}
#[test]
fn parse_qwen3_moe() {
let json = serde_json::json!({
"model_type": "qwen3_moe",
"num_hidden_layers": 62,
"hidden_size": 6144,
"num_attention_heads": 96,
"num_key_value_heads": 8,
"head_dim": 128,
"intermediate_size": 8192,
"num_experts": 160,
"num_experts_per_tok": 8,
"moe_intermediate_size": 2560,
});
let m = parse_value(&json).unwrap();
assert_eq!(m.num_layers, 62);
assert_eq!(m.num_kv_heads, 8);
assert_eq!(m.head_dim, 128);
let moe = m.moe.as_ref().unwrap();
assert_eq!(moe.num_experts, 160);
assert_eq!(moe.num_active_experts, 8);
assert_eq!(moe.expert_intermediate_size, Some(2560));
assert_eq!(moe.num_shared_experts, 0);
assert!(m.mla.is_none());
assert!(m.attention.is_none());
}
#[test]
fn parse_moe_mla_dsa() {
let json = serde_json::json!({
"model_type": "glm_moe_dsa",
"num_hidden_layers": 78,
"hidden_size": 6144,
"num_attention_heads": 64,
"num_key_value_heads": 64,
"head_dim": 64,
"intermediate_size": 12288,
"n_routed_experts": 256,
"num_experts_per_tok": 8,
"n_shared_experts": 1,
"moe_intermediate_size": 2048,
"kv_lora_rank": 512,
"q_lora_rank": 2048,
"qk_nope_head_dim": 192,
"qk_rope_head_dim": 64,
"v_head_dim": 256,
"first_k_dense_replace": 3,
});
let m = parse_value(&json).unwrap();
assert_eq!(m.num_layers, 78);
assert_eq!(m.head_dim, 64);
let moe = m.moe.as_ref().unwrap();
assert_eq!(moe.num_experts, 256);
assert_eq!(moe.num_active_experts, 8);
let mla = m.mla.as_ref().unwrap();
assert_eq!(mla.kv_lora_rank, 512);
assert!(matches!(m.attention, Some(AttentionConfig::Dsa { first_dense_layers: 3, .. })));
}
}
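Note that MLA detection above is all-or-nothing: the `?` operators inside `and_then` mean a config.json with `kv_lora_rank` but missing any of the other four MLA fields silently falls back to `mla: None`, and therefore to the standard GQA KV-block formula. A sketch in the style of the existing tests (not part of the committed test module):

```rust
#[test]
fn partial_mla_fields_fall_back_to_none() {
    let json = serde_json::json!({
        "model_type": "test",
        "num_hidden_layers": 4,
        "hidden_size": 256,
        "num_attention_heads": 4,
        "num_key_value_heads": 4,
        "kv_lora_rank": 512, // present, but q_lora_rank / qk_*_head_dim / v_head_dim are absent
    });
    let m = parse_value(&json).unwrap();
    assert!(m.mla.is_none());
}
```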

src/instance/compute.rs
@@ -0,0 +1,405 @@
//! Roofline cost model for prefill (PD disaggregation — decode not modeled).
//!
//! Two construction modes:
//!
//! **Architecture-derived** (`ModelConfig.hidden_size` present):
//! All FLOPs, attention coefficients, and weight-stream costs are computed
//! from the model shape. Handles standard / GQA / MLA attention projections,
//! MoE routing, and DSA / sliding-window sub-quadratic attention patterns.
//!
//! **Legacy manual** (`hidden_size` absent): uses the raw
//! `flops_per_token_prefill` + `attn_quadratic_coeff` scalars from the YAML.
//!
//! ```text
//! prefill_time(N) = max(compute_time(N), mem_time)
//!
//! compute_time = sum over layers of:
//! (N * linear_flops + attn_coeff * N * effective_ctx(N)) / gpu_flops
//!
//! mem_time = num_layers * weight_bytes_per_layer / gpu_mem_bw
//! ```
//!
//! `effective_ctx(N)` equals `N` for dense attention (→ O(N²) total). For
//! sliding-window it is capped at the window size, and for DSA it grows with a
//! `1/sparse_stride` slope beyond the dense window, so total attention work is
//! sub-quadratic in both cases.
use crate::config::{AttentionConfig, HardwareConfig, ModelConfig};
/// Resolved attention pattern used at runtime.
#[derive(Debug, Clone)]
pub enum AttentionPattern {
/// Full quadratic: effective_ctx = N.
Dense,
/// Sliding window: effective_ctx = min(N, window).
SlidingWindow { window: f64 },
/// DeepSeek Sparse Attention: effective_ctx = min(N, dense_window) +
/// max(0, N - dense_window) / sparse_stride.
Dsa { dense_window: f64, sparse_stride: f64 },
}
#[derive(Debug, Clone)]
pub struct ComputeModel {
/// Total transformer layers.
pub num_layers: f64,
/// How many initial layers use dense attention (rest use `attn_pattern`).
/// For `Dense` pattern this equals `num_layers`.
pub first_dense_layers: f64,
/// Non-attention FLOPs per token per layer (QKV proj + output proj + MLP).
pub linear_flops_per_token: f64,
/// Attention score coefficient: per-layer attention FLOPs =
/// `attn_coeff * N * effective_ctx(N)`.
pub attn_coeff: f64,
/// Attention pattern for non-dense layers.
pub attn_pattern: AttentionPattern,
/// Weight bytes read from HBM per layer (for memory-bound check).
pub weight_bytes_per_layer: f64,
/// Peak GPU FLOPs (aggregate across TP group).
pub gpu_flops: f64,
/// Peak GPU memory bandwidth (aggregate across TP group).
pub gpu_mem_bw: f64,
}
impl ComputeModel {
pub fn new(model: &ModelConfig, hw: &HardwareConfig) -> Self {
if model.is_arch_mode() {
Self::from_arch(model, hw)
} else {
Self::from_manual(model, hw)
}
}
// ----- Architecture-derived construction --------------------------------
fn from_arch(model: &ModelConfig, hw: &HardwareConfig) -> Self {
let h = model.hidden_size.unwrap() as f64;
let n_heads = model.num_attention_heads.unwrap_or(model.num_kv_heads) as f64;
let n_kv = model.num_kv_heads as f64;
let hd = model.head_dim as f64;
let inter = model.intermediate_size.unwrap_or(0) as f64;
let dtype = model.dtype_bytes as f64;
// --- Attention linear FLOPs/token/layer ---
let attn_linear = if let Some(mla) = &model.mla {
let qlr = mla.q_lora_rank as f64;
let kvlr = mla.kv_lora_rank as f64;
let qk_hd = (mla.qk_nope_head_dim + mla.qk_rope_head_dim) as f64;
let qk_rd = mla.qk_rope_head_dim as f64;
let vhd = mla.v_head_dim as f64;
// Q: down-project + up-project
let q = 2.0 * h * qlr + 2.0 * qlr * n_heads * qk_hd;
// KV: down-project (compressed latent + RoPE key)
let kv = 2.0 * h * (kvlr + qk_rd);
// Output: up-project
let o = 2.0 * n_heads * vhd * h;
q + kv + o
} else {
// Standard / GQA
let qkv = 2.0 * h * (n_heads + 2.0 * n_kv) * hd;
let o = 2.0 * n_heads * hd * h;
qkv + o
};
// --- MLP FLOPs/token/layer (SwiGLU: gate + up + down = 3 matmuls) ---
let mlp = if let Some(moe) = &model.moe {
let expert_inter = moe.expert_intermediate_size
.unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
let active = moe.num_active_experts as f64;
let shared = moe.num_shared_experts as f64;
active * 6.0 * h * expert_inter + shared * 6.0 * h * inter
} else {
6.0 * h * inter
};
let linear_flops = attn_linear + mlp;
// --- Attention quadratic coefficient ---
// attn_flops_per_layer(N) = attn_coeff * N * effective_ctx(N)
let attn_coeff = if let Some(mla) = &model.mla {
let kvlr = mla.kv_lora_rank as f64;
let qk_rd = mla.qk_rope_head_dim as f64;
// Absorbed QK^T: each head dots over (kv_lora_rank + qk_rope_head_dim) dims.
// Absorbed V: each head dots over kv_lora_rank dims.
2.0 * n_heads * (2.0 * kvlr + qk_rd)
} else {
// Standard: QK^T + attn@V, each 2 * n_heads * head_dim per pair.
4.0 * n_heads * hd
};
// --- Weight bytes per layer (active params only for MoE) ---
let attn_wt = if let Some(mla) = &model.mla {
let qlr = mla.q_lora_rank as f64;
let kvlr = mla.kv_lora_rank as f64;
let qk_hd = (mla.qk_nope_head_dim + mla.qk_rope_head_dim) as f64;
let qk_rd = mla.qk_rope_head_dim as f64;
let vhd = mla.v_head_dim as f64;
(h * qlr + qlr * n_heads * qk_hd
+ h * (kvlr + qk_rd)
+ n_heads * vhd * h)
* dtype
} else {
((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * dtype
};
let mlp_wt = if let Some(moe) = &model.moe {
let expert_inter = moe.expert_intermediate_size
.unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
let active = moe.num_active_experts as f64;
let shared = moe.num_shared_experts as f64;
(active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * dtype
} else {
3.0 * h * inter * dtype
};
let weight_bytes = attn_wt + mlp_wt;
// --- Attention pattern ---
let (attn_pattern, first_dense) = match &model.attention {
Some(AttentionConfig::Dsa {
dense_window,
sparse_stride,
first_dense_layers,
}) => (
AttentionPattern::Dsa {
dense_window: *dense_window as f64,
sparse_stride: *sparse_stride as f64,
},
*first_dense_layers as f64,
),
Some(AttentionConfig::SlidingWindow { window_size }) => (
AttentionPattern::SlidingWindow {
window: *window_size as f64,
},
0.0,
),
Some(AttentionConfig::Dense) | None => (
AttentionPattern::Dense,
model.num_layers as f64,
),
};
Self {
num_layers: model.num_layers as f64,
first_dense_layers: first_dense,
linear_flops_per_token: linear_flops,
attn_coeff,
attn_pattern,
weight_bytes_per_layer: weight_bytes,
gpu_flops: hw.gpu_flops,
gpu_mem_bw: hw.gpu_mem_bw,
}
}
// ----- Legacy manual construction ---------------------------------------
fn from_manual(model: &ModelConfig, hw: &HardwareConfig) -> Self {
Self {
num_layers: model.num_layers as f64,
first_dense_layers: model.num_layers as f64,
linear_flops_per_token: model.flops_per_token_prefill.unwrap_or(0.0),
attn_coeff: model.attn_quadratic_coeff.unwrap_or(0.0),
attn_pattern: AttentionPattern::Dense,
weight_bytes_per_layer: 0.0,
gpu_flops: hw.gpu_flops,
gpu_mem_bw: hw.gpu_mem_bw,
}
}
// ----- Prefill time -----------------------------------------------------
/// Effective context length a single token attends to at sequence length N.
fn effective_ctx(&self, n: f64, dense_layer: bool) -> f64 {
if dense_layer {
return n;
}
match &self.attn_pattern {
AttentionPattern::Dense => n,
AttentionPattern::SlidingWindow { window } => n.min(*window),
AttentionPattern::Dsa {
dense_window,
sparse_stride,
} => {
if n <= *dense_window {
n
} else {
*dense_window + (n - *dense_window) / *sparse_stride
}
}
}
}
/// Time (s) to prefill `n` tokens.
pub fn prefill_time(&self, n: u32) -> f64 {
if n == 0 {
return 0.0;
}
let n = n as f64;
let linear = n * self.linear_flops_per_token;
// Compute FLOPs across all layers (dense + sparse may differ).
let dense_layers = self.first_dense_layers;
let sparse_layers = self.num_layers - dense_layers;
let dense_flops = dense_layers
* (linear + self.attn_coeff * n * self.effective_ctx(n, true));
let sparse_flops = sparse_layers
* (linear + self.attn_coeff * n * self.effective_ctx(n, false));
let total_flops = dense_flops + sparse_flops;
let compute_time = total_flops / self.gpu_flops;
// Weight stream: all layers' active weights read once from HBM.
let mem_time = self.weight_bytes_per_layer * self.num_layers / self.gpu_mem_bw;
compute_time.max(mem_time)
}
/// Print human-readable derived coefficients (for `validate` output).
pub fn describe(&self) -> String {
let pattern_str = match &self.attn_pattern {
AttentionPattern::Dense => "dense".to_string(),
AttentionPattern::SlidingWindow { window } => format!("sliding_window({})", *window as u64),
AttentionPattern::Dsa {
dense_window,
sparse_stride,
} => format!(
"dsa(window={}, stride={}, {} dense layers)",
*dense_window as u64, *sparse_stride as u64, self.first_dense_layers as u64
),
};
format!(
"linear_flops/tok/layer={:.3e}, attn_coeff={:.0}, pattern={}, \
weight_bytes/layer={:.2e}",
self.linear_flops_per_token, self.attn_coeff, pattern_str,
self.weight_bytes_per_layer,
)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn cm_legacy() -> ComputeModel {
ComputeModel {
num_layers: 28.0,
first_dense_layers: 28.0,
linear_flops_per_token: 1.4e10,
attn_coeff: 1024.0,
attn_pattern: AttentionPattern::Dense,
weight_bytes_per_layer: 0.0,
gpu_flops: 9.89e14,
gpu_mem_bw: 3.35e12,
}
}
#[test]
fn prefill_monotonic_in_n() {
let m = cm_legacy();
let mut prev = 0.0;
for &n in &[1u32, 8, 64, 512, 4096, 32768] {
let t = m.prefill_time(n);
assert!(t > prev, "prefill_time should be monotonic; n={n} t={t}");
prev = t;
}
}
#[test]
fn prefill_superlinear_in_prompt_length() {
// With attn_coeff = 1024 the linear term still dominates at these sizes,
// but the quadratic attention term must push the 32x-longer prompt
// strictly past a pure-linear 32x ratio.
let m = cm_legacy();
let lin = m.prefill_time(1024);
let big = m.prefill_time(32768);
assert!(big / lin > 32.0);
}
#[test]
fn zero_tokens_is_free() {
let m = cm_legacy();
assert_eq!(m.prefill_time(0), 0.0);
}
#[test]
fn dsa_subquadratic() {
// With DSA (window=4096, stride=8) the cost at 128k should be
// MUCH less than pure quadratic.
let dense = ComputeModel {
num_layers: 78.0,
first_dense_layers: 78.0,
linear_flops_per_token: 1.0e9,
attn_coeff: 139264.0,
attn_pattern: AttentionPattern::Dense,
weight_bytes_per_layer: 0.0,
gpu_flops: 1.8e16,
gpu_mem_bw: 6.4e13,
};
let dsa = ComputeModel {
attn_pattern: AttentionPattern::Dsa {
dense_window: 4096.0,
sparse_stride: 8.0,
},
first_dense_layers: 3.0,
..dense.clone()
};
let n = 131072; // 128k tokens
let t_dense = dense.prefill_time(n);
let t_dsa = dsa.prefill_time(n);
// DSA should be dramatically cheaper at long context.
assert!(
t_dsa < t_dense * 0.3,
"DSA should be <30% of dense at 128k: dense={t_dense:.3} dsa={t_dsa:.3}"
);
// But still monotonic.
assert!(t_dsa > dsa.prefill_time(n / 2));
}
#[test]
fn mem_bound_short_prefill() {
// With very heavy weights and a short prompt, memory should dominate.
let m = ComputeModel {
num_layers: 10.0,
first_dense_layers: 10.0,
linear_flops_per_token: 1.0e6, // tiny compute
attn_coeff: 1.0,
attn_pattern: AttentionPattern::Dense,
weight_bytes_per_layer: 1.0e12, // 1 TB per layer
gpu_flops: 1.0e15,
gpu_mem_bw: 1.0e12,
};
let t1 = m.prefill_time(1);
let t8 = m.prefill_time(8);
// Memory time = 10 * 1e12 / 1e12 = 10s, should dominate.
assert!((t1 - 10.0).abs() < 0.01);
// Doubling tokens shouldn't change time much (mem-bound).
assert!((t8 - t1).abs() / t1 < 0.01);
}
#[test]
fn arch_derives_from_model_config() {
// Minimal dense model: verify from_arch produces something sensible.
let model = ModelConfig {
name: "test".into(),
num_layers: 4,
num_kv_heads: 2,
head_dim: 64,
dtype_bytes: 2,
block_size_tokens: 16,
hidden_size: Some(256),
num_attention_heads: Some(4),
intermediate_size: Some(512),
..Default::default()
};
let hw = HardwareConfig {
gpu_flops: 1e14,
gpu_mem_bw: 1e12,
hbm_bytes: 1e9,
dram_bytes: 4e9,
pcie_bw: 32e9,
pcie_latency_us: 1.0,
rdma_bw: 12e9,
rdma_latency_us: 5.0,
max_batch_slots: 32,
prefill_chunk_tokens: 1024,
};
let cm = ComputeModel::new(&model, &hw);
assert!(cm.linear_flops_per_token > 0.0);
assert!(cm.attn_coeff > 0.0);
assert!(cm.weight_bytes_per_layer > 0.0);
let t = cm.prefill_time(1024);
assert!(t > 0.0);
}
}
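Plugging the `cm_legacy` fixture above into the roofline formula gives a feel for the magnitudes. This is plain arithmetic mirroring `prefill_time` for the dense pattern with no weight-stream term:

```rust
fn main() {
    // cm_legacy: 28 layers, 1.4e10 linear FLOPs/token/layer, attn_coeff 1024,
    // 9.89e14 FLOP/s, weight_bytes_per_layer = 0 (compute-bound by construction).
    let n = 4096.0_f64;
    let per_layer = n * 1.4e10 + 1024.0 * n * n; // linear + quadratic attention
    let total_flops = 28.0 * per_layer;
    let prefill_s = total_flops / 9.89e14;
    println!("prefill_time(4096) ~= {prefill_s:.2} s"); // about 1.62 s
}
```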

src/instance/instance.rs
@@ -0,0 +1,191 @@
//! One simulated **prefill** serving instance.
//!
//! This simulator assumes **PD (prefill/decode) disaggregation**: prefill
//! and decode run on dedicated instance pools, and only prefill instances
//! are modeled here. The decode side is invisible to the KV-cache-aware
//! routing problem we are studying — once prefill finishes, the KV cache is
//! shipped to a decode instance via a separate (out-of-scope) path.
//!
//! As a result this `Instance`:
//! * Owns a two-tier KV cache (L0 = HBM, L1 = DRAM / v6d) used only for
//! prefill prefix reuse.
//! * Owns the PCIe / RDMA links used for cache fetches.
//! * Runs a simple FCFS chunked-prefill scheduler: one request's prefill
//! chunk per step, up to `prefill_chunk_tokens` per chunk.
//!
//! The cluster (`crate::cluster`) is responsible for routing arrivals,
//! consulting the global meta store, and inserting fetched blocks into the
//! instance's caches before handing the request off via `admit`.
use std::collections::VecDeque;
use crate::config::{HardwareConfig, ModelConfig};
use crate::instance::compute::ComputeModel;
use crate::instance::kv_cache::TwoTierCache;
use crate::network::InstanceLinks;
use crate::types::{InstanceId, ReqId};
#[derive(Debug, Clone)]
pub struct AdmittedRequest {
pub req_id: ReqId,
pub arrival: f64,
/// Earliest time at which the KV fetch chain (L1 + RDMA + PCIe) for this
/// request has completed, so its prefill compute can begin.
pub ready_at: f64,
/// Tokens still needing prefill compute (after cache hits accounted for).
pub prefill_tokens_remaining: u32,
/// KV blocks reserved on this instance's HBM for the lifetime of this
/// request's prefill (= number of input blocks).
pub reserved_blocks: u32,
}
#[derive(Debug)]
pub struct StepResult {
pub completed: Vec<(ReqId, f64, f64)>, // (req_id, ttft, end_time)
pub next_tick: Option<f64>,
}
pub struct Instance {
pub id: InstanceId,
pub cache: TwoTierCache,
pub links: InstanceLinks,
pub compute: ComputeModel,
pub block_size_tokens: u32,
pub hbm_block_budget: u32,
pub dram_block_budget: u32,
pub max_batch_slots: u32,
pub prefill_chunk_tokens: u32,
pub kv_blocks_used: u32,
/// Admitted but not yet ready (waiting for fetch chain to land).
pending: VecDeque<AdmittedRequest>,
/// Ready and currently being prefilled (FCFS, one at a time per step).
prefilling: VecDeque<AdmittedRequest>,
/// True if a BatchTick is already on the global queue for us.
pub tick_scheduled: bool,
}
impl Instance {
pub fn new(id: InstanceId, model: &ModelConfig, hw: &HardwareConfig) -> Self {
let block_bytes = model.kv_block_bytes() as f64;
let hbm_blocks = (hw.hbm_bytes / block_bytes).max(1.0) as u32;
let dram_blocks = (hw.dram_bytes / block_bytes).max(1.0) as u32;
Self {
id,
cache: TwoTierCache::new(hbm_blocks as usize, dram_blocks as usize),
links: InstanceLinks::from_hw(hw),
compute: ComputeModel::new(model, hw),
block_size_tokens: model.block_size_tokens,
hbm_block_budget: hbm_blocks,
dram_block_budget: dram_blocks,
max_batch_slots: hw.max_batch_slots,
prefill_chunk_tokens: hw.prefill_chunk_tokens,
kv_blocks_used: 0,
pending: VecDeque::new(),
prefilling: VecDeque::new(),
tick_scheduled: false,
}
}
pub fn queue_len(&self) -> u32 {
(self.pending.len() + self.prefilling.len()) as u32
}
/// Total prefill tokens remaining across all pending and prefilling requests.
pub fn waiting_tokens(&self) -> u64 {
self.pending
.iter()
.chain(self.prefilling.iter())
.map(|r| r.prefill_tokens_remaining as u64)
.sum()
}
/// Estimated wall-clock time to drain all currently queued requests.
///
/// Sums `compute.prefill_time(tokens_j)` for each queued request,
/// capturing the non-linear (quadratic / DSA) cost accurately.
pub fn estimated_drain_time(&self) -> f64 {
self.pending
.iter()
.chain(self.prefilling.iter())
.map(|r| self.compute.prefill_time(r.prefill_tokens_remaining))
.sum()
}
pub fn admit(&mut self, req: AdmittedRequest) {
self.pending.push_back(req);
}
/// Run one batch step. Returns any requests that finished prefill during
/// this step plus the next wakeup time for the instance.
pub fn step(&mut self, now: f64) -> StepResult {
let mut completed = Vec::new();
// 1. Drain ready pending requests into prefilling, respecting KV
// budget and slot cap. A request whose fetch chain is complete
// *and* has zero prefill tokens (full cache hit) finishes
// immediately at `now`.
while let Some(front) = self.pending.front() {
if front.ready_at > now {
break;
}
if self.prefilling.len() as u32 >= self.max_batch_slots {
break;
}
if self.kv_blocks_used + front.reserved_blocks > self.hbm_block_budget {
break;
}
let r = self.pending.pop_front().unwrap();
self.kv_blocks_used += r.reserved_blocks;
if r.prefill_tokens_remaining == 0 {
// Full cache hit: nothing to compute. TTFT == fetch time.
let ttft = now - r.arrival;
self.kv_blocks_used = self.kv_blocks_used.saturating_sub(r.reserved_blocks);
completed.push((r.req_id, ttft, now));
} else {
self.prefilling.push_back(r);
}
}
// 2. Run one chunked-prefill step on the head of `prefilling`.
let chunk_tokens = self
.prefilling
.front()
.map(|r| r.prefill_tokens_remaining.min(self.prefill_chunk_tokens))
.unwrap_or(0);
if chunk_tokens == 0 {
// Nothing compute-bound in flight right now.
return StepResult {
completed,
next_tick: self.next_wakeup(now),
};
}
let dt = self.compute.prefill_time(chunk_tokens);
let t_end = now + dt;
let head = self.prefilling.front_mut().unwrap();
head.prefill_tokens_remaining -= chunk_tokens;
if head.prefill_tokens_remaining == 0 {
let done = self.prefilling.pop_front().unwrap();
let ttft = t_end - done.arrival;
self.kv_blocks_used = self.kv_blocks_used.saturating_sub(done.reserved_blocks);
completed.push((done.req_id, ttft, t_end));
}
StepResult {
completed,
next_tick: self.next_wakeup(t_end),
}
}
fn next_wakeup(&self, after: f64) -> Option<f64> {
if !self.prefilling.is_empty() {
return Some(after);
}
self.pending.front().map(|r| r.ready_at.max(after))
}
}

226
src/instance/kv_cache.rs Normal file
View File

@@ -0,0 +1,226 @@
//! Two-tier LRU KV cache (L0 = GPU HBM, L1 = CPU DRAM / v6d).
//!
//! Each tier stores block hashes; the unit of accounting is one 16-token
//! block. `longest_prefix` walks the hash slice front-to-back and returns the
//! count of leading blocks present in the tier (and touches them so they
//! stay hot).
//!
//! On insert, evicted block hashes are returned so the caller (instance) can
//! propagate them to the global meta store if desired.
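//!
//! A minimal usage sketch (hashes and capacities are illustrative):
//!
//! ```ignore
//! let mut l0 = LruBlocks::new(4);
//! let mut evicted = Vec::new();
//! l0.insert_blocks(&[10, 11, 12], &mut evicted);    // MRU order: 12, 11, 10
//! assert_eq!(l0.longest_prefix(&[10, 11, 99]), 2);  // match stops at 99; 10 and 11 are touched
//! l0.insert_blocks(&[13, 14], &mut evicted);        // second insert overflows capacity 4
//! assert_eq!(evicted, vec![12]);                    // 12 is the LRU block after the touch
//! ```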
use ahash::AHashMap;
/// Doubly-linked-list-backed LRU keyed by block hash.
#[derive(Debug)]
pub struct LruBlocks {
capacity: usize,
map: AHashMap<u64, usize>,
nodes: Vec<Node>,
head: Option<usize>, // most recently used
tail: Option<usize>, // least recently used
free: Vec<usize>,
}
#[derive(Debug, Clone, Copy)]
struct Node {
key: u64,
prev: Option<usize>,
next: Option<usize>,
}
impl LruBlocks {
pub fn new(capacity: usize) -> Self {
Self {
capacity,
map: AHashMap::with_capacity(capacity),
nodes: Vec::with_capacity(capacity),
head: None,
tail: None,
free: Vec::new(),
}
}
pub fn capacity(&self) -> usize {
self.capacity
}
pub fn len(&self) -> usize {
self.map.len()
}
pub fn is_empty(&self) -> bool {
self.map.is_empty()
}
pub fn contains(&self, key: u64) -> bool {
self.map.contains_key(&key)
}
/// Touch (move to MRU) if present. Returns whether the key was present.
pub fn touch(&mut self, key: u64) -> bool {
if let Some(&idx) = self.map.get(&key) {
self.move_to_head(idx);
true
} else {
false
}
}
/// Insert blocks; evicted hashes appended to `evicted_out`. Reinserting an
/// existing block just touches it.
pub fn insert_blocks(&mut self, hashes: &[u64], evicted_out: &mut Vec<u64>) {
for &h in hashes {
if self.touch(h) {
continue;
}
// need to make room?
if self.map.len() == self.capacity {
if let Some(tail_idx) = self.tail {
let tail_key = self.nodes[tail_idx].key;
self.detach(tail_idx);
self.map.remove(&tail_key);
self.free.push(tail_idx);
evicted_out.push(tail_key);
}
}
// allocate node
let idx = if let Some(i) = self.free.pop() {
self.nodes[i] = Node { key: h, prev: None, next: None };
i
} else {
let i = self.nodes.len();
self.nodes.push(Node { key: h, prev: None, next: None });
i
};
self.map.insert(h, idx);
self.attach_to_head(idx);
}
}
/// Longest leading prefix of `hashes` present; touches the matched blocks.
pub fn longest_prefix(&mut self, hashes: &[u64]) -> usize {
let mut n = 0usize;
for &h in hashes {
if !self.touch(h) {
break;
}
n += 1;
}
n
}
/// Read-only longest prefix without LRU updates (used for routing probes).
pub fn longest_prefix_peek(&self, hashes: &[u64]) -> usize {
let mut n = 0usize;
for &h in hashes {
if !self.map.contains_key(&h) {
break;
}
n += 1;
}
n
}
fn move_to_head(&mut self, idx: usize) {
if Some(idx) == self.head {
return;
}
self.detach(idx);
self.attach_to_head(idx);
}
fn detach(&mut self, idx: usize) {
let (prev, next) = {
let n = &self.nodes[idx];
(n.prev, n.next)
};
if let Some(p) = prev {
self.nodes[p].next = next;
} else {
// it was the head
self.head = next;
}
if let Some(nx) = next {
self.nodes[nx].prev = prev;
} else {
// it was the tail
self.tail = prev;
}
self.nodes[idx].prev = None;
self.nodes[idx].next = None;
}
fn attach_to_head(&mut self, idx: usize) {
let old_head = self.head;
self.nodes[idx].prev = None;
self.nodes[idx].next = old_head;
if let Some(h) = old_head {
self.nodes[h].prev = Some(idx);
}
self.head = Some(idx);
if self.tail.is_none() {
self.tail = Some(idx);
}
}
}
/// Two-tier (HBM, DRAM) cache.
#[derive(Debug)]
pub struct TwoTierCache {
pub l0: LruBlocks,
pub l1: LruBlocks,
}
impl TwoTierCache {
pub fn new(l0_cap: usize, l1_cap: usize) -> Self {
Self {
l0: LruBlocks::new(l0_cap),
l1: LruBlocks::new(l1_cap),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn lcp_full_partial_empty() {
let mut c = LruBlocks::new(8);
let mut ev = Vec::new();
c.insert_blocks(&[1, 2, 3, 4], &mut ev);
assert_eq!(c.longest_prefix(&[1, 2, 3, 4, 5, 6]), 4);
assert_eq!(c.longest_prefix(&[1, 2, 9]), 2);
assert_eq!(c.longest_prefix(&[99, 1]), 0);
}
#[test]
fn lru_eviction_order() {
let mut c = LruBlocks::new(3);
let mut ev = Vec::new();
c.insert_blocks(&[1, 2, 3], &mut ev);
assert!(ev.is_empty());
// touch 1 -> MRU
c.touch(1);
// insert 4 -> evicts LRU which should be 2
c.insert_blocks(&[4], &mut ev);
assert_eq!(ev, vec![2]);
assert!(c.contains(1));
assert!(c.contains(3));
assert!(c.contains(4));
assert!(!c.contains(2));
}
#[test]
fn longest_prefix_touches_blocks() {
let mut c = LruBlocks::new(3);
let mut ev = Vec::new();
c.insert_blocks(&[1, 2, 3], &mut ev);
// touch 1 via prefix lookup (only the first matching block: 1)
assert_eq!(c.longest_prefix(&[1, 99]), 1);
// after the prefix touch the recency order (MRU -> LRU) is 1, 3, 2,
// so inserting 4 at capacity 3 evicts 2
c.insert_blocks(&[4], &mut ev);
assert_eq!(ev, vec![2]);
}
}

6
src/instance/mod.rs Normal file
View File

@@ -0,0 +1,6 @@
pub mod compute;
pub mod kv_cache;
#[allow(clippy::module_inception)]
pub mod instance;
pub use instance::Instance;

13
src/lib.rs Normal file
View File

@@ -0,0 +1,13 @@
pub mod cluster;
pub mod config;
pub mod driver;
pub mod hardware_presets;
pub mod hf_config;
pub mod instance;
pub mod metrics;
pub mod network;
pub mod oracle;
pub mod router;
pub mod sim;
pub mod trace;
pub mod types;

271
src/main.rs Normal file
View File

@@ -0,0 +1,271 @@
use anyhow::{Context, Result};
use clap::{Args, Parser, Subcommand};
use std::path::PathBuf;
use kvcache_simulator::config::{Config, RouterMode};
use kvcache_simulator::{driver, oracle, trace::TraceReader};
#[derive(Debug, Parser)]
#[command(name = "kvcache-sim", about = "Cluster-level KV cache simulator")]
struct Cli {
#[command(subcommand)]
cmd: Cmd,
}
/// Optional CLI overrides applied on top of the YAML config so the same
/// config can be reused across sweeps without editing the file.
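///
/// For example (hypothetical paths), one base config can drive a sweep:
/// `kvcache-sim run -c configs/base.yaml --num-instances 32 --seed 7`.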
#[derive(Debug, Args, Clone, Default)]
struct ConfigOverrides {
/// Override `cluster.num_instances`.
#[arg(long)]
num_instances: Option<u32>,
/// Override `sim.max_requests` (cap on processed trace records).
#[arg(long)]
max_requests: Option<u64>,
/// Override `sim.trace_path`.
#[arg(long)]
trace: Option<PathBuf>,
/// Override `sim.output_dir`.
#[arg(long)]
output_dir: Option<PathBuf>,
/// Override `sim.seed`.
#[arg(long)]
seed: Option<u64>,
/// Override `cluster.router.precise_probe_topk`.
#[arg(long)]
precise_topk: Option<u32>,
/// Override `cluster.meta_store.ttl_seconds`.
#[arg(long)]
ttl_seconds: Option<f64>,
}
impl ConfigOverrides {
fn apply(&self, cfg: &mut Config) {
if let Some(n) = self.num_instances {
cfg.cluster.num_instances = n;
}
if let Some(m) = self.max_requests {
cfg.sim.max_requests = Some(m);
}
if let Some(t) = &self.trace {
cfg.sim.trace_path = t.to_string_lossy().into_owned();
}
if let Some(o) = &self.output_dir {
cfg.sim.output_dir = o.to_string_lossy().into_owned();
}
if let Some(s) = self.seed {
cfg.sim.seed = s;
}
if let Some(k) = self.precise_topk {
cfg.cluster.router.precise_probe_topk = k;
}
if let Some(ttl) = self.ttl_seconds {
cfg.cluster.meta_store.ttl_seconds = ttl;
}
}
}
#[derive(Debug, Subcommand)]
enum Cmd {
/// Run a single simulation with the router specified in the config.
Run {
#[arg(short, long)]
config: PathBuf,
#[command(flatten)]
overrides: ConfigOverrides,
},
/// Run the same trace under multiple routers and compare summaries.
Ablate {
#[arg(short, long)]
config: PathBuf,
/// Comma-separated router modes
#[arg(
short,
long,
default_value = "random,least_loaded,least_tokens,ttl_aware,min_pd,cache_load,cache_score,estimated_ttft,prefix_affinity"
)]
routers: String,
#[command(flatten)]
overrides: ConfigOverrides,
},
/// Parse the config and trace head; do not run a simulation.
Validate {
#[arg(short, long)]
config: PathBuf,
#[command(flatten)]
overrides: ConfigOverrides,
},
/// Offline oracle analysis: theoretical hit-rate ceilings (unlimited
/// cache and offline-optimal Belady eviction at finite capacity), plus
/// LRU at the same capacity for comparison.
Oracle {
#[arg(short, long)]
config: PathBuf,
#[command(flatten)]
overrides: ConfigOverrides,
/// Cache capacity (in 16-token blocks) used for the Belady and LRU
/// analyses. Defaults to `num_instances * per_instance_HBM_blocks`
/// (the cluster-aggregate capacity).
#[arg(long)]
capacity_blocks: Option<u64>,
/// Use the per-instance HBM block budget instead of the
/// cluster-aggregate. Mutually exclusive with --capacity-blocks.
#[arg(long, default_value_t = false)]
per_instance: bool,
/// Optional output JSON path. Defaults to `<output_dir>/oracle.json`.
#[arg(long)]
out: Option<PathBuf>,
},
}
fn main() -> Result<()> {
let cli = Cli::parse();
match cli.cmd {
Cmd::Run { config, overrides } => cmd_run(&config, &overrides),
Cmd::Ablate {
config,
routers,
overrides,
} => cmd_ablate(&config, &routers, &overrides),
Cmd::Validate { config, overrides } => cmd_validate(&config, &overrides),
Cmd::Oracle {
config,
overrides,
capacity_blocks,
per_instance,
out,
} => cmd_oracle(&config, &overrides, capacity_blocks, per_instance, out.as_deref()),
}
}
fn load(config: &PathBuf, overrides: &ConfigOverrides) -> Result<Config> {
let mut cfg = Config::from_yaml_path(config)?;
overrides.apply(&mut cfg);
Ok(cfg)
}
fn cmd_run(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> {
let cfg = load(path, overrides)?;
let out = driver::run(&cfg, None)?;
println!("{}", serde_json::to_string_pretty(&out.summary)?);
Ok(())
}
fn cmd_ablate(path: &PathBuf, routers: &str, overrides: &ConfigOverrides) -> Result<()> {
let base = load(path, overrides)?;
let modes: Vec<RouterMode> = routers
.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.map(RouterMode::parse)
.collect::<Result<Vec<_>>>()
.with_context(|| format!("parsing --routers='{routers}'"))?;
let mut all = Vec::new();
for mode in modes {
let mut cfg = base.clone();
cfg.cluster.router.mode = mode;
let sub = mode.as_str().to_string();
eprintln!("[ablate] running router={}", sub);
let out = driver::run(&cfg, Some(&sub))?;
all.push(out.summary);
}
let agg_path = std::path::Path::new(&base.sim.output_dir).join("ablation.json");
std::fs::create_dir_all(&base.sim.output_dir)?;
std::fs::write(&agg_path, serde_json::to_string_pretty(&all)?)?;
println!("{}", serde_json::to_string_pretty(&all)?);
eprintln!("[ablate] wrote {}", agg_path.display());
Ok(())
}
fn cmd_validate(path: &PathBuf, overrides: &ConfigOverrides) -> Result<()> {
use kvcache_simulator::instance::compute::ComputeModel;
let cfg = load(path, overrides)?;
eprintln!("config OK: {}", cfg.model.name);
eprintln!("mode = {}", if cfg.model.is_arch_mode() { "architecture-derived" } else { "legacy manual" });
let cm = ComputeModel::new(&cfg.model, &cfg.hardware);
eprintln!("compute: {}", cm.describe());
eprintln!("kv_block_bytes = {} ({:.2} MB{})",
cfg.model.kv_block_bytes(),
cfg.model.kv_block_bytes() as f64 / 1e6,
if cfg.model.mla.is_some() { ", MLA compressed" } else { "" },
);
let block_bytes = cfg.model.kv_block_bytes() as f64;
let hbm_blocks = (cfg.hardware.hbm_bytes / block_bytes) as u64;
let dram_blocks = (cfg.hardware.dram_bytes / block_bytes) as u64;
eprintln!("per-instance HBM blocks = {hbm_blocks}, DRAM blocks = {dram_blocks}");
eprintln!("num_instances = {}", cfg.cluster.num_instances);
// Sample prefill times at a few prompt lengths.
eprintln!("prefill_time samples:");
for &n in &[256, 1024, 4096, 16384, 65536, 131072] {
let t = cm.prefill_time(n);
eprintln!(" N={n:>7} -> {t:.4} s");
}
let reader = TraceReader::open(&cfg.sim.trace_path, Some(5))?;
for rec in reader {
let rec = rec?;
eprintln!(
" req {} chat={} t={:.3}s in={} out={} blocks={}",
rec.req_id,
rec.chat_id,
rec.arrival,
rec.input_len,
rec.output_len,
rec.hash_ids.len()
);
}
Ok(())
}
fn cmd_oracle(
path: &PathBuf,
overrides: &ConfigOverrides,
capacity_blocks: Option<u64>,
per_instance: bool,
out_path: Option<&std::path::Path>,
) -> Result<()> {
let cfg = load(path, overrides)?;
let block_bytes = cfg.model.kv_block_bytes() as f64;
let per_instance_blocks = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as u64;
let aggregate_blocks = per_instance_blocks * cfg.cluster.num_instances as u64;
let capacity = match (capacity_blocks, per_instance) {
(Some(_), true) => {
return Err(anyhow::anyhow!(
"--capacity-blocks and --per-instance are mutually exclusive"
))
}
(Some(c), false) => c,
(None, true) => per_instance_blocks,
(None, false) => aggregate_blocks,
};
eprintln!(
"[oracle] loading trace {} (max_requests={:?})",
cfg.sim.trace_path, cfg.sim.max_requests
);
let reader = TraceReader::open(&cfg.sim.trace_path, cfg.sim.max_requests)?;
let records: Vec<_> = reader.collect::<Result<Vec<_>, _>>()?;
eprintln!(
"[oracle] loaded {} requests; analyzing with capacity = {} blocks \
({} per-instance × {} instances{})",
records.len(),
capacity,
per_instance_blocks,
cfg.cluster.num_instances,
if per_instance { ", per-instance mode" } else { "" }
);
let result = oracle::analyze(&records, capacity);
let json = serde_json::to_string_pretty(&result)?;
println!("{}", json);
let target = match out_path {
Some(p) => p.to_path_buf(),
None => std::path::Path::new(&cfg.sim.output_dir).join("oracle.json"),
};
if let Some(parent) = target.parent() {
std::fs::create_dir_all(parent)?;
}
std::fs::write(&target, &json)?;
eprintln!("[oracle] wrote {}", target.display());
Ok(())
}

7
src/metrics/mod.rs Normal file
View File

@@ -0,0 +1,7 @@
pub mod per_request;
pub mod routing_log;
pub mod summary;
pub mod timeseries;
pub use per_request::PerRequestRow;
pub use summary::Summary;

42
src/metrics/per_request.rs Normal file
View File

@@ -0,0 +1,42 @@
use anyhow::Result;
use serde::Serialize;
use std::path::Path;
#[derive(Debug, Clone, Serialize)]
pub struct PerRequestRow {
pub req_id: u64,
pub arrival: f64,
pub ttft: f64,
pub e2e: f64,
pub instance: u32,
pub total_blocks: u32,
pub l0_hit_blocks: u32,
pub l1_hit_blocks: u32,
pub remote_hit_blocks: u32,
pub miss_blocks: u32,
pub rdma_bytes: u64,
pub pcie_bytes: u64,
pub probe_overhead_s: f64,
}
pub struct PerRequestWriter {
inner: csv::Writer<std::fs::File>,
}
impl PerRequestWriter {
pub fn create<P: AsRef<Path>>(path: P) -> Result<Self> {
let f = std::fs::File::create(path)?;
let inner = csv::Writer::from_writer(f);
Ok(Self { inner })
}
pub fn write(&mut self, row: &PerRequestRow) -> Result<()> {
self.inner.serialize(row)?;
Ok(())
}
pub fn finish(mut self) -> Result<()> {
self.inner.flush()?;
Ok(())
}
}

29
src/metrics/routing_log.rs Normal file
View File

@@ -0,0 +1,29 @@
use anyhow::Result;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
use crate::router::RouteDecision;
pub struct RoutingLogWriter {
inner: BufWriter<File>,
}
impl RoutingLogWriter {
pub fn create<P: AsRef<Path>>(path: P) -> Result<Self> {
let f = File::create(path)?;
Ok(Self { inner: BufWriter::new(f) })
}
pub fn write(&mut self, decision: &RouteDecision) -> Result<()> {
let line = serde_json::to_string(decision)?;
self.inner.write_all(line.as_bytes())?;
self.inner.write_all(b"\n")?;
Ok(())
}
pub fn finish(mut self) -> Result<()> {
self.inner.flush()?;
Ok(())
}
}

80
src/metrics/summary.rs Normal file
View File

@@ -0,0 +1,80 @@
use serde::Serialize;
use crate::metrics::per_request::PerRequestRow;
#[derive(Debug, Clone, Serialize, Default)]
pub struct Summary {
pub router: String,
pub num_requests: u64,
pub sim_duration_s: f64,
pub throughput_req_per_s: f64,
pub ttft_mean: f64,
pub ttft_p50: f64,
pub ttft_p95: f64,
pub ttft_p99: f64,
pub e2e_mean: f64,
pub e2e_p50: f64,
pub e2e_p95: f64,
pub e2e_p99: f64,
pub total_blocks: u64,
pub hit_rate_l0: f64,
pub hit_rate_l1: f64,
pub hit_rate_remote: f64,
pub miss_rate: f64,
pub total_rdma_bytes: u64,
pub total_pcie_bytes: u64,
}
impl Summary {
pub fn from_rows(router: &str, rows: &[PerRequestRow], sim_duration_s: f64) -> Self {
if rows.is_empty() {
return Summary {
router: router.to_string(),
..Default::default()
};
}
let mut ttfts: Vec<f64> = rows.iter().map(|r| r.ttft).collect();
let mut e2es: Vec<f64> = rows.iter().map(|r| r.e2e).collect();
ttfts.sort_by(|a, b| a.partial_cmp(b).unwrap());
e2es.sort_by(|a, b| a.partial_cmp(b).unwrap());
let pct = |v: &[f64], q: f64| -> f64 {
let n = v.len();
let idx = ((n as f64 - 1.0) * q).round() as usize;
v[idx.min(n - 1)]
};
let mean = |v: &[f64]| -> f64 {
if v.is_empty() {
0.0
} else {
v.iter().sum::<f64>() / v.len() as f64
}
};
let total_blocks: u64 = rows.iter().map(|r| r.total_blocks as u64).sum();
let l0: u64 = rows.iter().map(|r| r.l0_hit_blocks as u64).sum();
let l1: u64 = rows.iter().map(|r| r.l1_hit_blocks as u64).sum();
let remote: u64 = rows.iter().map(|r| r.remote_hit_blocks as u64).sum();
let miss: u64 = rows.iter().map(|r| r.miss_blocks as u64).sum();
let denom = total_blocks.max(1) as f64;
Summary {
router: router.to_string(),
num_requests: rows.len() as u64,
sim_duration_s,
throughput_req_per_s: rows.len() as f64 / sim_duration_s.max(1e-9),
ttft_mean: mean(&ttfts),
ttft_p50: pct(&ttfts, 0.50),
ttft_p95: pct(&ttfts, 0.95),
ttft_p99: pct(&ttfts, 0.99),
e2e_mean: mean(&e2es),
e2e_p50: pct(&e2es, 0.50),
e2e_p95: pct(&e2es, 0.95),
e2e_p99: pct(&e2es, 0.99),
total_blocks,
hit_rate_l0: l0 as f64 / denom,
hit_rate_l1: l1 as f64 / denom,
hit_rate_remote: remote as f64 / denom,
miss_rate: miss as f64 / denom,
total_rdma_bytes: rows.iter().map(|r| r.rdma_bytes).sum(),
total_pcie_bytes: rows.iter().map(|r| r.pcie_bytes).sum(),
}
}
}

34
src/metrics/timeseries.rs Normal file
View File

@@ -0,0 +1,34 @@
use anyhow::Result;
use serde::Serialize;
use std::path::Path;
#[derive(Debug, Clone, Serialize)]
pub struct TimeseriesRow {
pub t: f64,
pub instance: u32,
pub queue_len: u32,
pub kv_blocks_used: u32,
pub kv_blocks_total: u32,
pub busy: u8,
}
pub struct TimeseriesWriter {
inner: csv::Writer<std::fs::File>,
}
impl TimeseriesWriter {
pub fn create<P: AsRef<Path>>(path: P) -> Result<Self> {
let f = std::fs::File::create(path)?;
Ok(Self { inner: csv::Writer::from_writer(f) })
}
pub fn write(&mut self, row: &TimeseriesRow) -> Result<()> {
self.inner.serialize(row)?;
Ok(())
}
pub fn finish(mut self) -> Result<()> {
self.inner.flush()?;
Ok(())
}
}

84
src/network.rs Normal file
View File

@@ -0,0 +1,84 @@
//! Network cost models for RDMA (cross-instance) and PCIe (host<->GPU).
//!
//! Each link is modeled as a token bucket via a `next_free` cursor: a fetch of
//! `bytes` starting at `now` waits until `next_free`, then advances the cursor
//! by `bytes / bw`. Latency is added on top of transfer time. This captures
//! contention without simulating individual packets.
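//!
//! A small sketch of the cursor behaviour (illustrative numbers):
//!
//! ```ignore
//! let mut link = LinkModel::new(1.0e9, 1.0e-6);  // 1 GB/s, 1 µs latency
//! let t1 = link.reserve(0.0, 2_000_000_000);     // cursor busy until t = 2.0 s
//! let t2 = link.reserve(0.5, 1_000_000_000);     // arrives at 0.5 s, starts at 2.0 s
//! // t1 ≈ 2.0 s, t2 ≈ 3.0 s: the second fetch is pushed back by contention.
//! ```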
use crate::config::HardwareConfig;
#[derive(Debug, Clone)]
pub struct LinkModel {
pub bw_bytes_per_s: f64,
pub latency_s: f64,
next_free: f64,
}
impl LinkModel {
pub fn new(bw_bytes_per_s: f64, latency_s: f64) -> Self {
Self {
bw_bytes_per_s,
latency_s,
next_free: 0.0,
}
}
/// Reserve a transfer of `bytes` starting at `now`. Returns the absolute
/// time at which the bytes have all arrived (advances internal cursor).
pub fn reserve(&mut self, now: f64, bytes: u64) -> f64 {
if bytes == 0 {
return now + self.latency_s;
}
let xfer = bytes as f64 / self.bw_bytes_per_s;
let start = self.next_free.max(now);
self.next_free = start + xfer;
self.next_free + self.latency_s
}
/// Pure cost (no contention): how long to push `bytes` over this link.
pub fn cost(&self, bytes: u64) -> f64 {
if bytes == 0 {
self.latency_s
} else {
self.latency_s + bytes as f64 / self.bw_bytes_per_s
}
}
}
/// Per-instance bundle of links: PCIe (host<->GPU) and RDMA (host<->remote).
#[derive(Debug, Clone)]
pub struct InstanceLinks {
pub pcie: LinkModel,
pub rdma: LinkModel,
}
impl InstanceLinks {
pub fn from_hw(hw: &HardwareConfig) -> Self {
Self {
pcie: LinkModel::new(hw.pcie_bw, hw.pcie_latency_us * 1e-6),
rdma: LinkModel::new(hw.rdma_bw, hw.rdma_latency_us * 1e-6),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn link_cost_matches_formula() {
let l = LinkModel::new(1.0e9, 1.0e-6);
// 1 GB / (1 GB/s) = 1s, plus 1us latency
let t = l.cost(1_000_000_000);
assert!((t - (1.0 + 1e-6)).abs() < 1e-9);
}
#[test]
fn reserve_serializes_concurrent_transfers() {
let mut l = LinkModel::new(1.0e9, 0.0);
let t1 = l.reserve(0.0, 500_000_000); // 0.5s
let t2 = l.reserve(0.0, 500_000_000); // contended -> 1.0s
assert!((t1 - 0.5).abs() < 1e-9);
assert!((t2 - 1.0).abs() < 1e-9);
}
}

279
src/oracle.rs Normal file
View File

@@ -0,0 +1,279 @@
//! Offline oracle analyzers for upper-bound KV-cache hit rates.
//!
//! Two analyses, both treating the cluster as a single aggregated cache so
//! the result is independent of routing — i.e. they answer the question
//! "what is the best the cluster could possibly do?":
//!
//! 1. **Unlimited capacity**: longest-prefix-match against an unbounded
//! cache. The only misses are blocks that the prefix walk encounters for
//! the first time. Sets the absolute ceiling.
//!
//! 2. **Belady (offline optimal eviction) at finite capacity**: classic
//! OPT replacement — evict the cached block whose *next* access is
//! furthest in the future. Run alongside an LRU baseline at the same
//! capacity so the gap tells you how much room LRU is leaving.
//!
//! Hit accounting uses prefix-match semantics matching the rest of the
//! simulator: a block at position k in a request counts as a hit only if
//! all positions 0..k are also in the cache.
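//!
//! A worked example of the prefix-match accounting (capacity aside):
//!
//! ```text
//! cache    = {A, B, D}
//! request  = [A, B, C, D]
//! hits     = 2        A and B match; C misses and breaks the prefix,
//!                     so D is not counted even though it is cached.
//! ```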
use ahash::{AHashMap, AHashSet};
use serde::Serialize;
use std::collections::BinaryHeap;
use crate::instance::kv_cache::LruBlocks;
use crate::trace::RequestRecord;
#[derive(Debug, Clone, Serialize)]
pub struct OracleResult {
pub num_requests: u64,
pub total_blocks: u64,
pub unique_blocks: u64,
pub unlimited: TierResult,
pub belady_finite: TierResult,
pub lru_finite: TierResult,
}
#[derive(Debug, Clone, Serialize, Default)]
pub struct TierResult {
pub label: String,
pub capacity_blocks: u64,
pub hits: u64,
pub misses: u64,
pub hit_rate: f64,
}
impl TierResult {
fn from_counts(label: &str, capacity_blocks: u64, hits: u64, total: u64) -> Self {
let misses = total.saturating_sub(hits);
TierResult {
label: label.to_string(),
capacity_blocks,
hits,
misses,
hit_rate: if total == 0 { 0.0 } else { hits as f64 / total as f64 },
}
}
}
pub fn analyze(records: &[RequestRecord], capacity_blocks: u64) -> OracleResult {
// total / unique counters
let total_blocks: u64 = records.iter().map(|r| r.hash_ids.len() as u64).sum();
let mut unique = AHashSet::new();
for r in records {
for &h in &r.hash_ids {
unique.insert(h);
}
}
// 1. Unlimited cache
let unlimited_hits = run_unlimited(records);
let unlimited = TierResult::from_counts(
"unlimited",
u64::MAX,
unlimited_hits,
total_blocks,
);
// 2. Precompute next-use index for Belady
let next_use = build_next_use(records);
// 3. Belady at the given capacity
let belady_hits = run_belady(records, &next_use, capacity_blocks as usize);
let belady = TierResult::from_counts("belady", capacity_blocks, belady_hits, total_blocks);
// 4. LRU baseline at the same capacity
let lru_hits = run_lru(records, capacity_blocks as usize);
let lru = TierResult::from_counts("lru", capacity_blocks, lru_hits, total_blocks);
OracleResult {
num_requests: records.len() as u64,
total_blocks,
unique_blocks: unique.len() as u64,
unlimited,
belady_finite: belady,
lru_finite: lru,
}
}
fn run_unlimited(records: &[RequestRecord]) -> u64 {
let mut seen: AHashSet<u64> = AHashSet::with_capacity(1 << 18);
let mut hits: u64 = 0;
for r in records {
// Longest prefix match against `seen`
for &h in &r.hash_ids {
if seen.contains(&h) {
hits += 1;
} else {
break;
}
}
for &h in &r.hash_ids {
seen.insert(h);
}
}
hits
}
fn run_lru(records: &[RequestRecord], capacity: usize) -> u64 {
if capacity == 0 {
return 0;
}
let mut cache = LruBlocks::new(capacity);
let mut hits: u64 = 0;
let mut evicted = Vec::new();
for r in records {
hits += cache.longest_prefix(&r.hash_ids) as u64;
evicted.clear();
cache.insert_blocks(&r.hash_ids, &mut evicted);
}
hits
}
/// For each (request_idx, position_in_hash_ids) compute the next request
/// index whose `hash_ids` contains the same block (`u32::MAX` if none).
fn build_next_use(records: &[RequestRecord]) -> Vec<Vec<u32>> {
let n = records.len();
let mut next_use: Vec<Vec<u32>> = Vec::with_capacity(n);
for r in records {
next_use.push(vec![u32::MAX; r.hash_ids.len()]);
}
let mut last_seen: AHashMap<u64, u32> = AHashMap::with_capacity(1 << 18);
for i in (0..n).rev() {
let r = &records[i];
for (j, &h) in r.hash_ids.iter().enumerate() {
next_use[i][j] = *last_seen.get(&h).unwrap_or(&u32::MAX);
}
for &h in &r.hash_ids {
last_seen.insert(h, i as u32);
}
}
next_use
}
/// Belady (offline OPT) eviction over the trace.
///
/// Implementation: lazy-deletion max-heap keyed by next-use index. Each
/// cache entry has a version; the heap may contain stale entries from
/// previous insertions, which we skip on pop.
fn run_belady(records: &[RequestRecord], next_use: &[Vec<u32>], capacity: usize) -> u64 {
if capacity == 0 {
return 0;
}
// block_hash -> (current_version, current_next_use)
let mut in_cache: AHashMap<u64, (u64, u32)> = AHashMap::with_capacity(capacity);
// (next_use, version, block_hash) — BinaryHeap is max-heap, which is what
// we want for "evict the entry whose next access is furthest".
let mut heap: BinaryHeap<(u32, u64, u64)> = BinaryHeap::with_capacity(capacity);
let mut version: u64 = 0;
let mut hits: u64 = 0;
for (i, r) in records.iter().enumerate() {
// 1. Longest-prefix hit accounting against current cache.
for &h in &r.hash_ids {
if in_cache.contains_key(&h) {
hits += 1;
} else {
break;
}
}
// 2. Insert / update each block in the request with its new next-use.
for (j, &h) in r.hash_ids.iter().enumerate() {
let nu = next_use[i][j];
if let Some(slot) = in_cache.get_mut(&h) {
version += 1;
slot.0 = version;
slot.1 = nu;
heap.push((nu, version, h));
continue;
}
// Need to make room?
if in_cache.len() == capacity {
// Evict max next_use entry, skipping stale heap entries.
loop {
let (nu_top, ver_top, h_top) = match heap.pop() {
Some(x) => x,
None => break,
};
if let Some(&(cur_ver, cur_nu)) = in_cache.get(&h_top) {
if cur_ver == ver_top && cur_nu == nu_top {
in_cache.remove(&h_top);
break;
}
}
// stale; loop
}
}
version += 1;
in_cache.insert(h, (version, nu));
heap.push((nu, version, h));
}
}
hits
}
#[cfg(test)]
mod tests {
use super::*;
fn req(id: u64, t: f64, hashes: Vec<u64>) -> RequestRecord {
RequestRecord {
req_id: id,
chat_id: id as i64,
arrival: t,
input_len: (hashes.len() as u32) * 16,
output_len: 16,
hash_ids: hashes,
}
}
#[test]
fn unlimited_first_occurrence_misses() {
let recs = vec![
req(0, 0.0, vec![1, 2, 3]),
req(1, 1.0, vec![1, 2, 3, 4]),
req(2, 2.0, vec![1, 2, 3, 4, 5]),
];
let out = analyze(&recs, 100);
// total = 3 + 4 + 5 = 12
assert_eq!(out.total_blocks, 12);
// unique = {1,2,3,4,5} = 5
assert_eq!(out.unique_blocks, 5);
// unlimited hits = 0 (req 0 all miss) + 3 (req 1 has [1,2,3] cached, then 4 miss) + 4
assert_eq!(out.unlimited.hits, 7);
assert!((out.unlimited.hit_rate - 7.0 / 12.0).abs() < 1e-9);
}
#[test]
fn belady_beats_lru_when_lru_thrashes() {
// Capacity 2. Pattern designed so LRU thrashes but Belady keeps the
// useful block: A B A C A B A C A ...
let mut recs = Vec::new();
let pattern = [1u64, 2, 1, 3, 1, 2, 1, 3];
for (i, &h) in pattern.iter().enumerate() {
recs.push(req(i as u64, i as f64, vec![h]));
}
let out = analyze(&recs, 2);
assert!(
out.belady_finite.hits >= out.lru_finite.hits,
"belady should be at least as good as lru: belady={} lru={}",
out.belady_finite.hits,
out.lru_finite.hits
);
}
#[test]
fn unlimited_is_upper_bound() {
let recs = vec![
req(0, 0.0, vec![10, 20, 30]),
req(1, 1.0, vec![10, 20, 30, 40, 50]),
req(2, 2.0, vec![60]),
req(3, 3.0, vec![10, 20, 30, 40, 50, 60]),
];
let out = analyze(&recs, 3);
assert!(out.unlimited.hit_rate >= out.belady_finite.hit_rate);
assert!(out.belady_finite.hit_rate >= out.lru_finite.hit_rate - 1e-9);
}
}

89
src/router/cache_load.rs Normal file
View File

@@ -0,0 +1,89 @@
//! Load-filtered cache-aware routing.
//!
//! **Step 1** — filter: sort all instances by `queue_len` ascending and take the
//! least-loaded quarter (≥ 2 instances).
//!
//! **Step 2** — select: among that pool, pick the instance with the highest
//! meta-store prefix score. Tiebreak on lowest `queue_len`.
//!
//! This cleanly separates concerns: step 1 guarantees the request won't land
//! on a saturated instance, while step 2 maximises cache reuse within the
//! load-safe pool. The 1/4 fraction keeps the pool large enough that good
//! cache candidates are rarely excluded.
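//!
//! Illustrative example (8 instances, so the pool is the 2 least loaded):
//!
//! ```text
//! queue_len    = [5, 1, 3, 0, 4, 2, 6, 2]   ->  pool = {instance 3, instance 1}
//! prefix score = [9, 4, 0, 0, 7, 1, 8, 2]   ->  within the pool, instance 1 wins (4 > 0)
//! ```
//!
//! Note that instance 0, despite the best cluster-wide prefix score (9), is
//! never considered because its queue is too deep to pass the load filter.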
use crate::cluster::meta_store::MetaStore;
use crate::instance::Instance;
use crate::router::{CandidateInfo, RouteDecision, Router};
use crate::trace::RequestRecord;
pub struct CacheLoadRouter;
impl CacheLoadRouter {
pub fn new() -> Self {
Self
}
}
impl Default for CacheLoadRouter {
fn default() -> Self {
Self::new()
}
}
impl Router for CacheLoadRouter {
fn name(&self) -> &'static str {
"cache_load"
}
fn route(
&mut self,
req: &RequestRecord,
instances: &[Instance],
meta: &MetaStore,
now: f64,
) -> RouteDecision {
let n = instances.len();
let scores = meta.score_prefix(&req.hash_ids, now, n);
// Step 1: least-loaded 1/4 of instances (by queue_len).
let pool_size = (n / 4).max(2).min(n);
let mut indices: Vec<usize> = (0..n).collect();
indices.sort_by_key(|&i| instances[i].queue_len());
let pool = &indices[..pool_size];
// Step 2: among the pool, pick highest prefix score.
// Tiebreak: lowest queue_len.
let mut best_idx = pool[0];
let mut best_prefix = scores[pool[0]];
let mut best_queue = instances[pool[0]].queue_len();
for &i in &pool[1..] {
let p = scores[i];
let q = instances[i].queue_len();
if p > best_prefix || (p == best_prefix && q < best_queue) {
best_idx = i;
best_prefix = p;
best_queue = q;
}
}
let mut candidates = Vec::with_capacity(pool_size);
for &i in pool {
candidates.push(CandidateInfo {
instance: instances[i].id,
predicted_prefix: scores[i],
load_blocks: instances[i].kv_blocks_used,
queue_len: instances[i].queue_len(),
});
}
RouteDecision {
req_id: req.req_id,
mode: "cache_load",
chosen: instances[best_idx].id,
probe_overhead_s: 0.0,
candidates,
reason: "least-loaded 1/4, then best prefix",
}
}
}

111
src/router/cache_score.rs Normal file
View File

@@ -0,0 +1,111 @@
//! Combined-score cache-aware routing with exponential weighting.
//!
//! Each instance is scored by:
//!
//! ```text
//! score_i = 2^(α · load_i + β · miss_i)
//! ```
//!
//! where
//!
//! - `load_i = queue_len()` — requests pending or prefilling on instance i,
//! - `miss_i = input_blocks - prefix_blocks` — cache-miss blocks,
//! - `α` = `score_alpha` (YAML config, default 1.0),
//! - `β` = `score_beta` (YAML config, default 0.1).
//!
//! The instance with the **lowest** score is chosen. Since `2^x` is
//! monotonic, this is equivalent to minimising the linear exponent
//! `α·load + β·miss`, but the exponential framing highlights that
//! differences are amplified exponentially — a small edge in the exponent
//! creates a large gap in score.
//!
//! **Tuning guide**:
//!
//! - `α` controls how aggressively load is penalised.
//! - `β` controls how aggressively cache misses are penalised.
//! - Ratio `α/β` is what matters: higher → more load-sensitive.
//! - Defaults (`α=1.0, β=0.1`): 1 extra queue position ≈ 10 extra miss
//! blocks, which is a good starting point when block_size is large (512)
//! and queues are short (0–10).
//!
//! Ties are broken by fewest `queue_len`, then highest `prefix`.
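//!
//! Worked example with the defaults (α=1.0, β=0.1) on a 20-block request:
//!
//! ```text
//! instance A: queue_len=3, prefix=18  ->  exponent = 1.0·3 + 0.1·(20-18) = 3.2
//! instance B: queue_len=1, prefix=0   ->  exponent = 1.0·1 + 0.1·(20-0)  = 3.0
//! ```
//!
//! B wins: at these weights two free queue slots outweigh 18 blocks of cache
//! reuse, which is the load-first bias the defaults encode.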
use crate::cluster::meta_store::MetaStore;
use crate::instance::Instance;
use crate::router::{CandidateInfo, RouteDecision, Router};
use crate::trace::RequestRecord;
pub struct CacheScoreRouter {
alpha: f64,
beta: f64,
}
impl CacheScoreRouter {
pub fn new(alpha: f64, beta: f64) -> Self {
Self { alpha, beta }
}
}
impl Router for CacheScoreRouter {
fn name(&self) -> &'static str {
"cache_score"
}
fn route(
&mut self,
req: &RequestRecord,
instances: &[Instance],
meta: &MetaStore,
now: f64,
) -> RouteDecision {
let n = instances.len();
let scores = meta.score_prefix(&req.hash_ids, now, n);
let input_blocks = req.hash_ids.len() as f64;
let mut best_idx: usize = 0;
let mut best_exp = f64::INFINITY;
let mut best_queue = u32::MAX;
let mut best_prefix = 0u32;
let mut candidates = Vec::with_capacity(n);
for (i, inst) in instances.iter().enumerate() {
let prefix = scores[i] as f64;
let miss = (input_blocks - prefix).max(0.0);
let q = inst.queue_len() as f64;
// Minimise the exponent: α·load + β·miss
// (equivalent to minimising 2^exponent)
let exponent = self.alpha * q + self.beta * miss;
candidates.push(CandidateInfo {
instance: inst.id,
predicted_prefix: scores[i],
load_blocks: inst.kv_blocks_used,
queue_len: inst.queue_len(),
});
// Minimise (exponent, queue_len ASC, prefix DESC).
let better = exponent < best_exp
|| (exponent == best_exp && inst.queue_len() < best_queue)
|| (exponent == best_exp
&& inst.queue_len() == best_queue
&& scores[i] > best_prefix);
if better {
best_exp = exponent;
best_idx = i;
best_queue = inst.queue_len();
best_prefix = scores[i];
}
}
RouteDecision {
req_id: req.req_id,
mode: "cache_score",
chosen: instances[best_idx].id,
probe_overhead_s: 0.0,
candidates,
reason: "argmin 2^(α·load + β·miss)",
}
}
}

128
src/router/estimated_ttft.rs Normal file
View File

@@ -0,0 +1,128 @@
//! First-principles TTFT-optimal routing.
//!
//! Estimates the actual time-to-first-token for each candidate instance:
//!
//! `TTFT(r,i) = drain(i) + fetch(r,i) + prefill(miss)`
//!
//! - **drain** — exact queue drain time: sum of per-request `prefill_time()`
//! using the architecture-aware compute model (quadratic / DSA).
//!
//! - **fetch** — RDMA fetch time for blocks cached elsewhere in the cluster
//! but not on instance `i` locally.
//!
//! - **prefill** — compute for cluster-wide cache-miss tokens (constant
//! across instances, cancels in the argmin).
//!
//! The router minimises `drain(i) + fetch(r,i)`, with ties broken by
//! lowest `queue_len` then most local cache. Although the fetch can overlap
//! with the queue drain in reality, the cost is deliberately kept additive:
//! this doubly rewards instances with local cache, which empirically
//! outperforms the `max(drain, fetch)` alternative because even small
//! RDMA savings compound across thousands of routing decisions.
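//!
//! Worked example (illustrative numbers: 64-block cluster prefix, 1 MB KV
//! blocks, 100 GB/s RDMA):
//!
//! ```text
//! instance A: drain = 0.40 s, local prefix = 64  ->  cost = 0.40 s
//! instance B: drain = 0.10 s, local prefix = 0   ->  cost = 0.10 s + 64 MB/bw ≈ 0.1006 s
//! ```
//!
//! With fast RDMA the fetch term is small, so queue drain usually decides and
//! the fetch term mostly acts as the cache-affinity tiebreak described above.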
use crate::cluster::meta_store::MetaStore;
use crate::config::Config;
use crate::instance::Instance;
use crate::router::{CandidateInfo, RouteDecision, Router};
use crate::trace::RequestRecord;
pub struct EstimatedTtftRouter {
/// Bytes per KV block (for RDMA cost estimation).
kv_block_bytes: f64,
/// RDMA bandwidth in bytes/s.
rdma_bw: f64,
/// RDMA per-transfer latency in seconds.
rdma_latency_s: f64,
}
impl EstimatedTtftRouter {
pub fn new(config: &Config) -> Self {
Self {
kv_block_bytes: config.model.kv_block_bytes() as f64,
rdma_bw: config.hardware.rdma_bw,
rdma_latency_s: config.hardware.rdma_latency_us * 1e-6,
}
}
/// Estimate RDMA fetch time for `remote_blocks` blocks.
fn fetch_time(&self, remote_blocks: u32) -> f64 {
if remote_blocks == 0 {
return 0.0;
}
let bytes = remote_blocks as f64 * self.kv_block_bytes;
bytes / self.rdma_bw + self.rdma_latency_s
}
}
impl Router for EstimatedTtftRouter {
fn name(&self) -> &'static str {
"estimated_ttft"
}
fn route(
&mut self,
req: &RequestRecord,
instances: &[Instance],
meta: &MetaStore,
now: f64,
) -> RouteDecision {
let n = instances.len();
let scores = meta.score_prefix(&req.hash_ids, now, n);
// Cluster-wide max prefix: blocks reachable via RDMA from any peer.
let cluster_prefix = scores.iter().copied().max().unwrap_or(0);
let mut best: u32 = 0;
let mut best_cost = f64::INFINITY;
let mut best_queue = u32::MAX;
let mut best_local = 0u32;
let mut candidates = Vec::with_capacity(n);
for inst in instances {
let i = inst.id as usize;
let local_prefix = scores[i];
// 1. Exact queue drain time (architecture-aware, per-request sum).
let drain = inst.estimated_drain_time();
// 2. RDMA fetch cost for blocks not locally cached.
let remote_blocks = cluster_prefix.saturating_sub(local_prefix);
let fetch = self.fetch_time(remote_blocks);
// Additive cost: drain + fetch.
// The additive form gives explicit incentive to prefer local cache
// (lower fetch) even when the queue is non-empty, which reduces
// total RDMA traffic and improves TTFT in aggregate.
let cost = drain + fetch;
candidates.push(CandidateInfo {
instance: inst.id,
predicted_prefix: local_prefix,
load_blocks: inst.kv_blocks_used,
queue_len: inst.queue_len(),
});
// Minimise (cost, queue_len, -local_prefix).
let ql = inst.queue_len();
let better = cost < best_cost
|| (cost == best_cost && ql < best_queue)
|| (cost == best_cost && ql == best_queue && local_prefix > best_local);
if better {
best_cost = cost;
best = inst.id;
best_queue = ql;
best_local = local_prefix;
}
}
RouteDecision {
req_id: req.req_id,
mode: "estimated_ttft",
chosen: best,
probe_overhead_s: 0.0,
candidates,
reason: "argmin(drain_time + fetch_time)",
}
}
}

54
src/router/least_loaded.rs Normal file
View File

@@ -0,0 +1,54 @@
use crate::cluster::meta_store::MetaStore;
use crate::instance::Instance;
use crate::router::{CandidateInfo, RouteDecision, Router};
use crate::trace::RequestRecord;
pub struct LeastLoadedRouter {
pub alpha: f64,
}
impl LeastLoadedRouter {
pub fn new(alpha: f64) -> Self {
Self { alpha }
}
}
impl Router for LeastLoadedRouter {
fn name(&self) -> &'static str {
"least_loaded"
}
fn route(
&mut self,
req: &RequestRecord,
instances: &[Instance],
_meta: &MetaStore,
_now: f64,
) -> RouteDecision {
let mut best = 0u32;
let mut best_score = f64::INFINITY;
let mut candidates = Vec::with_capacity(instances.len());
for inst in instances {
let load = inst.kv_blocks_used as f64
+ self.alpha * inst.queue_len() as f64;
candidates.push(CandidateInfo {
instance: inst.id,
predicted_prefix: 0,
load_blocks: inst.kv_blocks_used,
queue_len: inst.queue_len(),
});
if load < best_score {
best_score = load;
best = inst.id;
}
}
RouteDecision {
req_id: req.req_id,
mode: "least_loaded",
chosen: best,
probe_overhead_s: 0.0,
candidates,
reason: "argmin(kv_used + alpha * queue_len)",
}
}
}

73
src/router/least_tokens.rs Normal file
View File

@@ -0,0 +1,73 @@
//! Least-waiting-tokens routing.
//!
//! Pure load-balancing baseline that picks the instance with the fewest
//! total prefill tokens remaining across its pending and prefilling queues.
//! Unlike `least_loaded` (which mixes KV memory pressure with queue depth),
//! this directly minimises the expected wait time by accounting for the
//! actual compute backlog in tokens.
//!
//! Tiebreak: fewest `queue_len`, then lowest instance ID.
use crate::cluster::meta_store::MetaStore;
use crate::instance::Instance;
use crate::router::{CandidateInfo, RouteDecision, Router};
use crate::trace::RequestRecord;
pub struct LeastTokensRouter;
impl LeastTokensRouter {
pub fn new() -> Self {
Self
}
}
impl Default for LeastTokensRouter {
fn default() -> Self {
Self::new()
}
}
impl Router for LeastTokensRouter {
fn name(&self) -> &'static str {
"least_tokens"
}
fn route(
&mut self,
req: &RequestRecord,
instances: &[Instance],
_meta: &MetaStore,
_now: f64,
) -> RouteDecision {
let mut best: u32 = 0;
let mut best_key: (u64, u32) = (u64::MAX, u32::MAX);
let mut candidates = Vec::with_capacity(instances.len());
for inst in instances {
let wt = inst.waiting_tokens();
let ql = inst.queue_len();
candidates.push(CandidateInfo {
instance: inst.id,
predicted_prefix: 0,
load_blocks: inst.kv_blocks_used,
queue_len: ql,
});
let key = (wt, ql);
if key < best_key {
best_key = key;
best = inst.id;
}
}
RouteDecision {
req_id: req.req_id,
mode: "least_tokens",
chosen: best,
probe_overhead_s: 0.0,
candidates,
reason: "argmin(waiting_prefill_tokens)",
}
}
}

124
src/router/min_pd.rs Normal file
View File

@@ -0,0 +1,124 @@
//! Minimum P*D routing.
//!
//! For each instance compute:
//! - `P` = real prefill tokens this request will do if routed there
//! - `D` = ongoing requests currently on that instance
//! (pending + prefilling)
//!
//! Score = `P * D`, pick the instance that minimizes it.
//!
//! `P` accounts for the **actual** prefill work left after the cluster fetch
//! chain runs: the fetch chain serves any block cached anywhere in the
//! cluster (L0 → L1 → remote v6d via RDMA), but it only fetches a contiguous
//! leading prefix: the first block missing cluster-wide ends the chain, and
//! every block after it must be prefilled even if it is cached somewhere
//! in the cluster.
//!
//! Concretely, for instance `i`:
//!
//! ```text
//! local_prefix_i = meta_store.score_prefix(req, now)[i] // blocks
//! cluster_prefix = max over all j of meta_store_score[j] // blocks
//! effective_prefix_i = min(cluster_prefix, input_blocks)
//! - if local_prefix_i == cluster_prefix the fetch chain stays local,
//! - otherwise the prefill still skips cluster_prefix blocks because
//! the missing tail is fetched via RDMA from a peer.
//! P_i = (input_blocks - effective_prefix_i) * block_size_tokens
//! ```
//!
//! This makes `P` nearly instance-independent on well-populated clusters
//! (so `min_pd` degenerates to balanced load with a cache-affinity
//! tiebreak), which is exactly what you want when RDMA is cheap relative
//! to prefill compute.
//!
//! Tiebreaks (essential on 128-instance clusters where many instances are
//! idle and the raw product collapses to zero):
//! 1. minimum `P*D`
//! 2. then minimum `D` — prefer the less-loaded instance
//! 3. then maximum `local_prefix_i` — prefer local affinity to avoid
//! paying the RDMA fetch cost when P and D are already tied
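//!
//! Tiny worked example (illustrative): input = 10 blocks, cluster_prefix = 8,
//! block_size = 16 tokens:
//!
//! ```text
//! P = (10 - 8) * 16 = 32 tokens        (same for every instance)
//! D = [0, 0, 3]   ->  P*D = [0, 0, 96]
//! ```
//!
//! The two idle instances tie on both P*D and D, so the one holding the
//! longer local prefix wins the final tiebreak and the RDMA fetch is avoided.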
use crate::cluster::meta_store::MetaStore;
use crate::instance::Instance;
use crate::router::{CandidateInfo, RouteDecision, Router};
use crate::trace::RequestRecord;
pub struct MinPdRouter;
impl MinPdRouter {
pub fn new() -> Self {
Self
}
}
impl Default for MinPdRouter {
fn default() -> Self {
Self::new()
}
}
impl Router for MinPdRouter {
fn name(&self) -> &'static str {
"min_pd"
}
fn route(
&mut self,
req: &RequestRecord,
instances: &[Instance],
meta: &MetaStore,
now: f64,
) -> RouteDecision {
let n = instances.len();
let scores = meta.score_prefix(&req.hash_ids, now, n);
let block_size = instances[0].block_size_tokens as u64;
let input_blocks = req.hash_ids.len() as u64;
// Cluster-wide max prefix: longest contiguous prefix that EXISTS
// somewhere in the cluster (and will be fetched via remote RDMA if
// not local). This determines the effective prefill work for every
// candidate, not just the one that owns the blocks.
let cluster_prefix_blocks = scores.iter().copied().max().unwrap_or(0) as u64;
let effective_prefix_blocks = cluster_prefix_blocks.min(input_blocks);
let miss_blocks = input_blocks.saturating_sub(effective_prefix_blocks);
let p_base = miss_blocks.saturating_mul(block_size); // tokens to prefill
let mut candidates = Vec::with_capacity(n);
let mut best: u32 = instances[0].id;
// Minimize (P*D, D, -local_prefix).
// P is nearly instance-independent; D is the real discriminator.
// When tied on D, prefer the instance with the best local prefix
// (avoids the RDMA fetch cost).
let mut best_key: (u128, u64, i64) = (u128::MAX, u64::MAX, i64::MAX);
for inst in instances {
let i = inst.id as usize;
let d = inst.queue_len() as u64;
let pd = p_base as u128 * d as u128;
let local_prefix = scores[i] as i64;
candidates.push(CandidateInfo {
instance: inst.id,
predicted_prefix: scores[i],
load_blocks: inst.kv_blocks_used,
queue_len: inst.queue_len(),
});
// minimize (pd, d, -local_prefix)
let key = (pd, d, -local_prefix);
if key < best_key {
best_key = key;
best = inst.id;
}
}
RouteDecision {
req_id: req.req_id,
mode: "min_pd",
chosen: best,
probe_overhead_s: 0.0,
candidates,
reason: "argmin(P*D), P=cluster-wide miss tokens, D=ongoing reqs",
}
}
}

80
src/router/mod.rs Normal file
View File

@@ -0,0 +1,80 @@
//! Cluster-level routing strategies.
pub mod cache_load;
pub mod cache_score;
pub mod estimated_ttft;
pub mod least_loaded;
pub mod least_tokens;
pub mod min_pd;
pub mod precise_aware;
pub mod prefix_affinity;
pub mod random;
pub mod ttl_aware;
use serde::Serialize;
use crate::cluster::meta_store::MetaStore;
use crate::config::Config;
use crate::instance::Instance;
use crate::trace::RequestRecord;
use crate::types::InstanceId;
#[derive(Debug, Clone, Serialize)]
pub struct CandidateInfo {
pub instance: InstanceId,
pub predicted_prefix: u32,
pub load_blocks: u32,
pub queue_len: u32,
}
#[derive(Debug, Clone, Serialize)]
pub struct RouteDecision {
pub req_id: u64,
pub mode: &'static str,
pub chosen: InstanceId,
pub probe_overhead_s: f64,
pub candidates: Vec<CandidateInfo>,
pub reason: &'static str,
}
pub trait Router: Send {
fn name(&self) -> &'static str;
fn route(
&mut self,
req: &RequestRecord,
instances: &[Instance],
meta: &MetaStore,
now: f64,
) -> RouteDecision;
}
pub fn build(full: &Config, seed: u64) -> Box<dyn Router> {
use crate::config::RouterMode::*;
let cfg = &full.cluster.router;
match cfg.mode {
Random => Box::new(random::RandomRouter::new(seed)) as Box<dyn Router>,
RoundRobin => Box::new(random::RoundRobinRouter::new()) as Box<dyn Router>,
LeastLoaded => {
Box::new(least_loaded::LeastLoadedRouter::new(cfg.load_alpha)) as Box<dyn Router>
}
TtlAware => Box::new(ttl_aware::TtlAwareRouter::new(cfg.load_alpha)) as Box<dyn Router>,
Precise => Box::new(precise_aware::PreciseRouter::new(
cfg.precise_probe_topk,
cfg.precise_probe_latency_us * 1e-6,
cfg.load_alpha,
)) as Box<dyn Router>,
MinPd => Box::new(min_pd::MinPdRouter::new()) as Box<dyn Router>,
LeastTokens => Box::new(least_tokens::LeastTokensRouter::new()) as Box<dyn Router>,
CacheLoad => Box::new(cache_load::CacheLoadRouter::new()) as Box<dyn Router>,
CacheScore => {
Box::new(cache_score::CacheScoreRouter::new(cfg.score_alpha, cfg.score_beta))
as Box<dyn Router>
}
EstimatedTtft => {
Box::new(estimated_ttft::EstimatedTtftRouter::new(full)) as Box<dyn Router>
}
PrefixAffinity => {
Box::new(prefix_affinity::PrefixAffinityRouter::new(full)) as Box<dyn Router>
}
}
}

120
src/router/precise_aware.rs Normal file
View File

@@ -0,0 +1,120 @@
//! KV-aware routing via meta-store candidate selection + precise probing.
//!
//! The global meta store is used as a *candidate pre-filter*: we score
//! every instance's predicted prefix from the store, take the top-K by
//! (predicted_prefix DESC, load ASC), and then exact-probe those K
//! candidates' actual L0+L1 caches to get the true longest prefix. This
//! catches two cases where the meta store is wrong:
//!
//! - the store is stale (block evicted from L0/L1 but TTL not yet up),
//! - the store undercounts because some blocks' TTL expired individually.
//!
//! Because the candidate set is sourced from the meta store rather than
//! from a load ranking, this router is a strict superset of `ttl_aware`:
//! any instance the meta store would pick is a candidate here, and the
//! exact probe can only move the decision toward a truthfully-better
//! instance. Each probe adds `probe_latency_s` to the request's
//! effective arrival time.
//!
//! If the meta store returns zero-prefix for every instance (e.g. cold
//! start, or a request whose blocks have never been seen), we fall back
//! to the top-K least-loaded instances so we still place the request.
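//!
//! Illustrative case with K = 2 (numbers are made up):
//!
//! ```text
//! meta-store prediction: [12, 10, 0, 0]   ->  probe instances 0 and 1
//! exact probe:           [ 2, 10]         ->  instance 0 was stale (blocks evicted)
//! chosen = instance 1, probe_overhead = 2 * probe_latency
//! ```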
use crate::cluster::meta_store::MetaStore;
use crate::instance::Instance;
use crate::router::{CandidateInfo, RouteDecision, Router};
use crate::trace::RequestRecord;
pub struct PreciseRouter {
pub topk: u32,
pub probe_latency_s: f64,
pub alpha: f64,
}
impl PreciseRouter {
pub fn new(topk: u32, probe_latency_s: f64, alpha: f64) -> Self {
Self { topk, probe_latency_s, alpha }
}
fn load_of(&self, inst: &Instance) -> f64 {
inst.kv_blocks_used as f64 + self.alpha * inst.queue_len() as f64
}
}
impl Router for PreciseRouter {
fn name(&self) -> &'static str {
"precise"
}
fn route(
&mut self,
req: &RequestRecord,
instances: &[Instance],
meta: &MetaStore,
now: f64,
) -> RouteDecision {
let n = instances.len();
let k = (self.topk as usize).min(n).max(1);
// 1. Meta-store candidate set: rank all instances by
// (predicted_prefix DESC, load ASC) and take the top-K.
let meta_scores = meta.score_prefix(&req.hash_ids, now, n);
let any_meta_hit = meta_scores.iter().any(|&p| p > 0);
let mut ranked: Vec<usize> = (0..n).collect();
if any_meta_hit {
ranked.sort_by(|&a, &b| {
let pa = meta_scores[a];
let pb = meta_scores[b];
// prefix desc, then load asc
pb.cmp(&pa)
.then_with(|| {
self.load_of(&instances[a])
.partial_cmp(&self.load_of(&instances[b]))
.unwrap_or(std::cmp::Ordering::Equal)
})
});
} else {
// Cold start fallback: pure load order.
ranked.sort_by(|&a, &b| {
self.load_of(&instances[a])
.partial_cmp(&self.load_of(&instances[b]))
.unwrap_or(std::cmp::Ordering::Equal)
});
}
let probed = &ranked[..k];
// 2. Exact probe each candidate and pick
// argmax(exact_prefix, tiebreak: -load).
let mut candidates = Vec::with_capacity(k);
let mut best = probed[0] as u32;
let mut best_key: (i64, f64) = (i64::MIN, f64::NEG_INFINITY); // (exact_prefix, -load), maximised lexicographically
for &i in probed {
let inst = &instances[i];
let l0 = inst.cache.l0.longest_prefix_peek(&req.hash_ids);
let l1 = inst.cache.l1.longest_prefix_peek(&req.hash_ids[l0..]);
let predicted = (l0 + l1) as u32;
let load = self.load_of(inst);
candidates.push(CandidateInfo {
instance: inst.id,
predicted_prefix: predicted,
load_blocks: inst.kv_blocks_used,
queue_len: inst.queue_len(),
});
let key = (predicted as i64, -load);
if key > best_key {
best_key = key;
best = inst.id;
}
}
RouteDecision {
req_id: req.req_id,
mode: "precise",
chosen: best,
probe_overhead_s: k as f64 * self.probe_latency_s,
candidates,
reason: "exact-probe top-K meta-store candidates",
}
}
}

196
src/router/prefix_affinity.rs Normal file
View File

@@ -0,0 +1,196 @@
//! Prefix-affinity routing with load-aware fallback.
//!
//! **Key insight**: in real LLM traces, 99%+ of requests share a common
//! system-prompt prefix (dozens to hundreds of 16-token blocks). If we
//! *consistently* route requests with the same prefix to the same small set
//! of instances, L0 (HBM) cache hit rates increase dramatically because the
//! working set per instance is concentrated rather than scattered.
//!
//! Algorithm (rendezvous hashing + drain-time-aware selection):
//!
//! 1. **Fingerprint**: hash the first `K` blocks of the request to produce a
//! prefix fingerprint that captures the system prompt identity.
//!
//! 2. **Rendezvous ranking**: for each instance `i`, compute
//! `rendezvous(fingerprint, i)` — a deterministic pseudo-random score.
//! Sort instances by this score descending to get a stable, per-prefix
//! ordering.
//!
//! 3. **Select from top candidates**: among the top `fan_out` instances in
//! the rendezvous ranking, pick the one with the lowest estimated drain
//! time (architecture-aware, per-request sum). This accounts for
//! heterogeneous request sizes in the queue.
//!
//! 4. **Overload fallback**: if all top candidates have queue length above a
//! threshold, expand to the full instance set and use estimated-TTFT
//! scoring (drain + fetch) for the best selection.
//!
//! The combination ensures:
//! - **Cache locality**: same-prefix requests cluster on a few instances,
//! building strong L0 cache entries that benefit subsequent requests.
//! - **Load balance**: within the affinity group, drain-time-aware selection
//! avoids hot-spotting from large-prompt requests.
//! - **Zero overhead**: no per-instance probes needed; fingerprint +
//! rendezvous are pure arithmetic.
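//!
//! A minimal sketch of steps 1 and 2 (mirrors the private helpers below; the
//! instance-id list is a placeholder):
//!
//! ```ignore
//! let fp = fingerprint(&req.hash_ids, prefix_k);          // same prefix -> same fp
//! let mut ranked: Vec<(u64, u32)> = instance_ids
//!     .iter()
//!     .map(|&id| (rendezvous_score(fp, id), id))
//!     .collect();
//! ranked.sort_unstable_by(|a, b| b.0.cmp(&a.0));          // deterministic per-prefix order
//! // Every request sharing the leading blocks reproduces this ranking exactly,
//! // so its traffic concentrates on the same top-`fan_out` instances.
//! ```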
use crate::cluster::meta_store::MetaStore;
use crate::config::Config;
use crate::instance::Instance;
use crate::router::{CandidateInfo, RouteDecision, Router};
use crate::trace::RequestRecord;
pub struct PrefixAffinityRouter {
/// Number of leading block hashes used for the prefix fingerprint.
prefix_k: usize,
/// Number of top-affinity instances to consider before fallback.
fan_out: usize,
/// Queue-length threshold: if all top candidates exceed this, expand to
/// the full instance set.
overload_threshold: u32,
/// Bytes per KV block (for RDMA cost estimation in fallback path).
kv_block_bytes: f64,
/// RDMA bandwidth in bytes/s.
rdma_bw: f64,
/// RDMA per-transfer latency in seconds.
rdma_latency_s: f64,
}
impl PrefixAffinityRouter {
pub fn new(config: &Config) -> Self {
let n = config.cluster.num_instances as usize;
let cfg_fan = config.cluster.router.affinity_fan_out;
// fan_out: if configured, use it; otherwise auto = max(2, n/8).
let fan_out = if cfg_fan > 0 {
cfg_fan.min(n)
} else {
(n / 8).max(2).min(n)
};
Self {
prefix_k: config.cluster.router.prefix_k,
fan_out,
overload_threshold: 4,
kv_block_bytes: config.model.kv_block_bytes() as f64,
rdma_bw: config.hardware.rdma_bw,
rdma_latency_s: config.hardware.rdma_latency_us * 1e-6,
}
}
/// Compute a prefix fingerprint from the first K block hashes.
fn fingerprint(hash_ids: &[u64], k: usize) -> u64 {
let n = hash_ids.len().min(k);
let mut fp: u64 = 0xcbf29ce484222325; // FNV offset basis
for &h in &hash_ids[..n] {
fp ^= h;
fp = fp.wrapping_mul(0x100000001b3); // FNV prime
}
fp
}
/// Rendezvous hash: deterministic pseudo-random score for (fingerprint, instance_id).
/// Higher score = higher affinity.
fn rendezvous_score(fp: u64, instance_id: u32) -> u64 {
let mut h = fp ^ (instance_id as u64).wrapping_mul(0x9e3779b97f4a7c15);
// Splitmix64 finalizer
h = h.wrapping_add(0x9e3779b97f4a7c15);
h = (h ^ (h >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
h = (h ^ (h >> 27)).wrapping_mul(0x94d049bb133111eb);
h ^ (h >> 31)
}
/// Estimate RDMA fetch time for `remote_blocks` blocks.
fn fetch_time(&self, remote_blocks: u32) -> f64 {
if remote_blocks == 0 {
return 0.0;
}
let bytes = remote_blocks as f64 * self.kv_block_bytes;
bytes / self.rdma_bw + self.rdma_latency_s
}
}
impl Router for PrefixAffinityRouter {
fn name(&self) -> &'static str {
"prefix_affinity"
}
fn route(
&mut self,
req: &RequestRecord,
instances: &[Instance],
meta: &MetaStore,
now: f64,
) -> RouteDecision {
let n = instances.len();
let fp = Self::fingerprint(&req.hash_ids, self.prefix_k);
// Build rendezvous-ranked list of (score, index).
let mut ranked: Vec<(u64, usize)> = (0..n)
.map(|i| (Self::rendezvous_score(fp, instances[i].id), i))
.collect();
ranked.sort_unstable_by(|a, b| b.0.cmp(&a.0)); // descending score
// Collect candidate info for logging (also needed for fallback).
let scores = meta.score_prefix(&req.hash_ids, now, n);
let candidates: Vec<CandidateInfo> = instances
.iter()
.map(|inst| CandidateInfo {
instance: inst.id,
predicted_prefix: scores[inst.id as usize],
load_blocks: inst.kv_blocks_used,
queue_len: inst.queue_len(),
})
.collect();
// Phase 1: among top fan_out instances, pick lowest drain time.
let top_k = self.fan_out.min(n);
let mut best_idx = ranked[0].1;
let mut best_drain = instances[best_idx].estimated_drain_time();
let mut best_ql = instances[best_idx].queue_len();
let mut all_overloaded = best_ql > self.overload_threshold;
for &(_, idx) in &ranked[1..top_k] {
let drain = instances[idx].estimated_drain_time();
let ql = instances[idx].queue_len();
if drain < best_drain || (drain == best_drain && ql < best_ql) {
best_idx = idx;
best_drain = drain;
best_ql = ql;
}
if ql <= self.overload_threshold {
all_overloaded = false;
}
}
// Phase 2: if all top candidates are overloaded, search globally
// using estimated TTFT (drain + fetch) to pick the fallback target.
let reason;
if all_overloaded {
reason = "affinity fallback: min(drain+fetch)";
let cluster_prefix = scores.iter().copied().max().unwrap_or(0);
let mut best_cost = f64::INFINITY;
for &(_, idx) in ranked.iter() {
let inst = &instances[idx];
let drain = inst.estimated_drain_time();
let local_prefix = scores[idx];
let remote_blocks = cluster_prefix.saturating_sub(local_prefix);
let cost = drain + self.fetch_time(remote_blocks);
let ql = inst.queue_len();
if cost < best_cost || (cost == best_cost && ql < best_ql) {
best_cost = cost;
best_idx = idx;
best_ql = ql;
}
}
} else {
reason = "prefix affinity: top-K min drain";
}
RouteDecision {
req_id: req.req_id,
mode: "prefix_affinity",
chosen: instances[best_idx].id,
probe_overhead_s: 0.0,
candidates,
reason,
}
}
}
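// Illustrative sanity checks (a minimal sketch; the hash values are arbitrary).
// They exercise the "pure arithmetic" claim above: the fingerprint depends only
// on the first K block hashes, and the rendezvous ranking for a given
// fingerprint is fully deterministic, so no per-instance probe is needed.
#[cfg(test)]
mod tests {
use super::PrefixAffinityRouter;
#[test]
fn fingerprint_ignores_blocks_beyond_k() {
let a = PrefixAffinityRouter::fingerprint(&[1, 2, 3, 99], 3);
let b = PrefixAffinityRouter::fingerprint(&[1, 2, 3, 77], 3);
assert_eq!(a, b); // the suffix past K does not affect the fingerprint
let c = PrefixAffinityRouter::fingerprint(&[1, 2, 4, 99], 3);
assert_ne!(a, c); // changing a block inside the prefix does
}
#[test]
fn rendezvous_ranking_is_deterministic() {
let fp = PrefixAffinityRouter::fingerprint(&[10, 20, 30], 3);
let rank = |fp: u64| {
let mut v: Vec<(u64, u32)> = (0u32..8)
.map(|i| (PrefixAffinityRouter::rendezvous_score(fp, i), i))
.collect();
v.sort_unstable_by(|a, b| b.0.cmp(&a.0));
v
};
assert_eq!(rank(fp), rank(fp)); // same fingerprint -> same instance ranking
}
}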

90
src/router/random.rs Normal file
View File

@@ -0,0 +1,90 @@
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use crate::cluster::meta_store::MetaStore;
use crate::instance::Instance;
use crate::router::{CandidateInfo, RouteDecision, Router};
use crate::trace::RequestRecord;
use crate::types::InstanceId;
pub struct RandomRouter {
rng: ChaCha8Rng,
}
impl RandomRouter {
pub fn new(seed: u64) -> Self {
Self { rng: ChaCha8Rng::seed_from_u64(seed) }
}
}
impl Router for RandomRouter {
fn name(&self) -> &'static str {
"random"
}
fn route(
&mut self,
req: &RequestRecord,
instances: &[Instance],
_meta: &MetaStore,
_now: f64,
) -> RouteDecision {
let n = instances.len();
let chosen = self.rng.gen_range(0..n) as InstanceId;
RouteDecision {
req_id: req.req_id,
mode: "random",
chosen,
probe_overhead_s: 0.0,
candidates: vec![CandidateInfo {
instance: chosen,
predicted_prefix: 0,
load_blocks: instances[chosen as usize].kv_blocks_used,
queue_len: instances[chosen as usize].queue_len(),
}],
reason: "uniform random",
}
}
}
#[derive(Default)]
pub struct RoundRobinRouter {
next: u32,
}
impl RoundRobinRouter {
pub fn new() -> Self {
Self::default()
}
}
impl Router for RoundRobinRouter {
fn name(&self) -> &'static str {
"round_robin"
}
fn route(
&mut self,
req: &RequestRecord,
instances: &[Instance],
_meta: &MetaStore,
_now: f64,
) -> RouteDecision {
let n = instances.len() as u32;
let chosen = self.next % n;
self.next = self.next.wrapping_add(1);
RouteDecision {
req_id: req.req_id,
mode: "round_robin",
chosen,
probe_overhead_s: 0.0,
candidates: vec![CandidateInfo {
instance: chosen,
predicted_prefix: 0,
load_blocks: instances[chosen as usize].kv_blocks_used,
queue_len: instances[chosen as usize].queue_len(),
}],
reason: "round robin",
}
}
}

59
src/router/ttl_aware.rs Normal file
View File

@@ -0,0 +1,59 @@
use crate::cluster::meta_store::MetaStore;
use crate::instance::Instance;
use crate::router::{CandidateInfo, RouteDecision, Router};
use crate::trace::RequestRecord;
pub struct TtlAwareRouter {
pub alpha: f64,
}
impl TtlAwareRouter {
pub fn new(alpha: f64) -> Self {
Self { alpha }
}
}
impl Router for TtlAwareRouter {
fn name(&self) -> &'static str {
"ttl_aware"
}
fn route(
&mut self,
req: &RequestRecord,
instances: &[Instance],
meta: &MetaStore,
now: f64,
) -> RouteDecision {
let n = instances.len();
let scores = meta.score_prefix(&req.hash_ids, now, n);
let mut best = 0u32;
let mut best_key = (i64::MIN, f64::NEG_INFINITY); // lexicographic key (prefix, -load), maximized
let mut candidates = Vec::with_capacity(n);
for inst in instances {
let p = scores[inst.id as usize];
let load = inst.kv_blocks_used as f64
+ self.alpha * inst.queue_len() as f64;
candidates.push(CandidateInfo {
instance: inst.id,
predicted_prefix: p,
load_blocks: inst.kv_blocks_used,
queue_len: inst.queue_len(),
});
// We want max prefix, then min load: compare (p, -load) lexicographically.
let key = (p as i64, -load);
if key > best_key {
best_key = key;
best = inst.id;
}
}
RouteDecision {
req_id: req.req_id,
mode: "ttl_aware",
chosen: best,
probe_overhead_s: 0.0,
candidates,
reason: "max meta_store prefix, tie -> least loaded",
}
}
}

113
src/sim/engine.rs Normal file
View File

@@ -0,0 +1,113 @@
//! Discrete-event engine.
//!
//! Single-threaded, with virtual time in `f64` seconds. Events are stored in a min-heap
//! keyed by `(time, seq)` so equal-time events fire in insertion order.
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use super::events::Event;
#[derive(Debug)]
struct Slot {
time: f64,
seq: u64,
event: Event,
}
impl Eq for Slot {}
impl PartialEq for Slot {
fn eq(&self, other: &Self) -> bool {
self.time == other.time && self.seq == other.seq
}
}
impl Ord for Slot {
fn cmp(&self, other: &Self) -> Ordering {
// Reverse so BinaryHeap acts as a min-heap.
other
.time
.partial_cmp(&self.time)
.unwrap_or(Ordering::Equal)
.then_with(|| other.seq.cmp(&self.seq))
}
}
impl PartialOrd for Slot {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
#[derive(Debug, Default)]
pub struct EventQueue {
heap: BinaryHeap<Slot>,
seq: u64,
now: f64,
}
impl EventQueue {
pub fn new() -> Self {
Self::default()
}
pub fn now(&self) -> f64 {
self.now
}
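/// Schedule `event` at virtual time `time`. Times earlier than the current
/// clock are clamped to `now`, so simulation time never runs backwards.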
pub fn schedule(&mut self, time: f64, event: Event) {
let t = time.max(self.now);
self.seq += 1;
self.heap.push(Slot { time: t, seq: self.seq, event });
}
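/// Pop the earliest event, advancing the virtual clock to its timestamp.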
pub fn pop(&mut self) -> Option<(f64, Event)> {
let slot = self.heap.pop()?;
if slot.time > self.now {
self.now = slot.time;
}
Some((slot.time, slot.event))
}
pub fn len(&self) -> usize {
self.heap.len()
}
pub fn is_empty(&self) -> bool {
self.heap.is_empty()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::InstanceId;
#[test]
fn pops_in_time_order() {
let mut q = EventQueue::new();
q.schedule(2.0, Event::BatchTick { instance: 0 as InstanceId });
q.schedule(1.0, Event::BatchTick { instance: 1 });
q.schedule(1.5, Event::BatchTick { instance: 2 });
let (t1, _) = q.pop().unwrap();
let (t2, _) = q.pop().unwrap();
let (t3, _) = q.pop().unwrap();
assert!(t1 <= t2 && t2 <= t3);
assert!((t1 - 1.0).abs() < 1e-12);
assert!((t3 - 2.0).abs() < 1e-12);
}
#[test]
fn equal_time_fifo() {
let mut q = EventQueue::new();
q.schedule(1.0, Event::BatchTick { instance: 7 });
q.schedule(1.0, Event::BatchTick { instance: 8 });
let (_, e1) = q.pop().unwrap();
let (_, e2) = q.pop().unwrap();
match (e1, e2) {
(Event::BatchTick { instance: a }, Event::BatchTick { instance: b }) => {
assert_eq!(a, 7);
assert_eq!(b, 8);
}
_ => panic!("wrong events"),
}
}
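// Illustrative sketch of the clamping behavior noted on `schedule` (it takes
// `time.max(self.now)`): events scheduled in the virtual past fire at the
// current clock instead of rewinding time.
#[test]
fn past_events_are_clamped_to_now() {
let mut q = EventQueue::new();
q.schedule(5.0, Event::Sample);
let (t, _) = q.pop().unwrap(); // advances `now` to 5.0
assert!((t - 5.0).abs() < 1e-12);
q.schedule(1.0, Event::Stop); // already in the past relative to `now`
let (t2, _) = q.pop().unwrap();
assert!((t2 - 5.0).abs() < 1e-12); // clamped; time never goes backwards
}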
}

15
src/sim/events.rs Normal file
View File

@@ -0,0 +1,15 @@
//! Event types for the discrete-event engine.
use crate::types::{InstanceId, ReqId};
#[derive(Debug)]
pub enum Event {
/// New trace request arrives at the cluster router.
Arrival { req_id: ReqId },
/// Per-instance scheduler tick (continuous batching).
BatchTick { instance: InstanceId },
/// Periodic time-series sample of all instances.
Sample,
/// Stop the simulation early (used internally).
Stop,
}

5
src/sim/mod.rs Normal file
View File

@@ -0,0 +1,5 @@
pub mod engine;
pub mod events;
pub use engine::EventQueue;
pub use events::Event;

102
src/trace.rs Normal file
View File

@@ -0,0 +1,102 @@
//! Streaming JSONL reader for the qwen-bailian trace format.
//!
//! Schema (per upstream README):
//! chat_id: i64
//! parent_chat_id: i64 (-1 = root)
//! timestamp: f64 (seconds since trace start)
//! input_length: i64
//! output_length: i64
//! type: string (text/search/image/file)
//! turn: i64
//! hash_ids: [i64] (16-token blocks, salted SipHash)
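//! Example line (illustrative values only, not taken from the real trace;
//! with 16-token blocks, the 48 input tokens map to the 3 hash_ids shown):
//! {"chat_id": 17, "parent_chat_id": -1, "timestamp": 3.25, "input_length": 48,
//!  "output_length": 64, "type": "text", "turn": 1, "hash_ids": [9021, 9022, 9023]}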
use anyhow::{Context, Result};
use serde::Deserialize;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
#[derive(Debug, Clone, Deserialize)]
struct RawRecord {
#[serde(default)]
chat_id: i64,
#[serde(default)]
timestamp: f64,
#[serde(default)]
input_length: i64,
#[serde(default)]
output_length: i64,
#[serde(default)]
hash_ids: Vec<i64>,
}
#[derive(Debug, Clone)]
pub struct RequestRecord {
pub req_id: u64,
pub chat_id: i64,
pub arrival: f64,
pub input_len: u32,
pub output_len: u32,
pub hash_ids: Vec<u64>,
}
pub struct TraceReader {
inner: BufReader<File>,
next_id: u64,
line_buf: String,
max_requests: Option<u64>,
}
impl TraceReader {
pub fn open<P: AsRef<Path>>(path: P, max_requests: Option<u64>) -> Result<Self> {
let path = path.as_ref();
let f = File::open(path)
.with_context(|| format!("opening trace {}", path.display()))?;
Ok(Self {
inner: BufReader::with_capacity(1 << 20, f),
next_id: 0,
line_buf: String::with_capacity(4096),
max_requests,
})
}
}
impl Iterator for TraceReader {
type Item = Result<RequestRecord>;
fn next(&mut self) -> Option<Self::Item> {
if let Some(cap) = self.max_requests {
if self.next_id >= cap {
return None;
}
}
loop {
self.line_buf.clear();
match self.inner.read_line(&mut self.line_buf) {
Ok(0) => return None,
Ok(_) => {
let trimmed = self.line_buf.trim();
if trimmed.is_empty() {
continue;
}
let parsed: Result<RawRecord, _> = serde_json::from_str(trimmed);
let raw = match parsed {
Ok(r) => r,
Err(e) => return Some(Err(anyhow::anyhow!("trace parse: {e}"))),
};
let id = self.next_id;
self.next_id += 1;
return Some(Ok(RequestRecord {
req_id: id,
chat_id: raw.chat_id,
arrival: raw.timestamp,
input_len: raw.input_length.max(0) as u32,
output_len: raw.output_length.max(0) as u32,
hash_ids: raw.hash_ids.into_iter().map(|h| h as u64).collect(),
}));
}
Err(e) => return Some(Err(e.into())),
}
}
}
}

4
src/types.rs Normal file
View File

@@ -0,0 +1,4 @@
//! Shared simple types.
pub type InstanceId = u32;
pub type ReqId = u64;