//! Network cost models for RDMA (cross-instance) and PCIe (host<->GPU). //! //! Each link is modeled as a token bucket via a `next_free` cursor: a fetch of //! `bytes` starting at `now` waits until `next_free`, then advances the cursor //! by `bytes / bw`. Latency is added on top of transfer time. This captures //! contention without simulating individual packets. use crate::config::HardwareConfig; #[derive(Debug, Clone)] pub struct LinkModel { pub bw_bytes_per_s: f64, pub latency_s: f64, next_free: f64, } impl LinkModel { pub fn new(bw_bytes_per_s: f64, latency_s: f64) -> Self { Self { bw_bytes_per_s, latency_s, next_free: 0.0, } } /// Reserve a transfer of `bytes` starting at `now`. Returns the absolute /// time at which the bytes have all arrived (advances internal cursor). pub fn reserve(&mut self, now: f64, bytes: u64) -> f64 { if bytes == 0 { return now + self.latency_s; } let xfer = bytes as f64 / self.bw_bytes_per_s; let start = self.next_free.max(now); self.next_free = start + xfer; self.next_free + self.latency_s } /// Pure cost (no contention): how long to push `bytes` over this link. pub fn cost(&self, bytes: u64) -> f64 { if bytes == 0 { self.latency_s } else { self.latency_s + bytes as f64 / self.bw_bytes_per_s } } } /// Per-instance bundle of links: PCIe (host<->GPU) and RDMA (host<->remote). #[derive(Debug, Clone)] pub struct InstanceLinks { pub pcie: LinkModel, pub rdma: LinkModel, } impl InstanceLinks { pub fn from_hw(hw: &HardwareConfig) -> Self { Self { pcie: LinkModel::new(hw.pcie_bw, hw.pcie_latency_us * 1e-6), rdma: LinkModel::new(hw.rdma_bw, hw.rdma_latency_us * 1e-6), } } } #[cfg(test)] mod tests { use super::*; #[test] fn link_cost_matches_formula() { let l = LinkModel::new(1.0e9, 1.0e-6); // 1 GB / (1 GB/s) = 1s, plus 1us latency let t = l.cost(1_000_000_000); assert!((t - (1.0 + 1e-6)).abs() < 1e-9); } #[test] fn reserve_serializes_concurrent_transfers() { let mut l = LinkModel::new(1.0e9, 0.0); let t1 = l.reserve(0.0, 500_000_000); // 0.5s let t2 = l.reserve(0.0, 500_000_000); // contended -> 1.0s assert!((t1 - 0.5).abs() < 1e-9); assert!((t2 - 1.0).abs() < 1e-9); } }