KVCache simulator for LLM serving cluster routing research
Discrete-event simulator for evaluating KV cache-aware routing policies in prefill-disaggregated LLM serving clusters. Models a two-tier KV cache hierarchy (L0 GPU HBM + L1 CPU DRAM) with RDMA/PCIe link contention, architecture-derived roofline compute (MoE, MLA, DSA), and a cluster-wide meta-store for prefix-aware routing decisions. Includes 11 routing policies (random, round_robin, least_loaded, least_tokens, ttl_aware, precise, min_pd, cache_load, cache_score, estimated_ttft, prefix_affinity), HuggingFace config.json auto-parsing, built-in GPU hardware presets (H100/H800/H20/A100/B200), and ablation tooling for systematic policy comparison across real Alibaba serving traces. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
405
src/instance/compute.rs
Normal file
405
src/instance/compute.rs
Normal file
@@ -0,0 +1,405 @@
|
||||
//! Roofline cost model for prefill (PD disaggregation — decode not modeled).
|
||||
//!
|
||||
//! Two construction modes:
|
||||
//!
|
||||
//! **Architecture-derived** (`ModelConfig.hidden_size` present):
|
||||
//! All FLOPs, attention coefficients, and weight-stream costs are computed
|
||||
//! from the model shape. Handles standard / GQA / MLA attention projections,
|
||||
//! MoE routing, and DSA / sliding-window sub-quadratic attention patterns.
|
||||
//!
|
||||
//! **Legacy manual** (`hidden_size` absent): uses the raw
|
||||
//! `flops_per_token_prefill` + `attn_quadratic_coeff` scalars from the YAML.
|
||||
//!
|
||||
//! ```text
|
||||
//! prefill_time(N) = max(compute_time(N), mem_time)
|
||||
//!
|
||||
//! compute_time = sum over layers of:
|
||||
//! (N * linear_flops + attn_coeff * N * effective_ctx(N)) / gpu_flops
|
||||
//!
|
||||
//! mem_time = num_layers * weight_bytes_per_layer / gpu_mem_bw
|
||||
//! ```
|
||||
//!
|
||||
//! `effective_ctx(N)` equals `N` for dense attention (→ O(N²) total) but
|
||||
//! is sub-linear for DSA / sliding-window.
|
||||
|
||||
use crate::config::{AttentionConfig, HardwareConfig, ModelConfig};
|
||||
|
||||
/// Resolved attention pattern used at runtime.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AttentionPattern {
|
||||
/// Full quadratic: effective_ctx = N.
|
||||
Dense,
|
||||
/// Sliding window: effective_ctx = min(N, window).
|
||||
SlidingWindow { window: f64 },
|
||||
/// DeepSeek Sparse Attention: effective_ctx = min(N, dense_window) +
|
||||
/// max(0, N - dense_window) / sparse_stride.
|
||||
Dsa { dense_window: f64, sparse_stride: f64 },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ComputeModel {
|
||||
/// Total transformer layers.
|
||||
pub num_layers: f64,
|
||||
/// How many initial layers use dense attention (rest use `attn_pattern`).
|
||||
/// For `Dense` pattern this equals `num_layers`.
|
||||
pub first_dense_layers: f64,
|
||||
/// Non-attention FLOPs per token per layer (QKV proj + output proj + MLP).
|
||||
pub linear_flops_per_token: f64,
|
||||
/// Attention score coefficient: per-layer attention FLOPs =
|
||||
/// `attn_coeff * N * effective_ctx(N)`.
|
||||
pub attn_coeff: f64,
|
||||
/// Attention pattern for non-dense layers.
|
||||
pub attn_pattern: AttentionPattern,
|
||||
/// Weight bytes read from HBM per layer (for memory-bound check).
|
||||
pub weight_bytes_per_layer: f64,
|
||||
/// Peak GPU FLOPs (aggregate across TP group).
|
||||
pub gpu_flops: f64,
|
||||
/// Peak GPU memory bandwidth (aggregate across TP group).
|
||||
pub gpu_mem_bw: f64,
|
||||
}
|
||||
|
||||
impl ComputeModel {
|
||||
pub fn new(model: &ModelConfig, hw: &HardwareConfig) -> Self {
|
||||
if model.is_arch_mode() {
|
||||
Self::from_arch(model, hw)
|
||||
} else {
|
||||
Self::from_manual(model, hw)
|
||||
}
|
||||
}
|
||||
|
||||
// ----- Architecture-derived construction --------------------------------
|
||||
|
||||
fn from_arch(model: &ModelConfig, hw: &HardwareConfig) -> Self {
|
||||
let h = model.hidden_size.unwrap() as f64;
|
||||
let n_heads = model.num_attention_heads.unwrap_or(model.num_kv_heads) as f64;
|
||||
let n_kv = model.num_kv_heads as f64;
|
||||
let hd = model.head_dim as f64;
|
||||
let inter = model.intermediate_size.unwrap_or(0) as f64;
|
||||
let dtype = model.dtype_bytes as f64;
|
||||
|
||||
// --- Attention linear FLOPs/token/layer ---
|
||||
let attn_linear = if let Some(mla) = &model.mla {
|
||||
let qlr = mla.q_lora_rank as f64;
|
||||
let kvlr = mla.kv_lora_rank as f64;
|
||||
let qk_hd = (mla.qk_nope_head_dim + mla.qk_rope_head_dim) as f64;
|
||||
let qk_rd = mla.qk_rope_head_dim as f64;
|
||||
let vhd = mla.v_head_dim as f64;
|
||||
// Q: down-project + up-project
|
||||
let q = 2.0 * h * qlr + 2.0 * qlr * n_heads * qk_hd;
|
||||
// KV: down-project (compressed latent + RoPE key)
|
||||
let kv = 2.0 * h * (kvlr + qk_rd);
|
||||
// Output: up-project
|
||||
let o = 2.0 * n_heads * vhd * h;
|
||||
q + kv + o
|
||||
} else {
|
||||
// Standard / GQA
|
||||
let qkv = 2.0 * h * (n_heads + 2.0 * n_kv) * hd;
|
||||
let o = 2.0 * n_heads * hd * h;
|
||||
qkv + o
|
||||
};
|
||||
|
||||
// --- MLP FLOPs/token/layer (SwiGLU: gate + up + down = 3 matmuls) ---
|
||||
let mlp = if let Some(moe) = &model.moe {
|
||||
let expert_inter = moe.expert_intermediate_size
|
||||
.unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
|
||||
let active = moe.num_active_experts as f64;
|
||||
let shared = moe.num_shared_experts as f64;
|
||||
active * 6.0 * h * expert_inter + shared * 6.0 * h * inter
|
||||
} else {
|
||||
6.0 * h * inter
|
||||
};
|
||||
|
||||
let linear_flops = attn_linear + mlp;
|
||||
|
||||
// --- Attention quadratic coefficient ---
|
||||
// attn_flops_per_layer(N) = attn_coeff * N * effective_ctx(N)
|
||||
let attn_coeff = if let Some(mla) = &model.mla {
|
||||
let kvlr = mla.kv_lora_rank as f64;
|
||||
let qk_rd = mla.qk_rope_head_dim as f64;
|
||||
// Absorbed QK^T: each head dots over (kv_lora_rank + qk_rope_head_dim) dims.
|
||||
// Absorbed V: each head dots over kv_lora_rank dims.
|
||||
2.0 * n_heads * (2.0 * kvlr + qk_rd)
|
||||
} else {
|
||||
// Standard: QK^T + attn@V, each 2 * n_heads * head_dim per pair.
|
||||
4.0 * n_heads * hd
|
||||
};
|
||||
|
||||
// --- Weight bytes per layer (active params only for MoE) ---
|
||||
let attn_wt = if let Some(mla) = &model.mla {
|
||||
let qlr = mla.q_lora_rank as f64;
|
||||
let kvlr = mla.kv_lora_rank as f64;
|
||||
let qk_hd = (mla.qk_nope_head_dim + mla.qk_rope_head_dim) as f64;
|
||||
let qk_rd = mla.qk_rope_head_dim as f64;
|
||||
let vhd = mla.v_head_dim as f64;
|
||||
(h * qlr + qlr * n_heads * qk_hd
|
||||
+ h * (kvlr + qk_rd)
|
||||
+ n_heads * vhd * h)
|
||||
* dtype
|
||||
} else {
|
||||
((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * dtype
|
||||
};
|
||||
let mlp_wt = if let Some(moe) = &model.moe {
|
||||
let expert_inter = moe.expert_intermediate_size
|
||||
.unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
|
||||
let active = moe.num_active_experts as f64;
|
||||
let shared = moe.num_shared_experts as f64;
|
||||
(active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * dtype
|
||||
} else {
|
||||
3.0 * h * inter * dtype
|
||||
};
|
||||
let weight_bytes = attn_wt + mlp_wt;
|
||||
|
||||
// --- Attention pattern ---
|
||||
let (attn_pattern, first_dense) = match &model.attention {
|
||||
Some(AttentionConfig::Dsa {
|
||||
dense_window,
|
||||
sparse_stride,
|
||||
first_dense_layers,
|
||||
}) => (
|
||||
AttentionPattern::Dsa {
|
||||
dense_window: *dense_window as f64,
|
||||
sparse_stride: *sparse_stride as f64,
|
||||
},
|
||||
*first_dense_layers as f64,
|
||||
),
|
||||
Some(AttentionConfig::SlidingWindow { window_size }) => (
|
||||
AttentionPattern::SlidingWindow {
|
||||
window: *window_size as f64,
|
||||
},
|
||||
0.0,
|
||||
),
|
||||
Some(AttentionConfig::Dense) | None => (
|
||||
AttentionPattern::Dense,
|
||||
model.num_layers as f64,
|
||||
),
|
||||
};
|
||||
|
||||
Self {
|
||||
num_layers: model.num_layers as f64,
|
||||
first_dense_layers: first_dense,
|
||||
linear_flops_per_token: linear_flops,
|
||||
attn_coeff,
|
||||
attn_pattern,
|
||||
weight_bytes_per_layer: weight_bytes,
|
||||
gpu_flops: hw.gpu_flops,
|
||||
gpu_mem_bw: hw.gpu_mem_bw,
|
||||
}
|
||||
}
|
||||
|
||||
// ----- Legacy manual construction ---------------------------------------
|
||||
|
||||
fn from_manual(model: &ModelConfig, hw: &HardwareConfig) -> Self {
|
||||
Self {
|
||||
num_layers: model.num_layers as f64,
|
||||
first_dense_layers: model.num_layers as f64,
|
||||
linear_flops_per_token: model.flops_per_token_prefill.unwrap_or(0.0),
|
||||
attn_coeff: model.attn_quadratic_coeff.unwrap_or(0.0),
|
||||
attn_pattern: AttentionPattern::Dense,
|
||||
weight_bytes_per_layer: 0.0,
|
||||
gpu_flops: hw.gpu_flops,
|
||||
gpu_mem_bw: hw.gpu_mem_bw,
|
||||
}
|
||||
}
|
||||
|
||||
// ----- Prefill time -----------------------------------------------------
|
||||
|
||||
/// Effective context length a single token attends to at sequence length N.
|
||||
fn effective_ctx(&self, n: f64, dense_layer: bool) -> f64 {
|
||||
if dense_layer {
|
||||
return n;
|
||||
}
|
||||
match &self.attn_pattern {
|
||||
AttentionPattern::Dense => n,
|
||||
AttentionPattern::SlidingWindow { window } => n.min(*window),
|
||||
AttentionPattern::Dsa {
|
||||
dense_window,
|
||||
sparse_stride,
|
||||
} => {
|
||||
if n <= *dense_window {
|
||||
n
|
||||
} else {
|
||||
*dense_window + (n - *dense_window) / *sparse_stride
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Time (s) to prefill `n` tokens.
|
||||
pub fn prefill_time(&self, n: u32) -> f64 {
|
||||
if n == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
let n = n as f64;
|
||||
let linear = n * self.linear_flops_per_token;
|
||||
|
||||
// Compute FLOPs across all layers (dense + sparse may differ).
|
||||
let dense_layers = self.first_dense_layers;
|
||||
let sparse_layers = self.num_layers - dense_layers;
|
||||
|
||||
let dense_flops = dense_layers
|
||||
* (linear + self.attn_coeff * n * self.effective_ctx(n, true));
|
||||
let sparse_flops = sparse_layers
|
||||
* (linear + self.attn_coeff * n * self.effective_ctx(n, false));
|
||||
let total_flops = dense_flops + sparse_flops;
|
||||
|
||||
let compute_time = total_flops / self.gpu_flops;
|
||||
// Weight stream: all layers' active weights read once from HBM.
|
||||
let mem_time = self.weight_bytes_per_layer * self.num_layers / self.gpu_mem_bw;
|
||||
|
||||
compute_time.max(mem_time)
|
||||
}
|
||||
|
||||
/// Print human-readable derived coefficients (for `validate` output).
|
||||
pub fn describe(&self) -> String {
|
||||
let pattern_str = match &self.attn_pattern {
|
||||
AttentionPattern::Dense => "dense".to_string(),
|
||||
AttentionPattern::SlidingWindow { window } => format!("sliding_window({})", *window as u64),
|
||||
AttentionPattern::Dsa {
|
||||
dense_window,
|
||||
sparse_stride,
|
||||
} => format!(
|
||||
"dsa(window={}, stride={}, {} dense layers)",
|
||||
*dense_window as u64, *sparse_stride as u64, self.first_dense_layers as u64
|
||||
),
|
||||
};
|
||||
format!(
|
||||
"linear_flops/tok/layer={:.3e}, attn_coeff={:.0}, pattern={}, \
|
||||
weight_bytes/layer={:.2e}",
|
||||
self.linear_flops_per_token, self.attn_coeff, pattern_str,
|
||||
self.weight_bytes_per_layer,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn cm_legacy() -> ComputeModel {
|
||||
ComputeModel {
|
||||
num_layers: 28.0,
|
||||
first_dense_layers: 28.0,
|
||||
linear_flops_per_token: 1.4e10,
|
||||
attn_coeff: 1024.0,
|
||||
attn_pattern: AttentionPattern::Dense,
|
||||
weight_bytes_per_layer: 0.0,
|
||||
gpu_flops: 9.89e14,
|
||||
gpu_mem_bw: 3.35e12,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prefill_monotonic_in_n() {
|
||||
let m = cm_legacy();
|
||||
let mut prev = 0.0;
|
||||
for &n in &[1u32, 8, 64, 512, 4096, 32768] {
|
||||
let t = m.prefill_time(n);
|
||||
assert!(t > prev, "prefill_time should be monotonic; n={n} t={t}");
|
||||
prev = t;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quadratic_dominates_for_long_prompt() {
|
||||
let m = cm_legacy();
|
||||
let lin = m.prefill_time(1024);
|
||||
let big = m.prefill_time(32768);
|
||||
assert!(big / lin > 32.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_tokens_is_free() {
|
||||
let m = cm_legacy();
|
||||
assert_eq!(m.prefill_time(0), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dsa_subquadratic() {
|
||||
// With DSA (window=4096, stride=8) the cost at 128k should be
|
||||
// MUCH less than pure quadratic.
|
||||
let dense = ComputeModel {
|
||||
num_layers: 78.0,
|
||||
first_dense_layers: 78.0,
|
||||
linear_flops_per_token: 1.0e9,
|
||||
attn_coeff: 139264.0,
|
||||
attn_pattern: AttentionPattern::Dense,
|
||||
weight_bytes_per_layer: 0.0,
|
||||
gpu_flops: 1.8e16,
|
||||
gpu_mem_bw: 6.4e13,
|
||||
};
|
||||
let dsa = ComputeModel {
|
||||
attn_pattern: AttentionPattern::Dsa {
|
||||
dense_window: 4096.0,
|
||||
sparse_stride: 8.0,
|
||||
},
|
||||
first_dense_layers: 3.0,
|
||||
..dense.clone()
|
||||
};
|
||||
let n = 131072; // 128k tokens
|
||||
let t_dense = dense.prefill_time(n);
|
||||
let t_dsa = dsa.prefill_time(n);
|
||||
// DSA should be dramatically cheaper at long context.
|
||||
assert!(
|
||||
t_dsa < t_dense * 0.3,
|
||||
"DSA should be <30% of dense at 128k: dense={t_dense:.3} dsa={t_dsa:.3}"
|
||||
);
|
||||
// But still monotonic.
|
||||
assert!(t_dsa > dsa.prefill_time(n / 2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mem_bound_short_prefill() {
|
||||
// With very heavy weights and a short prompt, memory should dominate.
|
||||
let m = ComputeModel {
|
||||
num_layers: 10.0,
|
||||
first_dense_layers: 10.0,
|
||||
linear_flops_per_token: 1.0e6, // tiny compute
|
||||
attn_coeff: 1.0,
|
||||
attn_pattern: AttentionPattern::Dense,
|
||||
weight_bytes_per_layer: 1.0e12, // 1 TB per layer
|
||||
gpu_flops: 1.0e15,
|
||||
gpu_mem_bw: 1.0e12,
|
||||
};
|
||||
let t1 = m.prefill_time(1);
|
||||
let t8 = m.prefill_time(8);
|
||||
// Memory time = 10 * 1e12 / 1e12 = 10s, should dominate.
|
||||
assert!((t1 - 10.0).abs() < 0.01);
|
||||
// Doubling tokens shouldn't change time much (mem-bound).
|
||||
assert!((t8 - t1).abs() / t1 < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn arch_derives_from_model_config() {
|
||||
// Minimal dense model: verify from_arch produces something sensible.
|
||||
let model = ModelConfig {
|
||||
name: "test".into(),
|
||||
num_layers: 4,
|
||||
num_kv_heads: 2,
|
||||
head_dim: 64,
|
||||
dtype_bytes: 2,
|
||||
block_size_tokens: 16,
|
||||
hidden_size: Some(256),
|
||||
num_attention_heads: Some(4),
|
||||
intermediate_size: Some(512),
|
||||
..Default::default()
|
||||
};
|
||||
let hw = HardwareConfig {
|
||||
gpu_flops: 1e14,
|
||||
gpu_mem_bw: 1e12,
|
||||
hbm_bytes: 1e9,
|
||||
dram_bytes: 4e9,
|
||||
pcie_bw: 32e9,
|
||||
pcie_latency_us: 1.0,
|
||||
rdma_bw: 12e9,
|
||||
rdma_latency_us: 5.0,
|
||||
max_batch_slots: 32,
|
||||
prefill_chunk_tokens: 1024,
|
||||
};
|
||||
let cm = ComputeModel::new(&model, &hw);
|
||||
assert!(cm.linear_flops_per_token > 0.0);
|
||||
assert!(cm.attn_coeff > 0.0);
|
||||
assert!(cm.weight_bytes_per_layer > 0.0);
|
||||
let t = cm.prefill_time(1024);
|
||||
assert!(t > 0.0);
|
||||
}
|
||||
}
|
||||
191
src/instance/instance.rs
Normal file
191
src/instance/instance.rs
Normal file
@@ -0,0 +1,191 @@
|
||||
//! One simulated **prefill** serving instance.
|
||||
//!
|
||||
//! This simulator assumes **PD (prefill/decode) disaggregation**: prefill
|
||||
//! and decode run on dedicated instance pools, and only prefill instances
|
||||
//! are modeled here. The decode side is invisible to the KV-cache-aware
|
||||
//! routing problem we are studying — once prefill finishes, the KV cache is
|
||||
//! shipped to a decode instance via a separate (out-of-scope) path.
|
||||
//!
|
||||
//! As a result this `Instance`:
|
||||
//! * Owns a two-tier KV cache (L0 = HBM, L1 = DRAM / v6d) used only for
|
||||
//! prefill prefix reuse.
|
||||
//! * Owns the PCIe / RDMA links used for cache fetches.
|
||||
//! * Runs a simple FCFS chunked-prefill scheduler: one request's prefill
|
||||
//! chunk per step, up to `prefill_chunk_tokens` per chunk.
|
||||
//!
|
||||
//! The cluster (`crate::cluster`) is responsible for routing arrivals,
|
||||
//! consulting the global meta store, and inserting fetched blocks into the
|
||||
//! instance's caches before handing the request off via `admit`.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
|
||||
use crate::config::{HardwareConfig, ModelConfig};
|
||||
use crate::instance::compute::ComputeModel;
|
||||
use crate::instance::kv_cache::TwoTierCache;
|
||||
use crate::network::InstanceLinks;
|
||||
use crate::types::{InstanceId, ReqId};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AdmittedRequest {
|
||||
pub req_id: ReqId,
|
||||
pub arrival: f64,
|
||||
/// Earliest time at which the KV fetch chain (L1 + RDMA + PCIe) for this
|
||||
/// request has completed, so its prefill compute can begin.
|
||||
pub ready_at: f64,
|
||||
/// Tokens still needing prefill compute (after cache hits accounted for).
|
||||
pub prefill_tokens_remaining: u32,
|
||||
/// KV blocks reserved on this instance's HBM for the lifetime of this
|
||||
/// request's prefill (= number of input blocks).
|
||||
pub reserved_blocks: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StepResult {
|
||||
pub completed: Vec<(ReqId, f64, f64)>, // (req_id, ttft, end_time)
|
||||
pub next_tick: Option<f64>,
|
||||
}
|
||||
|
||||
pub struct Instance {
|
||||
pub id: InstanceId,
|
||||
pub cache: TwoTierCache,
|
||||
pub links: InstanceLinks,
|
||||
pub compute: ComputeModel,
|
||||
pub block_size_tokens: u32,
|
||||
pub hbm_block_budget: u32,
|
||||
pub dram_block_budget: u32,
|
||||
pub max_batch_slots: u32,
|
||||
pub prefill_chunk_tokens: u32,
|
||||
|
||||
pub kv_blocks_used: u32,
|
||||
|
||||
/// Admitted but not yet ready (waiting for fetch chain to land).
|
||||
pending: VecDeque<AdmittedRequest>,
|
||||
/// Ready and currently being prefilled (FCFS, one at a time per step).
|
||||
prefilling: VecDeque<AdmittedRequest>,
|
||||
|
||||
/// True if a BatchTick is already on the global queue for us.
|
||||
pub tick_scheduled: bool,
|
||||
}
|
||||
|
||||
impl Instance {
|
||||
pub fn new(id: InstanceId, model: &ModelConfig, hw: &HardwareConfig) -> Self {
|
||||
let block_bytes = model.kv_block_bytes() as f64;
|
||||
let hbm_blocks = (hw.hbm_bytes / block_bytes).max(1.0) as u32;
|
||||
let dram_blocks = (hw.dram_bytes / block_bytes).max(1.0) as u32;
|
||||
Self {
|
||||
id,
|
||||
cache: TwoTierCache::new(hbm_blocks as usize, dram_blocks as usize),
|
||||
links: InstanceLinks::from_hw(hw),
|
||||
compute: ComputeModel::new(model, hw),
|
||||
block_size_tokens: model.block_size_tokens,
|
||||
hbm_block_budget: hbm_blocks,
|
||||
dram_block_budget: dram_blocks,
|
||||
max_batch_slots: hw.max_batch_slots,
|
||||
prefill_chunk_tokens: hw.prefill_chunk_tokens,
|
||||
kv_blocks_used: 0,
|
||||
pending: VecDeque::new(),
|
||||
prefilling: VecDeque::new(),
|
||||
tick_scheduled: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn queue_len(&self) -> u32 {
|
||||
(self.pending.len() + self.prefilling.len()) as u32
|
||||
}
|
||||
|
||||
/// Total prefill tokens remaining across all pending and prefilling requests.
|
||||
pub fn waiting_tokens(&self) -> u64 {
|
||||
self.pending
|
||||
.iter()
|
||||
.chain(self.prefilling.iter())
|
||||
.map(|r| r.prefill_tokens_remaining as u64)
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Estimated wall-clock time to drain all currently queued requests.
|
||||
///
|
||||
/// Sums `compute.prefill_time(tokens_j)` for each queued request,
|
||||
/// capturing the non-linear (quadratic / DSA) cost accurately.
|
||||
pub fn estimated_drain_time(&self) -> f64 {
|
||||
self.pending
|
||||
.iter()
|
||||
.chain(self.prefilling.iter())
|
||||
.map(|r| self.compute.prefill_time(r.prefill_tokens_remaining))
|
||||
.sum()
|
||||
}
|
||||
|
||||
pub fn admit(&mut self, req: AdmittedRequest) {
|
||||
self.pending.push_back(req);
|
||||
}
|
||||
|
||||
/// Run one batch step. Returns any requests that finished prefill during
|
||||
/// this step plus the next wakeup time for the instance.
|
||||
pub fn step(&mut self, now: f64) -> StepResult {
|
||||
let mut completed = Vec::new();
|
||||
|
||||
// 1. Drain ready pending requests into prefilling, respecting KV
|
||||
// budget and slot cap. A request whose fetch chain is complete
|
||||
// *and* has zero prefill tokens (full cache hit) finishes
|
||||
// immediately at `now`.
|
||||
while let Some(front) = self.pending.front() {
|
||||
if front.ready_at > now {
|
||||
break;
|
||||
}
|
||||
if self.prefilling.len() as u32 >= self.max_batch_slots {
|
||||
break;
|
||||
}
|
||||
if self.kv_blocks_used + front.reserved_blocks > self.hbm_block_budget {
|
||||
break;
|
||||
}
|
||||
let r = self.pending.pop_front().unwrap();
|
||||
self.kv_blocks_used += r.reserved_blocks;
|
||||
if r.prefill_tokens_remaining == 0 {
|
||||
// Full cache hit: nothing to compute. TTFT == fetch time.
|
||||
let ttft = now - r.arrival;
|
||||
self.kv_blocks_used = self.kv_blocks_used.saturating_sub(r.reserved_blocks);
|
||||
completed.push((r.req_id, ttft, now));
|
||||
} else {
|
||||
self.prefilling.push_back(r);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Run one chunked-prefill step on the head of `prefilling`.
|
||||
let chunk_tokens = self
|
||||
.prefilling
|
||||
.front()
|
||||
.map(|r| r.prefill_tokens_remaining.min(self.prefill_chunk_tokens))
|
||||
.unwrap_or(0);
|
||||
|
||||
if chunk_tokens == 0 {
|
||||
// Nothing compute-bound in flight right now.
|
||||
return StepResult {
|
||||
completed,
|
||||
next_tick: self.next_wakeup(now),
|
||||
};
|
||||
}
|
||||
|
||||
let dt = self.compute.prefill_time(chunk_tokens);
|
||||
let t_end = now + dt;
|
||||
|
||||
let head = self.prefilling.front_mut().unwrap();
|
||||
head.prefill_tokens_remaining -= chunk_tokens;
|
||||
if head.prefill_tokens_remaining == 0 {
|
||||
let done = self.prefilling.pop_front().unwrap();
|
||||
let ttft = t_end - done.arrival;
|
||||
self.kv_blocks_used = self.kv_blocks_used.saturating_sub(done.reserved_blocks);
|
||||
completed.push((done.req_id, ttft, t_end));
|
||||
}
|
||||
|
||||
StepResult {
|
||||
completed,
|
||||
next_tick: self.next_wakeup(t_end),
|
||||
}
|
||||
}
|
||||
|
||||
fn next_wakeup(&self, after: f64) -> Option<f64> {
|
||||
if !self.prefilling.is_empty() {
|
||||
return Some(after);
|
||||
}
|
||||
self.pending.front().map(|r| r.ready_at.max(after))
|
||||
}
|
||||
}
|
||||
226
src/instance/kv_cache.rs
Normal file
226
src/instance/kv_cache.rs
Normal file
@@ -0,0 +1,226 @@
|
||||
//! Two-tier LRU KV cache (L0 = GPU HBM, L1 = CPU DRAM / v6d).
|
||||
//!
|
||||
//! Each tier stores block hashes; the unit of accounting is one 16-token
|
||||
//! block. `longest_prefix` walks the hash slice front-to-back and returns the
|
||||
//! count of leading blocks present in the tier (and touches them so they
|
||||
//! stay hot).
|
||||
//!
|
||||
//! On insert, evicted block hashes are returned so the caller (instance) can
|
||||
//! propagate them to the global meta store if desired.
|
||||
|
||||
use ahash::AHashMap;
|
||||
|
||||
/// Doubly-linked-list-backed LRU keyed by block hash.
|
||||
#[derive(Debug)]
|
||||
pub struct LruBlocks {
|
||||
capacity: usize,
|
||||
map: AHashMap<u64, usize>,
|
||||
nodes: Vec<Node>,
|
||||
head: Option<usize>, // most recently used
|
||||
tail: Option<usize>, // least recently used
|
||||
free: Vec<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct Node {
|
||||
key: u64,
|
||||
prev: Option<usize>,
|
||||
next: Option<usize>,
|
||||
}
|
||||
|
||||
impl LruBlocks {
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
capacity,
|
||||
map: AHashMap::with_capacity(capacity),
|
||||
nodes: Vec::with_capacity(capacity),
|
||||
head: None,
|
||||
tail: None,
|
||||
free: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn capacity(&self) -> usize {
|
||||
self.capacity
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.map.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.map.is_empty()
|
||||
}
|
||||
|
||||
pub fn contains(&self, key: u64) -> bool {
|
||||
self.map.contains_key(&key)
|
||||
}
|
||||
|
||||
/// Touch (move to MRU) if present. Returns whether the key was present.
|
||||
pub fn touch(&mut self, key: u64) -> bool {
|
||||
if let Some(&idx) = self.map.get(&key) {
|
||||
self.move_to_head(idx);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert blocks; evicted hashes appended to `evicted_out`. Reinserting an
|
||||
/// existing block just touches it.
|
||||
pub fn insert_blocks(&mut self, hashes: &[u64], evicted_out: &mut Vec<u64>) {
|
||||
for &h in hashes {
|
||||
if self.touch(h) {
|
||||
continue;
|
||||
}
|
||||
// need to make room?
|
||||
if self.map.len() == self.capacity {
|
||||
if let Some(tail_idx) = self.tail {
|
||||
let tail_key = self.nodes[tail_idx].key;
|
||||
self.detach(tail_idx);
|
||||
self.map.remove(&tail_key);
|
||||
self.free.push(tail_idx);
|
||||
evicted_out.push(tail_key);
|
||||
}
|
||||
}
|
||||
// allocate node
|
||||
let idx = if let Some(i) = self.free.pop() {
|
||||
self.nodes[i] = Node { key: h, prev: None, next: None };
|
||||
i
|
||||
} else {
|
||||
let i = self.nodes.len();
|
||||
self.nodes.push(Node { key: h, prev: None, next: None });
|
||||
i
|
||||
};
|
||||
self.map.insert(h, idx);
|
||||
self.attach_to_head(idx);
|
||||
}
|
||||
}
|
||||
|
||||
/// Longest leading prefix of `hashes` present; touches the matched blocks.
|
||||
pub fn longest_prefix(&mut self, hashes: &[u64]) -> usize {
|
||||
let mut n = 0usize;
|
||||
for &h in hashes {
|
||||
if !self.touch(h) {
|
||||
break;
|
||||
}
|
||||
n += 1;
|
||||
}
|
||||
n
|
||||
}
|
||||
|
||||
/// Read-only longest prefix without LRU updates (used for routing probes).
|
||||
pub fn longest_prefix_peek(&self, hashes: &[u64]) -> usize {
|
||||
let mut n = 0usize;
|
||||
for &h in hashes {
|
||||
if !self.map.contains_key(&h) {
|
||||
break;
|
||||
}
|
||||
n += 1;
|
||||
}
|
||||
n
|
||||
}
|
||||
|
||||
fn move_to_head(&mut self, idx: usize) {
|
||||
if Some(idx) == self.head {
|
||||
return;
|
||||
}
|
||||
self.detach(idx);
|
||||
self.attach_to_head(idx);
|
||||
}
|
||||
|
||||
fn detach(&mut self, idx: usize) {
|
||||
let (prev, next) = {
|
||||
let n = &self.nodes[idx];
|
||||
(n.prev, n.next)
|
||||
};
|
||||
if let Some(p) = prev {
|
||||
self.nodes[p].next = next;
|
||||
} else {
|
||||
// it was the head
|
||||
self.head = next;
|
||||
}
|
||||
if let Some(nx) = next {
|
||||
self.nodes[nx].prev = prev;
|
||||
} else {
|
||||
// it was the tail
|
||||
self.tail = prev;
|
||||
}
|
||||
self.nodes[idx].prev = None;
|
||||
self.nodes[idx].next = None;
|
||||
}
|
||||
|
||||
fn attach_to_head(&mut self, idx: usize) {
|
||||
let old_head = self.head;
|
||||
self.nodes[idx].prev = None;
|
||||
self.nodes[idx].next = old_head;
|
||||
if let Some(h) = old_head {
|
||||
self.nodes[h].prev = Some(idx);
|
||||
}
|
||||
self.head = Some(idx);
|
||||
if self.tail.is_none() {
|
||||
self.tail = Some(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Two-tier (HBM, DRAM) cache.
|
||||
#[derive(Debug)]
|
||||
pub struct TwoTierCache {
|
||||
pub l0: LruBlocks,
|
||||
pub l1: LruBlocks,
|
||||
}
|
||||
|
||||
impl TwoTierCache {
|
||||
pub fn new(l0_cap: usize, l1_cap: usize) -> Self {
|
||||
Self {
|
||||
l0: LruBlocks::new(l0_cap),
|
||||
l1: LruBlocks::new(l1_cap),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn lcp_full_partial_empty() {
|
||||
let mut c = LruBlocks::new(8);
|
||||
let mut ev = Vec::new();
|
||||
c.insert_blocks(&[1, 2, 3, 4], &mut ev);
|
||||
assert_eq!(c.longest_prefix(&[1, 2, 3, 4, 5, 6]), 4);
|
||||
assert_eq!(c.longest_prefix(&[1, 2, 9]), 2);
|
||||
assert_eq!(c.longest_prefix(&[99, 1]), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lru_eviction_order() {
|
||||
let mut c = LruBlocks::new(3);
|
||||
let mut ev = Vec::new();
|
||||
c.insert_blocks(&[1, 2, 3], &mut ev);
|
||||
assert!(ev.is_empty());
|
||||
// touch 1 -> MRU
|
||||
c.touch(1);
|
||||
// insert 4 -> evicts LRU which should be 2
|
||||
c.insert_blocks(&[4], &mut ev);
|
||||
assert_eq!(ev, vec![2]);
|
||||
assert!(c.contains(1));
|
||||
assert!(c.contains(3));
|
||||
assert!(c.contains(4));
|
||||
assert!(!c.contains(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn longest_prefix_touches_blocks() {
|
||||
let mut c = LruBlocks::new(3);
|
||||
let mut ev = Vec::new();
|
||||
c.insert_blocks(&[1, 2, 3], &mut ev);
|
||||
// touch 1 via prefix lookup (only the first matching block: 1)
|
||||
assert_eq!(c.longest_prefix(&[1, 99]), 1);
|
||||
// now insert 4 -> LRU should be 2 (since 3 was just inserted MRU after 2,
|
||||
// 1 is freshest, then 3, then 2)
|
||||
c.insert_blocks(&[4], &mut ev);
|
||||
assert_eq!(ev, vec![2]);
|
||||
}
|
||||
}
|
||||
6
src/instance/mod.rs
Normal file
6
src/instance/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod compute;
|
||||
pub mod kv_cache;
|
||||
#[allow(clippy::module_inception)]
|
||||
pub mod instance;
|
||||
|
||||
pub use instance::Instance;
|
||||
Reference in New Issue
Block a user