use ahash::{AHashMap, AHashSet}; use anyhow::{anyhow, Result}; use serde::Serialize; use std::cmp::min; use std::collections::BinaryHeap; use crate::config::Config; use crate::instance::kv_cache::LruBlocks; use crate::trace::RequestRecord; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] #[serde(rename_all = "snake_case")] pub enum ReplayEvictPolicy { Lru, Belady, } impl ReplayEvictPolicy { pub fn parse(s: &str) -> Result { match s { "lru" => Ok(Self::Lru), "belady" => Err(anyhow!( "exact belady is not supported for fixed-placement full-hierarchy ablation" )), other => Err(anyhow!("unknown evict policy: {other}")), } } pub fn as_str(&self) -> &'static str { match self { Self::Lru => "lru", Self::Belady => "belady", } } } #[derive(Debug, Clone)] pub struct PlacementEntry { pub req_id: u64, pub instance: u32, } #[derive(Debug, Clone, Serialize, Default)] pub struct ReplaySummary { pub num_requests: u64, pub total_blocks: u64, pub l0_hit_blocks: u64, pub l1_hit_blocks: u64, pub remote_hit_blocks: u64, pub miss_blocks: u64, pub hit_rate_l0: f64, pub hit_rate_l1: f64, pub hit_rate_remote: f64, pub miss_rate: f64, pub total_rdma_bytes: u64, pub total_pcie_bytes: u64, } impl ReplaySummary { fn from_counts( num_requests: usize, total_blocks: u64, l0_hit_blocks: u64, l1_hit_blocks: u64, remote_hit_blocks: u64, miss_blocks: u64, total_rdma_bytes: u64, total_pcie_bytes: u64, ) -> Self { let denom = total_blocks.max(1) as f64; Self { num_requests: num_requests as u64, total_blocks, l0_hit_blocks, l1_hit_blocks, remote_hit_blocks, miss_blocks, hit_rate_l0: l0_hit_blocks as f64 / denom, hit_rate_l1: l1_hit_blocks as f64 / denom, hit_rate_remote: remote_hit_blocks as f64 / denom, miss_rate: miss_blocks as f64 / denom, total_rdma_bytes, total_pcie_bytes, } } } #[derive(Debug, Clone, Copy)] enum FutureKind { L0, L1, } #[derive(Debug)] struct FutureIndex { local: AHashMap<(u32, u64), Vec>, global: AHashMap>, } impl FutureIndex { fn build(records: &[RequestRecord], placement: &[u32]) -> Self { let mut local: AHashMap<(u32, u64), Vec> = AHashMap::new(); let mut global: AHashMap> = AHashMap::new(); for (req_idx, record) in records.iter().enumerate() { let inst = placement[req_idx]; let mut seen = AHashSet::new(); for &block in &record.hash_ids { if !seen.insert(block) { continue; } local.entry((inst, block)).or_default().push(req_idx); global.entry(block).or_default().push((req_idx, inst)); } } Self { local, global } } fn next_local(&self, inst: u32, block: u64, current_req_idx: usize) -> usize { match self.local.get(&(inst, block)) { Some(indices) => next_after(indices, current_req_idx), None => usize::MAX, } } fn next_other(&self, inst: u32, block: u64, current_req_idx: usize) -> usize { let Some(indices) = self.global.get(&block) else { return usize::MAX; }; let start = first_after_pair(indices, current_req_idx); for &(req_idx, owner_inst) in indices.iter().skip(start) { if owner_inst != inst { return req_idx; } } usize::MAX } fn next_use(&self, kind: FutureKind, inst: u32, block: u64, current_req_idx: usize) -> usize { match kind { FutureKind::L0 => self.next_local(inst, block, current_req_idx), FutureKind::L1 => min( self.next_local(inst, block, current_req_idx), self.next_other(inst, block, current_req_idx), ), } } } fn next_after(indices: &[usize], current_req_idx: usize) -> usize { let pos = indices.partition_point(|&idx| idx <= current_req_idx); indices.get(pos).copied().unwrap_or(usize::MAX) } fn first_after_pair(indices: &[(usize, u32)], current_req_idx: usize) -> usize { indices.partition_point(|&(idx, _)| idx <= current_req_idx) } #[derive(Debug)] struct BeladyTier { capacity: usize, resident: AHashSet, versions: AHashMap, heap: BinaryHeap<(usize, u64, u64)>, next_version: u64, } impl BeladyTier { fn new(capacity: usize) -> Self { Self { capacity, resident: AHashSet::with_capacity(capacity), versions: AHashMap::with_capacity(capacity), heap: BinaryHeap::with_capacity(capacity), next_version: 0, } } fn contains(&self, key: u64) -> bool { self.resident.contains(&key) } fn remove(&mut self, key: u64) -> bool { if self.resident.remove(&key) { self.versions.remove(&key); true } else { false } } fn touch( &mut self, key: u64, current_req_idx: usize, kind: FutureKind, inst: u32, futures: &FutureIndex, ) -> bool { if !self.resident.contains(&key) { return false; } self.next_version += 1; let version = self.next_version; let next_use = futures.next_use(kind, inst, key, current_req_idx); self.versions.insert(key, version); self.heap.push((next_use, version, key)); true } fn insert( &mut self, key: u64, current_req_idx: usize, kind: FutureKind, inst: u32, futures: &FutureIndex, ) -> Option { if self.touch(key, current_req_idx, kind, inst, futures) { return None; } if self.capacity == 0 { return Some(key); } let mut evicted = None; if self.resident.len() == self.capacity { evicted = self.evict(current_req_idx, kind, inst, futures); } self.next_version += 1; let version = self.next_version; let next_use = futures.next_use(kind, inst, key, current_req_idx); self.resident.insert(key); self.versions.insert(key, version); self.heap.push((next_use, version, key)); evicted } fn evict( &mut self, current_req_idx: usize, kind: FutureKind, inst: u32, futures: &FutureIndex, ) -> Option { while let Some((stored_next_use, version, key)) = self.heap.pop() { if !self.resident.contains(&key) { continue; } let Some(current_version) = self.versions.get(&key).copied() else { continue; }; if current_version != version { continue; } let actual_next_use = futures.next_use(kind, inst, key, current_req_idx); if actual_next_use != stored_next_use { self.next_version += 1; let new_version = self.next_version; self.versions.insert(key, new_version); self.heap.push((actual_next_use, new_version, key)); continue; } self.resident.remove(&key); self.versions.remove(&key); return Some(key); } None } } #[derive(Debug)] enum Tier { Lru(LruBlocks), Belady(BeladyTier), } impl Tier { fn new(policy: ReplayEvictPolicy, capacity: usize) -> Self { match policy { ReplayEvictPolicy::Lru => Self::Lru(LruBlocks::new(capacity)), ReplayEvictPolicy::Belady => Self::Belady(BeladyTier::new(capacity)), } } fn contains(&self, key: u64) -> bool { match self { Self::Lru(tier) => tier.contains(key), Self::Belady(tier) => tier.contains(key), } } fn remove(&mut self, key: u64) -> bool { match self { Self::Lru(tier) => tier.remove(key), Self::Belady(tier) => tier.remove(key), } } fn touch( &mut self, key: u64, req_idx: usize, kind: FutureKind, inst: u32, futures: &FutureIndex, ) -> bool { match self { Self::Lru(tier) => tier.touch(key), Self::Belady(tier) => tier.touch(key, req_idx, kind, inst, futures), } } fn insert( &mut self, key: u64, req_idx: usize, kind: FutureKind, inst: u32, futures: &FutureIndex, ) -> Option { match self { Self::Lru(tier) => tier.insert_block(key), Self::Belady(tier) => tier.insert(key, req_idx, kind, inst, futures), } } fn longest_prefix_touch( &mut self, hashes: &[u64], req_idx: usize, kind: FutureKind, inst: u32, futures: &FutureIndex, ) -> usize { match self { Self::Lru(tier) => tier.longest_prefix(hashes), Self::Belady(tier) => { let mut matched = 0usize; for &hash in hashes { if !tier.touch(hash, req_idx, kind, inst, futures) { break; } matched += 1; } matched } } } fn longest_prefix_peek(&self, hashes: &[u64]) -> usize { match self { Self::Lru(tier) => tier.longest_prefix_peek(hashes), Self::Belady(tier) => { let mut matched = 0usize; for &hash in hashes { if !tier.contains(hash) { break; } matched += 1; } matched } } } } #[derive(Debug)] struct ReplayInstanceCache { l0: Tier, l1: Tier, } impl ReplayInstanceCache { fn new(policy: ReplayEvictPolicy, l0_cap: usize, l1_cap: usize) -> Self { Self { l0: Tier::new(policy, l0_cap), l1: Tier::new(policy, l1_cap), } } fn promote_l1_blocks_to_l0( &mut self, hashes: &[u64], req_idx: usize, inst: u32, futures: &FutureIndex, owners: &mut AHashMap>, ) { for &hash in hashes { if self.l1.remove(hash) { remove_owner(owners, hash, inst); } self.insert_block_into_l0(hash, req_idx, inst, futures, owners); } } fn fetch_remote_blocks_to_l0( &mut self, hashes: &[u64], req_idx: usize, inst: u32, futures: &FutureIndex, owners: &mut AHashMap>, ) { for &hash in hashes { self.stage_remote_block_in_l1(hash, req_idx, inst, futures, owners); if self.l1.remove(hash) { remove_owner(owners, hash, inst); } self.insert_block_into_l0(hash, req_idx, inst, futures, owners); } } fn insert_blocks_into_l0( &mut self, hashes: &[u64], req_idx: usize, inst: u32, futures: &FutureIndex, owners: &mut AHashMap>, ) { for &hash in hashes { self.insert_block_into_l0(hash, req_idx, inst, futures, owners); } } fn insert_block_into_l0( &mut self, hash: u64, req_idx: usize, inst: u32, futures: &FutureIndex, owners: &mut AHashMap>, ) { if self.l0.touch(hash, req_idx, FutureKind::L0, inst, futures) { return; } if self.l1.remove(hash) { remove_owner(owners, hash, inst); } if let Some(evicted_l0) = self.l0.insert(hash, req_idx, FutureKind::L0, inst, futures) { self.demote_into_l1(evicted_l0, req_idx, inst, futures, owners); } } fn stage_remote_block_in_l1( &mut self, hash: u64, req_idx: usize, inst: u32, futures: &FutureIndex, owners: &mut AHashMap>, ) { if self.l0.contains(hash) || self.l1.contains(hash) { return; } if let Some(evicted_l1) = self.l1.insert(hash, req_idx, FutureKind::L1, inst, futures) { remove_owner(owners, evicted_l1, inst); } add_owner(owners, hash, inst); } fn demote_into_l1( &mut self, hash: u64, req_idx: usize, inst: u32, futures: &FutureIndex, owners: &mut AHashMap>, ) { if self.l1.touch(hash, req_idx, FutureKind::L1, inst, futures) { return; } if let Some(evicted_l1) = self.l1.insert(hash, req_idx, FutureKind::L1, inst, futures) { remove_owner(owners, evicted_l1, inst); } add_owner(owners, hash, inst); } } fn add_owner(owners: &mut AHashMap>, hash: u64, inst: u32) { owners.entry(hash).or_default().insert(inst); } fn remove_owner(owners: &mut AHashMap>, hash: u64, inst: u32) { if let Some(bucket) = owners.get_mut(&hash) { bucket.remove(&inst); if bucket.is_empty() { owners.remove(&hash); } } } pub fn replay_fixed_placement( cfg: &Config, records: &[RequestRecord], placements: &[PlacementEntry], policy: ReplayEvictPolicy, ) -> Result { cfg.cluster .require_legacy_single_pool("fixed-placement replay")?; if records.len() != placements.len() { return Err(anyhow!( "records/placements length mismatch: {} vs {}", records.len(), placements.len() )); } let placement_by_req: AHashMap = placements.iter().map(|p| (p.req_id, p.instance)).collect(); let ordered_placement: Vec = records .iter() .map(|r| { placement_by_req .get(&r.req_id) .copied() .ok_or_else(|| anyhow!("missing placement for req_id={}", r.req_id)) }) .collect::>()?; let futures = FutureIndex::build(records, &ordered_placement); let block_bytes = cfg.model.kv_block_bytes() as f64; let l0_cap = (cfg.hardware.hbm_bytes / block_bytes).max(1.0) as usize; let l1_cap = (cfg.hardware.dram_bytes / block_bytes).max(1.0) as usize; let num_instances = cfg.cluster.total_instances() as usize; let mut caches: Vec = (0..num_instances) .map(|_| ReplayInstanceCache::new(policy, l0_cap, l1_cap)) .collect(); let mut owners: AHashMap> = AHashMap::new(); let mut total_blocks = 0u64; let mut l0_hit_blocks = 0u64; let mut l1_hit_blocks = 0u64; let mut remote_hit_blocks = 0u64; let mut miss_blocks = 0u64; let mut total_rdma_bytes = 0u64; let mut total_pcie_bytes = 0u64; for (req_idx, record) in records.iter().enumerate() { let inst = ordered_placement[req_idx]; let cache = &mut caches[inst as usize]; total_blocks += record.hash_ids.len() as u64; let l0_hits = cache.l0.longest_prefix_touch( &record.hash_ids, req_idx, FutureKind::L0, inst, &futures, ); let suffix_after_l0 = &record.hash_ids[l0_hits..]; let l1_hits = cache.l1.longest_prefix_peek(suffix_after_l0); if l1_hits > 0 { cache.promote_l1_blocks_to_l0( &suffix_after_l0[..l1_hits], req_idx, inst, &futures, &mut owners, ); } let suffix_after_l1 = &suffix_after_l0[l1_hits..]; let mut remote_hits = 0usize; for &hash in suffix_after_l1 { let any_remote = owners .get(&hash) .map(|bucket| bucket.iter().any(|owner| *owner != inst)) .unwrap_or(false); if any_remote { remote_hits += 1; } else { break; } } if remote_hits > 0 { cache.fetch_remote_blocks_to_l0( &suffix_after_l1[..remote_hits], req_idx, inst, &futures, &mut owners, ); } let misses = record.hash_ids.len() - l0_hits - l1_hits - remote_hits; let new_input = &record.hash_ids[(l0_hits + l1_hits + remote_hits)..]; if !new_input.is_empty() { cache.insert_blocks_into_l0(new_input, req_idx, inst, &futures, &mut owners); } l0_hit_blocks += l0_hits as u64; l1_hit_blocks += l1_hits as u64; remote_hit_blocks += remote_hits as u64; miss_blocks += misses as u64; let kv_block_bytes = cfg.model.kv_block_bytes(); total_rdma_bytes += (remote_hits as u64) * kv_block_bytes; total_pcie_bytes += ((l1_hits + remote_hits) as u64) * kv_block_bytes; } Ok(ReplaySummary::from_counts( records.len(), total_blocks, l0_hit_blocks, l1_hit_blocks, remote_hit_blocks, miss_blocks, total_rdma_bytes, total_pcie_bytes, )) }