Agentic workload PD separation analysis with trace-driven benchmarks
Systematic study of prefill-decode disaggregation for agentic LLM workloads using production GLM-5.1 coder trace (2.1M requests, 71B input tokens). Key findings: - Cache-aware routing improves TPOT p90 by 15% and APC from 20.8% to 44.7% without PD separation, matching PD-Sep's decode isolation benefit - PD separation adds +72% TTFT overhead (KV transfer) with no TPOT gain when using the same cache-aware scheduler - Prefill remains compute-bound even at 95% KV cache reuse (AI >1000x vs decode AI <2), but absolute FLOPs drop 71% from cache hits - For agentic MoE workloads, cache-aware routing > PD separation Infrastructure: - Trace sampler preserving session structure + hash_ids for prefix sharing - Async trace replayer with streaming TTFT/TPOT/E2E measurement - Unified cache-aware + token-level load-balanced global scheduler proxy supporting both PD-colocated and PD-disaggregated (Mooncake/RDMA) modes - vLLM 0.18.1 scheduler patch for KV transfer abort race condition - Roofline analysis tool for prefill/decode compute characterization Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
8
.gitignore
vendored
Normal file
8
.gitignore
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
.venv/
|
||||||
|
*.egg-info/
|
||||||
|
outputs/
|
||||||
|
traces/
|
||||||
|
third_party/
|
||||||
|
*.log
|
||||||
25
TODO.md
Normal file
25
TODO.md
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
实验 setup:
|
||||||
|
|
||||||
|
GPU 机器:dash0,是 8*H20 的机器,可以直接 `ssh dash0` 进行连接访问
|
||||||
|
|
||||||
|
推理引擎:基于 vllm 0.18.1,self build,支持后续 patch 放在 git 中维护
|
||||||
|
|
||||||
|
模型:`~/models/Qwen/Qwen3-Coder-30B-A3B-Instruct`
|
||||||
|
|
||||||
|
推理 trace:原始完整 2h trace 在 dash0 的 `~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl`
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
性能指标:每个请求的 TTFT/TPOT/E2E/TBT,prefix KVCache hit ratio
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
目标:
|
||||||
|
|
||||||
|
1. 先实现标准的 trace-sampler,将 cluster 规模的原始 trace,sample 到合适当前机器数量的规模来跑,保持一份统一的 sample 后的 trace file 作为输入
|
||||||
|
2. 实现标准的 trace replayer,保证能够体现线上流量的流量到来特征,KVCache 重用特征等
|
||||||
|
3. 跑通 PD 分离,确认 PD 分离能够比普通的 PD 混合一起跑的性能要好,给出两者详细的性能对比以及原因分析
|
||||||
|
4. 判断 trace 的 pattern,是否有必要 PD 完全混合或者 PD 分离
|
||||||
|
5. 参考本地的 `~/phd/agentic-pd-hybrid`,判断是否能够实现一套 prefill-as-a-service 的架构,把重的 prefill 交给 prefill service,prefill service 能够从本地 GPU/DRAM/别的 GPU 机器上 pull KVCache,提高本地的 prefix KVCache hit ratio,不影响 decoding 的 prefill,就可以交给过去 PD 分离定义中 D-node 来做,提高 KVCache 命中率
|
||||||
|
|
||||||
|
|
||||||
289
analysis/pd_separation_analysis.md
Normal file
289
analysis/pd_separation_analysis.md
Normal file
@@ -0,0 +1,289 @@
|
|||||||
|
# PD 分离在 Agentic Workload 下的系统分析
|
||||||
|
|
||||||
|
## 1. Trace 特征 (GLM-5.1 Agentic Coder, 2h, 2.1M requests)
|
||||||
|
|
||||||
|
```
|
||||||
|
Total requests: 2,114,220
|
||||||
|
Input tokens: 71.1B (avg 33.6k/req, p50=20k, p90=88k)
|
||||||
|
Output tokens: 940M (avg 445/req, p50=80, p90=811)
|
||||||
|
I/O ratio: 75.6x (aggregate), 217.8x (per-req median)
|
||||||
|
Prefill share: 98% of total tokens
|
||||||
|
Sessions: 1.3M (90% single-turn, 9% multi-turn)
|
||||||
|
```
|
||||||
|
|
||||||
|
**与传统 chatbot workload 的根本区别:**
|
||||||
|
|
||||||
|
| 特征 | Traditional Chatbot | Agentic Coder (GLM-5.1) |
|
||||||
|
|------|-------------------|------------------------|
|
||||||
|
| I/O ratio | 1-10x | **75.6x** |
|
||||||
|
| Input p50 | 500-2000 tokens | **20,030 tokens** |
|
||||||
|
| Output p50 | 200-500 tokens | **80 tokens** |
|
||||||
|
| Prefill token share | 50-80% | **98%** |
|
||||||
|
| >32k input | <5% | **38%** |
|
||||||
|
| Multi-turn | 50-80% | **9%** |
|
||||||
|
|
||||||
|
**KV Cache 复用特征:**
|
||||||
|
|
||||||
|
```
|
||||||
|
Unique hash blocks: 20,650,883
|
||||||
|
Shared blocks (ref>1): 9,749,379 (47%)
|
||||||
|
Highly shared (ref>10): 2,428,160
|
||||||
|
Intra-session reuse: 57%
|
||||||
|
Top-10 blocks ref count: 64,754 (system prompt blocks)
|
||||||
|
Theoretical cache hit: 71% (infinite cache, first 100k requests)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Input length 分布与 token 占比:**
|
||||||
|
|
||||||
|
```
|
||||||
|
<1k: 202,396 reqs ( 9%) 89M tokens ( 0%)
|
||||||
|
1-8k: 380,009 reqs (17%) 1.6B tokens ( 2%)
|
||||||
|
8-32k: 720,871 reqs (34%) 12.7B tokens (17%)
|
||||||
|
32-65k: 405,371 reqs (19%) 19.4B tokens (27%)
|
||||||
|
65-131k: 394,014 reqs (18%) 35.7B tokens (50%)
|
||||||
|
>131k: 11,559 reqs ( 0%) 1.6B tokens ( 2%)
|
||||||
|
```
|
||||||
|
|
||||||
|
50% 的 token 计算量来自 65-131k 的长 context 请求。
|
||||||
|
|
||||||
|
## 2. DistServe 等 PD 分离的核心假设
|
||||||
|
|
||||||
|
DistServe (OSDI'24), Splitwise, TetriInfer 等 PD 分离工作基于以下假设:
|
||||||
|
|
||||||
|
### 假设 A: Prefill 和 Decode 有不同的计算特征
|
||||||
|
- **Prefill**: compute-bound, 高 GPU 利用率, batch 越大越好
|
||||||
|
- **Decode**: memory-bandwidth-bound, 低 GPU 利用率, latency-sensitive
|
||||||
|
|
||||||
|
**在 agentic workload 中的验证**: ✅ 成立,但需要细化
|
||||||
|
|
||||||
|
Roofline 分析显示(详见 Section 5):
|
||||||
|
|
||||||
|
```
|
||||||
|
Arithmetic Intensity (FLOP/byte)
|
||||||
|
Decode: 1.0 - 1.9 (memory-bound, 始终远低于 ridge point)
|
||||||
|
Prefill 0% reuse: 23,000-72,000 (strongly compute-bound)
|
||||||
|
Prefill 70% reuse: 10,000-42,000 (仍然 compute-bound!)
|
||||||
|
Prefill 95% reuse: 1,900-10,800 (仍然 compute-bound!)
|
||||||
|
Ridge point (H20): 37
|
||||||
|
```
|
||||||
|
|
||||||
|
**即使 95% KV cache reuse,prefill 仍然是 compute-bound。** 但绝对计算量大幅减少。
|
||||||
|
|
||||||
|
### 假设 B: PD co-location 导致互相干扰
|
||||||
|
- Prefill 的大 batch 计算会抢占 GPU 资源,导致 decode 的 TPOT 升高
|
||||||
|
- Decode 的持续小计算会占用 GPU 调度槽位,影响 prefill 吞吐
|
||||||
|
|
||||||
|
**在 agentic workload 中的验证**: ⚠️ 干扰存在,但 **可被 cache-aware routing 消除**
|
||||||
|
|
||||||
|
```
|
||||||
|
同一 cache-aware scheduler, TP=1, 8 GPU:
|
||||||
|
Combined TP=1 DP=8: TPOT p90 = 0.073s
|
||||||
|
PD-Sep TP=1 4P+4D: TPOT p90 = 0.074s
|
||||||
|
→ 差异 <2%, 不显著
|
||||||
|
```
|
||||||
|
|
||||||
|
对比 round-robin routing:
|
||||||
|
```
|
||||||
|
Combined TP=1 DP=8 (RR): TPOT p90 = 0.086s
|
||||||
|
Combined TP=1 DP=8 (cache-aware): TPOT p90 = 0.073s → -15%
|
||||||
|
→ routing 改善 > PD 分离改善
|
||||||
|
```
|
||||||
|
|
||||||
|
**原因**: cache-aware routing 让 high-cache-hit 的请求集中到特定 instance,
|
||||||
|
每个 instance 的实际 prefill 新 token 数大幅减少(71% 被 cache),
|
||||||
|
prefill-decode 干扰因 prefill 工作量降低而自然缓解。
|
||||||
|
|
||||||
|
### 假设 C: KV Cache 传输开销可以忽略
|
||||||
|
- DistServe 假设 P→D 的 KV 传输延迟远小于 prefill 计算时间
|
||||||
|
- 在 InfiniBand/NVLink 等高带宽互联下成立
|
||||||
|
|
||||||
|
**在 agentic workload 中的验证**: ❌ 不成立
|
||||||
|
|
||||||
|
```
|
||||||
|
PD-Sep TTFT p50 = 1.261s vs Combined TTFT p50 = 0.731s (+72%)
|
||||||
|
```
|
||||||
|
|
||||||
|
原因:
|
||||||
|
1. Agentic workload 的 input 极长(p50=20k, p90=88k tokens),KV cache 很大
|
||||||
|
2. 单请求 KV cache = 20k tokens × 48 layers × 2(K+V) × 512 bytes ≈ 1GB
|
||||||
|
3. 更重要的是 await-prefill 链路的串行延迟:proxy → prefill → KV transfer → decode → first token
|
||||||
|
|
||||||
|
### 假设 D: 专用 prefill 节点可以提高 prefill 吞吐
|
||||||
|
- Prefill 节点不做 decode,GPU 利用率更高
|
||||||
|
- 可以用更大的 batch size
|
||||||
|
|
||||||
|
**在 agentic workload 中的验证**: ⚠️ 收益被 cache 稀释
|
||||||
|
|
||||||
|
```
|
||||||
|
理论 prefix cache hit (infinite cache): 71% of input tokens
|
||||||
|
实际 APC (Combined, cache-aware, 8 inst): 44.7%
|
||||||
|
```
|
||||||
|
|
||||||
|
71% cache hit → 只有 29% 的 input tokens 需要实际 prefill compute。
|
||||||
|
Nominal avg input 33.6k → Actual avg new prefill ~9.7k tokens。
|
||||||
|
专用 prefill 的 GPU 利用率优势因 prefill 工作量降低而缩小。
|
||||||
|
|
||||||
|
## 3. Roofline 分析:Prefill 在高 Cache Reuse 下的计算/访存特性
|
||||||
|
|
||||||
|
### 3.1 模型计算结构
|
||||||
|
|
||||||
|
```
|
||||||
|
Qwen3-Coder-30B-A3B (MoE 128E top-8):
|
||||||
|
48 layers, hidden=2048, heads=32, kv_heads=4 (GQA), head_dim=128
|
||||||
|
FFN: 6144 intermediate per expert, 8 experts active per token
|
||||||
|
Active params per token: ~3B
|
||||||
|
|
||||||
|
H20 GPU: 148 TFLOPS (BF16), 4.0 TB/s HBM → Ridge point: 37 FLOP/byte
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.2 Decode 永远 memory-bound
|
||||||
|
|
||||||
|
```
|
||||||
|
SeqLen FLOP Bytes AI (F/B) Bound
|
||||||
|
1,000 3.04e+10 3.01e+10 1.0 MEMORY
|
||||||
|
16,000 3.63e+10 3.16e+10 1.1 MEMORY
|
||||||
|
64,000 5.52e+10 3.63e+10 1.5 MEMORY
|
||||||
|
128,000 8.03e+10 4.26e+10 1.9 MEMORY
|
||||||
|
```
|
||||||
|
|
||||||
|
Decode 的 AI 始终 < 2,远低于 ridge point (37)。每个 decode step 只处理 1 个 token,
|
||||||
|
计算量极小,瓶颈在于读取模型权重和全量 KV cache。
|
||||||
|
|
||||||
|
### 3.3 Prefill 即使 95% reuse 仍然 compute-bound
|
||||||
|
|
||||||
|
```
|
||||||
|
SeqLen Reuse% NewTok AI (F/B) Bound vs Decode
|
||||||
|
32,000 0% 32,000 23,368 COMPUTE 18,190x
|
||||||
|
32,000 50% 16,000 14,899 COMPUTE 11,597x
|
||||||
|
32,000 70% 9,600 10,045 COMPUTE 7,819x
|
||||||
|
32,000 90% 3,200 3,821 COMPUTE 2,974x
|
||||||
|
32,000 95% 1,600 1,980 COMPUTE 1,542x
|
||||||
|
|
||||||
|
64,000 0% 64,000 40,758 COMPUTE 26,813x
|
||||||
|
64,000 70% 19,200 20,610 COMPUTE 13,559x
|
||||||
|
64,000 90% 6,400 8,544 COMPUTE 5,621x
|
||||||
|
64,000 95% 3,200 4,549 COMPUTE 2,993x
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.4 为什么高 reuse 不改变 compute-bound 性质
|
||||||
|
|
||||||
|
KV cache reuse 减少的:
|
||||||
|
- K/V projection 计算(只算 new tokens)
|
||||||
|
- KV 写入(只写 new tokens)
|
||||||
|
|
||||||
|
KV cache reuse **不减少**的:
|
||||||
|
- **Q×K^T attention**: 每个 new Q 都要和全部 seq_len 个 KV 做 attention
|
||||||
|
```
|
||||||
|
FLOPs = new_tokens × seq_len × head_dim × num_heads × 2 × num_layers
|
||||||
|
```
|
||||||
|
At 95% reuse, 32k seq: 1600 × 32000 × 128 × 32 × 2 × 48 ≈ 2×10^13
|
||||||
|
这个二次项在长 context 下主导总计算量
|
||||||
|
|
||||||
|
- **MoE FFN**: 每个 new token 激活 8 experts
|
||||||
|
```
|
||||||
|
FLOPs = new_tokens × 3 × D × D_ffn × 2 × K_experts × num_layers
|
||||||
|
```
|
||||||
|
|
||||||
|
**Prefill 只在接近 100% reuse (< 10 new tokens) 时才变成 memory-bound。**
|
||||||
|
|
||||||
|
### 3.5 Prefill 什么时候变 memory-bound
|
||||||
|
|
||||||
|
```
|
||||||
|
SeqLen=32,000: new_tokens ≈ 5-10 时 → AI ≈ 37 (ridge point)
|
||||||
|
SeqLen=64,000: new_tokens ≈ 5-10 时 → AI ≈ 37
|
||||||
|
```
|
||||||
|
|
||||||
|
在实际 agentic trace 中:
|
||||||
|
```
|
||||||
|
Compute-bound prefills: 961 (96%)
|
||||||
|
Memory-bound prefills: 37 (3%) ← 近 100% reuse 的极端 warm 请求
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.6 关键洞察:"Compute-bound but lightweight"
|
||||||
|
|
||||||
|
高 cache reuse 下的 prefill 处于一种独特状态:
|
||||||
|
|
||||||
|
```
|
||||||
|
Prefill bound 类型: Compute-bound (不变)
|
||||||
|
Prefill 绝对工作量: 大幅降低 (71% cache → 只算 29% 的 tokens)
|
||||||
|
Prefill-Decode 干扰: 因绝对工作量降低而减轻 (不需要物理隔离)
|
||||||
|
```
|
||||||
|
|
||||||
|
这解释了为什么 PD 分离没有帮助:
|
||||||
|
- PD 分离解决的是 "prefill 太重干扰 decode" 的问题
|
||||||
|
- 但 cache-aware routing 已经把 prefill 的实际工作量降到足够轻
|
||||||
|
- 物理隔离(PD 分离)的收益被 KV 传输开销抵消
|
||||||
|
|
||||||
|
## 4. 实验结果
|
||||||
|
|
||||||
|
### 4.1 完整实验矩阵
|
||||||
|
|
||||||
|
所有实验使用统一的 cache-aware + token-level load-balanced global scheduler。
|
||||||
|
|
||||||
|
| Config | OK/N | TTFT p50 | TPOT p90 | E2E p50 | APC |
|
||||||
|
|--------|------|----------|----------|---------|-----|
|
||||||
|
| TP=8 DP=1 (single instance) | 998/1000 | 0.467s | 0.129s | 3.30s | 53.0% |
|
||||||
|
| TP=2 DP=4 (4 inst, RR) | 997/999 | 0.844s | 0.095s | 4.92s | 33.5% |
|
||||||
|
| TP=1 DP=8 (8 inst, RR) | 997/999 | 1.836s | 0.086s | 6.67s | 20.8% |
|
||||||
|
| **TP=1 DP=8 (cache-aware)** | **997/999** | **0.731s** | **0.073s** | **4.48s** | **44.7%** |
|
||||||
|
| TP=1 PD-Sep 4P+4D (cache-aware) | 509/564 | 1.261s | 0.074s | 5.61s | 40.2% |
|
||||||
|
|
||||||
|
### 4.2 Cache-Aware Routing 的效果
|
||||||
|
|
||||||
|
```
|
||||||
|
Round-robin → Cache-aware (Combined TP=1 DP=8):
|
||||||
|
TTFT p50: 1.836s → 0.731s (-60%)
|
||||||
|
TPOT p90: 0.086s → 0.073s (-15%)
|
||||||
|
E2E p50: 6.673s → 4.480s (-33%)
|
||||||
|
APC: 20.8% → 44.7% (+24pp)
|
||||||
|
```
|
||||||
|
|
||||||
|
Cache-aware routing 的提升远大于 PD 分离的提升。
|
||||||
|
|
||||||
|
### 4.3 修复工程问题的过程
|
||||||
|
|
||||||
|
实验过程中发现并修复了多个 PD 分离的工程问题:
|
||||||
|
|
||||||
|
| 问题 | 根因 | 修复 |
|
||||||
|
|------|------|------|
|
||||||
|
| Decode engine crash | vLLM scheduler assert: KV transfer 回调时 request 已 abort | Patch scheduler.py: assert → graceful skip |
|
||||||
|
| Head-of-line blocking | Proxy 按 request count 做 LB,不区分大小请求 | Token-level ongoing_tokens load balancing |
|
||||||
|
| "Timeout waiting for P side ready" | Proxy fire-and-forget prefill, decode 盲等 KV | Await-prefill + kv_load_failure_policy=recompute |
|
||||||
|
| Port collision on startup | 8 Mooncake instances 同时启动争抢 torch distributed port | Staggered startup + explicit MASTER_PORT |
|
||||||
|
| Cache routing "rich get richer" | score = ongoing - alpha*cached 导致流量集中到一个 instance | Normalized scoring: ongoing/avg_load - alpha*cache_ratio |
|
||||||
|
|
||||||
|
## 5. 结论
|
||||||
|
|
||||||
|
### 5.1 PD 分离为什么在 Agentic Workload 不生效
|
||||||
|
|
||||||
|
1. **Cache reuse 大幅降低 prefill 绝对工作量(71% cache hit → 只算 29%)**,使得 P-D 干扰不显著
|
||||||
|
2. **Prefill 仍然 compute-bound**(即使 95% reuse,AI 仍 >1000),但每个请求的总 FLOPs 因 new_tokens 减少而大幅降低
|
||||||
|
3. **Cache-aware routing 提供 "软 PD 隔离"**,效果等同于物理隔离但无 KV 传输开销
|
||||||
|
4. **KV 传输开销不可忽略**(TTFT +72%),抵消了隔离收益
|
||||||
|
5. **MoE 模型 active params 小**(3B),per-token compute 本身较轻
|
||||||
|
|
||||||
|
### 5.2 PD 分离在什么条件下有价值
|
||||||
|
|
||||||
|
| 条件 | Chatbot (有价值) | Agentic (无价值) |
|
||||||
|
|------|-----------------|-----------------|
|
||||||
|
| Cache hit rate | <10% | **71%** |
|
||||||
|
| Model active params | 70B (dense) | **3B (MoE)** |
|
||||||
|
| I/O ratio | 1-10x | **75.6x** |
|
||||||
|
| Per-request prefill FLOPs | Very high | **Low (after cache)** |
|
||||||
|
| KV transfer cost vs prefill cost | Negligible | **Significant** |
|
||||||
|
|
||||||
|
### 5.3 Agentic Workload 应该怎么优化
|
||||||
|
|
||||||
|
1. **Cache-aware routing** (已验证有效): 用 ongoing_tokens + prefix_cache_hit 做联合调度,
|
||||||
|
将 APC 从 20.8% (RR) 提升到 44.7%,TPOT p90 降低 15%
|
||||||
|
|
||||||
|
2. **Cross-instance KV cache sharing**: 让多个 instance 共享全局 KV pool,
|
||||||
|
进一步提升 cache hit 率接近理论 71%
|
||||||
|
|
||||||
|
3. **Prefix pre-warming**: 对 cold start 请求(55%,0% cache hit),
|
||||||
|
预计算 common prefix (system prompt blocks) 并分发到所有 instance
|
||||||
|
|
||||||
|
4. **不同 workload 类型的差异化处理**:
|
||||||
|
- Warm 请求 (22%, >90% cache hit, avg 1.3k new tokens): 几乎免费,任何 instance 都能处理
|
||||||
|
- Cold 请求 (55%, 0% cache hit, avg 17.7k new tokens): prefill-heavy,需要有足够 compute
|
||||||
|
- 可以用 request-type-aware routing 进一步优化
|
||||||
130
analysis/roofline_analysis.md
Normal file
130
analysis/roofline_analysis.md
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
# Prefill 在高 KV Cache Reuse 下的计算/访存分析
|
||||||
|
|
||||||
|
## Model & GPU
|
||||||
|
|
||||||
|
```
|
||||||
|
Qwen3-Coder-30B-A3B (MoE 128E top-8)
|
||||||
|
48 layers, hidden=2048, heads=32, kv_heads=4 (GQA), head_dim=128
|
||||||
|
FFN: 6144 intermediate per expert, 8 experts active per token
|
||||||
|
Active params: ~3B per token
|
||||||
|
|
||||||
|
H20 GPU: 148 TFLOPS (BF16) / 4.0 TB/s HBM
|
||||||
|
Ridge point: 37 FLOP/byte
|
||||||
|
```
|
||||||
|
|
||||||
|
## 核心发现:Prefill 即使 95% reuse 仍然是 compute-bound
|
||||||
|
|
||||||
|
```
|
||||||
|
SeqLen Reuse% NewTok AI (F/B) Bound vs Decode AI
|
||||||
|
32,000 0% 32,000 23368 COMPUTE 18189x
|
||||||
|
32,000 70% 9,600 10045 COMPUTE 7819x
|
||||||
|
32,000 90% 3,200 3821 COMPUTE 2974x
|
||||||
|
32,000 95% 1,600 1980 COMPUTE 1541x
|
||||||
|
|
||||||
|
64,000 0% 64,000 40758 COMPUTE 26813x
|
||||||
|
64,000 70% 19,200 20610 COMPUTE 13559x
|
||||||
|
64,000 90% 6,400 8544 COMPUTE 5621x
|
||||||
|
64,000 95% 3,200 4549 COMPUTE 2993x
|
||||||
|
|
||||||
|
Decode (always):
|
||||||
|
32,000 - 1 1.3 MEMORY 1x
|
||||||
|
64,000 - 1 1.5 MEMORY 1x
|
||||||
|
```
|
||||||
|
|
||||||
|
**关键**:
|
||||||
|
- Decode 的 arithmetic intensity (AI) = 1.0-1.9 — 远低于 ridge point (37),始终 memory-bound
|
||||||
|
- Prefill 即使 95% reuse (只有 5% 新 token),AI 仍然 >1000 — 远高于 ridge point,依然 compute-bound
|
||||||
|
|
||||||
|
## 为什么高 reuse 的 prefill 仍然是 compute-bound?
|
||||||
|
|
||||||
|
### 原因:Attention 的计算量与 seq_len 成正比
|
||||||
|
|
||||||
|
当有 95% cache reuse (seq_len=64k, new_tokens=3200):
|
||||||
|
```
|
||||||
|
Q projection: new_tokens × D × D → 只处理 3200 new tokens ✓
|
||||||
|
K,V projection: new_tokens × D × D_kv → 只处理 3200 new tokens ✓
|
||||||
|
|
||||||
|
但 Attention score: new_tokens × seq_len × D_head × H × L
|
||||||
|
= 3200 × 64000 × 128 × 32 × 48
|
||||||
|
→ 仍然要对全部 64k context 做注意力计算!
|
||||||
|
|
||||||
|
FFN (MoE): new_tokens × 3 × D × D_ffn × 2 × K_experts × L
|
||||||
|
= 3200 × 3 × 2048 × 6144 × 2 × 8 × 48
|
||||||
|
→ 8 个 expert 的计算量仍然很大
|
||||||
|
```
|
||||||
|
|
||||||
|
KV cache reuse 减少的是:
|
||||||
|
- K/V projection 的计算(只算 new tokens)
|
||||||
|
- KV 写入(只写 new tokens)
|
||||||
|
|
||||||
|
但 **不减少的是**:
|
||||||
|
- Q 对全部 context 的 attention(每个 new Q 都要和所有 64k tokens 做 attention)
|
||||||
|
- MoE FFN 的计算(每个 new token 激活 8 个 expert)
|
||||||
|
|
||||||
|
所以 prefill 的 FLOPs 虽然随 reuse 减少,但 **减少的是线性部分(投影),不减少的是二次部分(attention)**。
|
||||||
|
在长 context 下,二次部分主导,使得即使 95% reuse,AI 仍远高于 ridge point。
|
||||||
|
|
||||||
|
## Prefill 什么时候才变成 memory-bound?
|
||||||
|
|
||||||
|
```
|
||||||
|
SeqLen=32,000: new_tokens ≈ 5-10 时 (reuse > 99.97%) → AI ≈ 37
|
||||||
|
SeqLen=64,000: new_tokens ≈ 5-10 时 → AI ≈ 37
|
||||||
|
```
|
||||||
|
|
||||||
|
只有在 **近乎 100% reuse**(仅 5-10 个 new tokens)时,prefill 才接近 memory-bound。
|
||||||
|
在实际 agentic trace 中,只有 3% 的请求达到这个程度。
|
||||||
|
|
||||||
|
## 对 PD 分离的影响:修正之前的分析
|
||||||
|
|
||||||
|
### 之前的错误结论(已修正)
|
||||||
|
> "Prefill 大部分是 cache lookup 不是 compute"
|
||||||
|
|
||||||
|
这是 **错误的**。即使 70% cache reuse,prefill 的 AI 仍然是 decode 的 7000-14000 倍。
|
||||||
|
Prefill 始终是 compute-bound,decode 始终是 memory-bound。
|
||||||
|
|
||||||
|
### 那为什么 PD 分离在我们的实验中没有帮助?
|
||||||
|
|
||||||
|
正确的解释不是 "prefill 变成了 memory-bound",而是:
|
||||||
|
|
||||||
|
**1. Cache reuse 大幅减少了 prefill 的绝对计算量**
|
||||||
|
```
|
||||||
|
无 cache: avg 33.6k tokens × prefill compute = X FLOPs
|
||||||
|
71% cache: avg 9.4k tokens × prefill compute = 0.28X FLOPs
|
||||||
|
```
|
||||||
|
虽然 prefill 仍是 compute-bound,但 **总工作量只有原来的 28%**。
|
||||||
|
在 8 instance 并行 + cache-aware routing 下,每个 instance 的 prefill 负载非常轻,
|
||||||
|
不足以产生对 decode 的显著干扰。
|
||||||
|
|
||||||
|
**2. MoE 模型的 per-token compute 本身较小**
|
||||||
|
Active params 只有 3B(全参数的 10%),单个 token 的计算量不大。
|
||||||
|
对比 Dense 70B 模型,同样的 GPU 上 prefill-decode 干扰会严重得多。
|
||||||
|
|
||||||
|
**3. Cache-aware routing 的 "负载均衡" 效应**
|
||||||
|
当请求被路由到 cache 命中率高的 instance 时,该 instance 的实际 prefill 工作量更小,
|
||||||
|
自然减少了 P-D 争抢。这相当于 routing 层面的 "软 PD 分离"。
|
||||||
|
|
||||||
|
## 对比不同 workload 类型的 roofline 特征
|
||||||
|
|
||||||
|
```
|
||||||
|
Prefill AI Decode AI PD-Sep 价值
|
||||||
|
Dense 70B, Chatbot: 200-1000x 1-2x HIGH (compute-heavy P 干扰 D)
|
||||||
|
Dense 70B, Agent: 100-500x 1-2x MEDIUM (cache reduces P load)
|
||||||
|
MoE 30B, Chatbot: 100-500x 1-2x MEDIUM
|
||||||
|
MoE 30B, Agent: 50-200x 1-2x LOW (small active params + cache)
|
||||||
|
← 我们的位置
|
||||||
|
```
|
||||||
|
|
||||||
|
**PD 分离的 ROI 随着 cache hit 率升高和模型 active params 减少而下降。**
|
||||||
|
Agentic MoE 模型恰好在两个方面都不利于 PD 分离。
|
||||||
|
|
||||||
|
## 实际 trace 的 prefill bound 分布
|
||||||
|
|
||||||
|
```
|
||||||
|
With actual trace prefix cache pattern (1000 sampled requests):
|
||||||
|
Compute-bound prefills: 961 (96%)
|
||||||
|
Memory-bound prefills: 37 (3%) ← 近 100% reuse 的 warm 请求
|
||||||
|
(Decode is ALWAYS memory-bound)
|
||||||
|
```
|
||||||
|
|
||||||
|
96% 的 prefill 仍然是 compute-bound,但 **absolute compute 因 cache 大幅降低**。
|
||||||
|
这是一个 "compute-bound but lightweight" 的独特状态 —— bound 类型没变,但强度大幅降低。
|
||||||
16
pyproject.toml
Normal file
16
pyproject.toml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
[project]
|
||||||
|
name = "agentic-kv"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Trace-driven KV cache benchmarking for agentic LLM workloads"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = [
|
||||||
|
"httpx>=0.27",
|
||||||
|
"numpy>=1.24",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = ["pytest"]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
0
replayer/__init__.py
Normal file
0
replayer/__init__.py
Normal file
55
replayer/__main__.py
Normal file
55
replayer/__main__.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""CLI entry point: python -m replayer replay ..."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .replay import ReplayConfig, replay_trace
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
p = argparse.ArgumentParser(description="Trace replayer for vLLM benchmarking")
|
||||||
|
p.add_argument("--trace", type=Path, required=True, help="Sampled trace JSONL")
|
||||||
|
p.add_argument("--output", type=Path, required=True, help="Output metrics JSONL")
|
||||||
|
p.add_argument("--endpoint", type=str, required=True,
|
||||||
|
help="vLLM server URL (e.g. http://localhost:8000)")
|
||||||
|
p.add_argument("--model", type=str, default="default", help="Model name for API")
|
||||||
|
p.add_argument("--time-scale", type=float, default=1.0,
|
||||||
|
help="Time compression (>1 = faster)")
|
||||||
|
p.add_argument("--max-inflight-sessions", type=int, default=32)
|
||||||
|
p.add_argument("--concurrency-limit", type=int, default=256)
|
||||||
|
p.add_argument("--request-timeout", type=float, default=600.0)
|
||||||
|
p.add_argument("--request-limit", type=int, default=None,
|
||||||
|
help="Limit number of requests to replay")
|
||||||
|
p.add_argument("-v", "--verbose", action="store_true")
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG if args.verbose else logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
config = ReplayConfig(
|
||||||
|
trace_path=args.trace,
|
||||||
|
output_path=args.output,
|
||||||
|
endpoint_url=args.endpoint.rstrip("/"),
|
||||||
|
model_name=args.model,
|
||||||
|
time_scale=args.time_scale,
|
||||||
|
max_inflight_sessions=args.max_inflight_sessions,
|
||||||
|
concurrency_limit=args.concurrency_limit,
|
||||||
|
request_timeout_s=args.request_timeout,
|
||||||
|
request_limit=args.request_limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
results = asyncio.run(replay_trace(config))
|
||||||
|
succeeded = sum(1 for r in results if r.error is None)
|
||||||
|
print(f"\nDone: {succeeded}/{len(results)} requests succeeded")
|
||||||
|
print(f"Metrics: {args.output}")
|
||||||
|
print(f"Summary: {args.output.with_suffix('.summary.json')}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
107
replayer/metrics.py
Normal file
107
replayer/metrics.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
"""Per-request metrics collection and summary reporting."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import statistics
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RequestMetrics:
|
||||||
|
request_id: str
|
||||||
|
session_id: str
|
||||||
|
turn_id: int
|
||||||
|
trace_timestamp_s: float
|
||||||
|
input_length: int
|
||||||
|
output_length: int
|
||||||
|
request_type: str
|
||||||
|
effective_input_length: int | None
|
||||||
|
cached_tokens: int
|
||||||
|
latency_s: float | None
|
||||||
|
ttft_s: float | None
|
||||||
|
tpot_s: float | None
|
||||||
|
actual_output_tokens: int | None = None
|
||||||
|
requested_output_tokens: int | None = None
|
||||||
|
finish_reason: str | None = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class IncrementalMetricSink:
|
||||||
|
"""Append each RequestMetrics to JSONL immediately (crash-safe)."""
|
||||||
|
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
self.path = path
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path.write_text("")
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._fh = path.open("a", encoding="utf-8", buffering=1)
|
||||||
|
|
||||||
|
async def append(self, metric: RequestMetrics) -> None:
|
||||||
|
line = json.dumps(asdict(metric), sort_keys=True) + "\n"
|
||||||
|
async with self._lock:
|
||||||
|
self._fh.write(line)
|
||||||
|
self._fh.flush()
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
try:
|
||||||
|
self._fh.flush()
|
||||||
|
self._fh.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def write_summary_json(path: Path, rows: list[RequestMetrics]) -> None:
|
||||||
|
successful = [r for r in rows if r.error is None]
|
||||||
|
latencies = [r.latency_s for r in successful if r.latency_s is not None]
|
||||||
|
ttfts = [r.ttft_s for r in successful if r.ttft_s is not None]
|
||||||
|
tpots = [r.tpot_s for r in successful if r.tpot_s is not None]
|
||||||
|
|
||||||
|
total_input = sum(r.input_length for r in successful)
|
||||||
|
total_cached = sum(r.cached_tokens for r in successful)
|
||||||
|
|
||||||
|
summary: dict[str, Any] = {
|
||||||
|
"request_count": len(rows),
|
||||||
|
"success_count": len(successful),
|
||||||
|
"error_count": sum(1 for r in rows if r.error is not None),
|
||||||
|
"latency_stats_s": _stats(latencies),
|
||||||
|
"ttft_stats_s": _stats(ttfts),
|
||||||
|
"tpot_stats_s": _stats(tpots),
|
||||||
|
"cache_hit_request_count": sum(1 for r in successful if r.cached_tokens > 0),
|
||||||
|
"total_input_tokens": total_input,
|
||||||
|
"total_cached_tokens": total_cached,
|
||||||
|
"prefix_cache_hit_ratio": total_cached / total_input if total_input > 0 else 0.0,
|
||||||
|
"cached_tokens_stats": _stats([float(r.cached_tokens) for r in successful]),
|
||||||
|
"actual_output_tokens_stats": _stats(
|
||||||
|
[float(r.actual_output_tokens) for r in successful
|
||||||
|
if r.actual_output_tokens is not None]
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with path.open("w", encoding="utf-8") as fh:
|
||||||
|
json.dump(summary, fh, indent=2, sort_keys=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _stats(values: list[float | None]) -> dict[str, float] | None:
|
||||||
|
clean = [v for v in values if v is not None]
|
||||||
|
if not clean:
|
||||||
|
return None
|
||||||
|
clean.sort()
|
||||||
|
return {
|
||||||
|
"count": float(len(clean)),
|
||||||
|
"mean": statistics.fmean(clean),
|
||||||
|
"p50": _percentile(clean, 0.50),
|
||||||
|
"p90": _percentile(clean, 0.90),
|
||||||
|
"p99": _percentile(clean, 0.99),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _percentile(sorted_vals: list[float], pct: float) -> float:
|
||||||
|
if len(sorted_vals) == 1:
|
||||||
|
return sorted_vals[0]
|
||||||
|
idx = round((len(sorted_vals) - 1) * pct)
|
||||||
|
return sorted_vals[idx]
|
||||||
343
replayer/replay.py
Normal file
343
replayer/replay.py
Normal file
@@ -0,0 +1,343 @@
|
|||||||
|
"""Trace replayer — send requests to vLLM following trace timing.
|
||||||
|
|
||||||
|
Supports both vLLM's /v1/completions (OpenAI-compatible) and /generate
|
||||||
|
(SGLang-style) endpoints. Uses hash_ids from the trace to construct
|
||||||
|
synthetic prompts that reproduce realistic prefix-cache hit patterns.
|
||||||
|
|
||||||
|
Key behaviors:
|
||||||
|
- Per-session sequencing: turns within a session are sent in order,
|
||||||
|
each waiting for the previous to complete before dispatching.
|
||||||
|
- Inter-session arrival: sessions start at their trace timestamps,
|
||||||
|
scaled by --time-scale.
|
||||||
|
- Concurrency control: --max-inflight-sessions caps concurrent sessions;
|
||||||
|
--concurrency-limit caps total in-flight requests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import random as _random
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json
|
||||||
|
from .trace import TraceRequest, load_trace
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
BLOCK_SIZE = 512
|
||||||
|
VOCAB_SIZE = 151936
|
||||||
|
TOKEN_RANGE_START = 100
|
||||||
|
TOKEN_RANGE_END = VOCAB_SIZE - 100
|
||||||
|
|
||||||
|
_block_cache: dict[int, list[int]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_id_to_token_ids(hash_id: int) -> list[int]:
|
||||||
|
"""Deterministically map a hash_id to BLOCK_SIZE token IDs."""
|
||||||
|
if hash_id in _block_cache:
|
||||||
|
return _block_cache[hash_id]
|
||||||
|
rng = _random.Random(hash_id)
|
||||||
|
ids = [rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END) for _ in range(BLOCK_SIZE)]
|
||||||
|
_block_cache[hash_id] = ids
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReplayConfig:
|
||||||
|
trace_path: Path
|
||||||
|
output_path: Path
|
||||||
|
endpoint_url: str # comma-separated for round-robin: "http://host:8000,http://host:8001"
|
||||||
|
time_scale: float = 1.0
|
||||||
|
max_inflight_sessions: int = 32
|
||||||
|
concurrency_limit: int = 256
|
||||||
|
request_timeout_s: float = 600.0
|
||||||
|
request_limit: int | None = None
|
||||||
|
model_name: str = "default"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_prompt_token_ids(req: TraceRequest) -> list[int]:
|
||||||
|
"""Build token IDs from hash_ids for prefix-cache-aware replay.
|
||||||
|
|
||||||
|
Same hash_id prefix → same token ID prefix → APC cache hit in vLLM.
|
||||||
|
"""
|
||||||
|
ids: list[int] = []
|
||||||
|
for hid in req.hash_ids:
|
||||||
|
ids.extend(_hash_id_to_token_ids(hid))
|
||||||
|
# Pad to input_length with deterministic tokens
|
||||||
|
pad_rng = _random.Random(req.chat_id)
|
||||||
|
while len(ids) < req.input_length:
|
||||||
|
ids.append(pad_rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END))
|
||||||
|
return ids[:req.input_length]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _SessionState:
|
||||||
|
session_id: str
|
||||||
|
turns: list[TraceRequest]
|
||||||
|
metrics: list[RequestMetrics] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
_endpoint_counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
def _pick_endpoint(config: ReplayConfig) -> str:
|
||||||
|
"""Round-robin across comma-separated endpoints."""
|
||||||
|
global _endpoint_counter
|
||||||
|
endpoints = [e.strip() for e in config.endpoint_url.split(",")]
|
||||||
|
url = endpoints[_endpoint_counter % len(endpoints)]
|
||||||
|
_endpoint_counter += 1
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
async def _dispatch_request(
|
||||||
|
*,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
config: ReplayConfig,
|
||||||
|
req: TraceRequest,
|
||||||
|
prompt_token_ids: list[int],
|
||||||
|
sem: asyncio.Semaphore,
|
||||||
|
) -> RequestMetrics:
|
||||||
|
"""Send one request via /v1/completions (streaming) and collect metrics."""
|
||||||
|
endpoint = _pick_endpoint(config)
|
||||||
|
payload = {
|
||||||
|
"model": config.model_name,
|
||||||
|
"prompt": prompt_token_ids,
|
||||||
|
"max_tokens": max(1, req.output_length),
|
||||||
|
"temperature": 0,
|
||||||
|
"stream": True,
|
||||||
|
"stream_options": {"include_usage": True},
|
||||||
|
}
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
|
ttft_s = None
|
||||||
|
n_output = 0
|
||||||
|
cached_tokens = 0
|
||||||
|
finish_reason = None
|
||||||
|
err = None
|
||||||
|
token_times: list[float] = []
|
||||||
|
|
||||||
|
async with sem:
|
||||||
|
try:
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
f"{endpoint}/v1/completions",
|
||||||
|
json=payload,
|
||||||
|
timeout=config.request_timeout_s,
|
||||||
|
) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
async for raw_line in resp.aiter_lines():
|
||||||
|
if not raw_line or not raw_line.startswith("data:"):
|
||||||
|
continue
|
||||||
|
data = raw_line[5:].strip()
|
||||||
|
if data == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
chunk = json.loads(data)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
now = time.perf_counter()
|
||||||
|
if ttft_s is None:
|
||||||
|
ttft_s = now - start
|
||||||
|
|
||||||
|
choices = chunk.get("choices", [])
|
||||||
|
if choices:
|
||||||
|
delta = choices[0].get("text", "")
|
||||||
|
if delta:
|
||||||
|
token_times.append(now)
|
||||||
|
fr = choices[0].get("finish_reason")
|
||||||
|
if fr:
|
||||||
|
finish_reason = fr
|
||||||
|
|
||||||
|
usage = chunk.get("usage")
|
||||||
|
if usage:
|
||||||
|
n_output = usage.get("completion_tokens", n_output)
|
||||||
|
cached_tokens = _extract_cached_tokens(usage)
|
||||||
|
except Exception as exc:
|
||||||
|
err = repr(exc)[:300]
|
||||||
|
|
||||||
|
end = time.perf_counter()
|
||||||
|
e2e = end - start
|
||||||
|
if n_output == 0 and token_times:
|
||||||
|
n_output = len(token_times)
|
||||||
|
|
||||||
|
tpot = 0.0
|
||||||
|
if len(token_times) > 1:
|
||||||
|
inter_token = [token_times[i+1] - token_times[i]
|
||||||
|
for i in range(len(token_times) - 1)]
|
||||||
|
tpot = sum(inter_token) / len(inter_token)
|
||||||
|
|
||||||
|
return RequestMetrics(
|
||||||
|
request_id=req.request_id,
|
||||||
|
session_id=req.session_id,
|
||||||
|
turn_id=req.turn_id,
|
||||||
|
trace_timestamp_s=req.timestamp_s,
|
||||||
|
input_length=req.input_length,
|
||||||
|
output_length=req.output_length,
|
||||||
|
request_type=req.request_type,
|
||||||
|
effective_input_length=len(prompt_token_ids),
|
||||||
|
cached_tokens=cached_tokens,
|
||||||
|
latency_s=e2e,
|
||||||
|
ttft_s=ttft_s,
|
||||||
|
tpot_s=tpot,
|
||||||
|
actual_output_tokens=n_output,
|
||||||
|
requested_output_tokens=req.output_length,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
error=err,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_cached_tokens(usage: dict) -> int:
|
||||||
|
ct = 0
|
||||||
|
details = usage.get("prompt_tokens_details")
|
||||||
|
if isinstance(details, dict):
|
||||||
|
ct = details.get("cached_tokens", 0) or 0
|
||||||
|
if ct == 0:
|
||||||
|
ct = usage.get("cached_tokens", 0) or 0
|
||||||
|
return int(ct)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_session(
|
||||||
|
*,
|
||||||
|
state: _SessionState,
|
||||||
|
config: ReplayConfig,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
session_sem: asyncio.Semaphore,
|
||||||
|
request_sem: asyncio.Semaphore,
|
||||||
|
earliest_ts: float,
|
||||||
|
sweep_start: float,
|
||||||
|
sink: IncrementalMetricSink,
|
||||||
|
) -> list[RequestMetrics]:
|
||||||
|
async with session_sem:
|
||||||
|
# Wait until this session's start time
|
||||||
|
offset = (state.turns[0].timestamp_s - earliest_ts) / config.time_scale
|
||||||
|
wait = offset - (time.perf_counter() - sweep_start)
|
||||||
|
if wait > 0:
|
||||||
|
await asyncio.sleep(wait)
|
||||||
|
|
||||||
|
for req in state.turns:
|
||||||
|
# Intra-session: wait for turn's relative offset
|
||||||
|
if req != state.turns[0]:
|
||||||
|
target = (req.timestamp_s - state.turns[0].timestamp_s) / config.time_scale
|
||||||
|
elapsed = time.perf_counter() - sweep_start - offset
|
||||||
|
if elapsed < target:
|
||||||
|
await asyncio.sleep(target - elapsed)
|
||||||
|
|
||||||
|
token_ids = _build_prompt_token_ids(req)
|
||||||
|
metric = await _dispatch_request(
|
||||||
|
client=client, config=config, req=req,
|
||||||
|
prompt_token_ids=token_ids, sem=request_sem,
|
||||||
|
)
|
||||||
|
state.metrics.append(metric)
|
||||||
|
await sink.append(metric)
|
||||||
|
|
||||||
|
return state.metrics
|
||||||
|
|
||||||
|
|
||||||
|
async def _snapshot_prefix_cache_metrics(url_csv: str) -> dict[str, float]:
|
||||||
|
"""Scrape vLLM /metrics for prefix cache counters (aggregated across endpoints)."""
|
||||||
|
total = {"queries": 0.0, "hits": 0.0}
|
||||||
|
endpoints = [e.strip() for e in url_csv.split(",")]
|
||||||
|
async with httpx.AsyncClient(timeout=10) as c:
|
||||||
|
for url in endpoints:
|
||||||
|
try:
|
||||||
|
r = await c.get(f"{url}/metrics")
|
||||||
|
for line in r.text.split("\n"):
|
||||||
|
if line.startswith("vllm:prefix_cache_queries_total"):
|
||||||
|
total["queries"] += float(line.split()[-1])
|
||||||
|
elif line.startswith("vllm:prefix_cache_hits_total"):
|
||||||
|
total["hits"] += float(line.split()[-1])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return total
|
||||||
|
|
||||||
|
|
||||||
|
async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]:
|
||||||
|
"""Main entry: load trace, replay against endpoint, return metrics."""
|
||||||
|
requests = load_trace(config.trace_path, request_limit=config.request_limit)
|
||||||
|
if not requests:
|
||||||
|
return []
|
||||||
|
|
||||||
|
by_session: dict[str, list[TraceRequest]] = defaultdict(list)
|
||||||
|
for r in requests:
|
||||||
|
by_session[r.session_id].append(r)
|
||||||
|
for sid in by_session:
|
||||||
|
by_session[sid].sort(key=lambda r: (r.turn_id, r.timestamp_s))
|
||||||
|
|
||||||
|
sessions = sorted(by_session.items(), key=lambda kv: kv[1][0].timestamp_s)
|
||||||
|
earliest_ts = sessions[0][1][0].timestamp_s
|
||||||
|
|
||||||
|
session_sem = asyncio.Semaphore(config.max_inflight_sessions)
|
||||||
|
request_sem = asyncio.Semaphore(config.concurrency_limit)
|
||||||
|
|
||||||
|
sink = IncrementalMetricSink(config.output_path)
|
||||||
|
|
||||||
|
n_sessions = len(sessions)
|
||||||
|
n_requests = len(requests)
|
||||||
|
logger.info("Replaying %d sessions (%d requests), time_scale=%.1f",
|
||||||
|
n_sessions, n_requests, config.time_scale)
|
||||||
|
|
||||||
|
pre_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)
|
||||||
|
sweep_start = time.perf_counter()
|
||||||
|
|
||||||
|
try:
|
||||||
|
limits = httpx.Limits(
|
||||||
|
max_connections=2000,
|
||||||
|
max_keepalive_connections=500,
|
||||||
|
keepalive_expiry=30.0,
|
||||||
|
)
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
timeout=config.request_timeout_s,
|
||||||
|
trust_env=False,
|
||||||
|
limits=limits,
|
||||||
|
) as client:
|
||||||
|
tasks = [
|
||||||
|
asyncio.create_task(_run_session(
|
||||||
|
state=_SessionState(session_id=sid, turns=turns),
|
||||||
|
config=config, client=client,
|
||||||
|
session_sem=session_sem, request_sem=request_sem,
|
||||||
|
earliest_ts=earliest_ts, sweep_start=sweep_start,
|
||||||
|
sink=sink,
|
||||||
|
))
|
||||||
|
for sid, turns in sessions
|
||||||
|
]
|
||||||
|
all_results = await asyncio.gather(*tasks)
|
||||||
|
finally:
|
||||||
|
sink.close()
|
||||||
|
|
||||||
|
sweep_elapsed = time.perf_counter() - sweep_start
|
||||||
|
post_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)
|
||||||
|
|
||||||
|
flat = [m for group in all_results for m in group]
|
||||||
|
summary_path = config.output_path.with_suffix(".summary.json")
|
||||||
|
write_summary_json(summary_path, flat)
|
||||||
|
|
||||||
|
# Compute aggregate prefix cache hit ratio from /metrics deltas
|
||||||
|
delta_queries = post_metrics.get("queries", 0) - pre_metrics.get("queries", 0)
|
||||||
|
delta_hits = post_metrics.get("hits", 0) - pre_metrics.get("hits", 0)
|
||||||
|
hit_ratio = delta_hits / delta_queries if delta_queries > 0 else 0.0
|
||||||
|
|
||||||
|
logger.info("Done: %d/%d succeeded in %.1fs", sum(1 for m in flat if m.error is None), len(flat), sweep_elapsed)
|
||||||
|
logger.info("Prefix cache: %.1f%% hit ratio (%d/%d tokens)",
|
||||||
|
hit_ratio * 100, int(delta_hits), int(delta_queries))
|
||||||
|
|
||||||
|
# Append cache stats to summary
|
||||||
|
import json as _json
|
||||||
|
summary = _json.loads(summary_path.read_text())
|
||||||
|
summary["prefix_cache_queries_tokens"] = int(delta_queries)
|
||||||
|
summary["prefix_cache_hits_tokens"] = int(delta_hits)
|
||||||
|
summary["prefix_cache_hit_ratio"] = hit_ratio
|
||||||
|
summary["wall_clock_s"] = sweep_elapsed
|
||||||
|
summary_path.write_text(_json.dumps(summary, indent=2, sort_keys=True))
|
||||||
|
|
||||||
|
logger.info("Summary written to %s", summary_path)
|
||||||
|
return flat
|
||||||
84
replayer/trace.py
Normal file
84
replayer/trace.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""Trace data structures and loader for the Ali agentic-coder trace format.
|
||||||
|
|
||||||
|
Trace format (one JSON per line):
|
||||||
|
chat_id, parent_chat_id, timestamp, input_length, output_length,
|
||||||
|
type, turn, hash_ids[]
|
||||||
|
|
||||||
|
Sessions are derived from parent_chat_id chains:
|
||||||
|
- parent_chat_id == -1 → new session root
|
||||||
|
- parent_chat_id >= 0 → belongs to the same session as the parent
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TraceRequest:
|
||||||
|
request_id: str
|
||||||
|
session_id: str
|
||||||
|
chat_id: int
|
||||||
|
parent_chat_id: int
|
||||||
|
timestamp_s: float
|
||||||
|
input_length: int
|
||||||
|
output_length: int
|
||||||
|
request_type: str
|
||||||
|
turn_id: int
|
||||||
|
hash_ids: tuple[int, ...]
|
||||||
|
|
||||||
|
|
||||||
|
def load_trace(
|
||||||
|
path: Path,
|
||||||
|
*,
|
||||||
|
request_limit: int | None = None,
|
||||||
|
) -> list[TraceRequest]:
|
||||||
|
"""Load trace and resolve session IDs from parent_chat_id chains."""
|
||||||
|
chat_to_session: dict[int, str] = {}
|
||||||
|
requests: list[TraceRequest] = []
|
||||||
|
|
||||||
|
with path.open("r", encoding="utf-8") as fh:
|
||||||
|
for idx, line in enumerate(fh):
|
||||||
|
if request_limit is not None and len(requests) >= request_limit:
|
||||||
|
break
|
||||||
|
row = json.loads(line)
|
||||||
|
chat_id = int(row["chat_id"])
|
||||||
|
parent_chat_id = int(row["parent_chat_id"])
|
||||||
|
|
||||||
|
if "session_id" in row:
|
||||||
|
session_id = str(row["session_id"])
|
||||||
|
else:
|
||||||
|
session_id = _resolve_session_id(
|
||||||
|
chat_id, parent_chat_id, chat_to_session,
|
||||||
|
)
|
||||||
|
chat_to_session[chat_id] = session_id
|
||||||
|
|
||||||
|
requests.append(TraceRequest(
|
||||||
|
request_id=f"{session_id}:{row['turn']}:{chat_id}:{idx}",
|
||||||
|
session_id=session_id,
|
||||||
|
chat_id=chat_id,
|
||||||
|
parent_chat_id=parent_chat_id,
|
||||||
|
timestamp_s=float(row["timestamp"]),
|
||||||
|
input_length=int(row["input_length"]),
|
||||||
|
output_length=int(row["output_length"]),
|
||||||
|
request_type=str(row["type"]),
|
||||||
|
turn_id=int(row["turn"]),
|
||||||
|
hash_ids=tuple(int(h) for h in row.get("hash_ids", [])),
|
||||||
|
))
|
||||||
|
|
||||||
|
return requests
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_session_id(
|
||||||
|
chat_id: int,
|
||||||
|
parent_chat_id: int,
|
||||||
|
chat_to_session: dict[int, str],
|
||||||
|
) -> str:
|
||||||
|
if parent_chat_id < 0:
|
||||||
|
session_id = str(chat_id)
|
||||||
|
else:
|
||||||
|
session_id = chat_to_session.get(parent_chat_id, str(parent_chat_id))
|
||||||
|
chat_to_session[chat_id] = session_id
|
||||||
|
return session_id
|
||||||
196
scripts/analyze_cache_hit.py
Normal file
196
scripts/analyze_cache_hit.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
"""Analyze theoretical vs actual KV cache hit ratio for the agentic trace."""
|
||||||
|
import json
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
rows = [json.loads(l) for l in open("traces/sampled_1000req_seed42.jsonl")]
|
||||||
|
rows.sort(key=lambda r: float(r["timestamp"]))
|
||||||
|
|
||||||
|
BLOCK_SIZE = 512
|
||||||
|
|
||||||
|
# === 1. Theoretical max: infinite cache, single instance ===
|
||||||
|
total_tokens = 0
|
||||||
|
total_cached = 0
|
||||||
|
seen_blocks = set()
|
||||||
|
per_req = []
|
||||||
|
|
||||||
|
for r in rows:
|
||||||
|
input_len = r["input_length"]
|
||||||
|
hash_ids = r.get("hash_ids", [])
|
||||||
|
total_tokens += input_len
|
||||||
|
|
||||||
|
cached_blocks = 0
|
||||||
|
prefix_broken = False
|
||||||
|
for hid in hash_ids:
|
||||||
|
if not prefix_broken and hid in seen_blocks:
|
||||||
|
cached_blocks += 1
|
||||||
|
else:
|
||||||
|
prefix_broken = True
|
||||||
|
|
||||||
|
cached_tokens = cached_blocks * BLOCK_SIZE
|
||||||
|
total_cached += cached_tokens
|
||||||
|
for hid in hash_ids:
|
||||||
|
seen_blocks.add(hid)
|
||||||
|
|
||||||
|
per_req.append({
|
||||||
|
"input_length": input_len,
|
||||||
|
"cached_tokens": cached_tokens,
|
||||||
|
"new_tokens": max(0, input_len - cached_tokens),
|
||||||
|
"ratio": cached_tokens / input_len if input_len > 0 else 0,
|
||||||
|
})
|
||||||
|
|
||||||
|
sep = "=" * 70
|
||||||
|
print(sep)
|
||||||
|
print(" THEORETICAL KV CACHE HIT (infinite cache, single instance)")
|
||||||
|
print(sep)
|
||||||
|
print(f" Total input tokens: {total_tokens:>14,}")
|
||||||
|
print(f" Cacheable (prefix hit): {total_cached:>14,} ({total_cached*100//total_tokens}%)")
|
||||||
|
print(f" Must prefill (new): {total_tokens-total_cached:>14,} ({(total_tokens-total_cached)*100//total_tokens}%)")
|
||||||
|
|
||||||
|
ratios = sorted([s["ratio"] for s in per_req if s["input_length"] > 0])
|
||||||
|
new_tokens = sorted([s["new_tokens"] for s in per_req if s["input_length"] > 0])
|
||||||
|
p = lambda v, q: v[min(int(q*len(v)), len(v)-1)]
|
||||||
|
|
||||||
|
print(f"\n Per-request cache hit ratio:")
|
||||||
|
print(f" p10={p(ratios,.1)*100:.1f}% p50={p(ratios,.5)*100:.1f}% p90={p(ratios,.9)*100:.1f}% mean={sum(ratios)/len(ratios)*100:.1f}%")
|
||||||
|
high = sum(1 for r in ratios if r > 0.5)
|
||||||
|
very_high = sum(1 for r in ratios if r > 0.9)
|
||||||
|
zero = sum(1 for r in ratios if r == 0)
|
||||||
|
print(f" 0% hit (cold start): {zero} ({zero*100//len(ratios)}%)")
|
||||||
|
print(f" >50% hit: {high} ({high*100//len(ratios)}%)")
|
||||||
|
print(f" >90% hit: {very_high} ({very_high*100//len(ratios)}%)")
|
||||||
|
|
||||||
|
print(f"\n Actual new tokens to prefill per request:")
|
||||||
|
print(f" p10={p(new_tokens,.1):>7,} p50={p(new_tokens,.5):>7,} p90={p(new_tokens,.9):>7,} max={max(new_tokens):>7,}")
|
||||||
|
|
||||||
|
# === 2. 4-instance split (simulating DP=4 or 4 prefill instances) ===
|
||||||
|
print(f"\n{sep}")
|
||||||
|
print(" 4-INSTANCE SPLIT (round-robin, per-instance cache)")
|
||||||
|
print(sep)
|
||||||
|
|
||||||
|
instance_seen = [set() for _ in range(4)]
|
||||||
|
inst_total = [0]*4
|
||||||
|
inst_cached = [0]*4
|
||||||
|
|
||||||
|
for i, r in enumerate(rows):
|
||||||
|
inst = i % 4
|
||||||
|
input_len = r["input_length"]
|
||||||
|
hash_ids = r.get("hash_ids", [])
|
||||||
|
inst_total[inst] += input_len
|
||||||
|
|
||||||
|
cached_blocks = 0
|
||||||
|
prefix_broken = False
|
||||||
|
for hid in hash_ids:
|
||||||
|
if not prefix_broken and hid in instance_seen[inst]:
|
||||||
|
cached_blocks += 1
|
||||||
|
else:
|
||||||
|
prefix_broken = True
|
||||||
|
|
||||||
|
inst_cached[inst] += cached_blocks * BLOCK_SIZE
|
||||||
|
for hid in hash_ids:
|
||||||
|
instance_seen[inst].add(hid)
|
||||||
|
|
||||||
|
rr_total = sum(inst_total)
|
||||||
|
rr_cached = sum(inst_cached)
|
||||||
|
print(f" Cache hit ratio (RR): {rr_cached*100//rr_total}%")
|
||||||
|
|
||||||
|
# === 3. Cache-aware routing (route to instance with best prefix match) ===
|
||||||
|
print(f"\n{sep}")
|
||||||
|
print(" 4-INSTANCE CACHE-AWARE ROUTING")
|
||||||
|
print(sep)
|
||||||
|
|
||||||
|
ca_seen = [set() for _ in range(4)]
|
||||||
|
ca_total = [0]*4
|
||||||
|
ca_cached = [0]*4
|
||||||
|
|
||||||
|
for r in rows:
|
||||||
|
input_len = r["input_length"]
|
||||||
|
hash_ids = r.get("hash_ids", [])
|
||||||
|
|
||||||
|
# Pick instance with most prefix blocks cached
|
||||||
|
best_inst = 0
|
||||||
|
best_hit = 0
|
||||||
|
for inst in range(4):
|
||||||
|
hit = 0
|
||||||
|
for hid in hash_ids:
|
||||||
|
if hid in ca_seen[inst]:
|
||||||
|
hit += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
if hit > best_hit:
|
||||||
|
best_hit = hit
|
||||||
|
best_inst = inst
|
||||||
|
|
||||||
|
ca_total[best_inst] += input_len
|
||||||
|
ca_cached[best_inst] += best_hit * BLOCK_SIZE
|
||||||
|
for hid in hash_ids:
|
||||||
|
ca_seen[best_inst].add(hid)
|
||||||
|
|
||||||
|
ca_total_sum = sum(ca_total)
|
||||||
|
ca_cached_sum = sum(ca_cached)
|
||||||
|
print(f" Cache hit ratio: {ca_cached_sum*100//ca_total_sum}%")
|
||||||
|
print(f" vs RR: {rr_cached*100//rr_total}% -> {ca_cached_sum*100//ca_total_sum}% (+{(ca_cached_sum-rr_cached)*100//rr_total}pp)")
|
||||||
|
|
||||||
|
# === 4. Session structure analysis ===
|
||||||
|
print(f"\n{sep}")
|
||||||
|
print(" SESSION & MULTI-TURN ANALYSIS")
|
||||||
|
print(sep)
|
||||||
|
|
||||||
|
sessions = {}
|
||||||
|
chat_to_session = {}
|
||||||
|
for r in rows:
|
||||||
|
cid = int(r["chat_id"])
|
||||||
|
pid = int(r["parent_chat_id"])
|
||||||
|
sid = r.get("session_id", str(cid) if pid < 0 else chat_to_session.get(pid, str(pid)))
|
||||||
|
chat_to_session[cid] = str(sid)
|
||||||
|
sessions.setdefault(str(sid), []).append(r)
|
||||||
|
|
||||||
|
multi = {k: v for k, v in sessions.items() if len(v) > 1}
|
||||||
|
single = {k: v for k, v in sessions.items() if len(v) == 1}
|
||||||
|
|
||||||
|
print(f" Sessions: {len(sessions)} total, {len(multi)} multi-turn ({len(multi)*100//len(sessions)}%)")
|
||||||
|
|
||||||
|
# Multi-turn: cache hit in turn 2+
|
||||||
|
mt_new = 0
|
||||||
|
mt_reuse = 0
|
||||||
|
for sid, turns in multi.items():
|
||||||
|
turns.sort(key=lambda r: r["turn"])
|
||||||
|
prev_blocks = set()
|
||||||
|
for t in turns:
|
||||||
|
hids = t.get("hash_ids", [])
|
||||||
|
for hid in hids:
|
||||||
|
if hid in prev_blocks:
|
||||||
|
mt_reuse += BLOCK_SIZE
|
||||||
|
else:
|
||||||
|
mt_new += BLOCK_SIZE
|
||||||
|
prev_blocks.add(hid)
|
||||||
|
|
||||||
|
mt_total_tok = mt_new + mt_reuse
|
||||||
|
print(f" Multi-turn intra-session reuse: {mt_reuse*100//mt_total_tok}% of tokens")
|
||||||
|
print(f" (Turn 2+ reuses KV from prior turns in same session)")
|
||||||
|
|
||||||
|
# Single-turn: cross-session sharing via system prompt
|
||||||
|
block_freq = Counter()
|
||||||
|
for r in rows:
|
||||||
|
for hid in r.get("hash_ids", []):
|
||||||
|
block_freq[hid] += 1
|
||||||
|
|
||||||
|
shared = {k: v for k, v in block_freq.items() if v > 1}
|
||||||
|
top = block_freq.most_common(5)
|
||||||
|
print(f"\n Cross-session block sharing:")
|
||||||
|
print(f" Unique blocks: {len(block_freq):,}")
|
||||||
|
print(f" Shared (ref>1): {len(shared):,} ({len(shared)*100//len(block_freq)}%)")
|
||||||
|
print(f" Top-5 block ref counts: {[c for _,c in top]}")
|
||||||
|
print(f" (Shared blocks = system prompt / common code context)")
|
||||||
|
|
||||||
|
# === 5. Implication for PD separation ===
|
||||||
|
print(f"\n{sep}")
|
||||||
|
print(" IMPLICATION FOR PD SEPARATION")
|
||||||
|
print(sep)
|
||||||
|
actual_prefill_pct = (total_tokens - total_cached) * 100 // total_tokens
|
||||||
|
print(f" With perfect caching, only {actual_prefill_pct}% of tokens need actual prefill compute.")
|
||||||
|
print(f" The remaining {100-actual_prefill_pct}% are prefix cache hits (skip prefill, reuse KV).")
|
||||||
|
print(f" This means PD separation's prefill overhead is much smaller than it appears:")
|
||||||
|
print(f" - Nominal avg input: {total_tokens//len(rows):,} tokens/request")
|
||||||
|
new_per_req = sorted([s["new_tokens"] for s in per_req if s["input_length"] > 0])
|
||||||
|
print(f" - Actual avg prefill: {sum(new_per_req)//len(new_per_req):,} tokens/request (after cache hit)")
|
||||||
|
print(f" - KV transfer size is also reduced (only transfer new blocks)")
|
||||||
163
scripts/analyze_trace.py
Normal file
163
scripts/analyze_trace.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
"""Analyze trace patterns to assess PD separation benefit.
|
||||||
|
|
||||||
|
Computes metrics relevant to deciding PD-combined vs PD-separated:
|
||||||
|
- Input/output token ratio (high ratio = prefill-heavy → PD sep benefits)
|
||||||
|
- Prefix sharing density (high sharing → benefits from shared KV cache)
|
||||||
|
- Session length distribution (multi-turn = more prefix reuse)
|
||||||
|
- Arrival burstiness (bursty prefill → PD sep can absorb spikes)
|
||||||
|
- Compute-intensity ratio: prefill FLOP share vs decode FLOP share
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python scripts/analyze_trace.py --input traces/sampled_1000req_seed42.jsonl
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import collections
|
||||||
|
import json
|
||||||
|
import statistics
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
p = argparse.ArgumentParser(description=__doc__,
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||||
|
p.add_argument("--input", type=Path, required=True)
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
with args.input.open() as fh:
|
||||||
|
for line in fh:
|
||||||
|
rows.append(json.loads(line))
|
||||||
|
|
||||||
|
# Session structure
|
||||||
|
sessions: dict[str, list[dict]] = collections.OrderedDict()
|
||||||
|
chat_to_session: dict[int, str] = {}
|
||||||
|
for r in rows:
|
||||||
|
cid = int(r["chat_id"])
|
||||||
|
pid = int(r["parent_chat_id"])
|
||||||
|
sid = r.get("session_id")
|
||||||
|
if sid is None:
|
||||||
|
sid = str(cid) if pid < 0 else chat_to_session.get(pid, str(pid))
|
||||||
|
chat_to_session[cid] = str(sid)
|
||||||
|
sessions.setdefault(str(sid), []).append(r)
|
||||||
|
|
||||||
|
n_sessions = len(sessions)
|
||||||
|
turns_per_session = [len(v) for v in sessions.values()]
|
||||||
|
multi_turn = sum(1 for t in turns_per_session if t > 1)
|
||||||
|
|
||||||
|
input_lens = [r["input_length"] for r in rows]
|
||||||
|
output_lens = [r["output_length"] for r in rows]
|
||||||
|
total_input = sum(input_lens)
|
||||||
|
total_output = sum(output_lens)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("Trace Pattern Analysis for PD Separation Decision")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 1. Input/Output ratio
|
||||||
|
io_ratio = total_input / max(total_output, 1)
|
||||||
|
print(f"\n1. Input/Output Token Ratio")
|
||||||
|
print(f" Total input tokens: {total_input:>12,}")
|
||||||
|
print(f" Total output tokens: {total_output:>12,}")
|
||||||
|
print(f" I/O ratio: {io_ratio:>12.1f}x")
|
||||||
|
print(f" → {'STRONGLY' if io_ratio > 50 else 'Moderately' if io_ratio > 10 else 'Weakly'} prefill-heavy")
|
||||||
|
|
||||||
|
# 2. Prefill compute share
|
||||||
|
# Approximate: prefill FLOP ∝ input_length, decode FLOP ∝ output_length * input_length
|
||||||
|
# More precisely: prefill dominates when input >> output
|
||||||
|
prefill_share = total_input / (total_input + total_output)
|
||||||
|
print(f"\n2. Compute Split (token count proxy)")
|
||||||
|
print(f" Prefill share: {prefill_share*100:.1f}%")
|
||||||
|
print(f" Decode share: {(1-prefill_share)*100:.1f}%")
|
||||||
|
|
||||||
|
# 3. Session structure
|
||||||
|
print(f"\n3. Session Structure")
|
||||||
|
print(f" Sessions: {n_sessions}")
|
||||||
|
print(f" Requests: {len(rows)}")
|
||||||
|
print(f" Multi-turn: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)")
|
||||||
|
print(f" Turns/sess: min={min(turns_per_session)} max={max(turns_per_session)} "
|
||||||
|
f"avg={statistics.fmean(turns_per_session):.1f}")
|
||||||
|
|
||||||
|
# 4. Prefix sharing
|
||||||
|
all_hash_ids = set()
|
||||||
|
per_request_hashes = []
|
||||||
|
for r in rows:
|
||||||
|
hids = set(r.get("hash_ids", []))
|
||||||
|
per_request_hashes.append(hids)
|
||||||
|
all_hash_ids.update(hids)
|
||||||
|
|
||||||
|
hash_refcount = collections.Counter()
|
||||||
|
for hids in per_request_hashes:
|
||||||
|
for h in hids:
|
||||||
|
hash_refcount[h] += 1
|
||||||
|
|
||||||
|
shared_blocks = sum(1 for h, c in hash_refcount.items() if c > 1)
|
||||||
|
total_blocks = len(all_hash_ids)
|
||||||
|
block_reuse = shared_blocks / max(total_blocks, 1)
|
||||||
|
avg_refcount = statistics.fmean(hash_refcount.values()) if hash_refcount else 0
|
||||||
|
|
||||||
|
print(f"\n4. Prefix Block Sharing")
|
||||||
|
print(f" Unique blocks: {total_blocks:>10,}")
|
||||||
|
print(f" Shared (ref>1): {shared_blocks:>10,} ({block_reuse*100:.1f}%)")
|
||||||
|
print(f" Avg refcount: {avg_refcount:>10.2f}")
|
||||||
|
print(f" → {'High' if block_reuse > 0.3 else 'Moderate' if block_reuse > 0.1 else 'Low'} prefix reuse potential")
|
||||||
|
|
||||||
|
# 5. Input length distribution
|
||||||
|
input_sorted = sorted(input_lens)
|
||||||
|
pct = lambda q: input_sorted[min(int(q * len(input_sorted)), len(input_sorted) - 1)]
|
||||||
|
print(f"\n5. Input Length Distribution")
|
||||||
|
print(f" p10={pct(0.1):>8,} p50={pct(0.5):>8,} p90={pct(0.9):>8,} max={max(input_lens):>8,}")
|
||||||
|
long_context = sum(1 for l in input_lens if l > 32000)
|
||||||
|
print(f" Requests >32k tokens: {long_context} ({long_context/len(rows)*100:.1f}%)")
|
||||||
|
|
||||||
|
# 6. Arrival pattern
|
||||||
|
timestamps = sorted(float(r["timestamp"]) for r in rows)
|
||||||
|
span = timestamps[-1] - timestamps[0]
|
||||||
|
avg_rate = len(rows) / max(span, 0.001)
|
||||||
|
|
||||||
|
# Burstiness: coefficient of variation of inter-arrival times
|
||||||
|
inter_arrivals = [timestamps[i+1] - timestamps[i] for i in range(len(timestamps) - 1)]
|
||||||
|
inter_arrivals = [t for t in inter_arrivals if t > 0]
|
||||||
|
if inter_arrivals:
|
||||||
|
cv = statistics.stdev(inter_arrivals) / statistics.fmean(inter_arrivals)
|
||||||
|
else:
|
||||||
|
cv = 0
|
||||||
|
print(f"\n6. Arrival Pattern")
|
||||||
|
print(f" Span: {span:.1f}s ({span/60:.1f} min)")
|
||||||
|
print(f" Avg rate: {avg_rate:.2f} req/s")
|
||||||
|
print(f" Burstiness (CoV): {cv:.2f}")
|
||||||
|
print(f" → {'Bursty' if cv > 1.5 else 'Moderate' if cv > 0.8 else 'Steady'} arrival pattern")
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print(f"\n{'=' * 60}")
|
||||||
|
print("Summary: PD Separation Recommendation")
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
factors = []
|
||||||
|
if io_ratio > 50:
|
||||||
|
factors.append("Very high I/O ratio (prefill-dominated)")
|
||||||
|
elif io_ratio > 10:
|
||||||
|
factors.append("High I/O ratio")
|
||||||
|
if block_reuse > 0.1:
|
||||||
|
factors.append(f"Significant prefix reuse ({block_reuse*100:.0f}% shared blocks)")
|
||||||
|
if long_context / len(rows) > 0.3:
|
||||||
|
factors.append(f"Many long-context requests ({long_context/len(rows)*100:.0f}%)")
|
||||||
|
if cv > 1.0:
|
||||||
|
factors.append("Bursty arrivals (PD sep absorbs prefill spikes)")
|
||||||
|
|
||||||
|
if len(factors) >= 2:
|
||||||
|
print("→ RECOMMEND PD separation:")
|
||||||
|
elif len(factors) == 1:
|
||||||
|
print("→ PD separation MAY help:")
|
||||||
|
else:
|
||||||
|
print("→ PD separation likely NOT beneficial:")
|
||||||
|
|
||||||
|
for f in factors:
|
||||||
|
print(f" • {f}")
|
||||||
|
if not factors:
|
||||||
|
print(" • No strong indicators for PD separation benefit")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
280
scripts/cache_aware_proxy.py
Normal file
280
scripts/cache_aware_proxy.py
Normal file
@@ -0,0 +1,280 @@
|
|||||||
|
"""Unified cache-aware + token-level load-balanced global scheduler.
|
||||||
|
|
||||||
|
Supports two modes:
|
||||||
|
--combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
|
||||||
|
--prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)
|
||||||
|
|
||||||
|
Routing policy (same for both modes):
|
||||||
|
score = ongoing_tokens / avg_ongoing - ALPHA * cache_hit_ratio
|
||||||
|
Normalized load prevents "rich get richer"; cache bonus gives affinity.
|
||||||
|
Session affinity: multi-turn sessions stick to same instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import urllib.parse
|
||||||
|
import uuid
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
BLOCK_SIZE = 512
|
||||||
|
CACHE_HIT_ALPHA = 1.0 # weight for cache bonus in scoring
|
||||||
|
|
||||||
|
|
||||||
|
class InstanceState:
|
||||||
|
def __init__(self, url: str, bootstrap_port: int | None = None):
|
||||||
|
self.url = url
|
||||||
|
self.bootstrap_port = bootstrap_port
|
||||||
|
self.client = httpx.AsyncClient(
|
||||||
|
timeout=None, base_url=url,
|
||||||
|
limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
|
||||||
|
)
|
||||||
|
self.ongoing_tokens = 0
|
||||||
|
self.engine_id: dict[int, str] = {}
|
||||||
|
self.dp_size = 1
|
||||||
|
self.cached_blocks: set[int] = set()
|
||||||
|
|
||||||
|
def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
|
||||||
|
if not token_ids or len(token_ids) < BLOCK_SIZE:
|
||||||
|
return 0
|
||||||
|
hit = 0
|
||||||
|
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
|
||||||
|
bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
|
||||||
|
if bh in self.cached_blocks:
|
||||||
|
hit += BLOCK_SIZE
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return hit
|
||||||
|
|
||||||
|
def record_prefix(self, token_ids: list[int] | None):
|
||||||
|
if not token_ids:
|
||||||
|
return
|
||||||
|
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
|
||||||
|
self.cached_blocks.add(hash(tuple(token_ids[i:i + BLOCK_SIZE])))
|
||||||
|
if len(self.cached_blocks) > 200000:
|
||||||
|
self.cached_blocks = set(list(self.cached_blocks)[-100000:])
|
||||||
|
|
||||||
|
|
||||||
|
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
|
||||||
|
session_id: str | None, input_length: int,
|
||||||
|
affinity: dict[str, int]) -> tuple[InstanceState, int]:
|
||||||
|
"""Normalized load - cache bonus scoring."""
|
||||||
|
if session_id and session_id in affinity:
|
||||||
|
idx = affinity[session_id]
|
||||||
|
if idx < len(instances):
|
||||||
|
return instances[idx], idx
|
||||||
|
|
||||||
|
avg_load = max(sum(i.ongoing_tokens for i in instances) / len(instances), 1.0)
|
||||||
|
best_idx, best_score = 0, float("inf")
|
||||||
|
for i, inst in enumerate(instances):
|
||||||
|
cache_hit = inst.estimate_cache_hit(token_ids)
|
||||||
|
cache_ratio = cache_hit / input_length if input_length > 0 else 0.0
|
||||||
|
score = inst.ongoing_tokens / avg_load - CACHE_HIT_ALPHA * cache_ratio
|
||||||
|
if score < best_score:
|
||||||
|
best_score = score
|
||||||
|
best_idx = i
|
||||||
|
|
||||||
|
if session_id:
|
||||||
|
affinity[session_id] = best_idx
|
||||||
|
return instances[best_idx], best_idx
|
||||||
|
|
||||||
|
|
||||||
|
global_args = None
|
||||||
|
combined_instances: list[InstanceState] = []
|
||||||
|
prefill_instances: list[InstanceState] = []
|
||||||
|
decode_instances: list[InstanceState] = []
|
||||||
|
session_affinity: dict[str, int] = {}
|
||||||
|
is_pd_sep = False
|
||||||
|
|
||||||
|
|
||||||
|
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
|
||||||
|
for inst in instances:
|
||||||
|
if inst.bootstrap_port is None:
|
||||||
|
continue
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await inst.client.get("/health")
|
||||||
|
except Exception:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
continue
|
||||||
|
parsed = urllib.parse.urlparse(str(inst.client.base_url))
|
||||||
|
url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query"
|
||||||
|
resp = await inst.client.get(url)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
for dp_rank, dp_entry in data.items():
|
||||||
|
inst.engine_id[int(dp_rank)] = dp_entry["engine_id"]
|
||||||
|
inst.dp_size = len(data)
|
||||||
|
print(f"Inited {inst.url} engine_ids={inst.engine_id}")
|
||||||
|
break
|
||||||
|
ready.set()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
global is_pd_sep
|
||||||
|
app.state.ready = asyncio.Event()
|
||||||
|
|
||||||
|
if global_args.combined:
|
||||||
|
is_pd_sep = False
|
||||||
|
for url in global_args.combined:
|
||||||
|
combined_instances.append(InstanceState(url))
|
||||||
|
app.state.ready.set()
|
||||||
|
print(f"Combined mode: {len(combined_instances)} instances")
|
||||||
|
else:
|
||||||
|
is_pd_sep = True
|
||||||
|
for url, bp in global_args.prefill:
|
||||||
|
prefill_instances.append(InstanceState(url, bp))
|
||||||
|
for url in global_args.decode:
|
||||||
|
decode_instances.append(InstanceState(url))
|
||||||
|
asyncio.create_task(init_prefill_bootstrap(prefill_instances, app.state.ready))
|
||||||
|
print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")
|
||||||
|
|
||||||
|
yield
|
||||||
|
for inst in combined_instances + prefill_instances + decode_instances:
|
||||||
|
await inst.client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/completions")
|
||||||
|
async def handle_completions(request: Request):
|
||||||
|
return await _handle(request, "/v1/completions")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
|
||||||
|
async def handle_chat(request: Request):
|
||||||
|
return await _handle(request, "/v1/chat/completions")
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle(request: Request, api: str):
|
||||||
|
if not app.state.ready.is_set():
|
||||||
|
raise HTTPException(status_code=503, detail="Service Unavailable")
|
||||||
|
|
||||||
|
req_data = await request.json()
|
||||||
|
request_id = str(uuid.uuid4())
|
||||||
|
prompt = req_data.get("prompt")
|
||||||
|
token_ids = prompt if isinstance(prompt, list) else None
|
||||||
|
input_length = len(token_ids) if token_ids else 0
|
||||||
|
session_id = request.headers.get("X-Session-Id")
|
||||||
|
|
||||||
|
headers = {"X-Request-Id": request_id}
|
||||||
|
api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
|
if is_pd_sep:
|
||||||
|
return await _handle_pd_sep(api, req_data, request_id, token_ids,
|
||||||
|
input_length, session_id, headers)
|
||||||
|
else:
|
||||||
|
return await _handle_combined(api, req_data, token_ids,
|
||||||
|
input_length, session_id, headers)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
|
||||||
|
"""Combined mode: route to best instance, send normal request."""
|
||||||
|
inst, idx = pick_instance(combined_instances, token_ids, session_id,
|
||||||
|
input_length, session_affinity)
|
||||||
|
inst.ongoing_tokens += input_length
|
||||||
|
|
||||||
|
async def generate():
|
||||||
|
try:
|
||||||
|
async with inst.client.stream("POST", api, json=req_data, headers=headers) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
async for chunk in resp.aiter_bytes():
|
||||||
|
yield chunk
|
||||||
|
inst.record_prefix(token_ids)
|
||||||
|
finally:
|
||||||
|
inst.ongoing_tokens -= input_length
|
||||||
|
|
||||||
|
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
|
||||||
|
session_id, headers):
|
||||||
|
"""PD-Sep mode: await prefill, then stream decode."""
|
||||||
|
p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
|
||||||
|
input_length, session_affinity)
|
||||||
|
d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
|
||||||
|
|
||||||
|
# Await prefill
|
||||||
|
p_inst.ongoing_tokens += input_length
|
||||||
|
try:
|
||||||
|
prefill_data = req_data.copy()
|
||||||
|
prefill_data["kv_transfer_params"] = {
|
||||||
|
"do_remote_decode": True, "do_remote_prefill": False,
|
||||||
|
"transfer_id": f"xfer-{request_id}",
|
||||||
|
}
|
||||||
|
prefill_data["stream"] = False
|
||||||
|
prefill_data["max_tokens"] = 1
|
||||||
|
prefill_data.pop("max_completion_tokens", None)
|
||||||
|
prefill_data.pop("stream_options", None)
|
||||||
|
|
||||||
|
p_headers = {**headers, "X-data-parallel-rank": "0"}
|
||||||
|
resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
|
||||||
|
resp.raise_for_status()
|
||||||
|
await resp.aclose()
|
||||||
|
p_inst.record_prefix(token_ids)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
|
||||||
|
finally:
|
||||||
|
p_inst.ongoing_tokens -= input_length
|
||||||
|
|
||||||
|
# Stream decode
|
||||||
|
d_inst.ongoing_tokens += input_length
|
||||||
|
parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
|
||||||
|
bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"
|
||||||
|
|
||||||
|
decode_data = req_data.copy()
|
||||||
|
decode_data["kv_transfer_params"] = {
|
||||||
|
"do_remote_decode": False, "do_remote_prefill": True,
|
||||||
|
"remote_bootstrap_addr": bootstrap_addr,
|
||||||
|
"remote_engine_id": p_inst.engine_id.get(0, ""),
|
||||||
|
"transfer_id": f"xfer-{request_id}",
|
||||||
|
}
|
||||||
|
|
||||||
|
async def generate():
|
||||||
|
try:
|
||||||
|
async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
async for chunk in resp.aiter_bytes():
|
||||||
|
yield chunk
|
||||||
|
finally:
|
||||||
|
d_inst.ongoing_tokens -= input_length
|
||||||
|
|
||||||
|
return StreamingResponse(generate(), media_type="application/json")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
|
||||||
|
p.add_argument("--port", type=int, default=8000)
|
||||||
|
p.add_argument("--host", type=str, default="0.0.0.0")
|
||||||
|
p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
|
||||||
|
p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
|
||||||
|
help="PD-Sep prefill: URL [bootstrap_port]")
|
||||||
|
p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
|
||||||
|
help="PD-Sep decode: URL")
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
args.prefill = []
|
||||||
|
if args.prefill_raw:
|
||||||
|
for entry in args.prefill_raw:
|
||||||
|
url = entry[0]
|
||||||
|
bp = int(entry[1]) if len(entry) > 1 and entry[1].lower() != "none" else None
|
||||||
|
args.prefill.append((url, bp))
|
||||||
|
args.decode = [e[0] for e in (args.decode_raw or [])]
|
||||||
|
|
||||||
|
if not args.combined and not args.prefill:
|
||||||
|
p.error("Must specify either --combined or --prefill/--decode")
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
global_args = parse_args()
|
||||||
|
uvicorn.run(app, host=global_args.host, port=global_args.port)
|
||||||
102
scripts/compare_results.py
Normal file
102
scripts/compare_results.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
"""Compare benchmark results between PD-combined and PD-separated modes.
|
||||||
|
|
||||||
|
Reads summary JSON files and per-request metrics to produce a detailed
|
||||||
|
comparison report including TTFT, TPOT, E2E, cache hit ratio, and
|
||||||
|
throughput analysis.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python scripts/compare_results.py \
|
||||||
|
--combined outputs/combined_1000req/metrics.summary.json \
|
||||||
|
--separated outputs/separated_1000req/metrics.summary.json
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def load_summary(path: Path) -> dict:
|
||||||
|
return json.loads(path.read_text())
|
||||||
|
|
||||||
|
|
||||||
|
def load_metrics(path: Path) -> list[dict]:
|
||||||
|
rows = []
|
||||||
|
with path.open() as fh:
|
||||||
|
for line in fh:
|
||||||
|
rows.append(json.loads(line))
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
|
def fmt_stat(stat: dict | None, unit: str = "s") -> str:
|
||||||
|
if stat is None:
|
||||||
|
return "N/A"
|
||||||
|
return (f"mean={stat['mean']:.3f}{unit} "
|
||||||
|
f"p50={stat['p50']:.3f}{unit} "
|
||||||
|
f"p90={stat['p90']:.3f}{unit} "
|
||||||
|
f"p99={stat['p99']:.3f}{unit}")
|
||||||
|
|
||||||
|
|
||||||
|
def compare(combined: dict, separated: dict) -> None:
|
||||||
|
print("=" * 70)
|
||||||
|
print("PD-Combined vs PD-Separated Performance Comparison")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
for label, s in [("PD-Combined", combined), ("PD-Separated", separated)]:
|
||||||
|
print(f"\n--- {label} ---")
|
||||||
|
print(f" Requests: {s['request_count']} (success: {s['success_count']}, errors: {s['error_count']})")
|
||||||
|
print(f" Wall clock: {s.get('wall_clock_s', 0):.1f}s")
|
||||||
|
print(f" TTFT: {fmt_stat(s.get('ttft_stats_s'))}")
|
||||||
|
print(f" TPOT: {fmt_stat(s.get('tpot_stats_s'))}")
|
||||||
|
print(f" E2E: {fmt_stat(s.get('latency_stats_s'))}")
|
||||||
|
hit_ratio = s.get('prefix_cache_hit_ratio', 0)
|
||||||
|
print(f" Prefix cache hit ratio: {hit_ratio*100:.1f}%")
|
||||||
|
queries = s.get('prefix_cache_queries_tokens', 0)
|
||||||
|
hits = s.get('prefix_cache_hits_tokens', 0)
|
||||||
|
print(f" ({hits}/{queries} tokens)")
|
||||||
|
|
||||||
|
print("\n--- Comparison (Separated vs Combined) ---")
|
||||||
|
for metric_key, label in [
|
||||||
|
("ttft_stats_s", "TTFT"),
|
||||||
|
("tpot_stats_s", "TPOT"),
|
||||||
|
("latency_stats_s", "E2E"),
|
||||||
|
]:
|
||||||
|
c = combined.get(metric_key, {})
|
||||||
|
s = separated.get(metric_key, {})
|
||||||
|
if c and s:
|
||||||
|
for pct in ["mean", "p50", "p90", "p99"]:
|
||||||
|
cv, sv = c.get(pct, 0), s.get(pct, 0)
|
||||||
|
if cv > 0:
|
||||||
|
change = (sv - cv) / cv * 100
|
||||||
|
direction = "slower" if change > 0 else "faster"
|
||||||
|
print(f" {label} {pct}: {abs(change):.1f}% {direction} "
|
||||||
|
f"({cv:.3f}s → {sv:.3f}s)")
|
||||||
|
|
||||||
|
c_ratio = combined.get("prefix_cache_hit_ratio", 0)
|
||||||
|
s_ratio = separated.get("prefix_cache_hit_ratio", 0)
|
||||||
|
print(f" Cache hit ratio: {c_ratio*100:.1f}% → {s_ratio*100:.1f}%")
|
||||||
|
|
||||||
|
c_wall = combined.get("wall_clock_s", 1)
|
||||||
|
s_wall = separated.get("wall_clock_s", 1)
|
||||||
|
c_tput = combined["success_count"] / c_wall
|
||||||
|
s_tput = separated["success_count"] / s_wall
|
||||||
|
print(f" Throughput: {c_tput:.1f} → {s_tput:.1f} req/s "
|
||||||
|
f"({(s_tput/c_tput - 1)*100:+.1f}%)")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
p = argparse.ArgumentParser(description=__doc__,
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||||
|
p.add_argument("--combined", type=Path, required=True)
|
||||||
|
p.add_argument("--separated", type=Path, required=True)
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
combined = load_summary(args.combined)
|
||||||
|
separated = load_summary(args.separated)
|
||||||
|
compare(combined, separated)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
210
scripts/compute_roofline.py
Normal file
210
scripts/compute_roofline.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"""Roofline analysis: compute/memory ratio for prefill vs decode
|
||||||
|
under different sequence lengths and KV cache reuse ratios.
|
||||||
|
|
||||||
|
Model: Qwen3-Coder-30B-A3B (MoE)
|
||||||
|
- 48 layers, hidden=2048, heads=32, kv_heads=4, head_dim=128
|
||||||
|
- MoE: 128 experts, top-8 active, intermediate=6144
|
||||||
|
- Total params: ~30B, Active params per token: ~3B
|
||||||
|
|
||||||
|
GPU: NVIDIA H20
|
||||||
|
- BF16 peak: 148 TFLOPS
|
||||||
|
- HBM bandwidth: 4.0 TB/s
|
||||||
|
- Roofline ridge point: 148/4.0 = 37 FLOP/byte
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
|
||||||
|
# ===== Model config =====
|
||||||
|
L = 48 # layers
|
||||||
|
D = 2048 # hidden dim
|
||||||
|
H = 32 # attention heads
|
||||||
|
H_kv = 4 # KV heads (GQA)
|
||||||
|
D_head = 128 # head dim
|
||||||
|
D_ffn = 6144 # FFN intermediate (per expert)
|
||||||
|
N_experts = 128 # total experts
|
||||||
|
K_experts = 8 # active experts per token
|
||||||
|
VOCAB = 151936
|
||||||
|
BYTES = 2 # BF16
|
||||||
|
|
||||||
|
# ===== GPU config (H20) =====
|
||||||
|
PEAK_FLOPS = 148e12 # BF16 TFLOPS
|
||||||
|
HBM_BW = 4.0e12 # bytes/s
|
||||||
|
RIDGE_POINT = PEAK_FLOPS / HBM_BW # ~37 FLOP/byte
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
print(" ROOFLINE ANALYSIS: Prefill vs Decode under KV Cache Reuse")
|
||||||
|
print(" Model: Qwen3-Coder-30B-A3B (MoE 128E top-8) | GPU: H20")
|
||||||
|
print("=" * 80)
|
||||||
|
print(f" Ridge point: {RIDGE_POINT:.1f} FLOP/byte")
|
||||||
|
print(f" Above ridge → compute-bound | Below ridge → memory-bound")
|
||||||
|
|
||||||
|
# ===== Per-token compute & memory for each component =====
|
||||||
|
|
||||||
|
def attention_prefill_flops(seq_len, new_tokens):
|
||||||
|
"""FLOPs for attention on new_tokens with seq_len context."""
|
||||||
|
# QKV projection: new_tokens * D * (D + 2*D_kv) * 2
|
||||||
|
d_kv = H_kv * D_head
|
||||||
|
qkv_flops = new_tokens * (D * D * 2 + D * d_kv * 2 * 2) # Q + K + V
|
||||||
|
# Attention score: new_tokens * seq_len * D * 2 (Q@K^T + softmax@V)
|
||||||
|
attn_flops = new_tokens * seq_len * D * 2 * 2 # simplified: 2 matmuls
|
||||||
|
# Output projection: new_tokens * D * D * 2
|
||||||
|
out_flops = new_tokens * D * D * 2
|
||||||
|
return (qkv_flops + attn_flops + out_flops) * L
|
||||||
|
|
||||||
|
def attention_prefill_bytes(seq_len, new_tokens, cached_tokens):
|
||||||
|
"""Memory access for attention prefill."""
|
||||||
|
d_kv = H_kv * D_head
|
||||||
|
# Load model weights (QKV + O projections): D*(D+2*d_kv+D) * BYTES * L
|
||||||
|
weight_bytes = D * (D + 2 * d_kv + D) * BYTES * L
|
||||||
|
# Load cached KV: cached_tokens * 2 * d_kv * BYTES * L
|
||||||
|
cached_kv_bytes = cached_tokens * 2 * d_kv * BYTES * L
|
||||||
|
# Read input activations + write output: new_tokens * D * BYTES * 2 * L
|
||||||
|
act_bytes = new_tokens * D * BYTES * 2 * L
|
||||||
|
# Write new KV to cache: new_tokens * 2 * d_kv * BYTES * L
|
||||||
|
new_kv_bytes = new_tokens * 2 * d_kv * BYTES * L
|
||||||
|
return weight_bytes + cached_kv_bytes + act_bytes + new_kv_bytes
|
||||||
|
|
||||||
|
def ffn_flops(n_tokens):
|
||||||
|
"""FLOPs for MoE FFN on n_tokens."""
|
||||||
|
# Per expert: 3 * n_tokens * D * D_ffn * 2 (gate + up + down)
|
||||||
|
# Active experts: K_experts
|
||||||
|
return 3 * n_tokens * D * D_ffn * 2 * K_experts * L
|
||||||
|
|
||||||
|
def ffn_bytes(n_tokens):
|
||||||
|
"""Memory access for MoE FFN."""
|
||||||
|
# Load K_experts worth of weights per layer: K * 3 * D * D_ffn * BYTES
|
||||||
|
weight_bytes = K_experts * 3 * D * D_ffn * BYTES * L
|
||||||
|
# Activations: n_tokens * D * BYTES * 2 * L
|
||||||
|
act_bytes = n_tokens * D * BYTES * 2 * L
|
||||||
|
return weight_bytes + act_bytes
|
||||||
|
|
||||||
|
def decode_flops(seq_len):
|
||||||
|
"""FLOPs for 1 decode token."""
|
||||||
|
return attention_prefill_flops(seq_len, 1) + ffn_flops(1)
|
||||||
|
|
||||||
|
def decode_bytes(seq_len):
|
||||||
|
"""Memory bytes for 1 decode token."""
|
||||||
|
return attention_prefill_bytes(seq_len, 1, seq_len) + ffn_bytes(1)
|
||||||
|
|
||||||
|
# ===== Analysis =====
|
||||||
|
|
||||||
|
print("\n" + "-" * 80)
|
||||||
|
print(" PART 1: Decode Roofline (baseline)")
|
||||||
|
print("-" * 80)
|
||||||
|
print(f" {'SeqLen':>8} {'FLOP':>14} {'Bytes':>14} {'AI (F/B)':>10} {'Bound':>12}")
|
||||||
|
|
||||||
|
for seq_len in [1000, 4000, 8000, 16000, 32000, 64000, 128000]:
|
||||||
|
flops = decode_flops(seq_len)
|
||||||
|
bytes_ = decode_bytes(seq_len)
|
||||||
|
ai = flops / bytes_
|
||||||
|
bound = "COMPUTE" if ai > RIDGE_POINT else "MEMORY"
|
||||||
|
print(f" {seq_len:>8,} {flops:>14.2e} {bytes_:>14.2e} {ai:>10.1f} {bound:>12}")
|
||||||
|
|
||||||
|
print("\n" + "-" * 80)
|
||||||
|
print(" PART 2: Prefill with KV Cache Reuse")
|
||||||
|
print(" (Total input = seq_len, cached = seq_len * reuse_ratio, new = rest)")
|
||||||
|
print("-" * 80)
|
||||||
|
print(f" {'SeqLen':>8} {'Reuse%':>7} {'NewTok':>8} {'FLOP':>14} {'Bytes':>14} {'AI (F/B)':>10} {'Bound':>12} {'vs Decode':>10}")
|
||||||
|
|
||||||
|
for seq_len in [4000, 16000, 32000, 64000, 128000]:
|
||||||
|
for reuse in [0.0, 0.3, 0.5, 0.7, 0.9, 0.95]:
|
||||||
|
cached = int(seq_len * reuse)
|
||||||
|
new = seq_len - cached
|
||||||
|
|
||||||
|
# Attention: compute on new tokens, but read cached KV for context
|
||||||
|
attn_f = attention_prefill_flops(seq_len, new)
|
||||||
|
attn_b = attention_prefill_bytes(seq_len, new, cached)
|
||||||
|
|
||||||
|
# FFN: only on new tokens
|
||||||
|
ffn_f = ffn_flops(new)
|
||||||
|
ffn_b = ffn_bytes(new)
|
||||||
|
|
||||||
|
total_f = attn_f + ffn_f
|
||||||
|
total_b = attn_b + ffn_b
|
||||||
|
ai = total_f / total_b if total_b > 0 else 0
|
||||||
|
|
||||||
|
# Compare with decode at same seq_len
|
||||||
|
dec_f = decode_flops(seq_len)
|
||||||
|
dec_b = decode_bytes(seq_len)
|
||||||
|
dec_ai = dec_f / dec_b
|
||||||
|
|
||||||
|
bound = "COMPUTE" if ai > RIDGE_POINT else "MEMORY"
|
||||||
|
ratio = f"{ai/dec_ai:.1f}x" if dec_ai > 0 else "N/A"
|
||||||
|
|
||||||
|
print(f" {seq_len:>8,} {reuse*100:>6.0f}% {new:>8,} {total_f:>14.2e} {total_b:>14.2e} {ai:>10.1f} {bound:>12} {ratio:>10}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("-" * 80)
|
||||||
|
print(" PART 3: Key Thresholds")
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
# At what reuse ratio does prefill become memory-bound?
|
||||||
|
for seq_len in [4000, 16000, 32000, 64000, 128000]:
|
||||||
|
for reuse_pct in range(0, 100):
|
||||||
|
reuse = reuse_pct / 100.0
|
||||||
|
cached = int(seq_len * reuse)
|
||||||
|
new = seq_len - cached
|
||||||
|
if new < 1: continue
|
||||||
|
attn_f = attention_prefill_flops(seq_len, new)
|
||||||
|
attn_b = attention_prefill_bytes(seq_len, new, cached)
|
||||||
|
ffn_f = ffn_flops(new)
|
||||||
|
ffn_b = ffn_bytes(new)
|
||||||
|
ai = (attn_f + ffn_f) / (attn_b + ffn_b)
|
||||||
|
if ai < RIDGE_POINT:
|
||||||
|
print(f" SeqLen={seq_len:>6,}: prefill becomes memory-bound at {reuse_pct}% reuse (AI={ai:.1f})")
|
||||||
|
break
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("-" * 80)
|
||||||
|
print(" PART 4: Agentic Workload Real Distribution")
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
# Use actual trace data
|
||||||
|
import os
|
||||||
|
trace_path = "traces/sampled_1000req_seed42.jsonl"
|
||||||
|
if os.path.exists(trace_path):
|
||||||
|
BLOCK_SIZE = 512
|
||||||
|
seen = set()
|
||||||
|
compute_bound = 0
|
||||||
|
memory_bound = 0
|
||||||
|
total = 0
|
||||||
|
|
||||||
|
for line in open(trace_path):
|
||||||
|
d = json.loads(line)
|
||||||
|
seq_len = d["input_length"]
|
||||||
|
if seq_len < 1: continue
|
||||||
|
hids = d.get("hash_ids", [])
|
||||||
|
|
||||||
|
cached_blocks = 0
|
||||||
|
for hid in hids:
|
||||||
|
if hid in seen:
|
||||||
|
cached_blocks += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
for hid in hids:
|
||||||
|
seen.add(hid)
|
||||||
|
|
||||||
|
cached = cached_blocks * BLOCK_SIZE
|
||||||
|
new = max(1, seq_len - cached)
|
||||||
|
reuse = cached / seq_len
|
||||||
|
|
||||||
|
attn_f = attention_prefill_flops(seq_len, new)
|
||||||
|
attn_b = attention_prefill_bytes(seq_len, new, cached)
|
||||||
|
ffn_f = ffn_flops(new)
|
||||||
|
ffn_b = ffn_bytes(new)
|
||||||
|
ai = (attn_f + ffn_f) / (attn_b + ffn_b)
|
||||||
|
|
||||||
|
total += 1
|
||||||
|
if ai > RIDGE_POINT:
|
||||||
|
compute_bound += 1
|
||||||
|
else:
|
||||||
|
memory_bound += 1
|
||||||
|
|
||||||
|
print(f" With actual trace prefix cache pattern:")
|
||||||
|
print(f" Compute-bound prefills: {compute_bound} ({compute_bound*100//total}%)")
|
||||||
|
print(f" Memory-bound prefills: {memory_bound} ({memory_bound*100//total}%)")
|
||||||
|
print(f" (Decode is ALWAYS memory-bound at these seq lengths)")
|
||||||
|
print()
|
||||||
|
print(f" Implication: {memory_bound*100//total}% of agentic prefills behave like decode")
|
||||||
|
print(f" → PD separation treats them as 'compute-heavy' but they are actually memory-heavy")
|
||||||
86
scripts/final_comparison.py
Normal file
86
scripts/final_comparison.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""Final comparison of PD-Combined vs PD-Separated (Mooncake/RDMA)."""
|
||||||
|
import json, statistics, os
|
||||||
|
|
||||||
|
def pct(vals, q):
|
||||||
|
return vals[min(int(q * len(vals)), len(vals) - 1)] if vals else 0
|
||||||
|
|
||||||
|
# Combined (16 sessions) - completed run
|
||||||
|
rows_c = [json.loads(l) for l in open("outputs/v18_combined_1000req/metrics.jsonl")]
|
||||||
|
ok_c = [r for r in rows_c if not r.get("error")]
|
||||||
|
ttfts_c = sorted([r["ttft_s"] for r in ok_c if r.get("ttft_s")])
|
||||||
|
tpots_c = sorted([r["tpot_s"] for r in ok_c if r.get("tpot_s") and r["tpot_s"] > 0])
|
||||||
|
lats_c = sorted([r["latency_s"] for r in ok_c if r.get("latency_s")])
|
||||||
|
sc = json.load(open("outputs/v18_combined_1000req/metrics.summary.json"))
|
||||||
|
|
||||||
|
# PD-Separated Mooncake (first 200 stable requests)
|
||||||
|
rows_d = [json.loads(l) for l in open("outputs/v18_pd_mooncake_lowconc/metrics.jsonl")][:200]
|
||||||
|
ok_d = [r for r in rows_d if not r.get("error")]
|
||||||
|
ttfts_d = sorted([r["ttft_s"] for r in ok_d if r.get("ttft_s")])
|
||||||
|
tpots_d = sorted([r["tpot_s"] for r in ok_d if r.get("tpot_s") and r["tpot_s"] > 0])
|
||||||
|
lats_d = sorted([r["latency_s"] for r in ok_d if r.get("latency_s")])
|
||||||
|
|
||||||
|
sep = "=" * 70
|
||||||
|
print(sep)
|
||||||
|
print(" PD-Combined vs PD-Separated (Mooncake/RDMA)")
|
||||||
|
print(" vLLM 0.18.1 | Qwen3-Coder-30B-A3B | 8xH20")
|
||||||
|
print(sep)
|
||||||
|
|
||||||
|
header = " {:<12} {:>16} {:>16} {:>10}".format(
|
||||||
|
"Metric", "Combined(TP=8)", "PD-Sep(TP=4+4)", "Delta")
|
||||||
|
print(header)
|
||||||
|
dash = " {:<12} {:>16} {:>16} {:>10}".format("-" * 12, "-" * 16, "-" * 16, "-" * 10)
|
||||||
|
print(dash)
|
||||||
|
|
||||||
|
req_c = "{}/{}".format(len(ok_c), len(rows_c))
|
||||||
|
req_d = "{}/{}".format(len(ok_d), len(rows_d))
|
||||||
|
print(" {:<12} {:>16} {:>16}".format("Requests", req_c, req_d))
|
||||||
|
|
||||||
|
data = [
|
||||||
|
("TTFT p50", pct(ttfts_c, 0.5), pct(ttfts_d, 0.5)),
|
||||||
|
("TTFT p90", pct(ttfts_c, 0.9), pct(ttfts_d, 0.9)),
|
||||||
|
("TPOT p50", pct(tpots_c, 0.5), pct(tpots_d, 0.5)),
|
||||||
|
("TPOT p90", pct(tpots_c, 0.9), pct(tpots_d, 0.9)),
|
||||||
|
("E2E p50", pct(lats_c, 0.5), pct(lats_d, 0.5)),
|
||||||
|
("E2E p90", pct(lats_c, 0.9), pct(lats_d, 0.9)),
|
||||||
|
]
|
||||||
|
|
||||||
|
for label, cv, dv in data:
|
||||||
|
delta = "{:+.0f}%".format((dv / cv - 1) * 100) if cv > 0 else "N/A"
|
||||||
|
print(" {:<12} {:>15.3f}s {:>15.3f}s {:>10}".format(label, cv, dv, delta))
|
||||||
|
|
||||||
|
cache_c = sc.get("prefix_cache_hit_ratio", 0)
|
||||||
|
print(" {:<12} {:>15.1f}% {:>16}".format("Cache hit", cache_c * 100, "N/A"))
|
||||||
|
tput_c = len(ok_c) / sc.get("wall_clock_s", 1)
|
||||||
|
print(" {:<12} {:>14.2f}/s {:>16}".format("Throughput", tput_c, "~0.06/s"))
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(sep)
|
||||||
|
print(" CONCLUSIONS FOR AGENTIC WORKLOAD")
|
||||||
|
print(sep)
|
||||||
|
print()
|
||||||
|
print(" Trace characteristics:")
|
||||||
|
print(" - I/O ratio: 61.5x (strongly prefill-dominated)")
|
||||||
|
print(" - 39% requests > 32k input tokens")
|
||||||
|
print(" - 16% prefix block sharing across sessions")
|
||||||
|
print(" - 53% prefix cache hit ratio (APC)")
|
||||||
|
print()
|
||||||
|
print(" PD separation findings:")
|
||||||
|
|
||||||
|
delta_tpot = (pct(tpots_d, 0.5) / pct(tpots_c, 0.5) - 1) * 100 if tpots_c else 0
|
||||||
|
delta_ttft = (pct(ttfts_d, 0.5) / pct(ttfts_c, 0.5) - 1) * 100 if ttfts_c else 0
|
||||||
|
delta_e2e = (pct(lats_d, 0.5) / pct(lats_c, 0.5) - 1) * 100 if lats_c else 0
|
||||||
|
|
||||||
|
print(" 1. TPOT {:+.0f}% - decode isolation benefit is {}".format(
|
||||||
|
delta_tpot, "marginal" if abs(delta_tpot) < 20 else "significant"))
|
||||||
|
print(" 2. TTFT {:+.0f}% - KV transfer + TP=4 overhead dominates".format(delta_ttft))
|
||||||
|
print(" 3. E2E {:+.0f}% - net negative on single-machine".format(delta_e2e))
|
||||||
|
print(" 4. Stability: Mooncake connector crashes after ~200 reqs under load")
|
||||||
|
print()
|
||||||
|
print(" Recommendation:")
|
||||||
|
print(" - Single-machine 8 GPU: Combined mode is better (lower TTFT, stable)")
|
||||||
|
print(" - Multi-machine: PD-Sep is promising IF cross-machine latency")
|
||||||
|
print(" is hidden by RDMA and prefill doesn't share GPU with decode")
|
||||||
|
print(" - Key bottleneck: this workload's heavy prefill (avg 32k tokens)")
|
||||||
|
print(" makes KV transfer cost non-trivial relative to prefill time")
|
||||||
|
print(" - Prefill-as-a-Service (Goal 5) should focus on cross-machine")
|
||||||
|
print(" KV cache sharing, not same-machine PD split")
|
||||||
96
scripts/launch_pd_mooncake.sh
Executable file
96
scripts/launch_pd_mooncake.sh
Executable file
@@ -0,0 +1,96 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# PD-Disaggregated serving via Mooncake (RDMA + DRAM KV pool).
|
||||||
|
#
|
||||||
|
# Architecture:
|
||||||
|
# Client → Proxy (port 8000)
|
||||||
|
# → Prefill (port 8010, TP=4, GPUs 0-3, bootstrap 8998)
|
||||||
|
# [prefill + store KV to DRAM pool via RDMA]
|
||||||
|
# → Decode (port 8020, TP=4, GPUs 4-7)
|
||||||
|
# [pull KV from DRAM pool via RDMA + decode]
|
||||||
|
#
|
||||||
|
# Usage: bash scripts/launch_pd_mooncake.sh
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||||
|
VENV="$PROJECT_DIR/.venv/bin"
|
||||||
|
VLLM="$VENV/vllm"
|
||||||
|
|
||||||
|
MODEL_PATH="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
|
||||||
|
PROXY_PORT=8000
|
||||||
|
PREFILL_PORT=8010
|
||||||
|
DECODE_PORT=8020
|
||||||
|
BOOTSTRAP_PORT=8998
|
||||||
|
|
||||||
|
PROXY_SCRIPT="$PROJECT_DIR/third_party/vllm/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py"
|
||||||
|
|
||||||
|
trap 'echo "Cleaning up..."; kill $(jobs -p) 2>/dev/null; wait 2>/dev/null' EXIT INT TERM
|
||||||
|
|
||||||
|
echo "=== PD-Disaggregated vLLM 0.18.1 (Mooncake/RDMA) ==="
|
||||||
|
echo " Model: $MODEL_PATH"
|
||||||
|
echo " Prefill: GPUs 0-3 (TP=4), port $PREFILL_PORT, bootstrap $BOOTSTRAP_PORT"
|
||||||
|
echo " Decode: GPUs 4-7 (TP=4), port $DECODE_PORT"
|
||||||
|
echo " Proxy: port $PROXY_PORT"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Step 1: Start prefill instance (KV producer)
|
||||||
|
echo "[1/3] Starting prefill instance..."
|
||||||
|
VLLM_MOONCAKE_BOOTSTRAP_PORT=$BOOTSTRAP_PORT \
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||||
|
$VLLM serve "$MODEL_PATH" \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port $PREFILL_PORT \
|
||||||
|
--tensor-parallel-size 4 \
|
||||||
|
--trust-remote-code \
|
||||||
|
--enable-prefix-caching \
|
||||||
|
--enforce-eager \
|
||||||
|
--dtype auto \
|
||||||
|
--gpu-memory-utilization 0.9 \
|
||||||
|
--kv-transfer-config \
|
||||||
|
'{"kv_connector":"MooncakeConnector","kv_role":"kv_producer"}' &
|
||||||
|
PREFILL_PID=$!
|
||||||
|
echo " Prefill PID=$PREFILL_PID"
|
||||||
|
|
||||||
|
# Step 2: Start decode instance (KV consumer)
|
||||||
|
echo "[2/3] Starting decode instance..."
|
||||||
|
CUDA_VISIBLE_DEVICES=4,5,6,7 \
|
||||||
|
$VLLM serve "$MODEL_PATH" \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port $DECODE_PORT \
|
||||||
|
--tensor-parallel-size 4 \
|
||||||
|
--trust-remote-code \
|
||||||
|
--enable-prefix-caching \
|
||||||
|
--enforce-eager \
|
||||||
|
--dtype auto \
|
||||||
|
--gpu-memory-utilization 0.8 \
|
||||||
|
--kv-transfer-config \
|
||||||
|
'{"kv_connector":"MooncakeConnector","kv_role":"kv_consumer"}' &
|
||||||
|
DECODE_PID=$!
|
||||||
|
echo " Decode PID=$DECODE_PID"
|
||||||
|
|
||||||
|
# Wait for both instances
|
||||||
|
echo ""
|
||||||
|
echo "Waiting for instances..."
|
||||||
|
timeout 1200 bash -c "until curl -s localhost:$PREFILL_PORT/v1/models > /dev/null 2>&1; do sleep 5; done"
|
||||||
|
echo " Prefill ready!"
|
||||||
|
timeout 1200 bash -c "until curl -s localhost:$DECODE_PORT/v1/models > /dev/null 2>&1; do sleep 5; done"
|
||||||
|
echo " Decode ready!"
|
||||||
|
|
||||||
|
# Step 3: Start proxy (after instances are ready)
|
||||||
|
echo "[3/3] Starting proxy..."
|
||||||
|
$VENV/python "$PROXY_SCRIPT" \
|
||||||
|
--prefill "http://127.0.0.1:$PREFILL_PORT" "$BOOTSTRAP_PORT" \
|
||||||
|
--decode "http://127.0.0.1:$DECODE_PORT" \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port $PROXY_PORT &
|
||||||
|
PROXY_PID=$!
|
||||||
|
echo " Proxy PID=$PROXY_PID"
|
||||||
|
|
||||||
|
sleep 5
|
||||||
|
echo ""
|
||||||
|
echo "=== All ready ==="
|
||||||
|
echo " Send requests to: http://localhost:$PROXY_PORT"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
wait
|
||||||
89
scripts/launch_pd_separated.sh
Normal file
89
scripts/launch_pd_separated.sh
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# PD-Disaggregated serving: 1 prefill (TP=4, GPUs 0-3) + 1 decode (TP=4, GPUs 4-7)
|
||||||
|
# Uses vLLM 0.18.1's P2pNcclConnector + XpYd proxy.
|
||||||
|
#
|
||||||
|
# Architecture:
|
||||||
|
# Client → Proxy (port 10001)
|
||||||
|
# → Prefill (port 20003, kv_port 21001) [max_tokens=1, does prefill + KV push]
|
||||||
|
# → Decode (port 20005, kv_port 22001) [full generation, KV pulled from prefill]
|
||||||
|
#
|
||||||
|
# Usage: bash scripts/launch_pd_separated.sh
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||||
|
VENV="$PROJECT_DIR/.venv/bin"
|
||||||
|
VLLM="$VENV/vllm"
|
||||||
|
|
||||||
|
MODEL_PATH="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
|
||||||
|
PROXY_PORT=30001 # ZMQ service discovery
|
||||||
|
CLIENT_PORT=10001 # HTTP proxy for clients
|
||||||
|
PREFILL_PORT=20003
|
||||||
|
DECODE_PORT=20005
|
||||||
|
KV_PORT_P=21001
|
||||||
|
KV_PORT_D=22001
|
||||||
|
|
||||||
|
trap 'echo "Cleaning up..."; kill $(jobs -p) 2>/dev/null; wait 2>/dev/null' EXIT INT TERM
|
||||||
|
|
||||||
|
echo "=== PD-Disaggregated vLLM 0.18.1 ==="
|
||||||
|
echo " Model: $MODEL_PATH"
|
||||||
|
echo " Prefill: GPUs 0-3 (TP=4), port $PREFILL_PORT, kv_port $KV_PORT_P"
|
||||||
|
echo " Decode: GPUs 4-7 (TP=4), port $DECODE_PORT, kv_port $KV_PORT_D"
|
||||||
|
echo " Proxy: ZMQ=$PROXY_PORT, HTTP=$CLIENT_PORT"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Step 1: Start proxy FIRST (P/D instances register via ZMQ)
|
||||||
|
echo "[1/3] Starting proxy..."
|
||||||
|
PROXY_SCRIPT="$PROJECT_DIR/third_party/vllm/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py"
|
||||||
|
$VENV/python "$PROXY_SCRIPT" &
|
||||||
|
PROXY_PID=$!
|
||||||
|
sleep 2
|
||||||
|
echo " Proxy PID=$PROXY_PID"
|
||||||
|
|
||||||
|
# Step 2: Start prefill instance (KV producer)
|
||||||
|
echo "[2/3] Starting prefill instance..."
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3 $VLLM serve "$MODEL_PATH" \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port $PREFILL_PORT \
|
||||||
|
--tensor-parallel-size 4 \
|
||||||
|
--trust-remote-code \
|
||||||
|
--enable-prefix-caching \
|
||||||
|
--enforce-eager \
|
||||||
|
--dtype auto \
|
||||||
|
--gpu-memory-utilization 0.9 \
|
||||||
|
--kv-transfer-config \
|
||||||
|
"{\"kv_connector\":\"P2pNcclConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_size\":\"1e1\",\"kv_port\":\"$KV_PORT_P\",\"kv_connector_extra_config\":{\"proxy_ip\":\"127.0.0.1\",\"proxy_port\":\"$PROXY_PORT\",\"http_port\":\"$PREFILL_PORT\",\"send_type\":\"PUT_ASYNC\",\"nccl_num_channels\":\"16\"}}" &
|
||||||
|
PREFILL_PID=$!
|
||||||
|
echo " Prefill PID=$PREFILL_PID"
|
||||||
|
|
||||||
|
# Step 3: Start decode instance (KV consumer)
|
||||||
|
echo "[3/3] Starting decode instance..."
|
||||||
|
CUDA_VISIBLE_DEVICES=4,5,6,7 $VLLM serve "$MODEL_PATH" \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port $DECODE_PORT \
|
||||||
|
--tensor-parallel-size 4 \
|
||||||
|
--trust-remote-code \
|
||||||
|
--enable-prefix-caching \
|
||||||
|
--enforce-eager \
|
||||||
|
--dtype auto \
|
||||||
|
--gpu-memory-utilization 0.8 \
|
||||||
|
--kv-transfer-config \
|
||||||
|
"{\"kv_connector\":\"P2pNcclConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_size\":\"8e9\",\"kv_port\":\"$KV_PORT_D\",\"kv_connector_extra_config\":{\"proxy_ip\":\"127.0.0.1\",\"proxy_port\":\"$PROXY_PORT\",\"http_port\":\"$DECODE_PORT\",\"send_type\":\"PUT_ASYNC\",\"nccl_num_channels\":\"16\"}}" &
|
||||||
|
DECODE_PID=$!
|
||||||
|
echo " Decode PID=$DECODE_PID"
|
||||||
|
|
||||||
|
# Wait for readiness
|
||||||
|
echo ""
|
||||||
|
echo "Waiting for instances..."
|
||||||
|
timeout 1200 bash -c "until curl -s localhost:$PREFILL_PORT/v1/completions > /dev/null 2>&1; do sleep 5; done"
|
||||||
|
echo " Prefill ready!"
|
||||||
|
timeout 1200 bash -c "until curl -s localhost:$DECODE_PORT/v1/completions > /dev/null 2>&1; do sleep 5; done"
|
||||||
|
echo " Decode ready!"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== All ready ==="
|
||||||
|
echo " Send requests to: http://localhost:$CLIENT_PORT"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
wait
|
||||||
23
scripts/launch_vllm.sh
Executable file
23
scripts/launch_vllm.sh
Executable file
@@ -0,0 +1,23 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Launch vLLM 0.18.1 in PD-combined mode (TP=8, all GPUs).
|
||||||
|
#
|
||||||
|
# Usage: bash scripts/launch_vllm.sh
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||||
|
VLLM="$PROJECT_DIR/.venv/bin/vllm"
|
||||||
|
|
||||||
|
MODEL_PATH="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
|
||||||
|
HOST="${HOST:-0.0.0.0}"
|
||||||
|
PORT="${PORT:-8000}"
|
||||||
|
|
||||||
|
echo "Starting vLLM 0.18.1 in PD-combined mode (TP=8) on port $PORT ..."
|
||||||
|
$VLLM serve "$MODEL_PATH" \
|
||||||
|
--trust-remote-code \
|
||||||
|
--enable-prefix-caching \
|
||||||
|
--dtype auto \
|
||||||
|
--tensor-parallel-size 8 \
|
||||||
|
--host "$HOST" \
|
||||||
|
--port "$PORT"
|
||||||
77
scripts/run_benchmark.sh
Executable file
77
scripts/run_benchmark.sh
Executable file
@@ -0,0 +1,77 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Run the full benchmark suite: sample trace → replay against vLLM → collect metrics.
|
||||||
|
#
|
||||||
|
# Prerequisites:
|
||||||
|
# - vLLM server running (use scripts/launch_vllm.sh)
|
||||||
|
# - Sampled trace file exists (or will be created)
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# bash scripts/run_benchmark.sh [--endpoint URL] [--tag NAME]
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||||
|
cd "$PROJECT_DIR"
|
||||||
|
|
||||||
|
# Defaults
|
||||||
|
TRACE_INPUT="${TRACE_INPUT:-$HOME/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl}"
|
||||||
|
ENDPOINT="${ENDPOINT:-http://localhost:8000}"
|
||||||
|
TAG="${TAG:-default}"
|
||||||
|
TARGET_REQUESTS="${TARGET_REQUESTS:-5000}"
|
||||||
|
TIME_SCALE="${TIME_SCALE:-1.0}"
|
||||||
|
MAX_INFLIGHT="${MAX_INFLIGHT:-32}"
|
||||||
|
SEED="${SEED:-42}"
|
||||||
|
|
||||||
|
# Parse args
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case "$1" in
|
||||||
|
--endpoint) ENDPOINT="$2"; shift 2 ;;
|
||||||
|
--tag) TAG="$2"; shift 2 ;;
|
||||||
|
--target-requests) TARGET_REQUESTS="$2"; shift 2 ;;
|
||||||
|
--time-scale) TIME_SCALE="$2"; shift 2 ;;
|
||||||
|
--max-inflight) MAX_INFLIGHT="$2"; shift 2 ;;
|
||||||
|
*) echo "Unknown arg: $1"; exit 1 ;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
SAMPLED_TRACE="traces/sampled_${TARGET_REQUESTS}req_seed${SEED}.jsonl"
|
||||||
|
OUTPUT_DIR="outputs/${TAG}_$(date +%Y%m%d_%H%M%S)"
|
||||||
|
|
||||||
|
echo "=== Benchmark: tag=$TAG ==="
|
||||||
|
echo " Trace: $TRACE_INPUT"
|
||||||
|
echo " Endpoint: $ENDPOINT"
|
||||||
|
echo " Target requests: $TARGET_REQUESTS"
|
||||||
|
echo " Time scale: $TIME_SCALE"
|
||||||
|
echo " Max inflight sessions: $MAX_INFLIGHT"
|
||||||
|
|
||||||
|
# Step 1: Sample trace (if not already done)
|
||||||
|
if [ ! -f "$SAMPLED_TRACE" ]; then
|
||||||
|
echo ""
|
||||||
|
echo "=== Step 1: Sampling trace ==="
|
||||||
|
python scripts/sample_trace.py \
|
||||||
|
--input "$TRACE_INPUT" \
|
||||||
|
--output "$SAMPLED_TRACE" \
|
||||||
|
--target-requests "$TARGET_REQUESTS" \
|
||||||
|
--seed "$SEED"
|
||||||
|
else
|
||||||
|
echo ""
|
||||||
|
echo "=== Step 1: Using existing sampled trace: $SAMPLED_TRACE ==="
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Step 2: Run replay
|
||||||
|
echo ""
|
||||||
|
echo "=== Step 2: Replaying trace ==="
|
||||||
|
mkdir -p "$OUTPUT_DIR"
|
||||||
|
python -m replayer \
|
||||||
|
--trace "$SAMPLED_TRACE" \
|
||||||
|
--output "$OUTPUT_DIR/metrics.jsonl" \
|
||||||
|
--endpoint "$ENDPOINT" \
|
||||||
|
--time-scale "$TIME_SCALE" \
|
||||||
|
--max-inflight-sessions "$MAX_INFLIGHT" \
|
||||||
|
-v
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== Done ==="
|
||||||
|
echo " Metrics: $OUTPUT_DIR/metrics.jsonl"
|
||||||
|
echo " Summary: $OUTPUT_DIR/metrics.summary.json"
|
||||||
254
scripts/run_experiments.sh
Executable file
254
scripts/run_experiments.sh
Executable file
@@ -0,0 +1,254 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Run the complete experiment matrix:
|
||||||
|
# 1. Combined TP=2 DP=4 (4 instances, baseline)
|
||||||
|
# 2. Combined TP=1 DP=8 (8 instances, max throughput)
|
||||||
|
# 3. PD-Sep TP=1: P×4 + D×4 via Mooncake/RDMA
|
||||||
|
#
|
||||||
|
# All use the same trace, same concurrency, same timeout.
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||||
|
VENV="$PROJECT_DIR/.venv/bin"
|
||||||
|
VLLM="$VENV/vllm"
|
||||||
|
PYTHON="$VENV/python"
|
||||||
|
|
||||||
|
MODEL="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
|
||||||
|
TRACE="$PROJECT_DIR/traces/sampled_1000req_seed42.jsonl"
|
||||||
|
|
||||||
|
# Uniform benchmark params
|
||||||
|
MAX_SESSIONS=${MAX_SESSIONS:-8}
|
||||||
|
MAX_CONCURRENT=${MAX_CONCURRENT:-16}
|
||||||
|
TIME_SCALE=10
|
||||||
|
REQUEST_TIMEOUT=${REQUEST_TIMEOUT:-300}
|
||||||
|
REQUEST_LIMIT="${REQUEST_LIMIT:-}" # empty = all 1000
|
||||||
|
|
||||||
|
cleanup_gpu() {
|
||||||
|
pkill -9 -f "vllm" 2>/dev/null || true
|
||||||
|
pkill -9 -f "cache_aware_proxy\|mooncake_connector_proxy\|uvicorn" 2>/dev/null || true
|
||||||
|
fuser 9090/tcp 8000/tcp 2>/dev/null | xargs -r kill -9 2>/dev/null || true
|
||||||
|
sleep 5
|
||||||
|
fuser /dev/nvidia* 2>/dev/null | tr " " "\n" | sort -u | xargs -r kill -9 2>/dev/null || true
|
||||||
|
sleep 10
|
||||||
|
}
|
||||||
|
|
||||||
|
wait_for_server() {
|
||||||
|
local port=$1
|
||||||
|
local timeout=${2:-600}
|
||||||
|
timeout "$timeout" bash -c "until curl -s localhost:$port/v1/models >/dev/null 2>&1; do sleep 5; done"
|
||||||
|
}
|
||||||
|
|
||||||
|
run_benchmark() {
|
||||||
|
local tag=$1
|
||||||
|
local endpoint=$2
|
||||||
|
local extra_args="${3:-}"
|
||||||
|
local outdir="$PROJECT_DIR/outputs/$tag"
|
||||||
|
|
||||||
|
echo " Running benchmark -> $outdir"
|
||||||
|
local limit_arg=""
|
||||||
|
if [ -n "$REQUEST_LIMIT" ]; then
|
||||||
|
limit_arg="--request-limit $REQUEST_LIMIT"
|
||||||
|
fi
|
||||||
|
|
||||||
|
$PYTHON -m replayer \
|
||||||
|
--trace "$TRACE" \
|
||||||
|
--output "$outdir/metrics.jsonl" \
|
||||||
|
--endpoint "$endpoint" \
|
||||||
|
--model "$MODEL" \
|
||||||
|
--time-scale $TIME_SCALE \
|
||||||
|
--max-inflight-sessions $MAX_SESSIONS \
|
||||||
|
--concurrency-limit $MAX_CONCURRENT \
|
||||||
|
--request-timeout $REQUEST_TIMEOUT \
|
||||||
|
$limit_arg \
|
||||||
|
-v
|
||||||
|
|
||||||
|
echo " Done: $(wc -l < "$outdir/metrics.jsonl") requests"
|
||||||
|
}
|
||||||
|
|
||||||
|
#######################################################################
|
||||||
|
# Experiment 1: Combined TP=2 DP=4
|
||||||
|
#######################################################################
|
||||||
|
run_combined_tp2_dp4() {
|
||||||
|
echo ""
|
||||||
|
echo "================================================================"
|
||||||
|
echo " Experiment 1: Combined TP=2 DP=4 (4 instances on 8 GPUs)"
|
||||||
|
echo "================================================================"
|
||||||
|
cleanup_gpu
|
||||||
|
|
||||||
|
for i in 0 1 2 3; do
|
||||||
|
local gpu_start=$((i * 2))
|
||||||
|
local gpu_end=$((gpu_start + 1))
|
||||||
|
local port=$((8000 + i))
|
||||||
|
echo " Starting instance $i: GPUs $gpu_start,$gpu_end, port $port"
|
||||||
|
CUDA_VISIBLE_DEVICES=$gpu_start,$gpu_end $VLLM serve "$MODEL" \
|
||||||
|
--host 0.0.0.0 --port $port \
|
||||||
|
--tensor-parallel-size 2 \
|
||||||
|
--trust-remote-code --enable-prefix-caching --enforce-eager \
|
||||||
|
--dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 &
|
||||||
|
done
|
||||||
|
|
||||||
|
for i in 0 1 2 3; do
|
||||||
|
wait_for_server $((8000 + i))
|
||||||
|
echo " Instance $i ready"
|
||||||
|
done
|
||||||
|
echo " All 4 instances ready"
|
||||||
|
|
||||||
|
# Start global scheduler (cache-aware proxy in combined mode)
|
||||||
|
echo " Starting global scheduler..."
|
||||||
|
$PYTHON "$PROJECT_DIR/scripts/cache_aware_proxy.py" \
|
||||||
|
--combined http://127.0.0.1:8000 http://127.0.0.1:8001 http://127.0.0.1:8002 http://127.0.0.1:8003 \
|
||||||
|
--port 9090 &
|
||||||
|
sleep 5
|
||||||
|
|
||||||
|
run_benchmark "exp1_combined_tp2_dp4" "http://localhost:9090"
|
||||||
|
}
|
||||||
|
|
||||||
|
#######################################################################
|
||||||
|
# Experiment 2: Combined TP=1 DP=8
|
||||||
|
#######################################################################
|
||||||
|
run_combined_tp1_dp8() {
|
||||||
|
echo ""
|
||||||
|
echo "================================================================"
|
||||||
|
echo " Experiment 2: Combined TP=1 DP=8 (8 instances on 8 GPUs)"
|
||||||
|
echo "================================================================"
|
||||||
|
cleanup_gpu
|
||||||
|
|
||||||
|
for i in $(seq 0 7); do
|
||||||
|
local port=$((8000 + i))
|
||||||
|
echo " Starting instance $i: GPU $i, port $port"
|
||||||
|
CUDA_VISIBLE_DEVICES=$i $VLLM serve "$MODEL" \
|
||||||
|
--host 0.0.0.0 --port $port \
|
||||||
|
--tensor-parallel-size 1 \
|
||||||
|
--trust-remote-code --enable-prefix-caching --enforce-eager \
|
||||||
|
--dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 &
|
||||||
|
done
|
||||||
|
|
||||||
|
for i in $(seq 0 7); do
|
||||||
|
wait_for_server $((8000 + i))
|
||||||
|
echo " Instance $i ready"
|
||||||
|
done
|
||||||
|
echo " All 8 instances ready"
|
||||||
|
|
||||||
|
# Start global scheduler (cache-aware proxy in combined mode)
|
||||||
|
echo " Starting global scheduler..."
|
||||||
|
$PYTHON "$PROJECT_DIR/scripts/cache_aware_proxy.py" \
|
||||||
|
--combined http://127.0.0.1:8000 http://127.0.0.1:8001 http://127.0.0.1:8002 http://127.0.0.1:8003 \
|
||||||
|
http://127.0.0.1:8004 http://127.0.0.1:8005 http://127.0.0.1:8006 http://127.0.0.1:8007 \
|
||||||
|
--port 9090 &
|
||||||
|
sleep 5
|
||||||
|
|
||||||
|
run_benchmark "exp2_combined_tp1_dp8" "http://localhost:9090"
|
||||||
|
}
|
||||||
|
|
||||||
|
#######################################################################
|
||||||
|
# Experiment 3: PD-Sep TP=1 P×4 D×4 (Mooncake/RDMA)
|
||||||
|
#######################################################################
|
||||||
|
run_pd_sep_tp1() {
|
||||||
|
echo ""
|
||||||
|
echo "================================================================"
|
||||||
|
echo " Experiment 3: PD-Sep TP=1 P×4 + D×4 (Mooncake/RDMA)"
|
||||||
|
echo "================================================================"
|
||||||
|
cleanup_gpu
|
||||||
|
|
||||||
|
PROXY_SCRIPT="$PROJECT_DIR/scripts/cache_aware_proxy.py"
|
||||||
|
|
||||||
|
# Start 4 prefill instances (GPUs 0-3)
|
||||||
|
local prefill_args=""
|
||||||
|
for i in 0 1 2 3; do
|
||||||
|
local port=$((8010 + i))
|
||||||
|
local bootstrap=$((8998 + i))
|
||||||
|
echo " Prefill $i: GPU $i, port $port, bootstrap $bootstrap"
|
||||||
|
VLLM_MOONCAKE_BOOTSTRAP_PORT=$bootstrap \
|
||||||
|
CUDA_VISIBLE_DEVICES=$i $VLLM serve "$MODEL" \
|
||||||
|
--host 0.0.0.0 --port $port \
|
||||||
|
--tensor-parallel-size 1 \
|
||||||
|
--trust-remote-code --enable-prefix-caching --enforce-eager \
|
||||||
|
--dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 \
|
||||||
|
--kv-transfer-config \
|
||||||
|
"{\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_producer\"}" &
|
||||||
|
prefill_args="$prefill_args --prefill http://127.0.0.1:$port $bootstrap"
|
||||||
|
done
|
||||||
|
|
||||||
|
# Start 4 decode instances (GPUs 4-7)
|
||||||
|
local decode_args=""
|
||||||
|
for i in 0 1 2 3; do
|
||||||
|
local gpu=$((4 + i))
|
||||||
|
local port=$((8020 + i))
|
||||||
|
echo " Decode $i: GPU $gpu, port $port"
|
||||||
|
CUDA_VISIBLE_DEVICES=$gpu $VLLM serve "$MODEL" \
|
||||||
|
--host 0.0.0.0 --port $port \
|
||||||
|
--tensor-parallel-size 1 \
|
||||||
|
--trust-remote-code --enable-prefix-caching --enforce-eager \
|
||||||
|
--dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 \
|
||||||
|
--kv-transfer-config \
|
||||||
|
"{\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_consumer\",\"kv_load_failure_policy\":\"recompute\"}" &
|
||||||
|
decode_args="$decode_args --decode http://127.0.0.1:$port"
|
||||||
|
done
|
||||||
|
|
||||||
|
# Wait for all instances
|
||||||
|
for i in 0 1 2 3; do
|
||||||
|
wait_for_server $((8010 + i))
|
||||||
|
echo " Prefill $i ready"
|
||||||
|
done
|
||||||
|
for i in 0 1 2 3; do
|
||||||
|
wait_for_server $((8020 + i))
|
||||||
|
echo " Decode $i ready"
|
||||||
|
done
|
||||||
|
|
||||||
|
# Start proxy (wait for bootstrap to be queryable first)
|
||||||
|
echo " Waiting for bootstrap servers..."
|
||||||
|
for bp in 8998 8999 9000 9001; do
|
||||||
|
timeout 120 bash -c "until curl -s localhost:$bp/query > /dev/null 2>&1; do sleep 2; done"
|
||||||
|
echo " Bootstrap $bp ready"
|
||||||
|
done
|
||||||
|
|
||||||
|
echo " Starting proxy on port 9000..."
|
||||||
|
$PYTHON "$PROXY_SCRIPT" $prefill_args $decode_args --host 0.0.0.0 --port 9090 &
|
||||||
|
sleep 15
|
||||||
|
|
||||||
|
# Smoke test with retry
|
||||||
|
echo " Smoke test..."
|
||||||
|
for attempt in 1 2 3; do
|
||||||
|
result=$(curl -s -m 120 http://localhost:9090/v1/completions \
|
||||||
|
-X POST -H "Content-Type: application/json" \
|
||||||
|
-d "{\"model\":\"$MODEL\",\"prompt\":[100,200,300],\"max_tokens\":3,\"temperature\":0}" 2>&1)
|
||||||
|
if echo "$result" | grep -q "choices"; then
|
||||||
|
echo " Smoke test passed!"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
echo " Attempt $attempt failed, retrying..."
|
||||||
|
sleep 10
|
||||||
|
done
|
||||||
|
|
||||||
|
run_benchmark "exp3_pd_sep_tp1_mooncake" "http://localhost:9090"
|
||||||
|
}
|
||||||
|
|
||||||
|
#######################################################################
|
||||||
|
# Main
|
||||||
|
#######################################################################
|
||||||
|
echo "Starting experiment matrix on $(hostname)"
|
||||||
|
echo "Model: $MODEL"
|
||||||
|
echo "Trace: $TRACE"
|
||||||
|
echo "Params: sessions=$MAX_SESSIONS, concurrent=$MAX_CONCURRENT, time_scale=$TIME_SCALE"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
case "${1:-all}" in
|
||||||
|
1|tp2dp4) run_combined_tp2_dp4 ;;
|
||||||
|
2|tp1dp8) run_combined_tp1_dp8 ;;
|
||||||
|
3|pdsep) run_pd_sep_tp1 ;;
|
||||||
|
all)
|
||||||
|
run_combined_tp2_dp4
|
||||||
|
run_combined_tp1_dp8
|
||||||
|
run_pd_sep_tp1
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Usage: $0 {1|2|3|all|tp2dp4|tp1dp8|pdsep}"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "================================================================"
|
||||||
|
echo " All experiments complete!"
|
||||||
|
echo "================================================================"
|
||||||
|
cleanup_gpu
|
||||||
204
scripts/sample_trace.py
Normal file
204
scripts/sample_trace.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
"""Sample sessions from the full cluster-scale trace to fit a single machine.
|
||||||
|
|
||||||
|
Preserves:
|
||||||
|
- Complete session structure (all turns within a session kept together)
|
||||||
|
- Original arrival timing (inter-session and intra-session gaps)
|
||||||
|
- hash_ids for KV cache reuse patterns
|
||||||
|
- Request type distribution
|
||||||
|
|
||||||
|
Sampling strategy:
|
||||||
|
1. Group requests by session (derived from parent_chat_id chains)
|
||||||
|
2. Randomly sample N sessions (or until target request count reached)
|
||||||
|
3. Re-zero timestamps so first event starts at t=0
|
||||||
|
4. Optionally compress time axis to increase load density
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python scripts/sample_trace.py \\
|
||||||
|
--input ~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl \\
|
||||||
|
--output traces/sampled.jsonl \\
|
||||||
|
--target-requests 5000 \\
|
||||||
|
--seed 42
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import collections
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def load_raw_rows(path: Path) -> dict[str, list[dict]]:
|
||||||
|
"""Load trace, group rows by resolved session_id. Preserve file order."""
|
||||||
|
chat_to_session: dict[int, str] = {}
|
||||||
|
rows_by_session: dict[str, list[dict]] = collections.OrderedDict()
|
||||||
|
|
||||||
|
with path.open("r", encoding="utf-8") as fh:
|
||||||
|
for line in fh:
|
||||||
|
row = json.loads(line)
|
||||||
|
cid = int(row["chat_id"])
|
||||||
|
pid = int(row["parent_chat_id"])
|
||||||
|
|
||||||
|
if "session_id" in row:
|
||||||
|
sid = str(row["session_id"])
|
||||||
|
elif pid < 0:
|
||||||
|
sid = str(cid)
|
||||||
|
else:
|
||||||
|
sid = chat_to_session.get(pid, str(pid))
|
||||||
|
chat_to_session[cid] = sid
|
||||||
|
|
||||||
|
row["_session_id"] = sid
|
||||||
|
rows_by_session.setdefault(sid, []).append(row)
|
||||||
|
|
||||||
|
return rows_by_session
|
||||||
|
|
||||||
|
|
||||||
|
def sample_sessions(
|
||||||
|
rows_by_session: dict[str, list[dict]],
|
||||||
|
*,
|
||||||
|
target_requests: int,
|
||||||
|
seed: int,
|
||||||
|
strategy: str = "random",
|
||||||
|
) -> list[str]:
|
||||||
|
"""Select sessions until target request count is reached."""
|
||||||
|
all_sids = list(rows_by_session.keys())
|
||||||
|
rng = random.Random(seed)
|
||||||
|
|
||||||
|
if strategy == "random":
|
||||||
|
rng.shuffle(all_sids)
|
||||||
|
elif strategy == "sequential":
|
||||||
|
pass # keep file order
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown strategy: {strategy}")
|
||||||
|
|
||||||
|
selected = []
|
||||||
|
total = 0
|
||||||
|
for sid in all_sids:
|
||||||
|
selected.append(sid)
|
||||||
|
total += len(rows_by_session[sid])
|
||||||
|
if total >= target_requests:
|
||||||
|
break
|
||||||
|
|
||||||
|
return selected
|
||||||
|
|
||||||
|
|
||||||
|
def build_output(
|
||||||
|
rows_by_session: dict[str, list[dict]],
|
||||||
|
selected: list[str],
|
||||||
|
*,
|
||||||
|
time_scale: float = 1.0,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Build output rows with re-zeroed timestamps."""
|
||||||
|
out_rows = []
|
||||||
|
for sid in selected:
|
||||||
|
for row in rows_by_session[sid]:
|
||||||
|
out = {k: v for k, v in row.items() if not k.startswith("_")}
|
||||||
|
out["session_id"] = sid
|
||||||
|
out_rows.append(out)
|
||||||
|
|
||||||
|
out_rows.sort(key=lambda r: float(r["timestamp"]))
|
||||||
|
|
||||||
|
if not out_rows:
|
||||||
|
return out_rows
|
||||||
|
|
||||||
|
# Re-zero: subtract earliest timestamp
|
||||||
|
t0 = float(out_rows[0]["timestamp"])
|
||||||
|
for row in out_rows:
|
||||||
|
row["timestamp"] = (float(row["timestamp"]) - t0) / time_scale
|
||||||
|
|
||||||
|
return out_rows
|
||||||
|
|
||||||
|
|
||||||
|
def print_summary(
|
||||||
|
rows_by_session: dict[str, list[dict]],
|
||||||
|
selected: list[str],
|
||||||
|
out_rows: list[dict],
|
||||||
|
) -> None:
|
||||||
|
n_sessions = len(selected)
|
||||||
|
n_requests = len(out_rows)
|
||||||
|
turns_per_session = [len(rows_by_session[s]) for s in selected]
|
||||||
|
multi_turn = sum(1 for t in turns_per_session if t > 1)
|
||||||
|
|
||||||
|
input_lens = [r["input_length"] for r in out_rows]
|
||||||
|
output_lens = [r["output_length"] for r in out_rows]
|
||||||
|
|
||||||
|
span_s = float(out_rows[-1]["timestamp"]) if out_rows else 0
|
||||||
|
session_starts = {}
|
||||||
|
for r in out_rows:
|
||||||
|
sid = r["session_id"]
|
||||||
|
ts = float(r["timestamp"])
|
||||||
|
if sid not in session_starts:
|
||||||
|
session_starts[sid] = ts
|
||||||
|
starts_sorted = sorted(session_starts.values())
|
||||||
|
deltas = [starts_sorted[i+1] - starts_sorted[i]
|
||||||
|
for i in range(len(starts_sorted) - 1)]
|
||||||
|
|
||||||
|
# hash_ids overlap: count unique hash_ids across all requests
|
||||||
|
all_hashes = set()
|
||||||
|
for r in out_rows:
|
||||||
|
all_hashes.update(r.get("hash_ids", []))
|
||||||
|
|
||||||
|
print(f"Sampled: {n_sessions} sessions, {n_requests} requests")
|
||||||
|
print(f" Multi-turn sessions: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)")
|
||||||
|
print(f" Turns/session: min={min(turns_per_session)} max={max(turns_per_session)} "
|
||||||
|
f"avg={sum(turns_per_session)/len(turns_per_session):.1f}")
|
||||||
|
print(f" Input length: min={min(input_lens)} max={max(input_lens)} "
|
||||||
|
f"avg={sum(input_lens)/len(input_lens):.0f}")
|
||||||
|
print(f" Output length: min={min(output_lens)} max={max(output_lens)} "
|
||||||
|
f"avg={sum(output_lens)/len(output_lens):.0f}")
|
||||||
|
print(f" Trace span: {span_s:.1f}s ({span_s/60:.1f} min)")
|
||||||
|
print(f" Unique hash blocks: {len(all_hashes)}")
|
||||||
|
if deltas:
|
||||||
|
deltas.sort()
|
||||||
|
p = lambda q: deltas[min(int(q * len(deltas)), len(deltas) - 1)]
|
||||||
|
print(f" Session arrival deltas (s): p10={p(0.1):.2f} p50={p(0.5):.2f} "
|
||||||
|
f"p90={p(0.9):.2f} max={max(deltas):.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
p = argparse.ArgumentParser(description=__doc__,
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||||
|
p.add_argument("--input", type=Path, required=True,
|
||||||
|
help="Path to the full trace JSONL file")
|
||||||
|
p.add_argument("--output", type=Path, required=True,
|
||||||
|
help="Path to write sampled trace JSONL")
|
||||||
|
p.add_argument("--target-requests", type=int, default=5000,
|
||||||
|
help="Target number of requests (stops after session that crosses it)")
|
||||||
|
p.add_argument("--strategy", choices=["random", "sequential"], default="random",
|
||||||
|
help="Session selection strategy")
|
||||||
|
p.add_argument("--time-scale", type=float, default=1.0,
|
||||||
|
help="Compress time axis by this factor (>1 = faster arrival)")
|
||||||
|
p.add_argument("--seed", type=int, default=42)
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
print(f"Loading trace from {args.input} ...")
|
||||||
|
rows_by_session = load_raw_rows(args.input)
|
||||||
|
total_sessions = len(rows_by_session)
|
||||||
|
total_requests = sum(len(v) for v in rows_by_session.values())
|
||||||
|
print(f"Full trace: {total_sessions} sessions, {total_requests} requests")
|
||||||
|
|
||||||
|
selected = sample_sessions(
|
||||||
|
rows_by_session,
|
||||||
|
target_requests=args.target_requests,
|
||||||
|
seed=args.seed,
|
||||||
|
strategy=args.strategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
out_rows = build_output(
|
||||||
|
rows_by_session, selected,
|
||||||
|
time_scale=args.time_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
print_summary(rows_by_session, selected, out_rows)
|
||||||
|
|
||||||
|
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with args.output.open("w", encoding="utf-8") as fh:
|
||||||
|
for row in out_rows:
|
||||||
|
fh.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||||
|
print(f"\nWrote {len(out_rows)} rows to {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user