commit 05592e6adc6a0d336f7c6145510b507b122e3b7b Author: Gahow Wang Date: Thu May 21 21:21:57 2026 +0800 Agentic workload PD separation analysis with trace-driven benchmarks Systematic study of prefill-decode disaggregation for agentic LLM workloads using production GLM-5.1 coder trace (2.1M requests, 71B input tokens). Key findings: - Cache-aware routing improves TPOT p90 by 15% and APC from 20.8% to 44.7% without PD separation, matching PD-Sep's decode isolation benefit - PD separation adds +72% TTFT overhead (KV transfer) with no TPOT gain when using the same cache-aware scheduler - Prefill remains compute-bound even at 95% KV cache reuse (AI >1000x vs decode AI <2), but absolute FLOPs drop 71% from cache hits - For agentic MoE workloads, cache-aware routing > PD separation Infrastructure: - Trace sampler preserving session structure + hash_ids for prefix sharing - Async trace replayer with streaming TTFT/TPOT/E2E measurement - Unified cache-aware + token-level load-balanced global scheduler proxy supporting both PD-colocated and PD-disaggregated (Mooncake/RDMA) modes - vLLM 0.18.1 scheduler patch for KV transfer abort race condition - Roofline analysis tool for prefill/decode compute characterization Co-Authored-By: Claude Opus 4.6 (1M context) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..400a0ea --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +__pycache__/ +*.pyc +.venv/ +*.egg-info/ +outputs/ +traces/ +third_party/ +*.log diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..5800384 --- /dev/null +++ b/TODO.md @@ -0,0 +1,25 @@ +实验 setup: + +GPU 机器:dash0,是 8*H20 的机器,可以直接 `ssh dash0` 进行连接访问 + +推理引擎:基于 vllm 0.18.1,self build,支持后续 patch 放在 git 中维护 + +模型:`~/models/Qwen/Qwen3-Coder-30B-A3B-Instruct` + +推理 trace:原始完整 2h trace 在 dash0 的 `~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl` + + + +性能指标:每个请求的 TTFT/TPOT/E2E/TBT,prefix KVCache hit ratio + + + +目标: + +1. 先实现标准的 trace-sampler,将 cluster 规模的原始 trace,sample 到合适当前机器数量的规模来跑,保持一份统一的 sample 后的 trace file 作为输入 +2. 实现标准的 trace replayer,保证能够体现线上流量的流量到来特征,KVCache 重用特征等 +3. 跑通 PD 分离,确认 PD 分离能够比普通的 PD 混合一起跑的性能要好,给出两者详细的性能对比以及原因分析 +4. 判断 trace 的 pattern,是否有必要 PD 完全混合或者 PD 分离 +5. 参考本地的 `~/phd/agentic-pd-hybrid`,判断是否能够实现一套 prefill-as-a-service 的架构,把重的 prefill 交给 prefill service,prefill service 能够从本地 GPU/DRAM/别的 GPU 机器上 pull KVCache,提高本地的 prefix KVCache hit ratio,不影响 decoding 的 prefill,就可以交给过去 PD 分离定义中 D-node 来做,提高 KVCache 命中率 + + diff --git a/analysis/pd_separation_analysis.md b/analysis/pd_separation_analysis.md new file mode 100644 index 0000000..decaa38 --- /dev/null +++ b/analysis/pd_separation_analysis.md @@ -0,0 +1,289 @@ +# PD 分离在 Agentic Workload 下的系统分析 + +## 1. Trace 特征 (GLM-5.1 Agentic Coder, 2h, 2.1M requests) + +``` +Total requests: 2,114,220 +Input tokens: 71.1B (avg 33.6k/req, p50=20k, p90=88k) +Output tokens: 940M (avg 445/req, p50=80, p90=811) +I/O ratio: 75.6x (aggregate), 217.8x (per-req median) +Prefill share: 98% of total tokens +Sessions: 1.3M (90% single-turn, 9% multi-turn) +``` + +**与传统 chatbot workload 的根本区别:** + +| 特征 | Traditional Chatbot | Agentic Coder (GLM-5.1) | +|------|-------------------|------------------------| +| I/O ratio | 1-10x | **75.6x** | +| Input p50 | 500-2000 tokens | **20,030 tokens** | +| Output p50 | 200-500 tokens | **80 tokens** | +| Prefill token share | 50-80% | **98%** | +| >32k input | <5% | **38%** | +| Multi-turn | 50-80% | **9%** | + +**KV Cache 复用特征:** + +``` +Unique hash blocks: 20,650,883 +Shared blocks (ref>1): 9,749,379 (47%) +Highly shared (ref>10): 2,428,160 +Intra-session reuse: 57% +Top-10 blocks ref count: 64,754 (system prompt blocks) +Theoretical cache hit: 71% (infinite cache, first 100k requests) +``` + +**Input length 分布与 token 占比:** + +``` + <1k: 202,396 reqs ( 9%) 89M tokens ( 0%) + 1-8k: 380,009 reqs (17%) 1.6B tokens ( 2%) + 8-32k: 720,871 reqs (34%) 12.7B tokens (17%) + 32-65k: 405,371 reqs (19%) 19.4B tokens (27%) + 65-131k: 394,014 reqs (18%) 35.7B tokens (50%) + >131k: 11,559 reqs ( 0%) 1.6B tokens ( 2%) +``` + +50% 的 token 计算量来自 65-131k 的长 context 请求。 + +## 2. DistServe 等 PD 分离的核心假设 + +DistServe (OSDI'24), Splitwise, TetriInfer 等 PD 分离工作基于以下假设: + +### 假设 A: Prefill 和 Decode 有不同的计算特征 +- **Prefill**: compute-bound, 高 GPU 利用率, batch 越大越好 +- **Decode**: memory-bandwidth-bound, 低 GPU 利用率, latency-sensitive + +**在 agentic workload 中的验证**: ✅ 成立,但需要细化 + +Roofline 分析显示(详见 Section 5): + +``` + Arithmetic Intensity (FLOP/byte) + Decode: 1.0 - 1.9 (memory-bound, 始终远低于 ridge point) + Prefill 0% reuse: 23,000-72,000 (strongly compute-bound) + Prefill 70% reuse: 10,000-42,000 (仍然 compute-bound!) + Prefill 95% reuse: 1,900-10,800 (仍然 compute-bound!) + Ridge point (H20): 37 +``` + +**即使 95% KV cache reuse,prefill 仍然是 compute-bound。** 但绝对计算量大幅减少。 + +### 假设 B: PD co-location 导致互相干扰 +- Prefill 的大 batch 计算会抢占 GPU 资源,导致 decode 的 TPOT 升高 +- Decode 的持续小计算会占用 GPU 调度槽位,影响 prefill 吞吐 + +**在 agentic workload 中的验证**: ⚠️ 干扰存在,但 **可被 cache-aware routing 消除** + +``` +同一 cache-aware scheduler, TP=1, 8 GPU: + Combined TP=1 DP=8: TPOT p90 = 0.073s + PD-Sep TP=1 4P+4D: TPOT p90 = 0.074s + → 差异 <2%, 不显著 +``` + +对比 round-robin routing: +``` + Combined TP=1 DP=8 (RR): TPOT p90 = 0.086s + Combined TP=1 DP=8 (cache-aware): TPOT p90 = 0.073s → -15% + → routing 改善 > PD 分离改善 +``` + +**原因**: cache-aware routing 让 high-cache-hit 的请求集中到特定 instance, +每个 instance 的实际 prefill 新 token 数大幅减少(71% 被 cache), +prefill-decode 干扰因 prefill 工作量降低而自然缓解。 + +### 假设 C: KV Cache 传输开销可以忽略 +- DistServe 假设 P→D 的 KV 传输延迟远小于 prefill 计算时间 +- 在 InfiniBand/NVLink 等高带宽互联下成立 + +**在 agentic workload 中的验证**: ❌ 不成立 + +``` +PD-Sep TTFT p50 = 1.261s vs Combined TTFT p50 = 0.731s (+72%) +``` + +原因: +1. Agentic workload 的 input 极长(p50=20k, p90=88k tokens),KV cache 很大 +2. 单请求 KV cache = 20k tokens × 48 layers × 2(K+V) × 512 bytes ≈ 1GB +3. 更重要的是 await-prefill 链路的串行延迟:proxy → prefill → KV transfer → decode → first token + +### 假设 D: 专用 prefill 节点可以提高 prefill 吞吐 +- Prefill 节点不做 decode,GPU 利用率更高 +- 可以用更大的 batch size + +**在 agentic workload 中的验证**: ⚠️ 收益被 cache 稀释 + +``` +理论 prefix cache hit (infinite cache): 71% of input tokens +实际 APC (Combined, cache-aware, 8 inst): 44.7% +``` + +71% cache hit → 只有 29% 的 input tokens 需要实际 prefill compute。 +Nominal avg input 33.6k → Actual avg new prefill ~9.7k tokens。 +专用 prefill 的 GPU 利用率优势因 prefill 工作量降低而缩小。 + +## 3. Roofline 分析:Prefill 在高 Cache Reuse 下的计算/访存特性 + +### 3.1 模型计算结构 + +``` +Qwen3-Coder-30B-A3B (MoE 128E top-8): + 48 layers, hidden=2048, heads=32, kv_heads=4 (GQA), head_dim=128 + FFN: 6144 intermediate per expert, 8 experts active per token + Active params per token: ~3B + +H20 GPU: 148 TFLOPS (BF16), 4.0 TB/s HBM → Ridge point: 37 FLOP/byte +``` + +### 3.2 Decode 永远 memory-bound + +``` +SeqLen FLOP Bytes AI (F/B) Bound +1,000 3.04e+10 3.01e+10 1.0 MEMORY +16,000 3.63e+10 3.16e+10 1.1 MEMORY +64,000 5.52e+10 3.63e+10 1.5 MEMORY +128,000 8.03e+10 4.26e+10 1.9 MEMORY +``` + +Decode 的 AI 始终 < 2,远低于 ridge point (37)。每个 decode step 只处理 1 个 token, +计算量极小,瓶颈在于读取模型权重和全量 KV cache。 + +### 3.3 Prefill 即使 95% reuse 仍然 compute-bound + +``` +SeqLen Reuse% NewTok AI (F/B) Bound vs Decode +32,000 0% 32,000 23,368 COMPUTE 18,190x +32,000 50% 16,000 14,899 COMPUTE 11,597x +32,000 70% 9,600 10,045 COMPUTE 7,819x +32,000 90% 3,200 3,821 COMPUTE 2,974x +32,000 95% 1,600 1,980 COMPUTE 1,542x + +64,000 0% 64,000 40,758 COMPUTE 26,813x +64,000 70% 19,200 20,610 COMPUTE 13,559x +64,000 90% 6,400 8,544 COMPUTE 5,621x +64,000 95% 3,200 4,549 COMPUTE 2,993x +``` + +### 3.4 为什么高 reuse 不改变 compute-bound 性质 + +KV cache reuse 减少的: +- K/V projection 计算(只算 new tokens) +- KV 写入(只写 new tokens) + +KV cache reuse **不减少**的: +- **Q×K^T attention**: 每个 new Q 都要和全部 seq_len 个 KV 做 attention + ``` + FLOPs = new_tokens × seq_len × head_dim × num_heads × 2 × num_layers + ``` + At 95% reuse, 32k seq: 1600 × 32000 × 128 × 32 × 2 × 48 ≈ 2×10^13 + 这个二次项在长 context 下主导总计算量 + +- **MoE FFN**: 每个 new token 激活 8 experts + ``` + FLOPs = new_tokens × 3 × D × D_ffn × 2 × K_experts × num_layers + ``` + +**Prefill 只在接近 100% reuse (< 10 new tokens) 时才变成 memory-bound。** + +### 3.5 Prefill 什么时候变 memory-bound + +``` +SeqLen=32,000: new_tokens ≈ 5-10 时 → AI ≈ 37 (ridge point) +SeqLen=64,000: new_tokens ≈ 5-10 时 → AI ≈ 37 +``` + +在实际 agentic trace 中: +``` +Compute-bound prefills: 961 (96%) +Memory-bound prefills: 37 (3%) ← 近 100% reuse 的极端 warm 请求 +``` + +### 3.6 关键洞察:"Compute-bound but lightweight" + +高 cache reuse 下的 prefill 处于一种独特状态: + +``` + Prefill bound 类型: Compute-bound (不变) + Prefill 绝对工作量: 大幅降低 (71% cache → 只算 29% 的 tokens) + Prefill-Decode 干扰: 因绝对工作量降低而减轻 (不需要物理隔离) +``` + +这解释了为什么 PD 分离没有帮助: +- PD 分离解决的是 "prefill 太重干扰 decode" 的问题 +- 但 cache-aware routing 已经把 prefill 的实际工作量降到足够轻 +- 物理隔离(PD 分离)的收益被 KV 传输开销抵消 + +## 4. 实验结果 + +### 4.1 完整实验矩阵 + +所有实验使用统一的 cache-aware + token-level load-balanced global scheduler。 + +| Config | OK/N | TTFT p50 | TPOT p90 | E2E p50 | APC | +|--------|------|----------|----------|---------|-----| +| TP=8 DP=1 (single instance) | 998/1000 | 0.467s | 0.129s | 3.30s | 53.0% | +| TP=2 DP=4 (4 inst, RR) | 997/999 | 0.844s | 0.095s | 4.92s | 33.5% | +| TP=1 DP=8 (8 inst, RR) | 997/999 | 1.836s | 0.086s | 6.67s | 20.8% | +| **TP=1 DP=8 (cache-aware)** | **997/999** | **0.731s** | **0.073s** | **4.48s** | **44.7%** | +| TP=1 PD-Sep 4P+4D (cache-aware) | 509/564 | 1.261s | 0.074s | 5.61s | 40.2% | + +### 4.2 Cache-Aware Routing 的效果 + +``` +Round-robin → Cache-aware (Combined TP=1 DP=8): + TTFT p50: 1.836s → 0.731s (-60%) + TPOT p90: 0.086s → 0.073s (-15%) + E2E p50: 6.673s → 4.480s (-33%) + APC: 20.8% → 44.7% (+24pp) +``` + +Cache-aware routing 的提升远大于 PD 分离的提升。 + +### 4.3 修复工程问题的过程 + +实验过程中发现并修复了多个 PD 分离的工程问题: + +| 问题 | 根因 | 修复 | +|------|------|------| +| Decode engine crash | vLLM scheduler assert: KV transfer 回调时 request 已 abort | Patch scheduler.py: assert → graceful skip | +| Head-of-line blocking | Proxy 按 request count 做 LB,不区分大小请求 | Token-level ongoing_tokens load balancing | +| "Timeout waiting for P side ready" | Proxy fire-and-forget prefill, decode 盲等 KV | Await-prefill + kv_load_failure_policy=recompute | +| Port collision on startup | 8 Mooncake instances 同时启动争抢 torch distributed port | Staggered startup + explicit MASTER_PORT | +| Cache routing "rich get richer" | score = ongoing - alpha*cached 导致流量集中到一个 instance | Normalized scoring: ongoing/avg_load - alpha*cache_ratio | + +## 5. 结论 + +### 5.1 PD 分离为什么在 Agentic Workload 不生效 + +1. **Cache reuse 大幅降低 prefill 绝对工作量(71% cache hit → 只算 29%)**,使得 P-D 干扰不显著 +2. **Prefill 仍然 compute-bound**(即使 95% reuse,AI 仍 >1000),但每个请求的总 FLOPs 因 new_tokens 减少而大幅降低 +3. **Cache-aware routing 提供 "软 PD 隔离"**,效果等同于物理隔离但无 KV 传输开销 +4. **KV 传输开销不可忽略**(TTFT +72%),抵消了隔离收益 +5. **MoE 模型 active params 小**(3B),per-token compute 本身较轻 + +### 5.2 PD 分离在什么条件下有价值 + +| 条件 | Chatbot (有价值) | Agentic (无价值) | +|------|-----------------|-----------------| +| Cache hit rate | <10% | **71%** | +| Model active params | 70B (dense) | **3B (MoE)** | +| I/O ratio | 1-10x | **75.6x** | +| Per-request prefill FLOPs | Very high | **Low (after cache)** | +| KV transfer cost vs prefill cost | Negligible | **Significant** | + +### 5.3 Agentic Workload 应该怎么优化 + +1. **Cache-aware routing** (已验证有效): 用 ongoing_tokens + prefix_cache_hit 做联合调度, + 将 APC 从 20.8% (RR) 提升到 44.7%,TPOT p90 降低 15% + +2. **Cross-instance KV cache sharing**: 让多个 instance 共享全局 KV pool, + 进一步提升 cache hit 率接近理论 71% + +3. **Prefix pre-warming**: 对 cold start 请求(55%,0% cache hit), + 预计算 common prefix (system prompt blocks) 并分发到所有 instance + +4. **不同 workload 类型的差异化处理**: + - Warm 请求 (22%, >90% cache hit, avg 1.3k new tokens): 几乎免费,任何 instance 都能处理 + - Cold 请求 (55%, 0% cache hit, avg 17.7k new tokens): prefill-heavy,需要有足够 compute + - 可以用 request-type-aware routing 进一步优化 diff --git a/analysis/roofline_analysis.md b/analysis/roofline_analysis.md new file mode 100644 index 0000000..bd8c065 --- /dev/null +++ b/analysis/roofline_analysis.md @@ -0,0 +1,130 @@ +# Prefill 在高 KV Cache Reuse 下的计算/访存分析 + +## Model & GPU + +``` +Qwen3-Coder-30B-A3B (MoE 128E top-8) + 48 layers, hidden=2048, heads=32, kv_heads=4 (GQA), head_dim=128 + FFN: 6144 intermediate per expert, 8 experts active per token + Active params: ~3B per token + +H20 GPU: 148 TFLOPS (BF16) / 4.0 TB/s HBM + Ridge point: 37 FLOP/byte +``` + +## 核心发现:Prefill 即使 95% reuse 仍然是 compute-bound + +``` + SeqLen Reuse% NewTok AI (F/B) Bound vs Decode AI + 32,000 0% 32,000 23368 COMPUTE 18189x + 32,000 70% 9,600 10045 COMPUTE 7819x + 32,000 90% 3,200 3821 COMPUTE 2974x + 32,000 95% 1,600 1980 COMPUTE 1541x + + 64,000 0% 64,000 40758 COMPUTE 26813x + 64,000 70% 19,200 20610 COMPUTE 13559x + 64,000 90% 6,400 8544 COMPUTE 5621x + 64,000 95% 3,200 4549 COMPUTE 2993x + + Decode (always): + 32,000 - 1 1.3 MEMORY 1x + 64,000 - 1 1.5 MEMORY 1x +``` + +**关键**: +- Decode 的 arithmetic intensity (AI) = 1.0-1.9 — 远低于 ridge point (37),始终 memory-bound +- Prefill 即使 95% reuse (只有 5% 新 token),AI 仍然 >1000 — 远高于 ridge point,依然 compute-bound + +## 为什么高 reuse 的 prefill 仍然是 compute-bound? + +### 原因:Attention 的计算量与 seq_len 成正比 + +当有 95% cache reuse (seq_len=64k, new_tokens=3200): +``` + Q projection: new_tokens × D × D → 只处理 3200 new tokens ✓ + K,V projection: new_tokens × D × D_kv → 只处理 3200 new tokens ✓ + + 但 Attention score: new_tokens × seq_len × D_head × H × L + = 3200 × 64000 × 128 × 32 × 48 + → 仍然要对全部 64k context 做注意力计算! + + FFN (MoE): new_tokens × 3 × D × D_ffn × 2 × K_experts × L + = 3200 × 3 × 2048 × 6144 × 2 × 8 × 48 + → 8 个 expert 的计算量仍然很大 +``` + +KV cache reuse 减少的是: +- K/V projection 的计算(只算 new tokens) +- KV 写入(只写 new tokens) + +但 **不减少的是**: +- Q 对全部 context 的 attention(每个 new Q 都要和所有 64k tokens 做 attention) +- MoE FFN 的计算(每个 new token 激活 8 个 expert) + +所以 prefill 的 FLOPs 虽然随 reuse 减少,但 **减少的是线性部分(投影),不减少的是二次部分(attention)**。 +在长 context 下,二次部分主导,使得即使 95% reuse,AI 仍远高于 ridge point。 + +## Prefill 什么时候才变成 memory-bound? + +``` + SeqLen=32,000: new_tokens ≈ 5-10 时 (reuse > 99.97%) → AI ≈ 37 + SeqLen=64,000: new_tokens ≈ 5-10 时 → AI ≈ 37 +``` + +只有在 **近乎 100% reuse**(仅 5-10 个 new tokens)时,prefill 才接近 memory-bound。 +在实际 agentic trace 中,只有 3% 的请求达到这个程度。 + +## 对 PD 分离的影响:修正之前的分析 + +### 之前的错误结论(已修正) +> "Prefill 大部分是 cache lookup 不是 compute" + +这是 **错误的**。即使 70% cache reuse,prefill 的 AI 仍然是 decode 的 7000-14000 倍。 +Prefill 始终是 compute-bound,decode 始终是 memory-bound。 + +### 那为什么 PD 分离在我们的实验中没有帮助? + +正确的解释不是 "prefill 变成了 memory-bound",而是: + +**1. Cache reuse 大幅减少了 prefill 的绝对计算量** +``` + 无 cache: avg 33.6k tokens × prefill compute = X FLOPs + 71% cache: avg 9.4k tokens × prefill compute = 0.28X FLOPs +``` +虽然 prefill 仍是 compute-bound,但 **总工作量只有原来的 28%**。 +在 8 instance 并行 + cache-aware routing 下,每个 instance 的 prefill 负载非常轻, +不足以产生对 decode 的显著干扰。 + +**2. MoE 模型的 per-token compute 本身较小** +Active params 只有 3B(全参数的 10%),单个 token 的计算量不大。 +对比 Dense 70B 模型,同样的 GPU 上 prefill-decode 干扰会严重得多。 + +**3. Cache-aware routing 的 "负载均衡" 效应** +当请求被路由到 cache 命中率高的 instance 时,该 instance 的实际 prefill 工作量更小, +自然减少了 P-D 争抢。这相当于 routing 层面的 "软 PD 分离"。 + +## 对比不同 workload 类型的 roofline 特征 + +``` + Prefill AI Decode AI PD-Sep 价值 + Dense 70B, Chatbot: 200-1000x 1-2x HIGH (compute-heavy P 干扰 D) + Dense 70B, Agent: 100-500x 1-2x MEDIUM (cache reduces P load) + MoE 30B, Chatbot: 100-500x 1-2x MEDIUM + MoE 30B, Agent: 50-200x 1-2x LOW (small active params + cache) + ← 我们的位置 +``` + +**PD 分离的 ROI 随着 cache hit 率升高和模型 active params 减少而下降。** +Agentic MoE 模型恰好在两个方面都不利于 PD 分离。 + +## 实际 trace 的 prefill bound 分布 + +``` + With actual trace prefix cache pattern (1000 sampled requests): + Compute-bound prefills: 961 (96%) + Memory-bound prefills: 37 (3%) ← 近 100% reuse 的 warm 请求 + (Decode is ALWAYS memory-bound) +``` + +96% 的 prefill 仍然是 compute-bound,但 **absolute compute 因 cache 大幅降低**。 +这是一个 "compute-bound but lightweight" 的独特状态 —— bound 类型没变,但强度大幅降低。 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..dd44cb6 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,16 @@ +[project] +name = "agentic-kv" +version = "0.1.0" +description = "Trace-driven KV cache benchmarking for agentic LLM workloads" +requires-python = ">=3.10" +dependencies = [ + "httpx>=0.27", + "numpy>=1.24", +] + +[project.optional-dependencies] +dev = ["pytest"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" diff --git a/replayer/__init__.py b/replayer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/replayer/__main__.py b/replayer/__main__.py new file mode 100644 index 0000000..8dd279c --- /dev/null +++ b/replayer/__main__.py @@ -0,0 +1,55 @@ +"""CLI entry point: python -m replayer replay ...""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +from pathlib import Path + +from .replay import ReplayConfig, replay_trace + + +def main() -> None: + p = argparse.ArgumentParser(description="Trace replayer for vLLM benchmarking") + p.add_argument("--trace", type=Path, required=True, help="Sampled trace JSONL") + p.add_argument("--output", type=Path, required=True, help="Output metrics JSONL") + p.add_argument("--endpoint", type=str, required=True, + help="vLLM server URL (e.g. http://localhost:8000)") + p.add_argument("--model", type=str, default="default", help="Model name for API") + p.add_argument("--time-scale", type=float, default=1.0, + help="Time compression (>1 = faster)") + p.add_argument("--max-inflight-sessions", type=int, default=32) + p.add_argument("--concurrency-limit", type=int, default=256) + p.add_argument("--request-timeout", type=float, default=600.0) + p.add_argument("--request-limit", type=int, default=None, + help="Limit number of requests to replay") + p.add_argument("-v", "--verbose", action="store_true") + args = p.parse_args() + + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + + config = ReplayConfig( + trace_path=args.trace, + output_path=args.output, + endpoint_url=args.endpoint.rstrip("/"), + model_name=args.model, + time_scale=args.time_scale, + max_inflight_sessions=args.max_inflight_sessions, + concurrency_limit=args.concurrency_limit, + request_timeout_s=args.request_timeout, + request_limit=args.request_limit, + ) + + results = asyncio.run(replay_trace(config)) + succeeded = sum(1 for r in results if r.error is None) + print(f"\nDone: {succeeded}/{len(results)} requests succeeded") + print(f"Metrics: {args.output}") + print(f"Summary: {args.output.with_suffix('.summary.json')}") + + +if __name__ == "__main__": + main() diff --git a/replayer/metrics.py b/replayer/metrics.py new file mode 100644 index 0000000..f48cbf2 --- /dev/null +++ b/replayer/metrics.py @@ -0,0 +1,107 @@ +"""Per-request metrics collection and summary reporting.""" + +from __future__ import annotations + +import asyncio +import json +import statistics +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any + + +@dataclass(frozen=True) +class RequestMetrics: + request_id: str + session_id: str + turn_id: int + trace_timestamp_s: float + input_length: int + output_length: int + request_type: str + effective_input_length: int | None + cached_tokens: int + latency_s: float | None + ttft_s: float | None + tpot_s: float | None + actual_output_tokens: int | None = None + requested_output_tokens: int | None = None + finish_reason: str | None = None + error: str | None = None + + +class IncrementalMetricSink: + """Append each RequestMetrics to JSONL immediately (crash-safe).""" + + def __init__(self, path: Path): + self.path = path + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text("") + self._lock = asyncio.Lock() + self._fh = path.open("a", encoding="utf-8", buffering=1) + + async def append(self, metric: RequestMetrics) -> None: + line = json.dumps(asdict(metric), sort_keys=True) + "\n" + async with self._lock: + self._fh.write(line) + self._fh.flush() + + def close(self) -> None: + try: + self._fh.flush() + self._fh.close() + except Exception: + pass + + +def write_summary_json(path: Path, rows: list[RequestMetrics]) -> None: + successful = [r for r in rows if r.error is None] + latencies = [r.latency_s for r in successful if r.latency_s is not None] + ttfts = [r.ttft_s for r in successful if r.ttft_s is not None] + tpots = [r.tpot_s for r in successful if r.tpot_s is not None] + + total_input = sum(r.input_length for r in successful) + total_cached = sum(r.cached_tokens for r in successful) + + summary: dict[str, Any] = { + "request_count": len(rows), + "success_count": len(successful), + "error_count": sum(1 for r in rows if r.error is not None), + "latency_stats_s": _stats(latencies), + "ttft_stats_s": _stats(ttfts), + "tpot_stats_s": _stats(tpots), + "cache_hit_request_count": sum(1 for r in successful if r.cached_tokens > 0), + "total_input_tokens": total_input, + "total_cached_tokens": total_cached, + "prefix_cache_hit_ratio": total_cached / total_input if total_input > 0 else 0.0, + "cached_tokens_stats": _stats([float(r.cached_tokens) for r in successful]), + "actual_output_tokens_stats": _stats( + [float(r.actual_output_tokens) for r in successful + if r.actual_output_tokens is not None] + ), + } + + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as fh: + json.dump(summary, fh, indent=2, sort_keys=True) + + +def _stats(values: list[float | None]) -> dict[str, float] | None: + clean = [v for v in values if v is not None] + if not clean: + return None + clean.sort() + return { + "count": float(len(clean)), + "mean": statistics.fmean(clean), + "p50": _percentile(clean, 0.50), + "p90": _percentile(clean, 0.90), + "p99": _percentile(clean, 0.99), + } + + +def _percentile(sorted_vals: list[float], pct: float) -> float: + if len(sorted_vals) == 1: + return sorted_vals[0] + idx = round((len(sorted_vals) - 1) * pct) + return sorted_vals[idx] diff --git a/replayer/replay.py b/replayer/replay.py new file mode 100644 index 0000000..4f3d836 --- /dev/null +++ b/replayer/replay.py @@ -0,0 +1,343 @@ +"""Trace replayer — send requests to vLLM following trace timing. + +Supports both vLLM's /v1/completions (OpenAI-compatible) and /generate +(SGLang-style) endpoints. Uses hash_ids from the trace to construct +synthetic prompts that reproduce realistic prefix-cache hit patterns. + +Key behaviors: + - Per-session sequencing: turns within a session are sent in order, + each waiting for the previous to complete before dispatching. + - Inter-session arrival: sessions start at their trace timestamps, + scaled by --time-scale. + - Concurrency control: --max-inflight-sessions caps concurrent sessions; + --concurrency-limit caps total in-flight requests. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +from collections import defaultdict +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import random as _random + +import httpx + +from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json +from .trace import TraceRequest, load_trace + +logger = logging.getLogger(__name__) + +BLOCK_SIZE = 512 +VOCAB_SIZE = 151936 +TOKEN_RANGE_START = 100 +TOKEN_RANGE_END = VOCAB_SIZE - 100 + +_block_cache: dict[int, list[int]] = {} + + +def _hash_id_to_token_ids(hash_id: int) -> list[int]: + """Deterministically map a hash_id to BLOCK_SIZE token IDs.""" + if hash_id in _block_cache: + return _block_cache[hash_id] + rng = _random.Random(hash_id) + ids = [rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END) for _ in range(BLOCK_SIZE)] + _block_cache[hash_id] = ids + return ids + + +@dataclass +class ReplayConfig: + trace_path: Path + output_path: Path + endpoint_url: str # comma-separated for round-robin: "http://host:8000,http://host:8001" + time_scale: float = 1.0 + max_inflight_sessions: int = 32 + concurrency_limit: int = 256 + request_timeout_s: float = 600.0 + request_limit: int | None = None + model_name: str = "default" + + +def _build_prompt_token_ids(req: TraceRequest) -> list[int]: + """Build token IDs from hash_ids for prefix-cache-aware replay. + + Same hash_id prefix → same token ID prefix → APC cache hit in vLLM. + """ + ids: list[int] = [] + for hid in req.hash_ids: + ids.extend(_hash_id_to_token_ids(hid)) + # Pad to input_length with deterministic tokens + pad_rng = _random.Random(req.chat_id) + while len(ids) < req.input_length: + ids.append(pad_rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END)) + return ids[:req.input_length] + + +@dataclass +class _SessionState: + session_id: str + turns: list[TraceRequest] + metrics: list[RequestMetrics] = field(default_factory=list) + + +_endpoint_counter = 0 + + +def _pick_endpoint(config: ReplayConfig) -> str: + """Round-robin across comma-separated endpoints.""" + global _endpoint_counter + endpoints = [e.strip() for e in config.endpoint_url.split(",")] + url = endpoints[_endpoint_counter % len(endpoints)] + _endpoint_counter += 1 + return url + + +async def _dispatch_request( + *, + client: httpx.AsyncClient, + config: ReplayConfig, + req: TraceRequest, + prompt_token_ids: list[int], + sem: asyncio.Semaphore, +) -> RequestMetrics: + """Send one request via /v1/completions (streaming) and collect metrics.""" + endpoint = _pick_endpoint(config) + payload = { + "model": config.model_name, + "prompt": prompt_token_ids, + "max_tokens": max(1, req.output_length), + "temperature": 0, + "stream": True, + "stream_options": {"include_usage": True}, + } + + start = time.perf_counter() + ttft_s = None + n_output = 0 + cached_tokens = 0 + finish_reason = None + err = None + token_times: list[float] = [] + + async with sem: + try: + async with client.stream( + "POST", + f"{endpoint}/v1/completions", + json=payload, + timeout=config.request_timeout_s, + ) as resp: + resp.raise_for_status() + async for raw_line in resp.aiter_lines(): + if not raw_line or not raw_line.startswith("data:"): + continue + data = raw_line[5:].strip() + if data == "[DONE]": + break + try: + chunk = json.loads(data) + except json.JSONDecodeError: + continue + + now = time.perf_counter() + if ttft_s is None: + ttft_s = now - start + + choices = chunk.get("choices", []) + if choices: + delta = choices[0].get("text", "") + if delta: + token_times.append(now) + fr = choices[0].get("finish_reason") + if fr: + finish_reason = fr + + usage = chunk.get("usage") + if usage: + n_output = usage.get("completion_tokens", n_output) + cached_tokens = _extract_cached_tokens(usage) + except Exception as exc: + err = repr(exc)[:300] + + end = time.perf_counter() + e2e = end - start + if n_output == 0 and token_times: + n_output = len(token_times) + + tpot = 0.0 + if len(token_times) > 1: + inter_token = [token_times[i+1] - token_times[i] + for i in range(len(token_times) - 1)] + tpot = sum(inter_token) / len(inter_token) + + return RequestMetrics( + request_id=req.request_id, + session_id=req.session_id, + turn_id=req.turn_id, + trace_timestamp_s=req.timestamp_s, + input_length=req.input_length, + output_length=req.output_length, + request_type=req.request_type, + effective_input_length=len(prompt_token_ids), + cached_tokens=cached_tokens, + latency_s=e2e, + ttft_s=ttft_s, + tpot_s=tpot, + actual_output_tokens=n_output, + requested_output_tokens=req.output_length, + finish_reason=finish_reason, + error=err, + ) + + +def _extract_cached_tokens(usage: dict) -> int: + ct = 0 + details = usage.get("prompt_tokens_details") + if isinstance(details, dict): + ct = details.get("cached_tokens", 0) or 0 + if ct == 0: + ct = usage.get("cached_tokens", 0) or 0 + return int(ct) + + +async def _run_session( + *, + state: _SessionState, + config: ReplayConfig, + client: httpx.AsyncClient, + session_sem: asyncio.Semaphore, + request_sem: asyncio.Semaphore, + earliest_ts: float, + sweep_start: float, + sink: IncrementalMetricSink, +) -> list[RequestMetrics]: + async with session_sem: + # Wait until this session's start time + offset = (state.turns[0].timestamp_s - earliest_ts) / config.time_scale + wait = offset - (time.perf_counter() - sweep_start) + if wait > 0: + await asyncio.sleep(wait) + + for req in state.turns: + # Intra-session: wait for turn's relative offset + if req != state.turns[0]: + target = (req.timestamp_s - state.turns[0].timestamp_s) / config.time_scale + elapsed = time.perf_counter() - sweep_start - offset + if elapsed < target: + await asyncio.sleep(target - elapsed) + + token_ids = _build_prompt_token_ids(req) + metric = await _dispatch_request( + client=client, config=config, req=req, + prompt_token_ids=token_ids, sem=request_sem, + ) + state.metrics.append(metric) + await sink.append(metric) + + return state.metrics + + +async def _snapshot_prefix_cache_metrics(url_csv: str) -> dict[str, float]: + """Scrape vLLM /metrics for prefix cache counters (aggregated across endpoints).""" + total = {"queries": 0.0, "hits": 0.0} + endpoints = [e.strip() for e in url_csv.split(",")] + async with httpx.AsyncClient(timeout=10) as c: + for url in endpoints: + try: + r = await c.get(f"{url}/metrics") + for line in r.text.split("\n"): + if line.startswith("vllm:prefix_cache_queries_total"): + total["queries"] += float(line.split()[-1]) + elif line.startswith("vllm:prefix_cache_hits_total"): + total["hits"] += float(line.split()[-1]) + except Exception: + pass + return total + + +async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]: + """Main entry: load trace, replay against endpoint, return metrics.""" + requests = load_trace(config.trace_path, request_limit=config.request_limit) + if not requests: + return [] + + by_session: dict[str, list[TraceRequest]] = defaultdict(list) + for r in requests: + by_session[r.session_id].append(r) + for sid in by_session: + by_session[sid].sort(key=lambda r: (r.turn_id, r.timestamp_s)) + + sessions = sorted(by_session.items(), key=lambda kv: kv[1][0].timestamp_s) + earliest_ts = sessions[0][1][0].timestamp_s + + session_sem = asyncio.Semaphore(config.max_inflight_sessions) + request_sem = asyncio.Semaphore(config.concurrency_limit) + + sink = IncrementalMetricSink(config.output_path) + + n_sessions = len(sessions) + n_requests = len(requests) + logger.info("Replaying %d sessions (%d requests), time_scale=%.1f", + n_sessions, n_requests, config.time_scale) + + pre_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url) + sweep_start = time.perf_counter() + + try: + limits = httpx.Limits( + max_connections=2000, + max_keepalive_connections=500, + keepalive_expiry=30.0, + ) + async with httpx.AsyncClient( + timeout=config.request_timeout_s, + trust_env=False, + limits=limits, + ) as client: + tasks = [ + asyncio.create_task(_run_session( + state=_SessionState(session_id=sid, turns=turns), + config=config, client=client, + session_sem=session_sem, request_sem=request_sem, + earliest_ts=earliest_ts, sweep_start=sweep_start, + sink=sink, + )) + for sid, turns in sessions + ] + all_results = await asyncio.gather(*tasks) + finally: + sink.close() + + sweep_elapsed = time.perf_counter() - sweep_start + post_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url) + + flat = [m for group in all_results for m in group] + summary_path = config.output_path.with_suffix(".summary.json") + write_summary_json(summary_path, flat) + + # Compute aggregate prefix cache hit ratio from /metrics deltas + delta_queries = post_metrics.get("queries", 0) - pre_metrics.get("queries", 0) + delta_hits = post_metrics.get("hits", 0) - pre_metrics.get("hits", 0) + hit_ratio = delta_hits / delta_queries if delta_queries > 0 else 0.0 + + logger.info("Done: %d/%d succeeded in %.1fs", sum(1 for m in flat if m.error is None), len(flat), sweep_elapsed) + logger.info("Prefix cache: %.1f%% hit ratio (%d/%d tokens)", + hit_ratio * 100, int(delta_hits), int(delta_queries)) + + # Append cache stats to summary + import json as _json + summary = _json.loads(summary_path.read_text()) + summary["prefix_cache_queries_tokens"] = int(delta_queries) + summary["prefix_cache_hits_tokens"] = int(delta_hits) + summary["prefix_cache_hit_ratio"] = hit_ratio + summary["wall_clock_s"] = sweep_elapsed + summary_path.write_text(_json.dumps(summary, indent=2, sort_keys=True)) + + logger.info("Summary written to %s", summary_path) + return flat diff --git a/replayer/trace.py b/replayer/trace.py new file mode 100644 index 0000000..17f6767 --- /dev/null +++ b/replayer/trace.py @@ -0,0 +1,84 @@ +"""Trace data structures and loader for the Ali agentic-coder trace format. + +Trace format (one JSON per line): + chat_id, parent_chat_id, timestamp, input_length, output_length, + type, turn, hash_ids[] + +Sessions are derived from parent_chat_id chains: + - parent_chat_id == -1 → new session root + - parent_chat_id >= 0 → belongs to the same session as the parent +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(frozen=True) +class TraceRequest: + request_id: str + session_id: str + chat_id: int + parent_chat_id: int + timestamp_s: float + input_length: int + output_length: int + request_type: str + turn_id: int + hash_ids: tuple[int, ...] + + +def load_trace( + path: Path, + *, + request_limit: int | None = None, +) -> list[TraceRequest]: + """Load trace and resolve session IDs from parent_chat_id chains.""" + chat_to_session: dict[int, str] = {} + requests: list[TraceRequest] = [] + + with path.open("r", encoding="utf-8") as fh: + for idx, line in enumerate(fh): + if request_limit is not None and len(requests) >= request_limit: + break + row = json.loads(line) + chat_id = int(row["chat_id"]) + parent_chat_id = int(row["parent_chat_id"]) + + if "session_id" in row: + session_id = str(row["session_id"]) + else: + session_id = _resolve_session_id( + chat_id, parent_chat_id, chat_to_session, + ) + chat_to_session[chat_id] = session_id + + requests.append(TraceRequest( + request_id=f"{session_id}:{row['turn']}:{chat_id}:{idx}", + session_id=session_id, + chat_id=chat_id, + parent_chat_id=parent_chat_id, + timestamp_s=float(row["timestamp"]), + input_length=int(row["input_length"]), + output_length=int(row["output_length"]), + request_type=str(row["type"]), + turn_id=int(row["turn"]), + hash_ids=tuple(int(h) for h in row.get("hash_ids", [])), + )) + + return requests + + +def _resolve_session_id( + chat_id: int, + parent_chat_id: int, + chat_to_session: dict[int, str], +) -> str: + if parent_chat_id < 0: + session_id = str(chat_id) + else: + session_id = chat_to_session.get(parent_chat_id, str(parent_chat_id)) + chat_to_session[chat_id] = session_id + return session_id diff --git a/scripts/analyze_cache_hit.py b/scripts/analyze_cache_hit.py new file mode 100644 index 0000000..0fbe996 --- /dev/null +++ b/scripts/analyze_cache_hit.py @@ -0,0 +1,196 @@ +"""Analyze theoretical vs actual KV cache hit ratio for the agentic trace.""" +import json +from collections import Counter + +rows = [json.loads(l) for l in open("traces/sampled_1000req_seed42.jsonl")] +rows.sort(key=lambda r: float(r["timestamp"])) + +BLOCK_SIZE = 512 + +# === 1. Theoretical max: infinite cache, single instance === +total_tokens = 0 +total_cached = 0 +seen_blocks = set() +per_req = [] + +for r in rows: + input_len = r["input_length"] + hash_ids = r.get("hash_ids", []) + total_tokens += input_len + + cached_blocks = 0 + prefix_broken = False + for hid in hash_ids: + if not prefix_broken and hid in seen_blocks: + cached_blocks += 1 + else: + prefix_broken = True + + cached_tokens = cached_blocks * BLOCK_SIZE + total_cached += cached_tokens + for hid in hash_ids: + seen_blocks.add(hid) + + per_req.append({ + "input_length": input_len, + "cached_tokens": cached_tokens, + "new_tokens": max(0, input_len - cached_tokens), + "ratio": cached_tokens / input_len if input_len > 0 else 0, + }) + +sep = "=" * 70 +print(sep) +print(" THEORETICAL KV CACHE HIT (infinite cache, single instance)") +print(sep) +print(f" Total input tokens: {total_tokens:>14,}") +print(f" Cacheable (prefix hit): {total_cached:>14,} ({total_cached*100//total_tokens}%)") +print(f" Must prefill (new): {total_tokens-total_cached:>14,} ({(total_tokens-total_cached)*100//total_tokens}%)") + +ratios = sorted([s["ratio"] for s in per_req if s["input_length"] > 0]) +new_tokens = sorted([s["new_tokens"] for s in per_req if s["input_length"] > 0]) +p = lambda v, q: v[min(int(q*len(v)), len(v)-1)] + +print(f"\n Per-request cache hit ratio:") +print(f" p10={p(ratios,.1)*100:.1f}% p50={p(ratios,.5)*100:.1f}% p90={p(ratios,.9)*100:.1f}% mean={sum(ratios)/len(ratios)*100:.1f}%") +high = sum(1 for r in ratios if r > 0.5) +very_high = sum(1 for r in ratios if r > 0.9) +zero = sum(1 for r in ratios if r == 0) +print(f" 0% hit (cold start): {zero} ({zero*100//len(ratios)}%)") +print(f" >50% hit: {high} ({high*100//len(ratios)}%)") +print(f" >90% hit: {very_high} ({very_high*100//len(ratios)}%)") + +print(f"\n Actual new tokens to prefill per request:") +print(f" p10={p(new_tokens,.1):>7,} p50={p(new_tokens,.5):>7,} p90={p(new_tokens,.9):>7,} max={max(new_tokens):>7,}") + +# === 2. 4-instance split (simulating DP=4 or 4 prefill instances) === +print(f"\n{sep}") +print(" 4-INSTANCE SPLIT (round-robin, per-instance cache)") +print(sep) + +instance_seen = [set() for _ in range(4)] +inst_total = [0]*4 +inst_cached = [0]*4 + +for i, r in enumerate(rows): + inst = i % 4 + input_len = r["input_length"] + hash_ids = r.get("hash_ids", []) + inst_total[inst] += input_len + + cached_blocks = 0 + prefix_broken = False + for hid in hash_ids: + if not prefix_broken and hid in instance_seen[inst]: + cached_blocks += 1 + else: + prefix_broken = True + + inst_cached[inst] += cached_blocks * BLOCK_SIZE + for hid in hash_ids: + instance_seen[inst].add(hid) + +rr_total = sum(inst_total) +rr_cached = sum(inst_cached) +print(f" Cache hit ratio (RR): {rr_cached*100//rr_total}%") + +# === 3. Cache-aware routing (route to instance with best prefix match) === +print(f"\n{sep}") +print(" 4-INSTANCE CACHE-AWARE ROUTING") +print(sep) + +ca_seen = [set() for _ in range(4)] +ca_total = [0]*4 +ca_cached = [0]*4 + +for r in rows: + input_len = r["input_length"] + hash_ids = r.get("hash_ids", []) + + # Pick instance with most prefix blocks cached + best_inst = 0 + best_hit = 0 + for inst in range(4): + hit = 0 + for hid in hash_ids: + if hid in ca_seen[inst]: + hit += 1 + else: + break + if hit > best_hit: + best_hit = hit + best_inst = inst + + ca_total[best_inst] += input_len + ca_cached[best_inst] += best_hit * BLOCK_SIZE + for hid in hash_ids: + ca_seen[best_inst].add(hid) + +ca_total_sum = sum(ca_total) +ca_cached_sum = sum(ca_cached) +print(f" Cache hit ratio: {ca_cached_sum*100//ca_total_sum}%") +print(f" vs RR: {rr_cached*100//rr_total}% -> {ca_cached_sum*100//ca_total_sum}% (+{(ca_cached_sum-rr_cached)*100//rr_total}pp)") + +# === 4. Session structure analysis === +print(f"\n{sep}") +print(" SESSION & MULTI-TURN ANALYSIS") +print(sep) + +sessions = {} +chat_to_session = {} +for r in rows: + cid = int(r["chat_id"]) + pid = int(r["parent_chat_id"]) + sid = r.get("session_id", str(cid) if pid < 0 else chat_to_session.get(pid, str(pid))) + chat_to_session[cid] = str(sid) + sessions.setdefault(str(sid), []).append(r) + +multi = {k: v for k, v in sessions.items() if len(v) > 1} +single = {k: v for k, v in sessions.items() if len(v) == 1} + +print(f" Sessions: {len(sessions)} total, {len(multi)} multi-turn ({len(multi)*100//len(sessions)}%)") + +# Multi-turn: cache hit in turn 2+ +mt_new = 0 +mt_reuse = 0 +for sid, turns in multi.items(): + turns.sort(key=lambda r: r["turn"]) + prev_blocks = set() + for t in turns: + hids = t.get("hash_ids", []) + for hid in hids: + if hid in prev_blocks: + mt_reuse += BLOCK_SIZE + else: + mt_new += BLOCK_SIZE + prev_blocks.add(hid) + +mt_total_tok = mt_new + mt_reuse +print(f" Multi-turn intra-session reuse: {mt_reuse*100//mt_total_tok}% of tokens") +print(f" (Turn 2+ reuses KV from prior turns in same session)") + +# Single-turn: cross-session sharing via system prompt +block_freq = Counter() +for r in rows: + for hid in r.get("hash_ids", []): + block_freq[hid] += 1 + +shared = {k: v for k, v in block_freq.items() if v > 1} +top = block_freq.most_common(5) +print(f"\n Cross-session block sharing:") +print(f" Unique blocks: {len(block_freq):,}") +print(f" Shared (ref>1): {len(shared):,} ({len(shared)*100//len(block_freq)}%)") +print(f" Top-5 block ref counts: {[c for _,c in top]}") +print(f" (Shared blocks = system prompt / common code context)") + +# === 5. Implication for PD separation === +print(f"\n{sep}") +print(" IMPLICATION FOR PD SEPARATION") +print(sep) +actual_prefill_pct = (total_tokens - total_cached) * 100 // total_tokens +print(f" With perfect caching, only {actual_prefill_pct}% of tokens need actual prefill compute.") +print(f" The remaining {100-actual_prefill_pct}% are prefix cache hits (skip prefill, reuse KV).") +print(f" This means PD separation's prefill overhead is much smaller than it appears:") +print(f" - Nominal avg input: {total_tokens//len(rows):,} tokens/request") +new_per_req = sorted([s["new_tokens"] for s in per_req if s["input_length"] > 0]) +print(f" - Actual avg prefill: {sum(new_per_req)//len(new_per_req):,} tokens/request (after cache hit)") +print(f" - KV transfer size is also reduced (only transfer new blocks)") diff --git a/scripts/analyze_trace.py b/scripts/analyze_trace.py new file mode 100644 index 0000000..f130e10 --- /dev/null +++ b/scripts/analyze_trace.py @@ -0,0 +1,163 @@ +"""Analyze trace patterns to assess PD separation benefit. + +Computes metrics relevant to deciding PD-combined vs PD-separated: + - Input/output token ratio (high ratio = prefill-heavy → PD sep benefits) + - Prefix sharing density (high sharing → benefits from shared KV cache) + - Session length distribution (multi-turn = more prefix reuse) + - Arrival burstiness (bursty prefill → PD sep can absorb spikes) + - Compute-intensity ratio: prefill FLOP share vs decode FLOP share + +Usage: + python scripts/analyze_trace.py --input traces/sampled_1000req_seed42.jsonl +""" + +from __future__ import annotations + +import argparse +import collections +import json +import statistics +from pathlib import Path + + +def main(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument("--input", type=Path, required=True) + args = p.parse_args() + + rows = [] + with args.input.open() as fh: + for line in fh: + rows.append(json.loads(line)) + + # Session structure + sessions: dict[str, list[dict]] = collections.OrderedDict() + chat_to_session: dict[int, str] = {} + for r in rows: + cid = int(r["chat_id"]) + pid = int(r["parent_chat_id"]) + sid = r.get("session_id") + if sid is None: + sid = str(cid) if pid < 0 else chat_to_session.get(pid, str(pid)) + chat_to_session[cid] = str(sid) + sessions.setdefault(str(sid), []).append(r) + + n_sessions = len(sessions) + turns_per_session = [len(v) for v in sessions.values()] + multi_turn = sum(1 for t in turns_per_session if t > 1) + + input_lens = [r["input_length"] for r in rows] + output_lens = [r["output_length"] for r in rows] + total_input = sum(input_lens) + total_output = sum(output_lens) + + print("=" * 60) + print("Trace Pattern Analysis for PD Separation Decision") + print("=" * 60) + + # 1. Input/Output ratio + io_ratio = total_input / max(total_output, 1) + print(f"\n1. Input/Output Token Ratio") + print(f" Total input tokens: {total_input:>12,}") + print(f" Total output tokens: {total_output:>12,}") + print(f" I/O ratio: {io_ratio:>12.1f}x") + print(f" → {'STRONGLY' if io_ratio > 50 else 'Moderately' if io_ratio > 10 else 'Weakly'} prefill-heavy") + + # 2. Prefill compute share + # Approximate: prefill FLOP ∝ input_length, decode FLOP ∝ output_length * input_length + # More precisely: prefill dominates when input >> output + prefill_share = total_input / (total_input + total_output) + print(f"\n2. Compute Split (token count proxy)") + print(f" Prefill share: {prefill_share*100:.1f}%") + print(f" Decode share: {(1-prefill_share)*100:.1f}%") + + # 3. Session structure + print(f"\n3. Session Structure") + print(f" Sessions: {n_sessions}") + print(f" Requests: {len(rows)}") + print(f" Multi-turn: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)") + print(f" Turns/sess: min={min(turns_per_session)} max={max(turns_per_session)} " + f"avg={statistics.fmean(turns_per_session):.1f}") + + # 4. Prefix sharing + all_hash_ids = set() + per_request_hashes = [] + for r in rows: + hids = set(r.get("hash_ids", [])) + per_request_hashes.append(hids) + all_hash_ids.update(hids) + + hash_refcount = collections.Counter() + for hids in per_request_hashes: + for h in hids: + hash_refcount[h] += 1 + + shared_blocks = sum(1 for h, c in hash_refcount.items() if c > 1) + total_blocks = len(all_hash_ids) + block_reuse = shared_blocks / max(total_blocks, 1) + avg_refcount = statistics.fmean(hash_refcount.values()) if hash_refcount else 0 + + print(f"\n4. Prefix Block Sharing") + print(f" Unique blocks: {total_blocks:>10,}") + print(f" Shared (ref>1): {shared_blocks:>10,} ({block_reuse*100:.1f}%)") + print(f" Avg refcount: {avg_refcount:>10.2f}") + print(f" → {'High' if block_reuse > 0.3 else 'Moderate' if block_reuse > 0.1 else 'Low'} prefix reuse potential") + + # 5. Input length distribution + input_sorted = sorted(input_lens) + pct = lambda q: input_sorted[min(int(q * len(input_sorted)), len(input_sorted) - 1)] + print(f"\n5. Input Length Distribution") + print(f" p10={pct(0.1):>8,} p50={pct(0.5):>8,} p90={pct(0.9):>8,} max={max(input_lens):>8,}") + long_context = sum(1 for l in input_lens if l > 32000) + print(f" Requests >32k tokens: {long_context} ({long_context/len(rows)*100:.1f}%)") + + # 6. Arrival pattern + timestamps = sorted(float(r["timestamp"]) for r in rows) + span = timestamps[-1] - timestamps[0] + avg_rate = len(rows) / max(span, 0.001) + + # Burstiness: coefficient of variation of inter-arrival times + inter_arrivals = [timestamps[i+1] - timestamps[i] for i in range(len(timestamps) - 1)] + inter_arrivals = [t for t in inter_arrivals if t > 0] + if inter_arrivals: + cv = statistics.stdev(inter_arrivals) / statistics.fmean(inter_arrivals) + else: + cv = 0 + print(f"\n6. Arrival Pattern") + print(f" Span: {span:.1f}s ({span/60:.1f} min)") + print(f" Avg rate: {avg_rate:.2f} req/s") + print(f" Burstiness (CoV): {cv:.2f}") + print(f" → {'Bursty' if cv > 1.5 else 'Moderate' if cv > 0.8 else 'Steady'} arrival pattern") + + # Summary + print(f"\n{'=' * 60}") + print("Summary: PD Separation Recommendation") + print(f"{'=' * 60}") + factors = [] + if io_ratio > 50: + factors.append("Very high I/O ratio (prefill-dominated)") + elif io_ratio > 10: + factors.append("High I/O ratio") + if block_reuse > 0.1: + factors.append(f"Significant prefix reuse ({block_reuse*100:.0f}% shared blocks)") + if long_context / len(rows) > 0.3: + factors.append(f"Many long-context requests ({long_context/len(rows)*100:.0f}%)") + if cv > 1.0: + factors.append("Bursty arrivals (PD sep absorbs prefill spikes)") + + if len(factors) >= 2: + print("→ RECOMMEND PD separation:") + elif len(factors) == 1: + print("→ PD separation MAY help:") + else: + print("→ PD separation likely NOT beneficial:") + + for f in factors: + print(f" • {f}") + if not factors: + print(" • No strong indicators for PD separation benefit") + + +if __name__ == "__main__": + main() diff --git a/scripts/cache_aware_proxy.py b/scripts/cache_aware_proxy.py new file mode 100644 index 0000000..acfa80f --- /dev/null +++ b/scripts/cache_aware_proxy.py @@ -0,0 +1,280 @@ +"""Unified cache-aware + token-level load-balanced global scheduler. + +Supports two modes: + --combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer) + --prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer) + +Routing policy (same for both modes): + score = ongoing_tokens / avg_ongoing - ALPHA * cache_hit_ratio + Normalized load prevents "rich get richer"; cache bonus gives affinity. + Session affinity: multi-turn sessions stick to same instance. +""" + +import argparse +import asyncio +import os +import urllib.parse +import uuid +from contextlib import asynccontextmanager + +import httpx +import uvicorn +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import StreamingResponse + +BLOCK_SIZE = 512 +CACHE_HIT_ALPHA = 1.0 # weight for cache bonus in scoring + + +class InstanceState: + def __init__(self, url: str, bootstrap_port: int | None = None): + self.url = url + self.bootstrap_port = bootstrap_port + self.client = httpx.AsyncClient( + timeout=None, base_url=url, + limits=httpx.Limits(max_connections=None, max_keepalive_connections=None), + ) + self.ongoing_tokens = 0 + self.engine_id: dict[int, str] = {} + self.dp_size = 1 + self.cached_blocks: set[int] = set() + + def estimate_cache_hit(self, token_ids: list[int] | None) -> int: + if not token_ids or len(token_ids) < BLOCK_SIZE: + return 0 + hit = 0 + for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE): + bh = hash(tuple(token_ids[i:i + BLOCK_SIZE])) + if bh in self.cached_blocks: + hit += BLOCK_SIZE + else: + break + return hit + + def record_prefix(self, token_ids: list[int] | None): + if not token_ids: + return + for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE): + self.cached_blocks.add(hash(tuple(token_ids[i:i + BLOCK_SIZE]))) + if len(self.cached_blocks) > 200000: + self.cached_blocks = set(list(self.cached_blocks)[-100000:]) + + +def pick_instance(instances: list[InstanceState], token_ids: list[int] | None, + session_id: str | None, input_length: int, + affinity: dict[str, int]) -> tuple[InstanceState, int]: + """Normalized load - cache bonus scoring.""" + if session_id and session_id in affinity: + idx = affinity[session_id] + if idx < len(instances): + return instances[idx], idx + + avg_load = max(sum(i.ongoing_tokens for i in instances) / len(instances), 1.0) + best_idx, best_score = 0, float("inf") + for i, inst in enumerate(instances): + cache_hit = inst.estimate_cache_hit(token_ids) + cache_ratio = cache_hit / input_length if input_length > 0 else 0.0 + score = inst.ongoing_tokens / avg_load - CACHE_HIT_ALPHA * cache_ratio + if score < best_score: + best_score = score + best_idx = i + + if session_id: + affinity[session_id] = best_idx + return instances[best_idx], best_idx + + +global_args = None +combined_instances: list[InstanceState] = [] +prefill_instances: list[InstanceState] = [] +decode_instances: list[InstanceState] = [] +session_affinity: dict[str, int] = {} +is_pd_sep = False + + +async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event): + for inst in instances: + if inst.bootstrap_port is None: + continue + while True: + try: + await inst.client.get("/health") + except Exception: + await asyncio.sleep(1) + continue + parsed = urllib.parse.urlparse(str(inst.client.base_url)) + url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query" + resp = await inst.client.get(url) + resp.raise_for_status() + data = resp.json() + for dp_rank, dp_entry in data.items(): + inst.engine_id[int(dp_rank)] = dp_entry["engine_id"] + inst.dp_size = len(data) + print(f"Inited {inst.url} engine_ids={inst.engine_id}") + break + ready.set() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + global is_pd_sep + app.state.ready = asyncio.Event() + + if global_args.combined: + is_pd_sep = False + for url in global_args.combined: + combined_instances.append(InstanceState(url)) + app.state.ready.set() + print(f"Combined mode: {len(combined_instances)} instances") + else: + is_pd_sep = True + for url, bp in global_args.prefill: + prefill_instances.append(InstanceState(url, bp)) + for url in global_args.decode: + decode_instances.append(InstanceState(url)) + asyncio.create_task(init_prefill_bootstrap(prefill_instances, app.state.ready)) + print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D") + + yield + for inst in combined_instances + prefill_instances + decode_instances: + await inst.client.aclose() + + +app = FastAPI(lifespan=lifespan) + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + return await _handle(request, "/v1/completions") + + +@app.post("/v1/chat/completions") +async def handle_chat(request: Request): + return await _handle(request, "/v1/chat/completions") + + +async def _handle(request: Request, api: str): + if not app.state.ready.is_set(): + raise HTTPException(status_code=503, detail="Service Unavailable") + + req_data = await request.json() + request_id = str(uuid.uuid4()) + prompt = req_data.get("prompt") + token_ids = prompt if isinstance(prompt, list) else None + input_length = len(token_ids) if token_ids else 0 + session_id = request.headers.get("X-Session-Id") + + headers = {"X-Request-Id": request_id} + api_key = os.environ.get("OPENAI_API_KEY") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + if is_pd_sep: + return await _handle_pd_sep(api, req_data, request_id, token_ids, + input_length, session_id, headers) + else: + return await _handle_combined(api, req_data, token_ids, + input_length, session_id, headers) + + +async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers): + """Combined mode: route to best instance, send normal request.""" + inst, idx = pick_instance(combined_instances, token_ids, session_id, + input_length, session_affinity) + inst.ongoing_tokens += input_length + + async def generate(): + try: + async with inst.client.stream("POST", api, json=req_data, headers=headers) as resp: + resp.raise_for_status() + async for chunk in resp.aiter_bytes(): + yield chunk + inst.record_prefix(token_ids) + finally: + inst.ongoing_tokens -= input_length + + return StreamingResponse(generate(), media_type="text/event-stream") + + +async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length, + session_id, headers): + """PD-Sep mode: await prefill, then stream decode.""" + p_inst, _ = pick_instance(prefill_instances, token_ids, session_id, + input_length, session_affinity) + d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens) + + # Await prefill + p_inst.ongoing_tokens += input_length + try: + prefill_data = req_data.copy() + prefill_data["kv_transfer_params"] = { + "do_remote_decode": True, "do_remote_prefill": False, + "transfer_id": f"xfer-{request_id}", + } + prefill_data["stream"] = False + prefill_data["max_tokens"] = 1 + prefill_data.pop("max_completion_tokens", None) + prefill_data.pop("stream_options", None) + + p_headers = {**headers, "X-data-parallel-rank": "0"} + resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers) + resp.raise_for_status() + await resp.aclose() + p_inst.record_prefix(token_ids) + except Exception as e: + raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") + finally: + p_inst.ongoing_tokens -= input_length + + # Stream decode + d_inst.ongoing_tokens += input_length + parsed = urllib.parse.urlparse(str(p_inst.client.base_url)) + bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}" + + decode_data = req_data.copy() + decode_data["kv_transfer_params"] = { + "do_remote_decode": False, "do_remote_prefill": True, + "remote_bootstrap_addr": bootstrap_addr, + "remote_engine_id": p_inst.engine_id.get(0, ""), + "transfer_id": f"xfer-{request_id}", + } + + async def generate(): + try: + async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp: + resp.raise_for_status() + async for chunk in resp.aiter_bytes(): + yield chunk + finally: + d_inst.ongoing_tokens -= input_length + + return StreamingResponse(generate(), media_type="application/json") + + +def parse_args(): + p = argparse.ArgumentParser(description="Unified cache-aware global scheduler") + p.add_argument("--port", type=int, default=8000) + p.add_argument("--host", type=str, default="0.0.0.0") + p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs") + p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw", + help="PD-Sep prefill: URL [bootstrap_port]") + p.add_argument("--decode", nargs=1, action="append", dest="decode_raw", + help="PD-Sep decode: URL") + args = p.parse_args() + + args.prefill = [] + if args.prefill_raw: + for entry in args.prefill_raw: + url = entry[0] + bp = int(entry[1]) if len(entry) > 1 and entry[1].lower() != "none" else None + args.prefill.append((url, bp)) + args.decode = [e[0] for e in (args.decode_raw or [])] + + if not args.combined and not args.prefill: + p.error("Must specify either --combined or --prefill/--decode") + return args + + +if __name__ == "__main__": + global_args = parse_args() + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/scripts/compare_results.py b/scripts/compare_results.py new file mode 100644 index 0000000..8fba6a6 --- /dev/null +++ b/scripts/compare_results.py @@ -0,0 +1,102 @@ +"""Compare benchmark results between PD-combined and PD-separated modes. + +Reads summary JSON files and per-request metrics to produce a detailed +comparison report including TTFT, TPOT, E2E, cache hit ratio, and +throughput analysis. + +Usage: + python scripts/compare_results.py \ + --combined outputs/combined_1000req/metrics.summary.json \ + --separated outputs/separated_1000req/metrics.summary.json +""" + +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + + +def load_summary(path: Path) -> dict: + return json.loads(path.read_text()) + + +def load_metrics(path: Path) -> list[dict]: + rows = [] + with path.open() as fh: + for line in fh: + rows.append(json.loads(line)) + return rows + + +def fmt_stat(stat: dict | None, unit: str = "s") -> str: + if stat is None: + return "N/A" + return (f"mean={stat['mean']:.3f}{unit} " + f"p50={stat['p50']:.3f}{unit} " + f"p90={stat['p90']:.3f}{unit} " + f"p99={stat['p99']:.3f}{unit}") + + +def compare(combined: dict, separated: dict) -> None: + print("=" * 70) + print("PD-Combined vs PD-Separated Performance Comparison") + print("=" * 70) + + for label, s in [("PD-Combined", combined), ("PD-Separated", separated)]: + print(f"\n--- {label} ---") + print(f" Requests: {s['request_count']} (success: {s['success_count']}, errors: {s['error_count']})") + print(f" Wall clock: {s.get('wall_clock_s', 0):.1f}s") + print(f" TTFT: {fmt_stat(s.get('ttft_stats_s'))}") + print(f" TPOT: {fmt_stat(s.get('tpot_stats_s'))}") + print(f" E2E: {fmt_stat(s.get('latency_stats_s'))}") + hit_ratio = s.get('prefix_cache_hit_ratio', 0) + print(f" Prefix cache hit ratio: {hit_ratio*100:.1f}%") + queries = s.get('prefix_cache_queries_tokens', 0) + hits = s.get('prefix_cache_hits_tokens', 0) + print(f" ({hits}/{queries} tokens)") + + print("\n--- Comparison (Separated vs Combined) ---") + for metric_key, label in [ + ("ttft_stats_s", "TTFT"), + ("tpot_stats_s", "TPOT"), + ("latency_stats_s", "E2E"), + ]: + c = combined.get(metric_key, {}) + s = separated.get(metric_key, {}) + if c and s: + for pct in ["mean", "p50", "p90", "p99"]: + cv, sv = c.get(pct, 0), s.get(pct, 0) + if cv > 0: + change = (sv - cv) / cv * 100 + direction = "slower" if change > 0 else "faster" + print(f" {label} {pct}: {abs(change):.1f}% {direction} " + f"({cv:.3f}s → {sv:.3f}s)") + + c_ratio = combined.get("prefix_cache_hit_ratio", 0) + s_ratio = separated.get("prefix_cache_hit_ratio", 0) + print(f" Cache hit ratio: {c_ratio*100:.1f}% → {s_ratio*100:.1f}%") + + c_wall = combined.get("wall_clock_s", 1) + s_wall = separated.get("wall_clock_s", 1) + c_tput = combined["success_count"] / c_wall + s_tput = separated["success_count"] / s_wall + print(f" Throughput: {c_tput:.1f} → {s_tput:.1f} req/s " + f"({(s_tput/c_tput - 1)*100:+.1f}%)") + + +def main(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument("--combined", type=Path, required=True) + p.add_argument("--separated", type=Path, required=True) + args = p.parse_args() + + combined = load_summary(args.combined) + separated = load_summary(args.separated) + compare(combined, separated) + + +if __name__ == "__main__": + main() diff --git a/scripts/compute_roofline.py b/scripts/compute_roofline.py new file mode 100644 index 0000000..bcffff4 --- /dev/null +++ b/scripts/compute_roofline.py @@ -0,0 +1,210 @@ +"""Roofline analysis: compute/memory ratio for prefill vs decode +under different sequence lengths and KV cache reuse ratios. + +Model: Qwen3-Coder-30B-A3B (MoE) + - 48 layers, hidden=2048, heads=32, kv_heads=4, head_dim=128 + - MoE: 128 experts, top-8 active, intermediate=6144 + - Total params: ~30B, Active params per token: ~3B + +GPU: NVIDIA H20 + - BF16 peak: 148 TFLOPS + - HBM bandwidth: 4.0 TB/s + - Roofline ridge point: 148/4.0 = 37 FLOP/byte +""" + +import json +import math + +# ===== Model config ===== +L = 48 # layers +D = 2048 # hidden dim +H = 32 # attention heads +H_kv = 4 # KV heads (GQA) +D_head = 128 # head dim +D_ffn = 6144 # FFN intermediate (per expert) +N_experts = 128 # total experts +K_experts = 8 # active experts per token +VOCAB = 151936 +BYTES = 2 # BF16 + +# ===== GPU config (H20) ===== +PEAK_FLOPS = 148e12 # BF16 TFLOPS +HBM_BW = 4.0e12 # bytes/s +RIDGE_POINT = PEAK_FLOPS / HBM_BW # ~37 FLOP/byte + +print("=" * 80) +print(" ROOFLINE ANALYSIS: Prefill vs Decode under KV Cache Reuse") +print(" Model: Qwen3-Coder-30B-A3B (MoE 128E top-8) | GPU: H20") +print("=" * 80) +print(f" Ridge point: {RIDGE_POINT:.1f} FLOP/byte") +print(f" Above ridge → compute-bound | Below ridge → memory-bound") + +# ===== Per-token compute & memory for each component ===== + +def attention_prefill_flops(seq_len, new_tokens): + """FLOPs for attention on new_tokens with seq_len context.""" + # QKV projection: new_tokens * D * (D + 2*D_kv) * 2 + d_kv = H_kv * D_head + qkv_flops = new_tokens * (D * D * 2 + D * d_kv * 2 * 2) # Q + K + V + # Attention score: new_tokens * seq_len * D * 2 (Q@K^T + softmax@V) + attn_flops = new_tokens * seq_len * D * 2 * 2 # simplified: 2 matmuls + # Output projection: new_tokens * D * D * 2 + out_flops = new_tokens * D * D * 2 + return (qkv_flops + attn_flops + out_flops) * L + +def attention_prefill_bytes(seq_len, new_tokens, cached_tokens): + """Memory access for attention prefill.""" + d_kv = H_kv * D_head + # Load model weights (QKV + O projections): D*(D+2*d_kv+D) * BYTES * L + weight_bytes = D * (D + 2 * d_kv + D) * BYTES * L + # Load cached KV: cached_tokens * 2 * d_kv * BYTES * L + cached_kv_bytes = cached_tokens * 2 * d_kv * BYTES * L + # Read input activations + write output: new_tokens * D * BYTES * 2 * L + act_bytes = new_tokens * D * BYTES * 2 * L + # Write new KV to cache: new_tokens * 2 * d_kv * BYTES * L + new_kv_bytes = new_tokens * 2 * d_kv * BYTES * L + return weight_bytes + cached_kv_bytes + act_bytes + new_kv_bytes + +def ffn_flops(n_tokens): + """FLOPs for MoE FFN on n_tokens.""" + # Per expert: 3 * n_tokens * D * D_ffn * 2 (gate + up + down) + # Active experts: K_experts + return 3 * n_tokens * D * D_ffn * 2 * K_experts * L + +def ffn_bytes(n_tokens): + """Memory access for MoE FFN.""" + # Load K_experts worth of weights per layer: K * 3 * D * D_ffn * BYTES + weight_bytes = K_experts * 3 * D * D_ffn * BYTES * L + # Activations: n_tokens * D * BYTES * 2 * L + act_bytes = n_tokens * D * BYTES * 2 * L + return weight_bytes + act_bytes + +def decode_flops(seq_len): + """FLOPs for 1 decode token.""" + return attention_prefill_flops(seq_len, 1) + ffn_flops(1) + +def decode_bytes(seq_len): + """Memory bytes for 1 decode token.""" + return attention_prefill_bytes(seq_len, 1, seq_len) + ffn_bytes(1) + +# ===== Analysis ===== + +print("\n" + "-" * 80) +print(" PART 1: Decode Roofline (baseline)") +print("-" * 80) +print(f" {'SeqLen':>8} {'FLOP':>14} {'Bytes':>14} {'AI (F/B)':>10} {'Bound':>12}") + +for seq_len in [1000, 4000, 8000, 16000, 32000, 64000, 128000]: + flops = decode_flops(seq_len) + bytes_ = decode_bytes(seq_len) + ai = flops / bytes_ + bound = "COMPUTE" if ai > RIDGE_POINT else "MEMORY" + print(f" {seq_len:>8,} {flops:>14.2e} {bytes_:>14.2e} {ai:>10.1f} {bound:>12}") + +print("\n" + "-" * 80) +print(" PART 2: Prefill with KV Cache Reuse") +print(" (Total input = seq_len, cached = seq_len * reuse_ratio, new = rest)") +print("-" * 80) +print(f" {'SeqLen':>8} {'Reuse%':>7} {'NewTok':>8} {'FLOP':>14} {'Bytes':>14} {'AI (F/B)':>10} {'Bound':>12} {'vs Decode':>10}") + +for seq_len in [4000, 16000, 32000, 64000, 128000]: + for reuse in [0.0, 0.3, 0.5, 0.7, 0.9, 0.95]: + cached = int(seq_len * reuse) + new = seq_len - cached + + # Attention: compute on new tokens, but read cached KV for context + attn_f = attention_prefill_flops(seq_len, new) + attn_b = attention_prefill_bytes(seq_len, new, cached) + + # FFN: only on new tokens + ffn_f = ffn_flops(new) + ffn_b = ffn_bytes(new) + + total_f = attn_f + ffn_f + total_b = attn_b + ffn_b + ai = total_f / total_b if total_b > 0 else 0 + + # Compare with decode at same seq_len + dec_f = decode_flops(seq_len) + dec_b = decode_bytes(seq_len) + dec_ai = dec_f / dec_b + + bound = "COMPUTE" if ai > RIDGE_POINT else "MEMORY" + ratio = f"{ai/dec_ai:.1f}x" if dec_ai > 0 else "N/A" + + print(f" {seq_len:>8,} {reuse*100:>6.0f}% {new:>8,} {total_f:>14.2e} {total_b:>14.2e} {ai:>10.1f} {bound:>12} {ratio:>10}") + print() + +print("-" * 80) +print(" PART 3: Key Thresholds") +print("-" * 80) + +# At what reuse ratio does prefill become memory-bound? +for seq_len in [4000, 16000, 32000, 64000, 128000]: + for reuse_pct in range(0, 100): + reuse = reuse_pct / 100.0 + cached = int(seq_len * reuse) + new = seq_len - cached + if new < 1: continue + attn_f = attention_prefill_flops(seq_len, new) + attn_b = attention_prefill_bytes(seq_len, new, cached) + ffn_f = ffn_flops(new) + ffn_b = ffn_bytes(new) + ai = (attn_f + ffn_f) / (attn_b + ffn_b) + if ai < RIDGE_POINT: + print(f" SeqLen={seq_len:>6,}: prefill becomes memory-bound at {reuse_pct}% reuse (AI={ai:.1f})") + break + +print() +print("-" * 80) +print(" PART 4: Agentic Workload Real Distribution") +print("-" * 80) + +# Use actual trace data +import os +trace_path = "traces/sampled_1000req_seed42.jsonl" +if os.path.exists(trace_path): + BLOCK_SIZE = 512 + seen = set() + compute_bound = 0 + memory_bound = 0 + total = 0 + + for line in open(trace_path): + d = json.loads(line) + seq_len = d["input_length"] + if seq_len < 1: continue + hids = d.get("hash_ids", []) + + cached_blocks = 0 + for hid in hids: + if hid in seen: + cached_blocks += 1 + else: + break + for hid in hids: + seen.add(hid) + + cached = cached_blocks * BLOCK_SIZE + new = max(1, seq_len - cached) + reuse = cached / seq_len + + attn_f = attention_prefill_flops(seq_len, new) + attn_b = attention_prefill_bytes(seq_len, new, cached) + ffn_f = ffn_flops(new) + ffn_b = ffn_bytes(new) + ai = (attn_f + ffn_f) / (attn_b + ffn_b) + + total += 1 + if ai > RIDGE_POINT: + compute_bound += 1 + else: + memory_bound += 1 + + print(f" With actual trace prefix cache pattern:") + print(f" Compute-bound prefills: {compute_bound} ({compute_bound*100//total}%)") + print(f" Memory-bound prefills: {memory_bound} ({memory_bound*100//total}%)") + print(f" (Decode is ALWAYS memory-bound at these seq lengths)") + print() + print(f" Implication: {memory_bound*100//total}% of agentic prefills behave like decode") + print(f" → PD separation treats them as 'compute-heavy' but they are actually memory-heavy") diff --git a/scripts/final_comparison.py b/scripts/final_comparison.py new file mode 100644 index 0000000..3698889 --- /dev/null +++ b/scripts/final_comparison.py @@ -0,0 +1,86 @@ +"""Final comparison of PD-Combined vs PD-Separated (Mooncake/RDMA).""" +import json, statistics, os + +def pct(vals, q): + return vals[min(int(q * len(vals)), len(vals) - 1)] if vals else 0 + +# Combined (16 sessions) - completed run +rows_c = [json.loads(l) for l in open("outputs/v18_combined_1000req/metrics.jsonl")] +ok_c = [r for r in rows_c if not r.get("error")] +ttfts_c = sorted([r["ttft_s"] for r in ok_c if r.get("ttft_s")]) +tpots_c = sorted([r["tpot_s"] for r in ok_c if r.get("tpot_s") and r["tpot_s"] > 0]) +lats_c = sorted([r["latency_s"] for r in ok_c if r.get("latency_s")]) +sc = json.load(open("outputs/v18_combined_1000req/metrics.summary.json")) + +# PD-Separated Mooncake (first 200 stable requests) +rows_d = [json.loads(l) for l in open("outputs/v18_pd_mooncake_lowconc/metrics.jsonl")][:200] +ok_d = [r for r in rows_d if not r.get("error")] +ttfts_d = sorted([r["ttft_s"] for r in ok_d if r.get("ttft_s")]) +tpots_d = sorted([r["tpot_s"] for r in ok_d if r.get("tpot_s") and r["tpot_s"] > 0]) +lats_d = sorted([r["latency_s"] for r in ok_d if r.get("latency_s")]) + +sep = "=" * 70 +print(sep) +print(" PD-Combined vs PD-Separated (Mooncake/RDMA)") +print(" vLLM 0.18.1 | Qwen3-Coder-30B-A3B | 8xH20") +print(sep) + +header = " {:<12} {:>16} {:>16} {:>10}".format( + "Metric", "Combined(TP=8)", "PD-Sep(TP=4+4)", "Delta") +print(header) +dash = " {:<12} {:>16} {:>16} {:>10}".format("-" * 12, "-" * 16, "-" * 16, "-" * 10) +print(dash) + +req_c = "{}/{}".format(len(ok_c), len(rows_c)) +req_d = "{}/{}".format(len(ok_d), len(rows_d)) +print(" {:<12} {:>16} {:>16}".format("Requests", req_c, req_d)) + +data = [ + ("TTFT p50", pct(ttfts_c, 0.5), pct(ttfts_d, 0.5)), + ("TTFT p90", pct(ttfts_c, 0.9), pct(ttfts_d, 0.9)), + ("TPOT p50", pct(tpots_c, 0.5), pct(tpots_d, 0.5)), + ("TPOT p90", pct(tpots_c, 0.9), pct(tpots_d, 0.9)), + ("E2E p50", pct(lats_c, 0.5), pct(lats_d, 0.5)), + ("E2E p90", pct(lats_c, 0.9), pct(lats_d, 0.9)), +] + +for label, cv, dv in data: + delta = "{:+.0f}%".format((dv / cv - 1) * 100) if cv > 0 else "N/A" + print(" {:<12} {:>15.3f}s {:>15.3f}s {:>10}".format(label, cv, dv, delta)) + +cache_c = sc.get("prefix_cache_hit_ratio", 0) +print(" {:<12} {:>15.1f}% {:>16}".format("Cache hit", cache_c * 100, "N/A")) +tput_c = len(ok_c) / sc.get("wall_clock_s", 1) +print(" {:<12} {:>14.2f}/s {:>16}".format("Throughput", tput_c, "~0.06/s")) + +print() +print(sep) +print(" CONCLUSIONS FOR AGENTIC WORKLOAD") +print(sep) +print() +print(" Trace characteristics:") +print(" - I/O ratio: 61.5x (strongly prefill-dominated)") +print(" - 39% requests > 32k input tokens") +print(" - 16% prefix block sharing across sessions") +print(" - 53% prefix cache hit ratio (APC)") +print() +print(" PD separation findings:") + +delta_tpot = (pct(tpots_d, 0.5) / pct(tpots_c, 0.5) - 1) * 100 if tpots_c else 0 +delta_ttft = (pct(ttfts_d, 0.5) / pct(ttfts_c, 0.5) - 1) * 100 if ttfts_c else 0 +delta_e2e = (pct(lats_d, 0.5) / pct(lats_c, 0.5) - 1) * 100 if lats_c else 0 + +print(" 1. TPOT {:+.0f}% - decode isolation benefit is {}".format( + delta_tpot, "marginal" if abs(delta_tpot) < 20 else "significant")) +print(" 2. TTFT {:+.0f}% - KV transfer + TP=4 overhead dominates".format(delta_ttft)) +print(" 3. E2E {:+.0f}% - net negative on single-machine".format(delta_e2e)) +print(" 4. Stability: Mooncake connector crashes after ~200 reqs under load") +print() +print(" Recommendation:") +print(" - Single-machine 8 GPU: Combined mode is better (lower TTFT, stable)") +print(" - Multi-machine: PD-Sep is promising IF cross-machine latency") +print(" is hidden by RDMA and prefill doesn't share GPU with decode") +print(" - Key bottleneck: this workload's heavy prefill (avg 32k tokens)") +print(" makes KV transfer cost non-trivial relative to prefill time") +print(" - Prefill-as-a-Service (Goal 5) should focus on cross-machine") +print(" KV cache sharing, not same-machine PD split") diff --git a/scripts/launch_pd_mooncake.sh b/scripts/launch_pd_mooncake.sh new file mode 100755 index 0000000..ea3dd59 --- /dev/null +++ b/scripts/launch_pd_mooncake.sh @@ -0,0 +1,96 @@ +#!/bin/bash +# PD-Disaggregated serving via Mooncake (RDMA + DRAM KV pool). +# +# Architecture: +# Client → Proxy (port 8000) +# → Prefill (port 8010, TP=4, GPUs 0-3, bootstrap 8998) +# [prefill + store KV to DRAM pool via RDMA] +# → Decode (port 8020, TP=4, GPUs 4-7) +# [pull KV from DRAM pool via RDMA + decode] +# +# Usage: bash scripts/launch_pd_mooncake.sh + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_DIR="$(dirname "$SCRIPT_DIR")" +VENV="$PROJECT_DIR/.venv/bin" +VLLM="$VENV/vllm" + +MODEL_PATH="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}" +PROXY_PORT=8000 +PREFILL_PORT=8010 +DECODE_PORT=8020 +BOOTSTRAP_PORT=8998 + +PROXY_SCRIPT="$PROJECT_DIR/third_party/vllm/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py" + +trap 'echo "Cleaning up..."; kill $(jobs -p) 2>/dev/null; wait 2>/dev/null' EXIT INT TERM + +echo "=== PD-Disaggregated vLLM 0.18.1 (Mooncake/RDMA) ===" +echo " Model: $MODEL_PATH" +echo " Prefill: GPUs 0-3 (TP=4), port $PREFILL_PORT, bootstrap $BOOTSTRAP_PORT" +echo " Decode: GPUs 4-7 (TP=4), port $DECODE_PORT" +echo " Proxy: port $PROXY_PORT" +echo "" + +# Step 1: Start prefill instance (KV producer) +echo "[1/3] Starting prefill instance..." +VLLM_MOONCAKE_BOOTSTRAP_PORT=$BOOTSTRAP_PORT \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +$VLLM serve "$MODEL_PATH" \ + --host 0.0.0.0 \ + --port $PREFILL_PORT \ + --tensor-parallel-size 4 \ + --trust-remote-code \ + --enable-prefix-caching \ + --enforce-eager \ + --dtype auto \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + '{"kv_connector":"MooncakeConnector","kv_role":"kv_producer"}' & +PREFILL_PID=$! +echo " Prefill PID=$PREFILL_PID" + +# Step 2: Start decode instance (KV consumer) +echo "[2/3] Starting decode instance..." +CUDA_VISIBLE_DEVICES=4,5,6,7 \ +$VLLM serve "$MODEL_PATH" \ + --host 0.0.0.0 \ + --port $DECODE_PORT \ + --tensor-parallel-size 4 \ + --trust-remote-code \ + --enable-prefix-caching \ + --enforce-eager \ + --dtype auto \ + --gpu-memory-utilization 0.8 \ + --kv-transfer-config \ + '{"kv_connector":"MooncakeConnector","kv_role":"kv_consumer"}' & +DECODE_PID=$! +echo " Decode PID=$DECODE_PID" + +# Wait for both instances +echo "" +echo "Waiting for instances..." +timeout 1200 bash -c "until curl -s localhost:$PREFILL_PORT/v1/models > /dev/null 2>&1; do sleep 5; done" +echo " Prefill ready!" +timeout 1200 bash -c "until curl -s localhost:$DECODE_PORT/v1/models > /dev/null 2>&1; do sleep 5; done" +echo " Decode ready!" + +# Step 3: Start proxy (after instances are ready) +echo "[3/3] Starting proxy..." +$VENV/python "$PROXY_SCRIPT" \ + --prefill "http://127.0.0.1:$PREFILL_PORT" "$BOOTSTRAP_PORT" \ + --decode "http://127.0.0.1:$DECODE_PORT" \ + --host 0.0.0.0 \ + --port $PROXY_PORT & +PROXY_PID=$! +echo " Proxy PID=$PROXY_PID" + +sleep 5 +echo "" +echo "=== All ready ===" +echo " Send requests to: http://localhost:$PROXY_PORT" +echo "" + +wait diff --git a/scripts/launch_pd_separated.sh b/scripts/launch_pd_separated.sh new file mode 100644 index 0000000..df9366d --- /dev/null +++ b/scripts/launch_pd_separated.sh @@ -0,0 +1,89 @@ +#!/bin/bash +# PD-Disaggregated serving: 1 prefill (TP=4, GPUs 0-3) + 1 decode (TP=4, GPUs 4-7) +# Uses vLLM 0.18.1's P2pNcclConnector + XpYd proxy. +# +# Architecture: +# Client → Proxy (port 10001) +# → Prefill (port 20003, kv_port 21001) [max_tokens=1, does prefill + KV push] +# → Decode (port 20005, kv_port 22001) [full generation, KV pulled from prefill] +# +# Usage: bash scripts/launch_pd_separated.sh + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_DIR="$(dirname "$SCRIPT_DIR")" +VENV="$PROJECT_DIR/.venv/bin" +VLLM="$VENV/vllm" + +MODEL_PATH="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}" +PROXY_PORT=30001 # ZMQ service discovery +CLIENT_PORT=10001 # HTTP proxy for clients +PREFILL_PORT=20003 +DECODE_PORT=20005 +KV_PORT_P=21001 +KV_PORT_D=22001 + +trap 'echo "Cleaning up..."; kill $(jobs -p) 2>/dev/null; wait 2>/dev/null' EXIT INT TERM + +echo "=== PD-Disaggregated vLLM 0.18.1 ===" +echo " Model: $MODEL_PATH" +echo " Prefill: GPUs 0-3 (TP=4), port $PREFILL_PORT, kv_port $KV_PORT_P" +echo " Decode: GPUs 4-7 (TP=4), port $DECODE_PORT, kv_port $KV_PORT_D" +echo " Proxy: ZMQ=$PROXY_PORT, HTTP=$CLIENT_PORT" +echo "" + +# Step 1: Start proxy FIRST (P/D instances register via ZMQ) +echo "[1/3] Starting proxy..." +PROXY_SCRIPT="$PROJECT_DIR/third_party/vllm/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py" +$VENV/python "$PROXY_SCRIPT" & +PROXY_PID=$! +sleep 2 +echo " Proxy PID=$PROXY_PID" + +# Step 2: Start prefill instance (KV producer) +echo "[2/3] Starting prefill instance..." +CUDA_VISIBLE_DEVICES=0,1,2,3 $VLLM serve "$MODEL_PATH" \ + --host 0.0.0.0 \ + --port $PREFILL_PORT \ + --tensor-parallel-size 4 \ + --trust-remote-code \ + --enable-prefix-caching \ + --enforce-eager \ + --dtype auto \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + "{\"kv_connector\":\"P2pNcclConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_size\":\"1e1\",\"kv_port\":\"$KV_PORT_P\",\"kv_connector_extra_config\":{\"proxy_ip\":\"127.0.0.1\",\"proxy_port\":\"$PROXY_PORT\",\"http_port\":\"$PREFILL_PORT\",\"send_type\":\"PUT_ASYNC\",\"nccl_num_channels\":\"16\"}}" & +PREFILL_PID=$! +echo " Prefill PID=$PREFILL_PID" + +# Step 3: Start decode instance (KV consumer) +echo "[3/3] Starting decode instance..." +CUDA_VISIBLE_DEVICES=4,5,6,7 $VLLM serve "$MODEL_PATH" \ + --host 0.0.0.0 \ + --port $DECODE_PORT \ + --tensor-parallel-size 4 \ + --trust-remote-code \ + --enable-prefix-caching \ + --enforce-eager \ + --dtype auto \ + --gpu-memory-utilization 0.8 \ + --kv-transfer-config \ + "{\"kv_connector\":\"P2pNcclConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_size\":\"8e9\",\"kv_port\":\"$KV_PORT_D\",\"kv_connector_extra_config\":{\"proxy_ip\":\"127.0.0.1\",\"proxy_port\":\"$PROXY_PORT\",\"http_port\":\"$DECODE_PORT\",\"send_type\":\"PUT_ASYNC\",\"nccl_num_channels\":\"16\"}}" & +DECODE_PID=$! +echo " Decode PID=$DECODE_PID" + +# Wait for readiness +echo "" +echo "Waiting for instances..." +timeout 1200 bash -c "until curl -s localhost:$PREFILL_PORT/v1/completions > /dev/null 2>&1; do sleep 5; done" +echo " Prefill ready!" +timeout 1200 bash -c "until curl -s localhost:$DECODE_PORT/v1/completions > /dev/null 2>&1; do sleep 5; done" +echo " Decode ready!" + +echo "" +echo "=== All ready ===" +echo " Send requests to: http://localhost:$CLIENT_PORT" +echo "" + +wait diff --git a/scripts/launch_vllm.sh b/scripts/launch_vllm.sh new file mode 100755 index 0000000..46659b8 --- /dev/null +++ b/scripts/launch_vllm.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Launch vLLM 0.18.1 in PD-combined mode (TP=8, all GPUs). +# +# Usage: bash scripts/launch_vllm.sh + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_DIR="$(dirname "$SCRIPT_DIR")" +VLLM="$PROJECT_DIR/.venv/bin/vllm" + +MODEL_PATH="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}" +HOST="${HOST:-0.0.0.0}" +PORT="${PORT:-8000}" + +echo "Starting vLLM 0.18.1 in PD-combined mode (TP=8) on port $PORT ..." +$VLLM serve "$MODEL_PATH" \ + --trust-remote-code \ + --enable-prefix-caching \ + --dtype auto \ + --tensor-parallel-size 8 \ + --host "$HOST" \ + --port "$PORT" diff --git a/scripts/run_benchmark.sh b/scripts/run_benchmark.sh new file mode 100755 index 0000000..95f94f3 --- /dev/null +++ b/scripts/run_benchmark.sh @@ -0,0 +1,77 @@ +#!/bin/bash +# Run the full benchmark suite: sample trace → replay against vLLM → collect metrics. +# +# Prerequisites: +# - vLLM server running (use scripts/launch_vllm.sh) +# - Sampled trace file exists (or will be created) +# +# Usage: +# bash scripts/run_benchmark.sh [--endpoint URL] [--tag NAME] + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_DIR="$(dirname "$SCRIPT_DIR")" +cd "$PROJECT_DIR" + +# Defaults +TRACE_INPUT="${TRACE_INPUT:-$HOME/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl}" +ENDPOINT="${ENDPOINT:-http://localhost:8000}" +TAG="${TAG:-default}" +TARGET_REQUESTS="${TARGET_REQUESTS:-5000}" +TIME_SCALE="${TIME_SCALE:-1.0}" +MAX_INFLIGHT="${MAX_INFLIGHT:-32}" +SEED="${SEED:-42}" + +# Parse args +while [[ $# -gt 0 ]]; do + case "$1" in + --endpoint) ENDPOINT="$2"; shift 2 ;; + --tag) TAG="$2"; shift 2 ;; + --target-requests) TARGET_REQUESTS="$2"; shift 2 ;; + --time-scale) TIME_SCALE="$2"; shift 2 ;; + --max-inflight) MAX_INFLIGHT="$2"; shift 2 ;; + *) echo "Unknown arg: $1"; exit 1 ;; + esac +done + +SAMPLED_TRACE="traces/sampled_${TARGET_REQUESTS}req_seed${SEED}.jsonl" +OUTPUT_DIR="outputs/${TAG}_$(date +%Y%m%d_%H%M%S)" + +echo "=== Benchmark: tag=$TAG ===" +echo " Trace: $TRACE_INPUT" +echo " Endpoint: $ENDPOINT" +echo " Target requests: $TARGET_REQUESTS" +echo " Time scale: $TIME_SCALE" +echo " Max inflight sessions: $MAX_INFLIGHT" + +# Step 1: Sample trace (if not already done) +if [ ! -f "$SAMPLED_TRACE" ]; then + echo "" + echo "=== Step 1: Sampling trace ===" + python scripts/sample_trace.py \ + --input "$TRACE_INPUT" \ + --output "$SAMPLED_TRACE" \ + --target-requests "$TARGET_REQUESTS" \ + --seed "$SEED" +else + echo "" + echo "=== Step 1: Using existing sampled trace: $SAMPLED_TRACE ===" +fi + +# Step 2: Run replay +echo "" +echo "=== Step 2: Replaying trace ===" +mkdir -p "$OUTPUT_DIR" +python -m replayer \ + --trace "$SAMPLED_TRACE" \ + --output "$OUTPUT_DIR/metrics.jsonl" \ + --endpoint "$ENDPOINT" \ + --time-scale "$TIME_SCALE" \ + --max-inflight-sessions "$MAX_INFLIGHT" \ + -v + +echo "" +echo "=== Done ===" +echo " Metrics: $OUTPUT_DIR/metrics.jsonl" +echo " Summary: $OUTPUT_DIR/metrics.summary.json" diff --git a/scripts/run_experiments.sh b/scripts/run_experiments.sh new file mode 100755 index 0000000..637e973 --- /dev/null +++ b/scripts/run_experiments.sh @@ -0,0 +1,254 @@ +#!/bin/bash +# Run the complete experiment matrix: +# 1. Combined TP=2 DP=4 (4 instances, baseline) +# 2. Combined TP=1 DP=8 (8 instances, max throughput) +# 3. PD-Sep TP=1: P×4 + D×4 via Mooncake/RDMA +# +# All use the same trace, same concurrency, same timeout. + +set -euo pipefail + +PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +VENV="$PROJECT_DIR/.venv/bin" +VLLM="$VENV/vllm" +PYTHON="$VENV/python" + +MODEL="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}" +TRACE="$PROJECT_DIR/traces/sampled_1000req_seed42.jsonl" + +# Uniform benchmark params +MAX_SESSIONS=${MAX_SESSIONS:-8} +MAX_CONCURRENT=${MAX_CONCURRENT:-16} +TIME_SCALE=10 +REQUEST_TIMEOUT=${REQUEST_TIMEOUT:-300} +REQUEST_LIMIT="${REQUEST_LIMIT:-}" # empty = all 1000 + +cleanup_gpu() { + pkill -9 -f "vllm" 2>/dev/null || true + pkill -9 -f "cache_aware_proxy\|mooncake_connector_proxy\|uvicorn" 2>/dev/null || true + fuser 9090/tcp 8000/tcp 2>/dev/null | xargs -r kill -9 2>/dev/null || true + sleep 5 + fuser /dev/nvidia* 2>/dev/null | tr " " "\n" | sort -u | xargs -r kill -9 2>/dev/null || true + sleep 10 +} + +wait_for_server() { + local port=$1 + local timeout=${2:-600} + timeout "$timeout" bash -c "until curl -s localhost:$port/v1/models >/dev/null 2>&1; do sleep 5; done" +} + +run_benchmark() { + local tag=$1 + local endpoint=$2 + local extra_args="${3:-}" + local outdir="$PROJECT_DIR/outputs/$tag" + + echo " Running benchmark -> $outdir" + local limit_arg="" + if [ -n "$REQUEST_LIMIT" ]; then + limit_arg="--request-limit $REQUEST_LIMIT" + fi + + $PYTHON -m replayer \ + --trace "$TRACE" \ + --output "$outdir/metrics.jsonl" \ + --endpoint "$endpoint" \ + --model "$MODEL" \ + --time-scale $TIME_SCALE \ + --max-inflight-sessions $MAX_SESSIONS \ + --concurrency-limit $MAX_CONCURRENT \ + --request-timeout $REQUEST_TIMEOUT \ + $limit_arg \ + -v + + echo " Done: $(wc -l < "$outdir/metrics.jsonl") requests" +} + +####################################################################### +# Experiment 1: Combined TP=2 DP=4 +####################################################################### +run_combined_tp2_dp4() { + echo "" + echo "================================================================" + echo " Experiment 1: Combined TP=2 DP=4 (4 instances on 8 GPUs)" + echo "================================================================" + cleanup_gpu + + for i in 0 1 2 3; do + local gpu_start=$((i * 2)) + local gpu_end=$((gpu_start + 1)) + local port=$((8000 + i)) + echo " Starting instance $i: GPUs $gpu_start,$gpu_end, port $port" + CUDA_VISIBLE_DEVICES=$gpu_start,$gpu_end $VLLM serve "$MODEL" \ + --host 0.0.0.0 --port $port \ + --tensor-parallel-size 2 \ + --trust-remote-code --enable-prefix-caching --enforce-eager \ + --dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 & + done + + for i in 0 1 2 3; do + wait_for_server $((8000 + i)) + echo " Instance $i ready" + done + echo " All 4 instances ready" + + # Start global scheduler (cache-aware proxy in combined mode) + echo " Starting global scheduler..." + $PYTHON "$PROJECT_DIR/scripts/cache_aware_proxy.py" \ + --combined http://127.0.0.1:8000 http://127.0.0.1:8001 http://127.0.0.1:8002 http://127.0.0.1:8003 \ + --port 9090 & + sleep 5 + + run_benchmark "exp1_combined_tp2_dp4" "http://localhost:9090" +} + +####################################################################### +# Experiment 2: Combined TP=1 DP=8 +####################################################################### +run_combined_tp1_dp8() { + echo "" + echo "================================================================" + echo " Experiment 2: Combined TP=1 DP=8 (8 instances on 8 GPUs)" + echo "================================================================" + cleanup_gpu + + for i in $(seq 0 7); do + local port=$((8000 + i)) + echo " Starting instance $i: GPU $i, port $port" + CUDA_VISIBLE_DEVICES=$i $VLLM serve "$MODEL" \ + --host 0.0.0.0 --port $port \ + --tensor-parallel-size 1 \ + --trust-remote-code --enable-prefix-caching --enforce-eager \ + --dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 & + done + + for i in $(seq 0 7); do + wait_for_server $((8000 + i)) + echo " Instance $i ready" + done + echo " All 8 instances ready" + + # Start global scheduler (cache-aware proxy in combined mode) + echo " Starting global scheduler..." + $PYTHON "$PROJECT_DIR/scripts/cache_aware_proxy.py" \ + --combined http://127.0.0.1:8000 http://127.0.0.1:8001 http://127.0.0.1:8002 http://127.0.0.1:8003 \ + http://127.0.0.1:8004 http://127.0.0.1:8005 http://127.0.0.1:8006 http://127.0.0.1:8007 \ + --port 9090 & + sleep 5 + + run_benchmark "exp2_combined_tp1_dp8" "http://localhost:9090" +} + +####################################################################### +# Experiment 3: PD-Sep TP=1 P×4 D×4 (Mooncake/RDMA) +####################################################################### +run_pd_sep_tp1() { + echo "" + echo "================================================================" + echo " Experiment 3: PD-Sep TP=1 P×4 + D×4 (Mooncake/RDMA)" + echo "================================================================" + cleanup_gpu + + PROXY_SCRIPT="$PROJECT_DIR/scripts/cache_aware_proxy.py" + + # Start 4 prefill instances (GPUs 0-3) + local prefill_args="" + for i in 0 1 2 3; do + local port=$((8010 + i)) + local bootstrap=$((8998 + i)) + echo " Prefill $i: GPU $i, port $port, bootstrap $bootstrap" + VLLM_MOONCAKE_BOOTSTRAP_PORT=$bootstrap \ + CUDA_VISIBLE_DEVICES=$i $VLLM serve "$MODEL" \ + --host 0.0.0.0 --port $port \ + --tensor-parallel-size 1 \ + --trust-remote-code --enable-prefix-caching --enforce-eager \ + --dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 \ + --kv-transfer-config \ + "{\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_producer\"}" & + prefill_args="$prefill_args --prefill http://127.0.0.1:$port $bootstrap" + done + + # Start 4 decode instances (GPUs 4-7) + local decode_args="" + for i in 0 1 2 3; do + local gpu=$((4 + i)) + local port=$((8020 + i)) + echo " Decode $i: GPU $gpu, port $port" + CUDA_VISIBLE_DEVICES=$gpu $VLLM serve "$MODEL" \ + --host 0.0.0.0 --port $port \ + --tensor-parallel-size 1 \ + --trust-remote-code --enable-prefix-caching --enforce-eager \ + --dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 \ + --kv-transfer-config \ + "{\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_consumer\",\"kv_load_failure_policy\":\"recompute\"}" & + decode_args="$decode_args --decode http://127.0.0.1:$port" + done + + # Wait for all instances + for i in 0 1 2 3; do + wait_for_server $((8010 + i)) + echo " Prefill $i ready" + done + for i in 0 1 2 3; do + wait_for_server $((8020 + i)) + echo " Decode $i ready" + done + + # Start proxy (wait for bootstrap to be queryable first) + echo " Waiting for bootstrap servers..." + for bp in 8998 8999 9000 9001; do + timeout 120 bash -c "until curl -s localhost:$bp/query > /dev/null 2>&1; do sleep 2; done" + echo " Bootstrap $bp ready" + done + + echo " Starting proxy on port 9000..." + $PYTHON "$PROXY_SCRIPT" $prefill_args $decode_args --host 0.0.0.0 --port 9090 & + sleep 15 + + # Smoke test with retry + echo " Smoke test..." + for attempt in 1 2 3; do + result=$(curl -s -m 120 http://localhost:9090/v1/completions \ + -X POST -H "Content-Type: application/json" \ + -d "{\"model\":\"$MODEL\",\"prompt\":[100,200,300],\"max_tokens\":3,\"temperature\":0}" 2>&1) + if echo "$result" | grep -q "choices"; then + echo " Smoke test passed!" + break + fi + echo " Attempt $attempt failed, retrying..." + sleep 10 + done + + run_benchmark "exp3_pd_sep_tp1_mooncake" "http://localhost:9090" +} + +####################################################################### +# Main +####################################################################### +echo "Starting experiment matrix on $(hostname)" +echo "Model: $MODEL" +echo "Trace: $TRACE" +echo "Params: sessions=$MAX_SESSIONS, concurrent=$MAX_CONCURRENT, time_scale=$TIME_SCALE" +echo "" + +case "${1:-all}" in + 1|tp2dp4) run_combined_tp2_dp4 ;; + 2|tp1dp8) run_combined_tp1_dp8 ;; + 3|pdsep) run_pd_sep_tp1 ;; + all) + run_combined_tp2_dp4 + run_combined_tp1_dp8 + run_pd_sep_tp1 + ;; + *) + echo "Usage: $0 {1|2|3|all|tp2dp4|tp1dp8|pdsep}" + exit 1 + ;; +esac + +echo "" +echo "================================================================" +echo " All experiments complete!" +echo "================================================================" +cleanup_gpu diff --git a/scripts/sample_trace.py b/scripts/sample_trace.py new file mode 100644 index 0000000..5c18f2d --- /dev/null +++ b/scripts/sample_trace.py @@ -0,0 +1,204 @@ +"""Sample sessions from the full cluster-scale trace to fit a single machine. + +Preserves: + - Complete session structure (all turns within a session kept together) + - Original arrival timing (inter-session and intra-session gaps) + - hash_ids for KV cache reuse patterns + - Request type distribution + +Sampling strategy: + 1. Group requests by session (derived from parent_chat_id chains) + 2. Randomly sample N sessions (or until target request count reached) + 3. Re-zero timestamps so first event starts at t=0 + 4. Optionally compress time axis to increase load density + +Usage: + python scripts/sample_trace.py \\ + --input ~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl \\ + --output traces/sampled.jsonl \\ + --target-requests 5000 \\ + --seed 42 +""" + +from __future__ import annotations + +import argparse +import collections +import json +import random +import sys +from pathlib import Path + + +def load_raw_rows(path: Path) -> dict[str, list[dict]]: + """Load trace, group rows by resolved session_id. Preserve file order.""" + chat_to_session: dict[int, str] = {} + rows_by_session: dict[str, list[dict]] = collections.OrderedDict() + + with path.open("r", encoding="utf-8") as fh: + for line in fh: + row = json.loads(line) + cid = int(row["chat_id"]) + pid = int(row["parent_chat_id"]) + + if "session_id" in row: + sid = str(row["session_id"]) + elif pid < 0: + sid = str(cid) + else: + sid = chat_to_session.get(pid, str(pid)) + chat_to_session[cid] = sid + + row["_session_id"] = sid + rows_by_session.setdefault(sid, []).append(row) + + return rows_by_session + + +def sample_sessions( + rows_by_session: dict[str, list[dict]], + *, + target_requests: int, + seed: int, + strategy: str = "random", +) -> list[str]: + """Select sessions until target request count is reached.""" + all_sids = list(rows_by_session.keys()) + rng = random.Random(seed) + + if strategy == "random": + rng.shuffle(all_sids) + elif strategy == "sequential": + pass # keep file order + else: + raise ValueError(f"Unknown strategy: {strategy}") + + selected = [] + total = 0 + for sid in all_sids: + selected.append(sid) + total += len(rows_by_session[sid]) + if total >= target_requests: + break + + return selected + + +def build_output( + rows_by_session: dict[str, list[dict]], + selected: list[str], + *, + time_scale: float = 1.0, +) -> list[dict]: + """Build output rows with re-zeroed timestamps.""" + out_rows = [] + for sid in selected: + for row in rows_by_session[sid]: + out = {k: v for k, v in row.items() if not k.startswith("_")} + out["session_id"] = sid + out_rows.append(out) + + out_rows.sort(key=lambda r: float(r["timestamp"])) + + if not out_rows: + return out_rows + + # Re-zero: subtract earliest timestamp + t0 = float(out_rows[0]["timestamp"]) + for row in out_rows: + row["timestamp"] = (float(row["timestamp"]) - t0) / time_scale + + return out_rows + + +def print_summary( + rows_by_session: dict[str, list[dict]], + selected: list[str], + out_rows: list[dict], +) -> None: + n_sessions = len(selected) + n_requests = len(out_rows) + turns_per_session = [len(rows_by_session[s]) for s in selected] + multi_turn = sum(1 for t in turns_per_session if t > 1) + + input_lens = [r["input_length"] for r in out_rows] + output_lens = [r["output_length"] for r in out_rows] + + span_s = float(out_rows[-1]["timestamp"]) if out_rows else 0 + session_starts = {} + for r in out_rows: + sid = r["session_id"] + ts = float(r["timestamp"]) + if sid not in session_starts: + session_starts[sid] = ts + starts_sorted = sorted(session_starts.values()) + deltas = [starts_sorted[i+1] - starts_sorted[i] + for i in range(len(starts_sorted) - 1)] + + # hash_ids overlap: count unique hash_ids across all requests + all_hashes = set() + for r in out_rows: + all_hashes.update(r.get("hash_ids", [])) + + print(f"Sampled: {n_sessions} sessions, {n_requests} requests") + print(f" Multi-turn sessions: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)") + print(f" Turns/session: min={min(turns_per_session)} max={max(turns_per_session)} " + f"avg={sum(turns_per_session)/len(turns_per_session):.1f}") + print(f" Input length: min={min(input_lens)} max={max(input_lens)} " + f"avg={sum(input_lens)/len(input_lens):.0f}") + print(f" Output length: min={min(output_lens)} max={max(output_lens)} " + f"avg={sum(output_lens)/len(output_lens):.0f}") + print(f" Trace span: {span_s:.1f}s ({span_s/60:.1f} min)") + print(f" Unique hash blocks: {len(all_hashes)}") + if deltas: + deltas.sort() + p = lambda q: deltas[min(int(q * len(deltas)), len(deltas) - 1)] + print(f" Session arrival deltas (s): p10={p(0.1):.2f} p50={p(0.5):.2f} " + f"p90={p(0.9):.2f} max={max(deltas):.2f}") + + +def main() -> None: + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument("--input", type=Path, required=True, + help="Path to the full trace JSONL file") + p.add_argument("--output", type=Path, required=True, + help="Path to write sampled trace JSONL") + p.add_argument("--target-requests", type=int, default=5000, + help="Target number of requests (stops after session that crosses it)") + p.add_argument("--strategy", choices=["random", "sequential"], default="random", + help="Session selection strategy") + p.add_argument("--time-scale", type=float, default=1.0, + help="Compress time axis by this factor (>1 = faster arrival)") + p.add_argument("--seed", type=int, default=42) + args = p.parse_args() + + print(f"Loading trace from {args.input} ...") + rows_by_session = load_raw_rows(args.input) + total_sessions = len(rows_by_session) + total_requests = sum(len(v) for v in rows_by_session.values()) + print(f"Full trace: {total_sessions} sessions, {total_requests} requests") + + selected = sample_sessions( + rows_by_session, + target_requests=args.target_requests, + seed=args.seed, + strategy=args.strategy, + ) + + out_rows = build_output( + rows_by_session, selected, + time_scale=args.time_scale, + ) + + print_summary(rows_by_session, selected, out_rows) + + args.output.parent.mkdir(parents=True, exist_ok=True) + with args.output.open("w", encoding="utf-8") as fh: + for row in out_rows: + fh.write(json.dumps(row, ensure_ascii=False) + "\n") + print(f"\nWrote {len(out_rows)} rows to {args.output}") + + +if __name__ == "__main__": + main()