Agentic workload PD separation analysis with trace-driven benchmarks
Systematic study of prefill-decode disaggregation for agentic LLM workloads using production GLM-5.1 coder trace (2.1M requests, 71B input tokens). Key findings: - Cache-aware routing improves TPOT p90 by 15% and APC from 20.8% to 44.7% without PD separation, matching PD-Sep's decode isolation benefit - PD separation adds +72% TTFT overhead (KV transfer) with no TPOT gain when using the same cache-aware scheduler - Prefill remains compute-bound even at 95% KV cache reuse (AI >1000x vs decode AI <2), but absolute FLOPs drop 71% from cache hits - For agentic MoE workloads, cache-aware routing > PD separation Infrastructure: - Trace sampler preserving session structure + hash_ids for prefix sharing - Async trace replayer with streaming TTFT/TPOT/E2E measurement - Unified cache-aware + token-level load-balanced global scheduler proxy supporting both PD-colocated and PD-disaggregated (Mooncake/RDMA) modes - vLLM 0.18.1 scheduler patch for KV transfer abort race condition - Roofline analysis tool for prefill/decode compute characterization Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
8
.gitignore
vendored
Normal file
8
.gitignore
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.venv/
|
||||
*.egg-info/
|
||||
outputs/
|
||||
traces/
|
||||
third_party/
|
||||
*.log
|
||||
25
TODO.md
Normal file
25
TODO.md
Normal file
@@ -0,0 +1,25 @@
|
||||
实验 setup:
|
||||
|
||||
GPU 机器:dash0,是 8*H20 的机器,可以直接 `ssh dash0` 进行连接访问
|
||||
|
||||
推理引擎:基于 vllm 0.18.1,self build,支持后续 patch 放在 git 中维护
|
||||
|
||||
模型:`~/models/Qwen/Qwen3-Coder-30B-A3B-Instruct`
|
||||
|
||||
推理 trace:原始完整 2h trace 在 dash0 的 `~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl`
|
||||
|
||||
|
||||
|
||||
性能指标:每个请求的 TTFT/TPOT/E2E/TBT,prefix KVCache hit ratio
|
||||
|
||||
|
||||
|
||||
目标:
|
||||
|
||||
1. 先实现标准的 trace-sampler,将 cluster 规模的原始 trace,sample 到合适当前机器数量的规模来跑,保持一份统一的 sample 后的 trace file 作为输入
|
||||
2. 实现标准的 trace replayer,保证能够体现线上流量的流量到来特征,KVCache 重用特征等
|
||||
3. 跑通 PD 分离,确认 PD 分离能够比普通的 PD 混合一起跑的性能要好,给出两者详细的性能对比以及原因分析
|
||||
4. 判断 trace 的 pattern,是否有必要 PD 完全混合或者 PD 分离
|
||||
5. 参考本地的 `~/phd/agentic-pd-hybrid`,判断是否能够实现一套 prefill-as-a-service 的架构,把重的 prefill 交给 prefill service,prefill service 能够从本地 GPU/DRAM/别的 GPU 机器上 pull KVCache,提高本地的 prefix KVCache hit ratio,不影响 decoding 的 prefill,就可以交给过去 PD 分离定义中 D-node 来做,提高 KVCache 命中率
|
||||
|
||||
|
||||
289
analysis/pd_separation_analysis.md
Normal file
289
analysis/pd_separation_analysis.md
Normal file
@@ -0,0 +1,289 @@
|
||||
# PD 分离在 Agentic Workload 下的系统分析
|
||||
|
||||
## 1. Trace 特征 (GLM-5.1 Agentic Coder, 2h, 2.1M requests)
|
||||
|
||||
```
|
||||
Total requests: 2,114,220
|
||||
Input tokens: 71.1B (avg 33.6k/req, p50=20k, p90=88k)
|
||||
Output tokens: 940M (avg 445/req, p50=80, p90=811)
|
||||
I/O ratio: 75.6x (aggregate), 217.8x (per-req median)
|
||||
Prefill share: 98% of total tokens
|
||||
Sessions: 1.3M (90% single-turn, 9% multi-turn)
|
||||
```
|
||||
|
||||
**与传统 chatbot workload 的根本区别:**
|
||||
|
||||
| 特征 | Traditional Chatbot | Agentic Coder (GLM-5.1) |
|
||||
|------|-------------------|------------------------|
|
||||
| I/O ratio | 1-10x | **75.6x** |
|
||||
| Input p50 | 500-2000 tokens | **20,030 tokens** |
|
||||
| Output p50 | 200-500 tokens | **80 tokens** |
|
||||
| Prefill token share | 50-80% | **98%** |
|
||||
| >32k input | <5% | **38%** |
|
||||
| Multi-turn | 50-80% | **9%** |
|
||||
|
||||
**KV Cache 复用特征:**
|
||||
|
||||
```
|
||||
Unique hash blocks: 20,650,883
|
||||
Shared blocks (ref>1): 9,749,379 (47%)
|
||||
Highly shared (ref>10): 2,428,160
|
||||
Intra-session reuse: 57%
|
||||
Top-10 blocks ref count: 64,754 (system prompt blocks)
|
||||
Theoretical cache hit: 71% (infinite cache, first 100k requests)
|
||||
```
|
||||
|
||||
**Input length 分布与 token 占比:**
|
||||
|
||||
```
|
||||
<1k: 202,396 reqs ( 9%) 89M tokens ( 0%)
|
||||
1-8k: 380,009 reqs (17%) 1.6B tokens ( 2%)
|
||||
8-32k: 720,871 reqs (34%) 12.7B tokens (17%)
|
||||
32-65k: 405,371 reqs (19%) 19.4B tokens (27%)
|
||||
65-131k: 394,014 reqs (18%) 35.7B tokens (50%)
|
||||
>131k: 11,559 reqs ( 0%) 1.6B tokens ( 2%)
|
||||
```
|
||||
|
||||
50% 的 token 计算量来自 65-131k 的长 context 请求。
|
||||
|
||||
## 2. DistServe 等 PD 分离的核心假设
|
||||
|
||||
DistServe (OSDI'24), Splitwise, TetriInfer 等 PD 分离工作基于以下假设:
|
||||
|
||||
### 假设 A: Prefill 和 Decode 有不同的计算特征
|
||||
- **Prefill**: compute-bound, 高 GPU 利用率, batch 越大越好
|
||||
- **Decode**: memory-bandwidth-bound, 低 GPU 利用率, latency-sensitive
|
||||
|
||||
**在 agentic workload 中的验证**: ✅ 成立,但需要细化
|
||||
|
||||
Roofline 分析显示(详见 Section 5):
|
||||
|
||||
```
|
||||
Arithmetic Intensity (FLOP/byte)
|
||||
Decode: 1.0 - 1.9 (memory-bound, 始终远低于 ridge point)
|
||||
Prefill 0% reuse: 23,000-72,000 (strongly compute-bound)
|
||||
Prefill 70% reuse: 10,000-42,000 (仍然 compute-bound!)
|
||||
Prefill 95% reuse: 1,900-10,800 (仍然 compute-bound!)
|
||||
Ridge point (H20): 37
|
||||
```
|
||||
|
||||
**即使 95% KV cache reuse,prefill 仍然是 compute-bound。** 但绝对计算量大幅减少。
|
||||
|
||||
### 假设 B: PD co-location 导致互相干扰
|
||||
- Prefill 的大 batch 计算会抢占 GPU 资源,导致 decode 的 TPOT 升高
|
||||
- Decode 的持续小计算会占用 GPU 调度槽位,影响 prefill 吞吐
|
||||
|
||||
**在 agentic workload 中的验证**: ⚠️ 干扰存在,但 **可被 cache-aware routing 消除**
|
||||
|
||||
```
|
||||
同一 cache-aware scheduler, TP=1, 8 GPU:
|
||||
Combined TP=1 DP=8: TPOT p90 = 0.073s
|
||||
PD-Sep TP=1 4P+4D: TPOT p90 = 0.074s
|
||||
→ 差异 <2%, 不显著
|
||||
```
|
||||
|
||||
对比 round-robin routing:
|
||||
```
|
||||
Combined TP=1 DP=8 (RR): TPOT p90 = 0.086s
|
||||
Combined TP=1 DP=8 (cache-aware): TPOT p90 = 0.073s → -15%
|
||||
→ routing 改善 > PD 分离改善
|
||||
```
|
||||
|
||||
**原因**: cache-aware routing 让 high-cache-hit 的请求集中到特定 instance,
|
||||
每个 instance 的实际 prefill 新 token 数大幅减少(71% 被 cache),
|
||||
prefill-decode 干扰因 prefill 工作量降低而自然缓解。
|
||||
|
||||
### 假设 C: KV Cache 传输开销可以忽略
|
||||
- DistServe 假设 P→D 的 KV 传输延迟远小于 prefill 计算时间
|
||||
- 在 InfiniBand/NVLink 等高带宽互联下成立
|
||||
|
||||
**在 agentic workload 中的验证**: ❌ 不成立
|
||||
|
||||
```
|
||||
PD-Sep TTFT p50 = 1.261s vs Combined TTFT p50 = 0.731s (+72%)
|
||||
```
|
||||
|
||||
原因:
|
||||
1. Agentic workload 的 input 极长(p50=20k, p90=88k tokens),KV cache 很大
|
||||
2. 单请求 KV cache = 20k tokens × 48 layers × 2(K+V) × 512 bytes ≈ 1GB
|
||||
3. 更重要的是 await-prefill 链路的串行延迟:proxy → prefill → KV transfer → decode → first token
|
||||
|
||||
### 假设 D: 专用 prefill 节点可以提高 prefill 吞吐
|
||||
- Prefill 节点不做 decode,GPU 利用率更高
|
||||
- 可以用更大的 batch size
|
||||
|
||||
**在 agentic workload 中的验证**: ⚠️ 收益被 cache 稀释
|
||||
|
||||
```
|
||||
理论 prefix cache hit (infinite cache): 71% of input tokens
|
||||
实际 APC (Combined, cache-aware, 8 inst): 44.7%
|
||||
```
|
||||
|
||||
71% cache hit → 只有 29% 的 input tokens 需要实际 prefill compute。
|
||||
Nominal avg input 33.6k → Actual avg new prefill ~9.7k tokens。
|
||||
专用 prefill 的 GPU 利用率优势因 prefill 工作量降低而缩小。
|
||||
|
||||
## 3. Roofline 分析:Prefill 在高 Cache Reuse 下的计算/访存特性
|
||||
|
||||
### 3.1 模型计算结构
|
||||
|
||||
```
|
||||
Qwen3-Coder-30B-A3B (MoE 128E top-8):
|
||||
48 layers, hidden=2048, heads=32, kv_heads=4 (GQA), head_dim=128
|
||||
FFN: 6144 intermediate per expert, 8 experts active per token
|
||||
Active params per token: ~3B
|
||||
|
||||
H20 GPU: 148 TFLOPS (BF16), 4.0 TB/s HBM → Ridge point: 37 FLOP/byte
|
||||
```
|
||||
|
||||
### 3.2 Decode 永远 memory-bound
|
||||
|
||||
```
|
||||
SeqLen FLOP Bytes AI (F/B) Bound
|
||||
1,000 3.04e+10 3.01e+10 1.0 MEMORY
|
||||
16,000 3.63e+10 3.16e+10 1.1 MEMORY
|
||||
64,000 5.52e+10 3.63e+10 1.5 MEMORY
|
||||
128,000 8.03e+10 4.26e+10 1.9 MEMORY
|
||||
```
|
||||
|
||||
Decode 的 AI 始终 < 2,远低于 ridge point (37)。每个 decode step 只处理 1 个 token,
|
||||
计算量极小,瓶颈在于读取模型权重和全量 KV cache。
|
||||
|
||||
### 3.3 Prefill 即使 95% reuse 仍然 compute-bound
|
||||
|
||||
```
|
||||
SeqLen Reuse% NewTok AI (F/B) Bound vs Decode
|
||||
32,000 0% 32,000 23,368 COMPUTE 18,190x
|
||||
32,000 50% 16,000 14,899 COMPUTE 11,597x
|
||||
32,000 70% 9,600 10,045 COMPUTE 7,819x
|
||||
32,000 90% 3,200 3,821 COMPUTE 2,974x
|
||||
32,000 95% 1,600 1,980 COMPUTE 1,542x
|
||||
|
||||
64,000 0% 64,000 40,758 COMPUTE 26,813x
|
||||
64,000 70% 19,200 20,610 COMPUTE 13,559x
|
||||
64,000 90% 6,400 8,544 COMPUTE 5,621x
|
||||
64,000 95% 3,200 4,549 COMPUTE 2,993x
|
||||
```
|
||||
|
||||
### 3.4 为什么高 reuse 不改变 compute-bound 性质
|
||||
|
||||
KV cache reuse 减少的:
|
||||
- K/V projection 计算(只算 new tokens)
|
||||
- KV 写入(只写 new tokens)
|
||||
|
||||
KV cache reuse **不减少**的:
|
||||
- **Q×K^T attention**: 每个 new Q 都要和全部 seq_len 个 KV 做 attention
|
||||
```
|
||||
FLOPs = new_tokens × seq_len × head_dim × num_heads × 2 × num_layers
|
||||
```
|
||||
At 95% reuse, 32k seq: 1600 × 32000 × 128 × 32 × 2 × 48 ≈ 2×10^13
|
||||
这个二次项在长 context 下主导总计算量
|
||||
|
||||
- **MoE FFN**: 每个 new token 激活 8 experts
|
||||
```
|
||||
FLOPs = new_tokens × 3 × D × D_ffn × 2 × K_experts × num_layers
|
||||
```
|
||||
|
||||
**Prefill 只在接近 100% reuse (< 10 new tokens) 时才变成 memory-bound。**
|
||||
|
||||
### 3.5 Prefill 什么时候变 memory-bound
|
||||
|
||||
```
|
||||
SeqLen=32,000: new_tokens ≈ 5-10 时 → AI ≈ 37 (ridge point)
|
||||
SeqLen=64,000: new_tokens ≈ 5-10 时 → AI ≈ 37
|
||||
```
|
||||
|
||||
在实际 agentic trace 中:
|
||||
```
|
||||
Compute-bound prefills: 961 (96%)
|
||||
Memory-bound prefills: 37 (3%) ← 近 100% reuse 的极端 warm 请求
|
||||
```
|
||||
|
||||
### 3.6 关键洞察:"Compute-bound but lightweight"
|
||||
|
||||
高 cache reuse 下的 prefill 处于一种独特状态:
|
||||
|
||||
```
|
||||
Prefill bound 类型: Compute-bound (不变)
|
||||
Prefill 绝对工作量: 大幅降低 (71% cache → 只算 29% 的 tokens)
|
||||
Prefill-Decode 干扰: 因绝对工作量降低而减轻 (不需要物理隔离)
|
||||
```
|
||||
|
||||
这解释了为什么 PD 分离没有帮助:
|
||||
- PD 分离解决的是 "prefill 太重干扰 decode" 的问题
|
||||
- 但 cache-aware routing 已经把 prefill 的实际工作量降到足够轻
|
||||
- 物理隔离(PD 分离)的收益被 KV 传输开销抵消
|
||||
|
||||
## 4. 实验结果
|
||||
|
||||
### 4.1 完整实验矩阵
|
||||
|
||||
所有实验使用统一的 cache-aware + token-level load-balanced global scheduler。
|
||||
|
||||
| Config | OK/N | TTFT p50 | TPOT p90 | E2E p50 | APC |
|
||||
|--------|------|----------|----------|---------|-----|
|
||||
| TP=8 DP=1 (single instance) | 998/1000 | 0.467s | 0.129s | 3.30s | 53.0% |
|
||||
| TP=2 DP=4 (4 inst, RR) | 997/999 | 0.844s | 0.095s | 4.92s | 33.5% |
|
||||
| TP=1 DP=8 (8 inst, RR) | 997/999 | 1.836s | 0.086s | 6.67s | 20.8% |
|
||||
| **TP=1 DP=8 (cache-aware)** | **997/999** | **0.731s** | **0.073s** | **4.48s** | **44.7%** |
|
||||
| TP=1 PD-Sep 4P+4D (cache-aware) | 509/564 | 1.261s | 0.074s | 5.61s | 40.2% |
|
||||
|
||||
### 4.2 Cache-Aware Routing 的效果
|
||||
|
||||
```
|
||||
Round-robin → Cache-aware (Combined TP=1 DP=8):
|
||||
TTFT p50: 1.836s → 0.731s (-60%)
|
||||
TPOT p90: 0.086s → 0.073s (-15%)
|
||||
E2E p50: 6.673s → 4.480s (-33%)
|
||||
APC: 20.8% → 44.7% (+24pp)
|
||||
```
|
||||
|
||||
Cache-aware routing 的提升远大于 PD 分离的提升。
|
||||
|
||||
### 4.3 修复工程问题的过程
|
||||
|
||||
实验过程中发现并修复了多个 PD 分离的工程问题:
|
||||
|
||||
| 问题 | 根因 | 修复 |
|
||||
|------|------|------|
|
||||
| Decode engine crash | vLLM scheduler assert: KV transfer 回调时 request 已 abort | Patch scheduler.py: assert → graceful skip |
|
||||
| Head-of-line blocking | Proxy 按 request count 做 LB,不区分大小请求 | Token-level ongoing_tokens load balancing |
|
||||
| "Timeout waiting for P side ready" | Proxy fire-and-forget prefill, decode 盲等 KV | Await-prefill + kv_load_failure_policy=recompute |
|
||||
| Port collision on startup | 8 Mooncake instances 同时启动争抢 torch distributed port | Staggered startup + explicit MASTER_PORT |
|
||||
| Cache routing "rich get richer" | score = ongoing - alpha*cached 导致流量集中到一个 instance | Normalized scoring: ongoing/avg_load - alpha*cache_ratio |
|
||||
|
||||
## 5. 结论
|
||||
|
||||
### 5.1 PD 分离为什么在 Agentic Workload 不生效
|
||||
|
||||
1. **Cache reuse 大幅降低 prefill 绝对工作量(71% cache hit → 只算 29%)**,使得 P-D 干扰不显著
|
||||
2. **Prefill 仍然 compute-bound**(即使 95% reuse,AI 仍 >1000),但每个请求的总 FLOPs 因 new_tokens 减少而大幅降低
|
||||
3. **Cache-aware routing 提供 "软 PD 隔离"**,效果等同于物理隔离但无 KV 传输开销
|
||||
4. **KV 传输开销不可忽略**(TTFT +72%),抵消了隔离收益
|
||||
5. **MoE 模型 active params 小**(3B),per-token compute 本身较轻
|
||||
|
||||
### 5.2 PD 分离在什么条件下有价值
|
||||
|
||||
| 条件 | Chatbot (有价值) | Agentic (无价值) |
|
||||
|------|-----------------|-----------------|
|
||||
| Cache hit rate | <10% | **71%** |
|
||||
| Model active params | 70B (dense) | **3B (MoE)** |
|
||||
| I/O ratio | 1-10x | **75.6x** |
|
||||
| Per-request prefill FLOPs | Very high | **Low (after cache)** |
|
||||
| KV transfer cost vs prefill cost | Negligible | **Significant** |
|
||||
|
||||
### 5.3 Agentic Workload 应该怎么优化
|
||||
|
||||
1. **Cache-aware routing** (已验证有效): 用 ongoing_tokens + prefix_cache_hit 做联合调度,
|
||||
将 APC 从 20.8% (RR) 提升到 44.7%,TPOT p90 降低 15%
|
||||
|
||||
2. **Cross-instance KV cache sharing**: 让多个 instance 共享全局 KV pool,
|
||||
进一步提升 cache hit 率接近理论 71%
|
||||
|
||||
3. **Prefix pre-warming**: 对 cold start 请求(55%,0% cache hit),
|
||||
预计算 common prefix (system prompt blocks) 并分发到所有 instance
|
||||
|
||||
4. **不同 workload 类型的差异化处理**:
|
||||
- Warm 请求 (22%, >90% cache hit, avg 1.3k new tokens): 几乎免费,任何 instance 都能处理
|
||||
- Cold 请求 (55%, 0% cache hit, avg 17.7k new tokens): prefill-heavy,需要有足够 compute
|
||||
- 可以用 request-type-aware routing 进一步优化
|
||||
130
analysis/roofline_analysis.md
Normal file
130
analysis/roofline_analysis.md
Normal file
@@ -0,0 +1,130 @@
|
||||
# Prefill 在高 KV Cache Reuse 下的计算/访存分析
|
||||
|
||||
## Model & GPU
|
||||
|
||||
```
|
||||
Qwen3-Coder-30B-A3B (MoE 128E top-8)
|
||||
48 layers, hidden=2048, heads=32, kv_heads=4 (GQA), head_dim=128
|
||||
FFN: 6144 intermediate per expert, 8 experts active per token
|
||||
Active params: ~3B per token
|
||||
|
||||
H20 GPU: 148 TFLOPS (BF16) / 4.0 TB/s HBM
|
||||
Ridge point: 37 FLOP/byte
|
||||
```
|
||||
|
||||
## 核心发现:Prefill 即使 95% reuse 仍然是 compute-bound
|
||||
|
||||
```
|
||||
SeqLen Reuse% NewTok AI (F/B) Bound vs Decode AI
|
||||
32,000 0% 32,000 23368 COMPUTE 18189x
|
||||
32,000 70% 9,600 10045 COMPUTE 7819x
|
||||
32,000 90% 3,200 3821 COMPUTE 2974x
|
||||
32,000 95% 1,600 1980 COMPUTE 1541x
|
||||
|
||||
64,000 0% 64,000 40758 COMPUTE 26813x
|
||||
64,000 70% 19,200 20610 COMPUTE 13559x
|
||||
64,000 90% 6,400 8544 COMPUTE 5621x
|
||||
64,000 95% 3,200 4549 COMPUTE 2993x
|
||||
|
||||
Decode (always):
|
||||
32,000 - 1 1.3 MEMORY 1x
|
||||
64,000 - 1 1.5 MEMORY 1x
|
||||
```
|
||||
|
||||
**关键**:
|
||||
- Decode 的 arithmetic intensity (AI) = 1.0-1.9 — 远低于 ridge point (37),始终 memory-bound
|
||||
- Prefill 即使 95% reuse (只有 5% 新 token),AI 仍然 >1000 — 远高于 ridge point,依然 compute-bound
|
||||
|
||||
## 为什么高 reuse 的 prefill 仍然是 compute-bound?
|
||||
|
||||
### 原因:Attention 的计算量与 seq_len 成正比
|
||||
|
||||
当有 95% cache reuse (seq_len=64k, new_tokens=3200):
|
||||
```
|
||||
Q projection: new_tokens × D × D → 只处理 3200 new tokens ✓
|
||||
K,V projection: new_tokens × D × D_kv → 只处理 3200 new tokens ✓
|
||||
|
||||
但 Attention score: new_tokens × seq_len × D_head × H × L
|
||||
= 3200 × 64000 × 128 × 32 × 48
|
||||
→ 仍然要对全部 64k context 做注意力计算!
|
||||
|
||||
FFN (MoE): new_tokens × 3 × D × D_ffn × 2 × K_experts × L
|
||||
= 3200 × 3 × 2048 × 6144 × 2 × 8 × 48
|
||||
→ 8 个 expert 的计算量仍然很大
|
||||
```
|
||||
|
||||
KV cache reuse 减少的是:
|
||||
- K/V projection 的计算(只算 new tokens)
|
||||
- KV 写入(只写 new tokens)
|
||||
|
||||
但 **不减少的是**:
|
||||
- Q 对全部 context 的 attention(每个 new Q 都要和所有 64k tokens 做 attention)
|
||||
- MoE FFN 的计算(每个 new token 激活 8 个 expert)
|
||||
|
||||
所以 prefill 的 FLOPs 虽然随 reuse 减少,但 **减少的是线性部分(投影),不减少的是二次部分(attention)**。
|
||||
在长 context 下,二次部分主导,使得即使 95% reuse,AI 仍远高于 ridge point。
|
||||
|
||||
## Prefill 什么时候才变成 memory-bound?
|
||||
|
||||
```
|
||||
SeqLen=32,000: new_tokens ≈ 5-10 时 (reuse > 99.97%) → AI ≈ 37
|
||||
SeqLen=64,000: new_tokens ≈ 5-10 时 → AI ≈ 37
|
||||
```
|
||||
|
||||
只有在 **近乎 100% reuse**(仅 5-10 个 new tokens)时,prefill 才接近 memory-bound。
|
||||
在实际 agentic trace 中,只有 3% 的请求达到这个程度。
|
||||
|
||||
## 对 PD 分离的影响:修正之前的分析
|
||||
|
||||
### 之前的错误结论(已修正)
|
||||
> "Prefill 大部分是 cache lookup 不是 compute"
|
||||
|
||||
这是 **错误的**。即使 70% cache reuse,prefill 的 AI 仍然是 decode 的 7000-14000 倍。
|
||||
Prefill 始终是 compute-bound,decode 始终是 memory-bound。
|
||||
|
||||
### 那为什么 PD 分离在我们的实验中没有帮助?
|
||||
|
||||
正确的解释不是 "prefill 变成了 memory-bound",而是:
|
||||
|
||||
**1. Cache reuse 大幅减少了 prefill 的绝对计算量**
|
||||
```
|
||||
无 cache: avg 33.6k tokens × prefill compute = X FLOPs
|
||||
71% cache: avg 9.4k tokens × prefill compute = 0.28X FLOPs
|
||||
```
|
||||
虽然 prefill 仍是 compute-bound,但 **总工作量只有原来的 28%**。
|
||||
在 8 instance 并行 + cache-aware routing 下,每个 instance 的 prefill 负载非常轻,
|
||||
不足以产生对 decode 的显著干扰。
|
||||
|
||||
**2. MoE 模型的 per-token compute 本身较小**
|
||||
Active params 只有 3B(全参数的 10%),单个 token 的计算量不大。
|
||||
对比 Dense 70B 模型,同样的 GPU 上 prefill-decode 干扰会严重得多。
|
||||
|
||||
**3. Cache-aware routing 的 "负载均衡" 效应**
|
||||
当请求被路由到 cache 命中率高的 instance 时,该 instance 的实际 prefill 工作量更小,
|
||||
自然减少了 P-D 争抢。这相当于 routing 层面的 "软 PD 分离"。
|
||||
|
||||
## 对比不同 workload 类型的 roofline 特征
|
||||
|
||||
```
|
||||
Prefill AI Decode AI PD-Sep 价值
|
||||
Dense 70B, Chatbot: 200-1000x 1-2x HIGH (compute-heavy P 干扰 D)
|
||||
Dense 70B, Agent: 100-500x 1-2x MEDIUM (cache reduces P load)
|
||||
MoE 30B, Chatbot: 100-500x 1-2x MEDIUM
|
||||
MoE 30B, Agent: 50-200x 1-2x LOW (small active params + cache)
|
||||
← 我们的位置
|
||||
```
|
||||
|
||||
**PD 分离的 ROI 随着 cache hit 率升高和模型 active params 减少而下降。**
|
||||
Agentic MoE 模型恰好在两个方面都不利于 PD 分离。
|
||||
|
||||
## 实际 trace 的 prefill bound 分布
|
||||
|
||||
```
|
||||
With actual trace prefix cache pattern (1000 sampled requests):
|
||||
Compute-bound prefills: 961 (96%)
|
||||
Memory-bound prefills: 37 (3%) ← 近 100% reuse 的 warm 请求
|
||||
(Decode is ALWAYS memory-bound)
|
||||
```
|
||||
|
||||
96% 的 prefill 仍然是 compute-bound,但 **absolute compute 因 cache 大幅降低**。
|
||||
这是一个 "compute-bound but lightweight" 的独特状态 —— bound 类型没变,但强度大幅降低。
|
||||
16
pyproject.toml
Normal file
16
pyproject.toml
Normal file
@@ -0,0 +1,16 @@
|
||||
[project]
|
||||
name = "agentic-kv"
|
||||
version = "0.1.0"
|
||||
description = "Trace-driven KV cache benchmarking for agentic LLM workloads"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"httpx>=0.27",
|
||||
"numpy>=1.24",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = ["pytest"]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
0
replayer/__init__.py
Normal file
0
replayer/__init__.py
Normal file
55
replayer/__main__.py
Normal file
55
replayer/__main__.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""CLI entry point: python -m replayer replay ..."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from .replay import ReplayConfig, replay_trace
|
||||
|
||||
|
||||
def main() -> None:
|
||||
p = argparse.ArgumentParser(description="Trace replayer for vLLM benchmarking")
|
||||
p.add_argument("--trace", type=Path, required=True, help="Sampled trace JSONL")
|
||||
p.add_argument("--output", type=Path, required=True, help="Output metrics JSONL")
|
||||
p.add_argument("--endpoint", type=str, required=True,
|
||||
help="vLLM server URL (e.g. http://localhost:8000)")
|
||||
p.add_argument("--model", type=str, default="default", help="Model name for API")
|
||||
p.add_argument("--time-scale", type=float, default=1.0,
|
||||
help="Time compression (>1 = faster)")
|
||||
p.add_argument("--max-inflight-sessions", type=int, default=32)
|
||||
p.add_argument("--concurrency-limit", type=int, default=256)
|
||||
p.add_argument("--request-timeout", type=float, default=600.0)
|
||||
p.add_argument("--request-limit", type=int, default=None,
|
||||
help="Limit number of requests to replay")
|
||||
p.add_argument("-v", "--verbose", action="store_true")
|
||||
args = p.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG if args.verbose else logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
config = ReplayConfig(
|
||||
trace_path=args.trace,
|
||||
output_path=args.output,
|
||||
endpoint_url=args.endpoint.rstrip("/"),
|
||||
model_name=args.model,
|
||||
time_scale=args.time_scale,
|
||||
max_inflight_sessions=args.max_inflight_sessions,
|
||||
concurrency_limit=args.concurrency_limit,
|
||||
request_timeout_s=args.request_timeout,
|
||||
request_limit=args.request_limit,
|
||||
)
|
||||
|
||||
results = asyncio.run(replay_trace(config))
|
||||
succeeded = sum(1 for r in results if r.error is None)
|
||||
print(f"\nDone: {succeeded}/{len(results)} requests succeeded")
|
||||
print(f"Metrics: {args.output}")
|
||||
print(f"Summary: {args.output.with_suffix('.summary.json')}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
107
replayer/metrics.py
Normal file
107
replayer/metrics.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Per-request metrics collection and summary reporting."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import statistics
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RequestMetrics:
|
||||
request_id: str
|
||||
session_id: str
|
||||
turn_id: int
|
||||
trace_timestamp_s: float
|
||||
input_length: int
|
||||
output_length: int
|
||||
request_type: str
|
||||
effective_input_length: int | None
|
||||
cached_tokens: int
|
||||
latency_s: float | None
|
||||
ttft_s: float | None
|
||||
tpot_s: float | None
|
||||
actual_output_tokens: int | None = None
|
||||
requested_output_tokens: int | None = None
|
||||
finish_reason: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class IncrementalMetricSink:
|
||||
"""Append each RequestMetrics to JSONL immediately (crash-safe)."""
|
||||
|
||||
def __init__(self, path: Path):
|
||||
self.path = path
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text("")
|
||||
self._lock = asyncio.Lock()
|
||||
self._fh = path.open("a", encoding="utf-8", buffering=1)
|
||||
|
||||
async def append(self, metric: RequestMetrics) -> None:
|
||||
line = json.dumps(asdict(metric), sort_keys=True) + "\n"
|
||||
async with self._lock:
|
||||
self._fh.write(line)
|
||||
self._fh.flush()
|
||||
|
||||
def close(self) -> None:
|
||||
try:
|
||||
self._fh.flush()
|
||||
self._fh.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def write_summary_json(path: Path, rows: list[RequestMetrics]) -> None:
|
||||
successful = [r for r in rows if r.error is None]
|
||||
latencies = [r.latency_s for r in successful if r.latency_s is not None]
|
||||
ttfts = [r.ttft_s for r in successful if r.ttft_s is not None]
|
||||
tpots = [r.tpot_s for r in successful if r.tpot_s is not None]
|
||||
|
||||
total_input = sum(r.input_length for r in successful)
|
||||
total_cached = sum(r.cached_tokens for r in successful)
|
||||
|
||||
summary: dict[str, Any] = {
|
||||
"request_count": len(rows),
|
||||
"success_count": len(successful),
|
||||
"error_count": sum(1 for r in rows if r.error is not None),
|
||||
"latency_stats_s": _stats(latencies),
|
||||
"ttft_stats_s": _stats(ttfts),
|
||||
"tpot_stats_s": _stats(tpots),
|
||||
"cache_hit_request_count": sum(1 for r in successful if r.cached_tokens > 0),
|
||||
"total_input_tokens": total_input,
|
||||
"total_cached_tokens": total_cached,
|
||||
"prefix_cache_hit_ratio": total_cached / total_input if total_input > 0 else 0.0,
|
||||
"cached_tokens_stats": _stats([float(r.cached_tokens) for r in successful]),
|
||||
"actual_output_tokens_stats": _stats(
|
||||
[float(r.actual_output_tokens) for r in successful
|
||||
if r.actual_output_tokens is not None]
|
||||
),
|
||||
}
|
||||
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("w", encoding="utf-8") as fh:
|
||||
json.dump(summary, fh, indent=2, sort_keys=True)
|
||||
|
||||
|
||||
def _stats(values: list[float | None]) -> dict[str, float] | None:
|
||||
clean = [v for v in values if v is not None]
|
||||
if not clean:
|
||||
return None
|
||||
clean.sort()
|
||||
return {
|
||||
"count": float(len(clean)),
|
||||
"mean": statistics.fmean(clean),
|
||||
"p50": _percentile(clean, 0.50),
|
||||
"p90": _percentile(clean, 0.90),
|
||||
"p99": _percentile(clean, 0.99),
|
||||
}
|
||||
|
||||
|
||||
def _percentile(sorted_vals: list[float], pct: float) -> float:
|
||||
if len(sorted_vals) == 1:
|
||||
return sorted_vals[0]
|
||||
idx = round((len(sorted_vals) - 1) * pct)
|
||||
return sorted_vals[idx]
|
||||
343
replayer/replay.py
Normal file
343
replayer/replay.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""Trace replayer — send requests to vLLM following trace timing.
|
||||
|
||||
Supports both vLLM's /v1/completions (OpenAI-compatible) and /generate
|
||||
(SGLang-style) endpoints. Uses hash_ids from the trace to construct
|
||||
synthetic prompts that reproduce realistic prefix-cache hit patterns.
|
||||
|
||||
Key behaviors:
|
||||
- Per-session sequencing: turns within a session are sent in order,
|
||||
each waiting for the previous to complete before dispatching.
|
||||
- Inter-session arrival: sessions start at their trace timestamps,
|
||||
scaled by --time-scale.
|
||||
- Concurrency control: --max-inflight-sessions caps concurrent sessions;
|
||||
--concurrency-limit caps total in-flight requests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import random as _random
|
||||
|
||||
import httpx
|
||||
|
||||
from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json
|
||||
from .trace import TraceRequest, load_trace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BLOCK_SIZE = 512
|
||||
VOCAB_SIZE = 151936
|
||||
TOKEN_RANGE_START = 100
|
||||
TOKEN_RANGE_END = VOCAB_SIZE - 100
|
||||
|
||||
_block_cache: dict[int, list[int]] = {}
|
||||
|
||||
|
||||
def _hash_id_to_token_ids(hash_id: int) -> list[int]:
|
||||
"""Deterministically map a hash_id to BLOCK_SIZE token IDs."""
|
||||
if hash_id in _block_cache:
|
||||
return _block_cache[hash_id]
|
||||
rng = _random.Random(hash_id)
|
||||
ids = [rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END) for _ in range(BLOCK_SIZE)]
|
||||
_block_cache[hash_id] = ids
|
||||
return ids
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReplayConfig:
|
||||
trace_path: Path
|
||||
output_path: Path
|
||||
endpoint_url: str # comma-separated for round-robin: "http://host:8000,http://host:8001"
|
||||
time_scale: float = 1.0
|
||||
max_inflight_sessions: int = 32
|
||||
concurrency_limit: int = 256
|
||||
request_timeout_s: float = 600.0
|
||||
request_limit: int | None = None
|
||||
model_name: str = "default"
|
||||
|
||||
|
||||
def _build_prompt_token_ids(req: TraceRequest) -> list[int]:
|
||||
"""Build token IDs from hash_ids for prefix-cache-aware replay.
|
||||
|
||||
Same hash_id prefix → same token ID prefix → APC cache hit in vLLM.
|
||||
"""
|
||||
ids: list[int] = []
|
||||
for hid in req.hash_ids:
|
||||
ids.extend(_hash_id_to_token_ids(hid))
|
||||
# Pad to input_length with deterministic tokens
|
||||
pad_rng = _random.Random(req.chat_id)
|
||||
while len(ids) < req.input_length:
|
||||
ids.append(pad_rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END))
|
||||
return ids[:req.input_length]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SessionState:
|
||||
session_id: str
|
||||
turns: list[TraceRequest]
|
||||
metrics: list[RequestMetrics] = field(default_factory=list)
|
||||
|
||||
|
||||
_endpoint_counter = 0
|
||||
|
||||
|
||||
def _pick_endpoint(config: ReplayConfig) -> str:
|
||||
"""Round-robin across comma-separated endpoints."""
|
||||
global _endpoint_counter
|
||||
endpoints = [e.strip() for e in config.endpoint_url.split(",")]
|
||||
url = endpoints[_endpoint_counter % len(endpoints)]
|
||||
_endpoint_counter += 1
|
||||
return url
|
||||
|
||||
|
||||
async def _dispatch_request(
|
||||
*,
|
||||
client: httpx.AsyncClient,
|
||||
config: ReplayConfig,
|
||||
req: TraceRequest,
|
||||
prompt_token_ids: list[int],
|
||||
sem: asyncio.Semaphore,
|
||||
) -> RequestMetrics:
|
||||
"""Send one request via /v1/completions (streaming) and collect metrics."""
|
||||
endpoint = _pick_endpoint(config)
|
||||
payload = {
|
||||
"model": config.model_name,
|
||||
"prompt": prompt_token_ids,
|
||||
"max_tokens": max(1, req.output_length),
|
||||
"temperature": 0,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
}
|
||||
|
||||
start = time.perf_counter()
|
||||
ttft_s = None
|
||||
n_output = 0
|
||||
cached_tokens = 0
|
||||
finish_reason = None
|
||||
err = None
|
||||
token_times: list[float] = []
|
||||
|
||||
async with sem:
|
||||
try:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{endpoint}/v1/completions",
|
||||
json=payload,
|
||||
timeout=config.request_timeout_s,
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
async for raw_line in resp.aiter_lines():
|
||||
if not raw_line or not raw_line.startswith("data:"):
|
||||
continue
|
||||
data = raw_line[5:].strip()
|
||||
if data == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
now = time.perf_counter()
|
||||
if ttft_s is None:
|
||||
ttft_s = now - start
|
||||
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
delta = choices[0].get("text", "")
|
||||
if delta:
|
||||
token_times.append(now)
|
||||
fr = choices[0].get("finish_reason")
|
||||
if fr:
|
||||
finish_reason = fr
|
||||
|
||||
usage = chunk.get("usage")
|
||||
if usage:
|
||||
n_output = usage.get("completion_tokens", n_output)
|
||||
cached_tokens = _extract_cached_tokens(usage)
|
||||
except Exception as exc:
|
||||
err = repr(exc)[:300]
|
||||
|
||||
end = time.perf_counter()
|
||||
e2e = end - start
|
||||
if n_output == 0 and token_times:
|
||||
n_output = len(token_times)
|
||||
|
||||
tpot = 0.0
|
||||
if len(token_times) > 1:
|
||||
inter_token = [token_times[i+1] - token_times[i]
|
||||
for i in range(len(token_times) - 1)]
|
||||
tpot = sum(inter_token) / len(inter_token)
|
||||
|
||||
return RequestMetrics(
|
||||
request_id=req.request_id,
|
||||
session_id=req.session_id,
|
||||
turn_id=req.turn_id,
|
||||
trace_timestamp_s=req.timestamp_s,
|
||||
input_length=req.input_length,
|
||||
output_length=req.output_length,
|
||||
request_type=req.request_type,
|
||||
effective_input_length=len(prompt_token_ids),
|
||||
cached_tokens=cached_tokens,
|
||||
latency_s=e2e,
|
||||
ttft_s=ttft_s,
|
||||
tpot_s=tpot,
|
||||
actual_output_tokens=n_output,
|
||||
requested_output_tokens=req.output_length,
|
||||
finish_reason=finish_reason,
|
||||
error=err,
|
||||
)
|
||||
|
||||
|
||||
def _extract_cached_tokens(usage: dict) -> int:
|
||||
ct = 0
|
||||
details = usage.get("prompt_tokens_details")
|
||||
if isinstance(details, dict):
|
||||
ct = details.get("cached_tokens", 0) or 0
|
||||
if ct == 0:
|
||||
ct = usage.get("cached_tokens", 0) or 0
|
||||
return int(ct)
|
||||
|
||||
|
||||
async def _run_session(
|
||||
*,
|
||||
state: _SessionState,
|
||||
config: ReplayConfig,
|
||||
client: httpx.AsyncClient,
|
||||
session_sem: asyncio.Semaphore,
|
||||
request_sem: asyncio.Semaphore,
|
||||
earliest_ts: float,
|
||||
sweep_start: float,
|
||||
sink: IncrementalMetricSink,
|
||||
) -> list[RequestMetrics]:
|
||||
async with session_sem:
|
||||
# Wait until this session's start time
|
||||
offset = (state.turns[0].timestamp_s - earliest_ts) / config.time_scale
|
||||
wait = offset - (time.perf_counter() - sweep_start)
|
||||
if wait > 0:
|
||||
await asyncio.sleep(wait)
|
||||
|
||||
for req in state.turns:
|
||||
# Intra-session: wait for turn's relative offset
|
||||
if req != state.turns[0]:
|
||||
target = (req.timestamp_s - state.turns[0].timestamp_s) / config.time_scale
|
||||
elapsed = time.perf_counter() - sweep_start - offset
|
||||
if elapsed < target:
|
||||
await asyncio.sleep(target - elapsed)
|
||||
|
||||
token_ids = _build_prompt_token_ids(req)
|
||||
metric = await _dispatch_request(
|
||||
client=client, config=config, req=req,
|
||||
prompt_token_ids=token_ids, sem=request_sem,
|
||||
)
|
||||
state.metrics.append(metric)
|
||||
await sink.append(metric)
|
||||
|
||||
return state.metrics
|
||||
|
||||
|
||||
async def _snapshot_prefix_cache_metrics(url_csv: str) -> dict[str, float]:
|
||||
"""Scrape vLLM /metrics for prefix cache counters (aggregated across endpoints)."""
|
||||
total = {"queries": 0.0, "hits": 0.0}
|
||||
endpoints = [e.strip() for e in url_csv.split(",")]
|
||||
async with httpx.AsyncClient(timeout=10) as c:
|
||||
for url in endpoints:
|
||||
try:
|
||||
r = await c.get(f"{url}/metrics")
|
||||
for line in r.text.split("\n"):
|
||||
if line.startswith("vllm:prefix_cache_queries_total"):
|
||||
total["queries"] += float(line.split()[-1])
|
||||
elif line.startswith("vllm:prefix_cache_hits_total"):
|
||||
total["hits"] += float(line.split()[-1])
|
||||
except Exception:
|
||||
pass
|
||||
return total
|
||||
|
||||
|
||||
async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]:
|
||||
"""Main entry: load trace, replay against endpoint, return metrics."""
|
||||
requests = load_trace(config.trace_path, request_limit=config.request_limit)
|
||||
if not requests:
|
||||
return []
|
||||
|
||||
by_session: dict[str, list[TraceRequest]] = defaultdict(list)
|
||||
for r in requests:
|
||||
by_session[r.session_id].append(r)
|
||||
for sid in by_session:
|
||||
by_session[sid].sort(key=lambda r: (r.turn_id, r.timestamp_s))
|
||||
|
||||
sessions = sorted(by_session.items(), key=lambda kv: kv[1][0].timestamp_s)
|
||||
earliest_ts = sessions[0][1][0].timestamp_s
|
||||
|
||||
session_sem = asyncio.Semaphore(config.max_inflight_sessions)
|
||||
request_sem = asyncio.Semaphore(config.concurrency_limit)
|
||||
|
||||
sink = IncrementalMetricSink(config.output_path)
|
||||
|
||||
n_sessions = len(sessions)
|
||||
n_requests = len(requests)
|
||||
logger.info("Replaying %d sessions (%d requests), time_scale=%.1f",
|
||||
n_sessions, n_requests, config.time_scale)
|
||||
|
||||
pre_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)
|
||||
sweep_start = time.perf_counter()
|
||||
|
||||
try:
|
||||
limits = httpx.Limits(
|
||||
max_connections=2000,
|
||||
max_keepalive_connections=500,
|
||||
keepalive_expiry=30.0,
|
||||
)
|
||||
async with httpx.AsyncClient(
|
||||
timeout=config.request_timeout_s,
|
||||
trust_env=False,
|
||||
limits=limits,
|
||||
) as client:
|
||||
tasks = [
|
||||
asyncio.create_task(_run_session(
|
||||
state=_SessionState(session_id=sid, turns=turns),
|
||||
config=config, client=client,
|
||||
session_sem=session_sem, request_sem=request_sem,
|
||||
earliest_ts=earliest_ts, sweep_start=sweep_start,
|
||||
sink=sink,
|
||||
))
|
||||
for sid, turns in sessions
|
||||
]
|
||||
all_results = await asyncio.gather(*tasks)
|
||||
finally:
|
||||
sink.close()
|
||||
|
||||
sweep_elapsed = time.perf_counter() - sweep_start
|
||||
post_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)
|
||||
|
||||
flat = [m for group in all_results for m in group]
|
||||
summary_path = config.output_path.with_suffix(".summary.json")
|
||||
write_summary_json(summary_path, flat)
|
||||
|
||||
# Compute aggregate prefix cache hit ratio from /metrics deltas
|
||||
delta_queries = post_metrics.get("queries", 0) - pre_metrics.get("queries", 0)
|
||||
delta_hits = post_metrics.get("hits", 0) - pre_metrics.get("hits", 0)
|
||||
hit_ratio = delta_hits / delta_queries if delta_queries > 0 else 0.0
|
||||
|
||||
logger.info("Done: %d/%d succeeded in %.1fs", sum(1 for m in flat if m.error is None), len(flat), sweep_elapsed)
|
||||
logger.info("Prefix cache: %.1f%% hit ratio (%d/%d tokens)",
|
||||
hit_ratio * 100, int(delta_hits), int(delta_queries))
|
||||
|
||||
# Append cache stats to summary
|
||||
import json as _json
|
||||
summary = _json.loads(summary_path.read_text())
|
||||
summary["prefix_cache_queries_tokens"] = int(delta_queries)
|
||||
summary["prefix_cache_hits_tokens"] = int(delta_hits)
|
||||
summary["prefix_cache_hit_ratio"] = hit_ratio
|
||||
summary["wall_clock_s"] = sweep_elapsed
|
||||
summary_path.write_text(_json.dumps(summary, indent=2, sort_keys=True))
|
||||
|
||||
logger.info("Summary written to %s", summary_path)
|
||||
return flat
|
||||
84
replayer/trace.py
Normal file
84
replayer/trace.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Trace data structures and loader for the Ali agentic-coder trace format.
|
||||
|
||||
Trace format (one JSON per line):
|
||||
chat_id, parent_chat_id, timestamp, input_length, output_length,
|
||||
type, turn, hash_ids[]
|
||||
|
||||
Sessions are derived from parent_chat_id chains:
|
||||
- parent_chat_id == -1 → new session root
|
||||
- parent_chat_id >= 0 → belongs to the same session as the parent
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TraceRequest:
|
||||
request_id: str
|
||||
session_id: str
|
||||
chat_id: int
|
||||
parent_chat_id: int
|
||||
timestamp_s: float
|
||||
input_length: int
|
||||
output_length: int
|
||||
request_type: str
|
||||
turn_id: int
|
||||
hash_ids: tuple[int, ...]
|
||||
|
||||
|
||||
def load_trace(
|
||||
path: Path,
|
||||
*,
|
||||
request_limit: int | None = None,
|
||||
) -> list[TraceRequest]:
|
||||
"""Load trace and resolve session IDs from parent_chat_id chains."""
|
||||
chat_to_session: dict[int, str] = {}
|
||||
requests: list[TraceRequest] = []
|
||||
|
||||
with path.open("r", encoding="utf-8") as fh:
|
||||
for idx, line in enumerate(fh):
|
||||
if request_limit is not None and len(requests) >= request_limit:
|
||||
break
|
||||
row = json.loads(line)
|
||||
chat_id = int(row["chat_id"])
|
||||
parent_chat_id = int(row["parent_chat_id"])
|
||||
|
||||
if "session_id" in row:
|
||||
session_id = str(row["session_id"])
|
||||
else:
|
||||
session_id = _resolve_session_id(
|
||||
chat_id, parent_chat_id, chat_to_session,
|
||||
)
|
||||
chat_to_session[chat_id] = session_id
|
||||
|
||||
requests.append(TraceRequest(
|
||||
request_id=f"{session_id}:{row['turn']}:{chat_id}:{idx}",
|
||||
session_id=session_id,
|
||||
chat_id=chat_id,
|
||||
parent_chat_id=parent_chat_id,
|
||||
timestamp_s=float(row["timestamp"]),
|
||||
input_length=int(row["input_length"]),
|
||||
output_length=int(row["output_length"]),
|
||||
request_type=str(row["type"]),
|
||||
turn_id=int(row["turn"]),
|
||||
hash_ids=tuple(int(h) for h in row.get("hash_ids", [])),
|
||||
))
|
||||
|
||||
return requests
|
||||
|
||||
|
||||
def _resolve_session_id(
|
||||
chat_id: int,
|
||||
parent_chat_id: int,
|
||||
chat_to_session: dict[int, str],
|
||||
) -> str:
|
||||
if parent_chat_id < 0:
|
||||
session_id = str(chat_id)
|
||||
else:
|
||||
session_id = chat_to_session.get(parent_chat_id, str(parent_chat_id))
|
||||
chat_to_session[chat_id] = session_id
|
||||
return session_id
|
||||
196
scripts/analyze_cache_hit.py
Normal file
196
scripts/analyze_cache_hit.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Analyze theoretical vs actual KV cache hit ratio for the agentic trace."""
|
||||
import json
|
||||
from collections import Counter
|
||||
|
||||
rows = [json.loads(l) for l in open("traces/sampled_1000req_seed42.jsonl")]
|
||||
rows.sort(key=lambda r: float(r["timestamp"]))
|
||||
|
||||
BLOCK_SIZE = 512
|
||||
|
||||
# === 1. Theoretical max: infinite cache, single instance ===
|
||||
total_tokens = 0
|
||||
total_cached = 0
|
||||
seen_blocks = set()
|
||||
per_req = []
|
||||
|
||||
for r in rows:
|
||||
input_len = r["input_length"]
|
||||
hash_ids = r.get("hash_ids", [])
|
||||
total_tokens += input_len
|
||||
|
||||
cached_blocks = 0
|
||||
prefix_broken = False
|
||||
for hid in hash_ids:
|
||||
if not prefix_broken and hid in seen_blocks:
|
||||
cached_blocks += 1
|
||||
else:
|
||||
prefix_broken = True
|
||||
|
||||
cached_tokens = cached_blocks * BLOCK_SIZE
|
||||
total_cached += cached_tokens
|
||||
for hid in hash_ids:
|
||||
seen_blocks.add(hid)
|
||||
|
||||
per_req.append({
|
||||
"input_length": input_len,
|
||||
"cached_tokens": cached_tokens,
|
||||
"new_tokens": max(0, input_len - cached_tokens),
|
||||
"ratio": cached_tokens / input_len if input_len > 0 else 0,
|
||||
})
|
||||
|
||||
sep = "=" * 70
|
||||
print(sep)
|
||||
print(" THEORETICAL KV CACHE HIT (infinite cache, single instance)")
|
||||
print(sep)
|
||||
print(f" Total input tokens: {total_tokens:>14,}")
|
||||
print(f" Cacheable (prefix hit): {total_cached:>14,} ({total_cached*100//total_tokens}%)")
|
||||
print(f" Must prefill (new): {total_tokens-total_cached:>14,} ({(total_tokens-total_cached)*100//total_tokens}%)")
|
||||
|
||||
ratios = sorted([s["ratio"] for s in per_req if s["input_length"] > 0])
|
||||
new_tokens = sorted([s["new_tokens"] for s in per_req if s["input_length"] > 0])
|
||||
p = lambda v, q: v[min(int(q*len(v)), len(v)-1)]
|
||||
|
||||
print(f"\n Per-request cache hit ratio:")
|
||||
print(f" p10={p(ratios,.1)*100:.1f}% p50={p(ratios,.5)*100:.1f}% p90={p(ratios,.9)*100:.1f}% mean={sum(ratios)/len(ratios)*100:.1f}%")
|
||||
high = sum(1 for r in ratios if r > 0.5)
|
||||
very_high = sum(1 for r in ratios if r > 0.9)
|
||||
zero = sum(1 for r in ratios if r == 0)
|
||||
print(f" 0% hit (cold start): {zero} ({zero*100//len(ratios)}%)")
|
||||
print(f" >50% hit: {high} ({high*100//len(ratios)}%)")
|
||||
print(f" >90% hit: {very_high} ({very_high*100//len(ratios)}%)")
|
||||
|
||||
print(f"\n Actual new tokens to prefill per request:")
|
||||
print(f" p10={p(new_tokens,.1):>7,} p50={p(new_tokens,.5):>7,} p90={p(new_tokens,.9):>7,} max={max(new_tokens):>7,}")
|
||||
|
||||
# === 2. 4-instance split (simulating DP=4 or 4 prefill instances) ===
|
||||
print(f"\n{sep}")
|
||||
print(" 4-INSTANCE SPLIT (round-robin, per-instance cache)")
|
||||
print(sep)
|
||||
|
||||
instance_seen = [set() for _ in range(4)]
|
||||
inst_total = [0]*4
|
||||
inst_cached = [0]*4
|
||||
|
||||
for i, r in enumerate(rows):
|
||||
inst = i % 4
|
||||
input_len = r["input_length"]
|
||||
hash_ids = r.get("hash_ids", [])
|
||||
inst_total[inst] += input_len
|
||||
|
||||
cached_blocks = 0
|
||||
prefix_broken = False
|
||||
for hid in hash_ids:
|
||||
if not prefix_broken and hid in instance_seen[inst]:
|
||||
cached_blocks += 1
|
||||
else:
|
||||
prefix_broken = True
|
||||
|
||||
inst_cached[inst] += cached_blocks * BLOCK_SIZE
|
||||
for hid in hash_ids:
|
||||
instance_seen[inst].add(hid)
|
||||
|
||||
rr_total = sum(inst_total)
|
||||
rr_cached = sum(inst_cached)
|
||||
print(f" Cache hit ratio (RR): {rr_cached*100//rr_total}%")
|
||||
|
||||
# === 3. Cache-aware routing (route to instance with best prefix match) ===
|
||||
print(f"\n{sep}")
|
||||
print(" 4-INSTANCE CACHE-AWARE ROUTING")
|
||||
print(sep)
|
||||
|
||||
ca_seen = [set() for _ in range(4)]
|
||||
ca_total = [0]*4
|
||||
ca_cached = [0]*4
|
||||
|
||||
for r in rows:
|
||||
input_len = r["input_length"]
|
||||
hash_ids = r.get("hash_ids", [])
|
||||
|
||||
# Pick instance with most prefix blocks cached
|
||||
best_inst = 0
|
||||
best_hit = 0
|
||||
for inst in range(4):
|
||||
hit = 0
|
||||
for hid in hash_ids:
|
||||
if hid in ca_seen[inst]:
|
||||
hit += 1
|
||||
else:
|
||||
break
|
||||
if hit > best_hit:
|
||||
best_hit = hit
|
||||
best_inst = inst
|
||||
|
||||
ca_total[best_inst] += input_len
|
||||
ca_cached[best_inst] += best_hit * BLOCK_SIZE
|
||||
for hid in hash_ids:
|
||||
ca_seen[best_inst].add(hid)
|
||||
|
||||
ca_total_sum = sum(ca_total)
|
||||
ca_cached_sum = sum(ca_cached)
|
||||
print(f" Cache hit ratio: {ca_cached_sum*100//ca_total_sum}%")
|
||||
print(f" vs RR: {rr_cached*100//rr_total}% -> {ca_cached_sum*100//ca_total_sum}% (+{(ca_cached_sum-rr_cached)*100//rr_total}pp)")
|
||||
|
||||
# === 4. Session structure analysis ===
|
||||
print(f"\n{sep}")
|
||||
print(" SESSION & MULTI-TURN ANALYSIS")
|
||||
print(sep)
|
||||
|
||||
sessions = {}
|
||||
chat_to_session = {}
|
||||
for r in rows:
|
||||
cid = int(r["chat_id"])
|
||||
pid = int(r["parent_chat_id"])
|
||||
sid = r.get("session_id", str(cid) if pid < 0 else chat_to_session.get(pid, str(pid)))
|
||||
chat_to_session[cid] = str(sid)
|
||||
sessions.setdefault(str(sid), []).append(r)
|
||||
|
||||
multi = {k: v for k, v in sessions.items() if len(v) > 1}
|
||||
single = {k: v for k, v in sessions.items() if len(v) == 1}
|
||||
|
||||
print(f" Sessions: {len(sessions)} total, {len(multi)} multi-turn ({len(multi)*100//len(sessions)}%)")
|
||||
|
||||
# Multi-turn: cache hit in turn 2+
|
||||
mt_new = 0
|
||||
mt_reuse = 0
|
||||
for sid, turns in multi.items():
|
||||
turns.sort(key=lambda r: r["turn"])
|
||||
prev_blocks = set()
|
||||
for t in turns:
|
||||
hids = t.get("hash_ids", [])
|
||||
for hid in hids:
|
||||
if hid in prev_blocks:
|
||||
mt_reuse += BLOCK_SIZE
|
||||
else:
|
||||
mt_new += BLOCK_SIZE
|
||||
prev_blocks.add(hid)
|
||||
|
||||
mt_total_tok = mt_new + mt_reuse
|
||||
print(f" Multi-turn intra-session reuse: {mt_reuse*100//mt_total_tok}% of tokens")
|
||||
print(f" (Turn 2+ reuses KV from prior turns in same session)")
|
||||
|
||||
# Single-turn: cross-session sharing via system prompt
|
||||
block_freq = Counter()
|
||||
for r in rows:
|
||||
for hid in r.get("hash_ids", []):
|
||||
block_freq[hid] += 1
|
||||
|
||||
shared = {k: v for k, v in block_freq.items() if v > 1}
|
||||
top = block_freq.most_common(5)
|
||||
print(f"\n Cross-session block sharing:")
|
||||
print(f" Unique blocks: {len(block_freq):,}")
|
||||
print(f" Shared (ref>1): {len(shared):,} ({len(shared)*100//len(block_freq)}%)")
|
||||
print(f" Top-5 block ref counts: {[c for _,c in top]}")
|
||||
print(f" (Shared blocks = system prompt / common code context)")
|
||||
|
||||
# === 5. Implication for PD separation ===
|
||||
print(f"\n{sep}")
|
||||
print(" IMPLICATION FOR PD SEPARATION")
|
||||
print(sep)
|
||||
actual_prefill_pct = (total_tokens - total_cached) * 100 // total_tokens
|
||||
print(f" With perfect caching, only {actual_prefill_pct}% of tokens need actual prefill compute.")
|
||||
print(f" The remaining {100-actual_prefill_pct}% are prefix cache hits (skip prefill, reuse KV).")
|
||||
print(f" This means PD separation's prefill overhead is much smaller than it appears:")
|
||||
print(f" - Nominal avg input: {total_tokens//len(rows):,} tokens/request")
|
||||
new_per_req = sorted([s["new_tokens"] for s in per_req if s["input_length"] > 0])
|
||||
print(f" - Actual avg prefill: {sum(new_per_req)//len(new_per_req):,} tokens/request (after cache hit)")
|
||||
print(f" - KV transfer size is also reduced (only transfer new blocks)")
|
||||
163
scripts/analyze_trace.py
Normal file
163
scripts/analyze_trace.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Analyze trace patterns to assess PD separation benefit.
|
||||
|
||||
Computes metrics relevant to deciding PD-combined vs PD-separated:
|
||||
- Input/output token ratio (high ratio = prefill-heavy → PD sep benefits)
|
||||
- Prefix sharing density (high sharing → benefits from shared KV cache)
|
||||
- Session length distribution (multi-turn = more prefix reuse)
|
||||
- Arrival burstiness (bursty prefill → PD sep can absorb spikes)
|
||||
- Compute-intensity ratio: prefill FLOP share vs decode FLOP share
|
||||
|
||||
Usage:
|
||||
python scripts/analyze_trace.py --input traces/sampled_1000req_seed42.jsonl
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import json
|
||||
import statistics
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser(description=__doc__,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
p.add_argument("--input", type=Path, required=True)
|
||||
args = p.parse_args()
|
||||
|
||||
rows = []
|
||||
with args.input.open() as fh:
|
||||
for line in fh:
|
||||
rows.append(json.loads(line))
|
||||
|
||||
# Session structure
|
||||
sessions: dict[str, list[dict]] = collections.OrderedDict()
|
||||
chat_to_session: dict[int, str] = {}
|
||||
for r in rows:
|
||||
cid = int(r["chat_id"])
|
||||
pid = int(r["parent_chat_id"])
|
||||
sid = r.get("session_id")
|
||||
if sid is None:
|
||||
sid = str(cid) if pid < 0 else chat_to_session.get(pid, str(pid))
|
||||
chat_to_session[cid] = str(sid)
|
||||
sessions.setdefault(str(sid), []).append(r)
|
||||
|
||||
n_sessions = len(sessions)
|
||||
turns_per_session = [len(v) for v in sessions.values()]
|
||||
multi_turn = sum(1 for t in turns_per_session if t > 1)
|
||||
|
||||
input_lens = [r["input_length"] for r in rows]
|
||||
output_lens = [r["output_length"] for r in rows]
|
||||
total_input = sum(input_lens)
|
||||
total_output = sum(output_lens)
|
||||
|
||||
print("=" * 60)
|
||||
print("Trace Pattern Analysis for PD Separation Decision")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. Input/Output ratio
|
||||
io_ratio = total_input / max(total_output, 1)
|
||||
print(f"\n1. Input/Output Token Ratio")
|
||||
print(f" Total input tokens: {total_input:>12,}")
|
||||
print(f" Total output tokens: {total_output:>12,}")
|
||||
print(f" I/O ratio: {io_ratio:>12.1f}x")
|
||||
print(f" → {'STRONGLY' if io_ratio > 50 else 'Moderately' if io_ratio > 10 else 'Weakly'} prefill-heavy")
|
||||
|
||||
# 2. Prefill compute share
|
||||
# Approximate: prefill FLOP ∝ input_length, decode FLOP ∝ output_length * input_length
|
||||
# More precisely: prefill dominates when input >> output
|
||||
prefill_share = total_input / (total_input + total_output)
|
||||
print(f"\n2. Compute Split (token count proxy)")
|
||||
print(f" Prefill share: {prefill_share*100:.1f}%")
|
||||
print(f" Decode share: {(1-prefill_share)*100:.1f}%")
|
||||
|
||||
# 3. Session structure
|
||||
print(f"\n3. Session Structure")
|
||||
print(f" Sessions: {n_sessions}")
|
||||
print(f" Requests: {len(rows)}")
|
||||
print(f" Multi-turn: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)")
|
||||
print(f" Turns/sess: min={min(turns_per_session)} max={max(turns_per_session)} "
|
||||
f"avg={statistics.fmean(turns_per_session):.1f}")
|
||||
|
||||
# 4. Prefix sharing
|
||||
all_hash_ids = set()
|
||||
per_request_hashes = []
|
||||
for r in rows:
|
||||
hids = set(r.get("hash_ids", []))
|
||||
per_request_hashes.append(hids)
|
||||
all_hash_ids.update(hids)
|
||||
|
||||
hash_refcount = collections.Counter()
|
||||
for hids in per_request_hashes:
|
||||
for h in hids:
|
||||
hash_refcount[h] += 1
|
||||
|
||||
shared_blocks = sum(1 for h, c in hash_refcount.items() if c > 1)
|
||||
total_blocks = len(all_hash_ids)
|
||||
block_reuse = shared_blocks / max(total_blocks, 1)
|
||||
avg_refcount = statistics.fmean(hash_refcount.values()) if hash_refcount else 0
|
||||
|
||||
print(f"\n4. Prefix Block Sharing")
|
||||
print(f" Unique blocks: {total_blocks:>10,}")
|
||||
print(f" Shared (ref>1): {shared_blocks:>10,} ({block_reuse*100:.1f}%)")
|
||||
print(f" Avg refcount: {avg_refcount:>10.2f}")
|
||||
print(f" → {'High' if block_reuse > 0.3 else 'Moderate' if block_reuse > 0.1 else 'Low'} prefix reuse potential")
|
||||
|
||||
# 5. Input length distribution
|
||||
input_sorted = sorted(input_lens)
|
||||
pct = lambda q: input_sorted[min(int(q * len(input_sorted)), len(input_sorted) - 1)]
|
||||
print(f"\n5. Input Length Distribution")
|
||||
print(f" p10={pct(0.1):>8,} p50={pct(0.5):>8,} p90={pct(0.9):>8,} max={max(input_lens):>8,}")
|
||||
long_context = sum(1 for l in input_lens if l > 32000)
|
||||
print(f" Requests >32k tokens: {long_context} ({long_context/len(rows)*100:.1f}%)")
|
||||
|
||||
# 6. Arrival pattern
|
||||
timestamps = sorted(float(r["timestamp"]) for r in rows)
|
||||
span = timestamps[-1] - timestamps[0]
|
||||
avg_rate = len(rows) / max(span, 0.001)
|
||||
|
||||
# Burstiness: coefficient of variation of inter-arrival times
|
||||
inter_arrivals = [timestamps[i+1] - timestamps[i] for i in range(len(timestamps) - 1)]
|
||||
inter_arrivals = [t for t in inter_arrivals if t > 0]
|
||||
if inter_arrivals:
|
||||
cv = statistics.stdev(inter_arrivals) / statistics.fmean(inter_arrivals)
|
||||
else:
|
||||
cv = 0
|
||||
print(f"\n6. Arrival Pattern")
|
||||
print(f" Span: {span:.1f}s ({span/60:.1f} min)")
|
||||
print(f" Avg rate: {avg_rate:.2f} req/s")
|
||||
print(f" Burstiness (CoV): {cv:.2f}")
|
||||
print(f" → {'Bursty' if cv > 1.5 else 'Moderate' if cv > 0.8 else 'Steady'} arrival pattern")
|
||||
|
||||
# Summary
|
||||
print(f"\n{'=' * 60}")
|
||||
print("Summary: PD Separation Recommendation")
|
||||
print(f"{'=' * 60}")
|
||||
factors = []
|
||||
if io_ratio > 50:
|
||||
factors.append("Very high I/O ratio (prefill-dominated)")
|
||||
elif io_ratio > 10:
|
||||
factors.append("High I/O ratio")
|
||||
if block_reuse > 0.1:
|
||||
factors.append(f"Significant prefix reuse ({block_reuse*100:.0f}% shared blocks)")
|
||||
if long_context / len(rows) > 0.3:
|
||||
factors.append(f"Many long-context requests ({long_context/len(rows)*100:.0f}%)")
|
||||
if cv > 1.0:
|
||||
factors.append("Bursty arrivals (PD sep absorbs prefill spikes)")
|
||||
|
||||
if len(factors) >= 2:
|
||||
print("→ RECOMMEND PD separation:")
|
||||
elif len(factors) == 1:
|
||||
print("→ PD separation MAY help:")
|
||||
else:
|
||||
print("→ PD separation likely NOT beneficial:")
|
||||
|
||||
for f in factors:
|
||||
print(f" • {f}")
|
||||
if not factors:
|
||||
print(" • No strong indicators for PD separation benefit")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
280
scripts/cache_aware_proxy.py
Normal file
280
scripts/cache_aware_proxy.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""Unified cache-aware + token-level load-balanced global scheduler.
|
||||
|
||||
Supports two modes:
|
||||
--combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
|
||||
--prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)
|
||||
|
||||
Routing policy (same for both modes):
|
||||
score = ongoing_tokens / avg_ongoing - ALPHA * cache_hit_ratio
|
||||
Normalized load prevents "rich get richer"; cache bonus gives affinity.
|
||||
Session affinity: multi-turn sessions stick to same instance.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
BLOCK_SIZE = 512
|
||||
CACHE_HIT_ALPHA = 1.0 # weight for cache bonus in scoring
|
||||
|
||||
|
||||
class InstanceState:
|
||||
def __init__(self, url: str, bootstrap_port: int | None = None):
|
||||
self.url = url
|
||||
self.bootstrap_port = bootstrap_port
|
||||
self.client = httpx.AsyncClient(
|
||||
timeout=None, base_url=url,
|
||||
limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
|
||||
)
|
||||
self.ongoing_tokens = 0
|
||||
self.engine_id: dict[int, str] = {}
|
||||
self.dp_size = 1
|
||||
self.cached_blocks: set[int] = set()
|
||||
|
||||
def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
|
||||
if not token_ids or len(token_ids) < BLOCK_SIZE:
|
||||
return 0
|
||||
hit = 0
|
||||
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
|
||||
bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
|
||||
if bh in self.cached_blocks:
|
||||
hit += BLOCK_SIZE
|
||||
else:
|
||||
break
|
||||
return hit
|
||||
|
||||
def record_prefix(self, token_ids: list[int] | None):
|
||||
if not token_ids:
|
||||
return
|
||||
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
|
||||
self.cached_blocks.add(hash(tuple(token_ids[i:i + BLOCK_SIZE])))
|
||||
if len(self.cached_blocks) > 200000:
|
||||
self.cached_blocks = set(list(self.cached_blocks)[-100000:])
|
||||
|
||||
|
||||
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
|
||||
session_id: str | None, input_length: int,
|
||||
affinity: dict[str, int]) -> tuple[InstanceState, int]:
|
||||
"""Normalized load - cache bonus scoring."""
|
||||
if session_id and session_id in affinity:
|
||||
idx = affinity[session_id]
|
||||
if idx < len(instances):
|
||||
return instances[idx], idx
|
||||
|
||||
avg_load = max(sum(i.ongoing_tokens for i in instances) / len(instances), 1.0)
|
||||
best_idx, best_score = 0, float("inf")
|
||||
for i, inst in enumerate(instances):
|
||||
cache_hit = inst.estimate_cache_hit(token_ids)
|
||||
cache_ratio = cache_hit / input_length if input_length > 0 else 0.0
|
||||
score = inst.ongoing_tokens / avg_load - CACHE_HIT_ALPHA * cache_ratio
|
||||
if score < best_score:
|
||||
best_score = score
|
||||
best_idx = i
|
||||
|
||||
if session_id:
|
||||
affinity[session_id] = best_idx
|
||||
return instances[best_idx], best_idx
|
||||
|
||||
|
||||
global_args = None
|
||||
combined_instances: list[InstanceState] = []
|
||||
prefill_instances: list[InstanceState] = []
|
||||
decode_instances: list[InstanceState] = []
|
||||
session_affinity: dict[str, int] = {}
|
||||
is_pd_sep = False
|
||||
|
||||
|
||||
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
|
||||
for inst in instances:
|
||||
if inst.bootstrap_port is None:
|
||||
continue
|
||||
while True:
|
||||
try:
|
||||
await inst.client.get("/health")
|
||||
except Exception:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
parsed = urllib.parse.urlparse(str(inst.client.base_url))
|
||||
url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query"
|
||||
resp = await inst.client.get(url)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
for dp_rank, dp_entry in data.items():
|
||||
inst.engine_id[int(dp_rank)] = dp_entry["engine_id"]
|
||||
inst.dp_size = len(data)
|
||||
print(f"Inited {inst.url} engine_ids={inst.engine_id}")
|
||||
break
|
||||
ready.set()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global is_pd_sep
|
||||
app.state.ready = asyncio.Event()
|
||||
|
||||
if global_args.combined:
|
||||
is_pd_sep = False
|
||||
for url in global_args.combined:
|
||||
combined_instances.append(InstanceState(url))
|
||||
app.state.ready.set()
|
||||
print(f"Combined mode: {len(combined_instances)} instances")
|
||||
else:
|
||||
is_pd_sep = True
|
||||
for url, bp in global_args.prefill:
|
||||
prefill_instances.append(InstanceState(url, bp))
|
||||
for url in global_args.decode:
|
||||
decode_instances.append(InstanceState(url))
|
||||
asyncio.create_task(init_prefill_bootstrap(prefill_instances, app.state.ready))
|
||||
print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")
|
||||
|
||||
yield
|
||||
for inst in combined_instances + prefill_instances + decode_instances:
|
||||
await inst.client.aclose()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def handle_completions(request: Request):
|
||||
return await _handle(request, "/v1/completions")
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def handle_chat(request: Request):
|
||||
return await _handle(request, "/v1/chat/completions")
|
||||
|
||||
|
||||
async def _handle(request: Request, api: str):
|
||||
if not app.state.ready.is_set():
|
||||
raise HTTPException(status_code=503, detail="Service Unavailable")
|
||||
|
||||
req_data = await request.json()
|
||||
request_id = str(uuid.uuid4())
|
||||
prompt = req_data.get("prompt")
|
||||
token_ids = prompt if isinstance(prompt, list) else None
|
||||
input_length = len(token_ids) if token_ids else 0
|
||||
session_id = request.headers.get("X-Session-Id")
|
||||
|
||||
headers = {"X-Request-Id": request_id}
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
if is_pd_sep:
|
||||
return await _handle_pd_sep(api, req_data, request_id, token_ids,
|
||||
input_length, session_id, headers)
|
||||
else:
|
||||
return await _handle_combined(api, req_data, token_ids,
|
||||
input_length, session_id, headers)
|
||||
|
||||
|
||||
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
|
||||
"""Combined mode: route to best instance, send normal request."""
|
||||
inst, idx = pick_instance(combined_instances, token_ids, session_id,
|
||||
input_length, session_affinity)
|
||||
inst.ongoing_tokens += input_length
|
||||
|
||||
async def generate():
|
||||
try:
|
||||
async with inst.client.stream("POST", api, json=req_data, headers=headers) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
yield chunk
|
||||
inst.record_prefix(token_ids)
|
||||
finally:
|
||||
inst.ongoing_tokens -= input_length
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
|
||||
session_id, headers):
|
||||
"""PD-Sep mode: await prefill, then stream decode."""
|
||||
p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
|
||||
input_length, session_affinity)
|
||||
d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
|
||||
|
||||
# Await prefill
|
||||
p_inst.ongoing_tokens += input_length
|
||||
try:
|
||||
prefill_data = req_data.copy()
|
||||
prefill_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": True, "do_remote_prefill": False,
|
||||
"transfer_id": f"xfer-{request_id}",
|
||||
}
|
||||
prefill_data["stream"] = False
|
||||
prefill_data["max_tokens"] = 1
|
||||
prefill_data.pop("max_completion_tokens", None)
|
||||
prefill_data.pop("stream_options", None)
|
||||
|
||||
p_headers = {**headers, "X-data-parallel-rank": "0"}
|
||||
resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
|
||||
resp.raise_for_status()
|
||||
await resp.aclose()
|
||||
p_inst.record_prefix(token_ids)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
|
||||
finally:
|
||||
p_inst.ongoing_tokens -= input_length
|
||||
|
||||
# Stream decode
|
||||
d_inst.ongoing_tokens += input_length
|
||||
parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
|
||||
bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"
|
||||
|
||||
decode_data = req_data.copy()
|
||||
decode_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": False, "do_remote_prefill": True,
|
||||
"remote_bootstrap_addr": bootstrap_addr,
|
||||
"remote_engine_id": p_inst.engine_id.get(0, ""),
|
||||
"transfer_id": f"xfer-{request_id}",
|
||||
}
|
||||
|
||||
async def generate():
|
||||
try:
|
||||
async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
yield chunk
|
||||
finally:
|
||||
d_inst.ongoing_tokens -= input_length
|
||||
|
||||
return StreamingResponse(generate(), media_type="application/json")
|
||||
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
|
||||
p.add_argument("--port", type=int, default=8000)
|
||||
p.add_argument("--host", type=str, default="0.0.0.0")
|
||||
p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
|
||||
p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
|
||||
help="PD-Sep prefill: URL [bootstrap_port]")
|
||||
p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
|
||||
help="PD-Sep decode: URL")
|
||||
args = p.parse_args()
|
||||
|
||||
args.prefill = []
|
||||
if args.prefill_raw:
|
||||
for entry in args.prefill_raw:
|
||||
url = entry[0]
|
||||
bp = int(entry[1]) if len(entry) > 1 and entry[1].lower() != "none" else None
|
||||
args.prefill.append((url, bp))
|
||||
args.decode = [e[0] for e in (args.decode_raw or [])]
|
||||
|
||||
if not args.combined and not args.prefill:
|
||||
p.error("Must specify either --combined or --prefill/--decode")
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
global_args = parse_args()
|
||||
uvicorn.run(app, host=global_args.host, port=global_args.port)
|
||||
102
scripts/compare_results.py
Normal file
102
scripts/compare_results.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Compare benchmark results between PD-combined and PD-separated modes.
|
||||
|
||||
Reads summary JSON files and per-request metrics to produce a detailed
|
||||
comparison report including TTFT, TPOT, E2E, cache hit ratio, and
|
||||
throughput analysis.
|
||||
|
||||
Usage:
|
||||
python scripts/compare_results.py \
|
||||
--combined outputs/combined_1000req/metrics.summary.json \
|
||||
--separated outputs/separated_1000req/metrics.summary.json
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def load_summary(path: Path) -> dict:
|
||||
return json.loads(path.read_text())
|
||||
|
||||
|
||||
def load_metrics(path: Path) -> list[dict]:
|
||||
rows = []
|
||||
with path.open() as fh:
|
||||
for line in fh:
|
||||
rows.append(json.loads(line))
|
||||
return rows
|
||||
|
||||
|
||||
def fmt_stat(stat: dict | None, unit: str = "s") -> str:
|
||||
if stat is None:
|
||||
return "N/A"
|
||||
return (f"mean={stat['mean']:.3f}{unit} "
|
||||
f"p50={stat['p50']:.3f}{unit} "
|
||||
f"p90={stat['p90']:.3f}{unit} "
|
||||
f"p99={stat['p99']:.3f}{unit}")
|
||||
|
||||
|
||||
def compare(combined: dict, separated: dict) -> None:
|
||||
print("=" * 70)
|
||||
print("PD-Combined vs PD-Separated Performance Comparison")
|
||||
print("=" * 70)
|
||||
|
||||
for label, s in [("PD-Combined", combined), ("PD-Separated", separated)]:
|
||||
print(f"\n--- {label} ---")
|
||||
print(f" Requests: {s['request_count']} (success: {s['success_count']}, errors: {s['error_count']})")
|
||||
print(f" Wall clock: {s.get('wall_clock_s', 0):.1f}s")
|
||||
print(f" TTFT: {fmt_stat(s.get('ttft_stats_s'))}")
|
||||
print(f" TPOT: {fmt_stat(s.get('tpot_stats_s'))}")
|
||||
print(f" E2E: {fmt_stat(s.get('latency_stats_s'))}")
|
||||
hit_ratio = s.get('prefix_cache_hit_ratio', 0)
|
||||
print(f" Prefix cache hit ratio: {hit_ratio*100:.1f}%")
|
||||
queries = s.get('prefix_cache_queries_tokens', 0)
|
||||
hits = s.get('prefix_cache_hits_tokens', 0)
|
||||
print(f" ({hits}/{queries} tokens)")
|
||||
|
||||
print("\n--- Comparison (Separated vs Combined) ---")
|
||||
for metric_key, label in [
|
||||
("ttft_stats_s", "TTFT"),
|
||||
("tpot_stats_s", "TPOT"),
|
||||
("latency_stats_s", "E2E"),
|
||||
]:
|
||||
c = combined.get(metric_key, {})
|
||||
s = separated.get(metric_key, {})
|
||||
if c and s:
|
||||
for pct in ["mean", "p50", "p90", "p99"]:
|
||||
cv, sv = c.get(pct, 0), s.get(pct, 0)
|
||||
if cv > 0:
|
||||
change = (sv - cv) / cv * 100
|
||||
direction = "slower" if change > 0 else "faster"
|
||||
print(f" {label} {pct}: {abs(change):.1f}% {direction} "
|
||||
f"({cv:.3f}s → {sv:.3f}s)")
|
||||
|
||||
c_ratio = combined.get("prefix_cache_hit_ratio", 0)
|
||||
s_ratio = separated.get("prefix_cache_hit_ratio", 0)
|
||||
print(f" Cache hit ratio: {c_ratio*100:.1f}% → {s_ratio*100:.1f}%")
|
||||
|
||||
c_wall = combined.get("wall_clock_s", 1)
|
||||
s_wall = separated.get("wall_clock_s", 1)
|
||||
c_tput = combined["success_count"] / c_wall
|
||||
s_tput = separated["success_count"] / s_wall
|
||||
print(f" Throughput: {c_tput:.1f} → {s_tput:.1f} req/s "
|
||||
f"({(s_tput/c_tput - 1)*100:+.1f}%)")
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser(description=__doc__,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
p.add_argument("--combined", type=Path, required=True)
|
||||
p.add_argument("--separated", type=Path, required=True)
|
||||
args = p.parse_args()
|
||||
|
||||
combined = load_summary(args.combined)
|
||||
separated = load_summary(args.separated)
|
||||
compare(combined, separated)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
210
scripts/compute_roofline.py
Normal file
210
scripts/compute_roofline.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Roofline analysis: compute/memory ratio for prefill vs decode
|
||||
under different sequence lengths and KV cache reuse ratios.
|
||||
|
||||
Model: Qwen3-Coder-30B-A3B (MoE)
|
||||
- 48 layers, hidden=2048, heads=32, kv_heads=4, head_dim=128
|
||||
- MoE: 128 experts, top-8 active, intermediate=6144
|
||||
- Total params: ~30B, Active params per token: ~3B
|
||||
|
||||
GPU: NVIDIA H20
|
||||
- BF16 peak: 148 TFLOPS
|
||||
- HBM bandwidth: 4.0 TB/s
|
||||
- Roofline ridge point: 148/4.0 = 37 FLOP/byte
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
|
||||
# ===== Model config =====
|
||||
L = 48 # layers
|
||||
D = 2048 # hidden dim
|
||||
H = 32 # attention heads
|
||||
H_kv = 4 # KV heads (GQA)
|
||||
D_head = 128 # head dim
|
||||
D_ffn = 6144 # FFN intermediate (per expert)
|
||||
N_experts = 128 # total experts
|
||||
K_experts = 8 # active experts per token
|
||||
VOCAB = 151936
|
||||
BYTES = 2 # BF16
|
||||
|
||||
# ===== GPU config (H20) =====
|
||||
PEAK_FLOPS = 148e12 # BF16 TFLOPS
|
||||
HBM_BW = 4.0e12 # bytes/s
|
||||
RIDGE_POINT = PEAK_FLOPS / HBM_BW # ~37 FLOP/byte
|
||||
|
||||
print("=" * 80)
|
||||
print(" ROOFLINE ANALYSIS: Prefill vs Decode under KV Cache Reuse")
|
||||
print(" Model: Qwen3-Coder-30B-A3B (MoE 128E top-8) | GPU: H20")
|
||||
print("=" * 80)
|
||||
print(f" Ridge point: {RIDGE_POINT:.1f} FLOP/byte")
|
||||
print(f" Above ridge → compute-bound | Below ridge → memory-bound")
|
||||
|
||||
# ===== Per-token compute & memory for each component =====
|
||||
|
||||
def attention_prefill_flops(seq_len, new_tokens):
|
||||
"""FLOPs for attention on new_tokens with seq_len context."""
|
||||
# QKV projection: new_tokens * D * (D + 2*D_kv) * 2
|
||||
d_kv = H_kv * D_head
|
||||
qkv_flops = new_tokens * (D * D * 2 + D * d_kv * 2 * 2) # Q + K + V
|
||||
# Attention score: new_tokens * seq_len * D * 2 (Q@K^T + softmax@V)
|
||||
attn_flops = new_tokens * seq_len * D * 2 * 2 # simplified: 2 matmuls
|
||||
# Output projection: new_tokens * D * D * 2
|
||||
out_flops = new_tokens * D * D * 2
|
||||
return (qkv_flops + attn_flops + out_flops) * L
|
||||
|
||||
def attention_prefill_bytes(seq_len, new_tokens, cached_tokens):
|
||||
"""Memory access for attention prefill."""
|
||||
d_kv = H_kv * D_head
|
||||
# Load model weights (QKV + O projections): D*(D+2*d_kv+D) * BYTES * L
|
||||
weight_bytes = D * (D + 2 * d_kv + D) * BYTES * L
|
||||
# Load cached KV: cached_tokens * 2 * d_kv * BYTES * L
|
||||
cached_kv_bytes = cached_tokens * 2 * d_kv * BYTES * L
|
||||
# Read input activations + write output: new_tokens * D * BYTES * 2 * L
|
||||
act_bytes = new_tokens * D * BYTES * 2 * L
|
||||
# Write new KV to cache: new_tokens * 2 * d_kv * BYTES * L
|
||||
new_kv_bytes = new_tokens * 2 * d_kv * BYTES * L
|
||||
return weight_bytes + cached_kv_bytes + act_bytes + new_kv_bytes
|
||||
|
||||
def ffn_flops(n_tokens):
|
||||
"""FLOPs for MoE FFN on n_tokens."""
|
||||
# Per expert: 3 * n_tokens * D * D_ffn * 2 (gate + up + down)
|
||||
# Active experts: K_experts
|
||||
return 3 * n_tokens * D * D_ffn * 2 * K_experts * L
|
||||
|
||||
def ffn_bytes(n_tokens):
|
||||
"""Memory access for MoE FFN."""
|
||||
# Load K_experts worth of weights per layer: K * 3 * D * D_ffn * BYTES
|
||||
weight_bytes = K_experts * 3 * D * D_ffn * BYTES * L
|
||||
# Activations: n_tokens * D * BYTES * 2 * L
|
||||
act_bytes = n_tokens * D * BYTES * 2 * L
|
||||
return weight_bytes + act_bytes
|
||||
|
||||
def decode_flops(seq_len):
|
||||
"""FLOPs for 1 decode token."""
|
||||
return attention_prefill_flops(seq_len, 1) + ffn_flops(1)
|
||||
|
||||
def decode_bytes(seq_len):
|
||||
"""Memory bytes for 1 decode token."""
|
||||
return attention_prefill_bytes(seq_len, 1, seq_len) + ffn_bytes(1)
|
||||
|
||||
# ===== Analysis =====
|
||||
|
||||
print("\n" + "-" * 80)
|
||||
print(" PART 1: Decode Roofline (baseline)")
|
||||
print("-" * 80)
|
||||
print(f" {'SeqLen':>8} {'FLOP':>14} {'Bytes':>14} {'AI (F/B)':>10} {'Bound':>12}")
|
||||
|
||||
for seq_len in [1000, 4000, 8000, 16000, 32000, 64000, 128000]:
|
||||
flops = decode_flops(seq_len)
|
||||
bytes_ = decode_bytes(seq_len)
|
||||
ai = flops / bytes_
|
||||
bound = "COMPUTE" if ai > RIDGE_POINT else "MEMORY"
|
||||
print(f" {seq_len:>8,} {flops:>14.2e} {bytes_:>14.2e} {ai:>10.1f} {bound:>12}")
|
||||
|
||||
print("\n" + "-" * 80)
|
||||
print(" PART 2: Prefill with KV Cache Reuse")
|
||||
print(" (Total input = seq_len, cached = seq_len * reuse_ratio, new = rest)")
|
||||
print("-" * 80)
|
||||
print(f" {'SeqLen':>8} {'Reuse%':>7} {'NewTok':>8} {'FLOP':>14} {'Bytes':>14} {'AI (F/B)':>10} {'Bound':>12} {'vs Decode':>10}")
|
||||
|
||||
for seq_len in [4000, 16000, 32000, 64000, 128000]:
|
||||
for reuse in [0.0, 0.3, 0.5, 0.7, 0.9, 0.95]:
|
||||
cached = int(seq_len * reuse)
|
||||
new = seq_len - cached
|
||||
|
||||
# Attention: compute on new tokens, but read cached KV for context
|
||||
attn_f = attention_prefill_flops(seq_len, new)
|
||||
attn_b = attention_prefill_bytes(seq_len, new, cached)
|
||||
|
||||
# FFN: only on new tokens
|
||||
ffn_f = ffn_flops(new)
|
||||
ffn_b = ffn_bytes(new)
|
||||
|
||||
total_f = attn_f + ffn_f
|
||||
total_b = attn_b + ffn_b
|
||||
ai = total_f / total_b if total_b > 0 else 0
|
||||
|
||||
# Compare with decode at same seq_len
|
||||
dec_f = decode_flops(seq_len)
|
||||
dec_b = decode_bytes(seq_len)
|
||||
dec_ai = dec_f / dec_b
|
||||
|
||||
bound = "COMPUTE" if ai > RIDGE_POINT else "MEMORY"
|
||||
ratio = f"{ai/dec_ai:.1f}x" if dec_ai > 0 else "N/A"
|
||||
|
||||
print(f" {seq_len:>8,} {reuse*100:>6.0f}% {new:>8,} {total_f:>14.2e} {total_b:>14.2e} {ai:>10.1f} {bound:>12} {ratio:>10}")
|
||||
print()
|
||||
|
||||
print("-" * 80)
|
||||
print(" PART 3: Key Thresholds")
|
||||
print("-" * 80)
|
||||
|
||||
# At what reuse ratio does prefill become memory-bound?
|
||||
for seq_len in [4000, 16000, 32000, 64000, 128000]:
|
||||
for reuse_pct in range(0, 100):
|
||||
reuse = reuse_pct / 100.0
|
||||
cached = int(seq_len * reuse)
|
||||
new = seq_len - cached
|
||||
if new < 1: continue
|
||||
attn_f = attention_prefill_flops(seq_len, new)
|
||||
attn_b = attention_prefill_bytes(seq_len, new, cached)
|
||||
ffn_f = ffn_flops(new)
|
||||
ffn_b = ffn_bytes(new)
|
||||
ai = (attn_f + ffn_f) / (attn_b + ffn_b)
|
||||
if ai < RIDGE_POINT:
|
||||
print(f" SeqLen={seq_len:>6,}: prefill becomes memory-bound at {reuse_pct}% reuse (AI={ai:.1f})")
|
||||
break
|
||||
|
||||
print()
|
||||
print("-" * 80)
|
||||
print(" PART 4: Agentic Workload Real Distribution")
|
||||
print("-" * 80)
|
||||
|
||||
# Use actual trace data
|
||||
import os
|
||||
trace_path = "traces/sampled_1000req_seed42.jsonl"
|
||||
if os.path.exists(trace_path):
|
||||
BLOCK_SIZE = 512
|
||||
seen = set()
|
||||
compute_bound = 0
|
||||
memory_bound = 0
|
||||
total = 0
|
||||
|
||||
for line in open(trace_path):
|
||||
d = json.loads(line)
|
||||
seq_len = d["input_length"]
|
||||
if seq_len < 1: continue
|
||||
hids = d.get("hash_ids", [])
|
||||
|
||||
cached_blocks = 0
|
||||
for hid in hids:
|
||||
if hid in seen:
|
||||
cached_blocks += 1
|
||||
else:
|
||||
break
|
||||
for hid in hids:
|
||||
seen.add(hid)
|
||||
|
||||
cached = cached_blocks * BLOCK_SIZE
|
||||
new = max(1, seq_len - cached)
|
||||
reuse = cached / seq_len
|
||||
|
||||
attn_f = attention_prefill_flops(seq_len, new)
|
||||
attn_b = attention_prefill_bytes(seq_len, new, cached)
|
||||
ffn_f = ffn_flops(new)
|
||||
ffn_b = ffn_bytes(new)
|
||||
ai = (attn_f + ffn_f) / (attn_b + ffn_b)
|
||||
|
||||
total += 1
|
||||
if ai > RIDGE_POINT:
|
||||
compute_bound += 1
|
||||
else:
|
||||
memory_bound += 1
|
||||
|
||||
print(f" With actual trace prefix cache pattern:")
|
||||
print(f" Compute-bound prefills: {compute_bound} ({compute_bound*100//total}%)")
|
||||
print(f" Memory-bound prefills: {memory_bound} ({memory_bound*100//total}%)")
|
||||
print(f" (Decode is ALWAYS memory-bound at these seq lengths)")
|
||||
print()
|
||||
print(f" Implication: {memory_bound*100//total}% of agentic prefills behave like decode")
|
||||
print(f" → PD separation treats them as 'compute-heavy' but they are actually memory-heavy")
|
||||
86
scripts/final_comparison.py
Normal file
86
scripts/final_comparison.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Final comparison of PD-Combined vs PD-Separated (Mooncake/RDMA)."""
|
||||
import json, statistics, os
|
||||
|
||||
def pct(vals, q):
|
||||
return vals[min(int(q * len(vals)), len(vals) - 1)] if vals else 0
|
||||
|
||||
# Combined (16 sessions) - completed run
|
||||
rows_c = [json.loads(l) for l in open("outputs/v18_combined_1000req/metrics.jsonl")]
|
||||
ok_c = [r for r in rows_c if not r.get("error")]
|
||||
ttfts_c = sorted([r["ttft_s"] for r in ok_c if r.get("ttft_s")])
|
||||
tpots_c = sorted([r["tpot_s"] for r in ok_c if r.get("tpot_s") and r["tpot_s"] > 0])
|
||||
lats_c = sorted([r["latency_s"] for r in ok_c if r.get("latency_s")])
|
||||
sc = json.load(open("outputs/v18_combined_1000req/metrics.summary.json"))
|
||||
|
||||
# PD-Separated Mooncake (first 200 stable requests)
|
||||
rows_d = [json.loads(l) for l in open("outputs/v18_pd_mooncake_lowconc/metrics.jsonl")][:200]
|
||||
ok_d = [r for r in rows_d if not r.get("error")]
|
||||
ttfts_d = sorted([r["ttft_s"] for r in ok_d if r.get("ttft_s")])
|
||||
tpots_d = sorted([r["tpot_s"] for r in ok_d if r.get("tpot_s") and r["tpot_s"] > 0])
|
||||
lats_d = sorted([r["latency_s"] for r in ok_d if r.get("latency_s")])
|
||||
|
||||
sep = "=" * 70
|
||||
print(sep)
|
||||
print(" PD-Combined vs PD-Separated (Mooncake/RDMA)")
|
||||
print(" vLLM 0.18.1 | Qwen3-Coder-30B-A3B | 8xH20")
|
||||
print(sep)
|
||||
|
||||
header = " {:<12} {:>16} {:>16} {:>10}".format(
|
||||
"Metric", "Combined(TP=8)", "PD-Sep(TP=4+4)", "Delta")
|
||||
print(header)
|
||||
dash = " {:<12} {:>16} {:>16} {:>10}".format("-" * 12, "-" * 16, "-" * 16, "-" * 10)
|
||||
print(dash)
|
||||
|
||||
req_c = "{}/{}".format(len(ok_c), len(rows_c))
|
||||
req_d = "{}/{}".format(len(ok_d), len(rows_d))
|
||||
print(" {:<12} {:>16} {:>16}".format("Requests", req_c, req_d))
|
||||
|
||||
data = [
|
||||
("TTFT p50", pct(ttfts_c, 0.5), pct(ttfts_d, 0.5)),
|
||||
("TTFT p90", pct(ttfts_c, 0.9), pct(ttfts_d, 0.9)),
|
||||
("TPOT p50", pct(tpots_c, 0.5), pct(tpots_d, 0.5)),
|
||||
("TPOT p90", pct(tpots_c, 0.9), pct(tpots_d, 0.9)),
|
||||
("E2E p50", pct(lats_c, 0.5), pct(lats_d, 0.5)),
|
||||
("E2E p90", pct(lats_c, 0.9), pct(lats_d, 0.9)),
|
||||
]
|
||||
|
||||
for label, cv, dv in data:
|
||||
delta = "{:+.0f}%".format((dv / cv - 1) * 100) if cv > 0 else "N/A"
|
||||
print(" {:<12} {:>15.3f}s {:>15.3f}s {:>10}".format(label, cv, dv, delta))
|
||||
|
||||
cache_c = sc.get("prefix_cache_hit_ratio", 0)
|
||||
print(" {:<12} {:>15.1f}% {:>16}".format("Cache hit", cache_c * 100, "N/A"))
|
||||
tput_c = len(ok_c) / sc.get("wall_clock_s", 1)
|
||||
print(" {:<12} {:>14.2f}/s {:>16}".format("Throughput", tput_c, "~0.06/s"))
|
||||
|
||||
print()
|
||||
print(sep)
|
||||
print(" CONCLUSIONS FOR AGENTIC WORKLOAD")
|
||||
print(sep)
|
||||
print()
|
||||
print(" Trace characteristics:")
|
||||
print(" - I/O ratio: 61.5x (strongly prefill-dominated)")
|
||||
print(" - 39% requests > 32k input tokens")
|
||||
print(" - 16% prefix block sharing across sessions")
|
||||
print(" - 53% prefix cache hit ratio (APC)")
|
||||
print()
|
||||
print(" PD separation findings:")
|
||||
|
||||
delta_tpot = (pct(tpots_d, 0.5) / pct(tpots_c, 0.5) - 1) * 100 if tpots_c else 0
|
||||
delta_ttft = (pct(ttfts_d, 0.5) / pct(ttfts_c, 0.5) - 1) * 100 if ttfts_c else 0
|
||||
delta_e2e = (pct(lats_d, 0.5) / pct(lats_c, 0.5) - 1) * 100 if lats_c else 0
|
||||
|
||||
print(" 1. TPOT {:+.0f}% - decode isolation benefit is {}".format(
|
||||
delta_tpot, "marginal" if abs(delta_tpot) < 20 else "significant"))
|
||||
print(" 2. TTFT {:+.0f}% - KV transfer + TP=4 overhead dominates".format(delta_ttft))
|
||||
print(" 3. E2E {:+.0f}% - net negative on single-machine".format(delta_e2e))
|
||||
print(" 4. Stability: Mooncake connector crashes after ~200 reqs under load")
|
||||
print()
|
||||
print(" Recommendation:")
|
||||
print(" - Single-machine 8 GPU: Combined mode is better (lower TTFT, stable)")
|
||||
print(" - Multi-machine: PD-Sep is promising IF cross-machine latency")
|
||||
print(" is hidden by RDMA and prefill doesn't share GPU with decode")
|
||||
print(" - Key bottleneck: this workload's heavy prefill (avg 32k tokens)")
|
||||
print(" makes KV transfer cost non-trivial relative to prefill time")
|
||||
print(" - Prefill-as-a-Service (Goal 5) should focus on cross-machine")
|
||||
print(" KV cache sharing, not same-machine PD split")
|
||||
96
scripts/launch_pd_mooncake.sh
Executable file
96
scripts/launch_pd_mooncake.sh
Executable file
@@ -0,0 +1,96 @@
|
||||
#!/bin/bash
|
||||
# PD-Disaggregated serving via Mooncake (RDMA + DRAM KV pool).
|
||||
#
|
||||
# Architecture:
|
||||
# Client → Proxy (port 8000)
|
||||
# → Prefill (port 8010, TP=4, GPUs 0-3, bootstrap 8998)
|
||||
# [prefill + store KV to DRAM pool via RDMA]
|
||||
# → Decode (port 8020, TP=4, GPUs 4-7)
|
||||
# [pull KV from DRAM pool via RDMA + decode]
|
||||
#
|
||||
# Usage: bash scripts/launch_pd_mooncake.sh
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
VENV="$PROJECT_DIR/.venv/bin"
|
||||
VLLM="$VENV/vllm"
|
||||
|
||||
MODEL_PATH="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
|
||||
PROXY_PORT=8000
|
||||
PREFILL_PORT=8010
|
||||
DECODE_PORT=8020
|
||||
BOOTSTRAP_PORT=8998
|
||||
|
||||
PROXY_SCRIPT="$PROJECT_DIR/third_party/vllm/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py"
|
||||
|
||||
trap 'echo "Cleaning up..."; kill $(jobs -p) 2>/dev/null; wait 2>/dev/null' EXIT INT TERM
|
||||
|
||||
echo "=== PD-Disaggregated vLLM 0.18.1 (Mooncake/RDMA) ==="
|
||||
echo " Model: $MODEL_PATH"
|
||||
echo " Prefill: GPUs 0-3 (TP=4), port $PREFILL_PORT, bootstrap $BOOTSTRAP_PORT"
|
||||
echo " Decode: GPUs 4-7 (TP=4), port $DECODE_PORT"
|
||||
echo " Proxy: port $PROXY_PORT"
|
||||
echo ""
|
||||
|
||||
# Step 1: Start prefill instance (KV producer)
|
||||
echo "[1/3] Starting prefill instance..."
|
||||
VLLM_MOONCAKE_BOOTSTRAP_PORT=$BOOTSTRAP_PORT \
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||
$VLLM serve "$MODEL_PATH" \
|
||||
--host 0.0.0.0 \
|
||||
--port $PREFILL_PORT \
|
||||
--tensor-parallel-size 4 \
|
||||
--trust-remote-code \
|
||||
--enable-prefix-caching \
|
||||
--enforce-eager \
|
||||
--dtype auto \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--kv-transfer-config \
|
||||
'{"kv_connector":"MooncakeConnector","kv_role":"kv_producer"}' &
|
||||
PREFILL_PID=$!
|
||||
echo " Prefill PID=$PREFILL_PID"
|
||||
|
||||
# Step 2: Start decode instance (KV consumer)
|
||||
echo "[2/3] Starting decode instance..."
|
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 \
|
||||
$VLLM serve "$MODEL_PATH" \
|
||||
--host 0.0.0.0 \
|
||||
--port $DECODE_PORT \
|
||||
--tensor-parallel-size 4 \
|
||||
--trust-remote-code \
|
||||
--enable-prefix-caching \
|
||||
--enforce-eager \
|
||||
--dtype auto \
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--kv-transfer-config \
|
||||
'{"kv_connector":"MooncakeConnector","kv_role":"kv_consumer"}' &
|
||||
DECODE_PID=$!
|
||||
echo " Decode PID=$DECODE_PID"
|
||||
|
||||
# Wait for both instances
|
||||
echo ""
|
||||
echo "Waiting for instances..."
|
||||
timeout 1200 bash -c "until curl -s localhost:$PREFILL_PORT/v1/models > /dev/null 2>&1; do sleep 5; done"
|
||||
echo " Prefill ready!"
|
||||
timeout 1200 bash -c "until curl -s localhost:$DECODE_PORT/v1/models > /dev/null 2>&1; do sleep 5; done"
|
||||
echo " Decode ready!"
|
||||
|
||||
# Step 3: Start proxy (after instances are ready)
|
||||
echo "[3/3] Starting proxy..."
|
||||
$VENV/python "$PROXY_SCRIPT" \
|
||||
--prefill "http://127.0.0.1:$PREFILL_PORT" "$BOOTSTRAP_PORT" \
|
||||
--decode "http://127.0.0.1:$DECODE_PORT" \
|
||||
--host 0.0.0.0 \
|
||||
--port $PROXY_PORT &
|
||||
PROXY_PID=$!
|
||||
echo " Proxy PID=$PROXY_PID"
|
||||
|
||||
sleep 5
|
||||
echo ""
|
||||
echo "=== All ready ==="
|
||||
echo " Send requests to: http://localhost:$PROXY_PORT"
|
||||
echo ""
|
||||
|
||||
wait
|
||||
89
scripts/launch_pd_separated.sh
Normal file
89
scripts/launch_pd_separated.sh
Normal file
@@ -0,0 +1,89 @@
|
||||
#!/bin/bash
|
||||
# PD-Disaggregated serving: 1 prefill (TP=4, GPUs 0-3) + 1 decode (TP=4, GPUs 4-7)
|
||||
# Uses vLLM 0.18.1's P2pNcclConnector + XpYd proxy.
|
||||
#
|
||||
# Architecture:
|
||||
# Client → Proxy (port 10001)
|
||||
# → Prefill (port 20003, kv_port 21001) [max_tokens=1, does prefill + KV push]
|
||||
# → Decode (port 20005, kv_port 22001) [full generation, KV pulled from prefill]
|
||||
#
|
||||
# Usage: bash scripts/launch_pd_separated.sh
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
VENV="$PROJECT_DIR/.venv/bin"
|
||||
VLLM="$VENV/vllm"
|
||||
|
||||
MODEL_PATH="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
|
||||
PROXY_PORT=30001 # ZMQ service discovery
|
||||
CLIENT_PORT=10001 # HTTP proxy for clients
|
||||
PREFILL_PORT=20003
|
||||
DECODE_PORT=20005
|
||||
KV_PORT_P=21001
|
||||
KV_PORT_D=22001
|
||||
|
||||
trap 'echo "Cleaning up..."; kill $(jobs -p) 2>/dev/null; wait 2>/dev/null' EXIT INT TERM
|
||||
|
||||
echo "=== PD-Disaggregated vLLM 0.18.1 ==="
|
||||
echo " Model: $MODEL_PATH"
|
||||
echo " Prefill: GPUs 0-3 (TP=4), port $PREFILL_PORT, kv_port $KV_PORT_P"
|
||||
echo " Decode: GPUs 4-7 (TP=4), port $DECODE_PORT, kv_port $KV_PORT_D"
|
||||
echo " Proxy: ZMQ=$PROXY_PORT, HTTP=$CLIENT_PORT"
|
||||
echo ""
|
||||
|
||||
# Step 1: Start proxy FIRST (P/D instances register via ZMQ)
|
||||
echo "[1/3] Starting proxy..."
|
||||
PROXY_SCRIPT="$PROJECT_DIR/third_party/vllm/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py"
|
||||
$VENV/python "$PROXY_SCRIPT" &
|
||||
PROXY_PID=$!
|
||||
sleep 2
|
||||
echo " Proxy PID=$PROXY_PID"
|
||||
|
||||
# Step 2: Start prefill instance (KV producer)
|
||||
echo "[2/3] Starting prefill instance..."
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 $VLLM serve "$MODEL_PATH" \
|
||||
--host 0.0.0.0 \
|
||||
--port $PREFILL_PORT \
|
||||
--tensor-parallel-size 4 \
|
||||
--trust-remote-code \
|
||||
--enable-prefix-caching \
|
||||
--enforce-eager \
|
||||
--dtype auto \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--kv-transfer-config \
|
||||
"{\"kv_connector\":\"P2pNcclConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_size\":\"1e1\",\"kv_port\":\"$KV_PORT_P\",\"kv_connector_extra_config\":{\"proxy_ip\":\"127.0.0.1\",\"proxy_port\":\"$PROXY_PORT\",\"http_port\":\"$PREFILL_PORT\",\"send_type\":\"PUT_ASYNC\",\"nccl_num_channels\":\"16\"}}" &
|
||||
PREFILL_PID=$!
|
||||
echo " Prefill PID=$PREFILL_PID"
|
||||
|
||||
# Step 3: Start decode instance (KV consumer)
|
||||
echo "[3/3] Starting decode instance..."
|
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 $VLLM serve "$MODEL_PATH" \
|
||||
--host 0.0.0.0 \
|
||||
--port $DECODE_PORT \
|
||||
--tensor-parallel-size 4 \
|
||||
--trust-remote-code \
|
||||
--enable-prefix-caching \
|
||||
--enforce-eager \
|
||||
--dtype auto \
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--kv-transfer-config \
|
||||
"{\"kv_connector\":\"P2pNcclConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_size\":\"8e9\",\"kv_port\":\"$KV_PORT_D\",\"kv_connector_extra_config\":{\"proxy_ip\":\"127.0.0.1\",\"proxy_port\":\"$PROXY_PORT\",\"http_port\":\"$DECODE_PORT\",\"send_type\":\"PUT_ASYNC\",\"nccl_num_channels\":\"16\"}}" &
|
||||
DECODE_PID=$!
|
||||
echo " Decode PID=$DECODE_PID"
|
||||
|
||||
# Wait for readiness
|
||||
echo ""
|
||||
echo "Waiting for instances..."
|
||||
timeout 1200 bash -c "until curl -s localhost:$PREFILL_PORT/v1/completions > /dev/null 2>&1; do sleep 5; done"
|
||||
echo " Prefill ready!"
|
||||
timeout 1200 bash -c "until curl -s localhost:$DECODE_PORT/v1/completions > /dev/null 2>&1; do sleep 5; done"
|
||||
echo " Decode ready!"
|
||||
|
||||
echo ""
|
||||
echo "=== All ready ==="
|
||||
echo " Send requests to: http://localhost:$CLIENT_PORT"
|
||||
echo ""
|
||||
|
||||
wait
|
||||
23
scripts/launch_vllm.sh
Executable file
23
scripts/launch_vllm.sh
Executable file
@@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
# Launch vLLM 0.18.1 in PD-combined mode (TP=8, all GPUs).
|
||||
#
|
||||
# Usage: bash scripts/launch_vllm.sh
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
VLLM="$PROJECT_DIR/.venv/bin/vllm"
|
||||
|
||||
MODEL_PATH="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
|
||||
HOST="${HOST:-0.0.0.0}"
|
||||
PORT="${PORT:-8000}"
|
||||
|
||||
echo "Starting vLLM 0.18.1 in PD-combined mode (TP=8) on port $PORT ..."
|
||||
$VLLM serve "$MODEL_PATH" \
|
||||
--trust-remote-code \
|
||||
--enable-prefix-caching \
|
||||
--dtype auto \
|
||||
--tensor-parallel-size 8 \
|
||||
--host "$HOST" \
|
||||
--port "$PORT"
|
||||
77
scripts/run_benchmark.sh
Executable file
77
scripts/run_benchmark.sh
Executable file
@@ -0,0 +1,77 @@
|
||||
#!/bin/bash
|
||||
# Run the full benchmark suite: sample trace → replay against vLLM → collect metrics.
|
||||
#
|
||||
# Prerequisites:
|
||||
# - vLLM server running (use scripts/launch_vllm.sh)
|
||||
# - Sampled trace file exists (or will be created)
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/run_benchmark.sh [--endpoint URL] [--tag NAME]
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
cd "$PROJECT_DIR"
|
||||
|
||||
# Defaults
|
||||
TRACE_INPUT="${TRACE_INPUT:-$HOME/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl}"
|
||||
ENDPOINT="${ENDPOINT:-http://localhost:8000}"
|
||||
TAG="${TAG:-default}"
|
||||
TARGET_REQUESTS="${TARGET_REQUESTS:-5000}"
|
||||
TIME_SCALE="${TIME_SCALE:-1.0}"
|
||||
MAX_INFLIGHT="${MAX_INFLIGHT:-32}"
|
||||
SEED="${SEED:-42}"
|
||||
|
||||
# Parse args
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--endpoint) ENDPOINT="$2"; shift 2 ;;
|
||||
--tag) TAG="$2"; shift 2 ;;
|
||||
--target-requests) TARGET_REQUESTS="$2"; shift 2 ;;
|
||||
--time-scale) TIME_SCALE="$2"; shift 2 ;;
|
||||
--max-inflight) MAX_INFLIGHT="$2"; shift 2 ;;
|
||||
*) echo "Unknown arg: $1"; exit 1 ;;
|
||||
esac
|
||||
done
|
||||
|
||||
SAMPLED_TRACE="traces/sampled_${TARGET_REQUESTS}req_seed${SEED}.jsonl"
|
||||
OUTPUT_DIR="outputs/${TAG}_$(date +%Y%m%d_%H%M%S)"
|
||||
|
||||
echo "=== Benchmark: tag=$TAG ==="
|
||||
echo " Trace: $TRACE_INPUT"
|
||||
echo " Endpoint: $ENDPOINT"
|
||||
echo " Target requests: $TARGET_REQUESTS"
|
||||
echo " Time scale: $TIME_SCALE"
|
||||
echo " Max inflight sessions: $MAX_INFLIGHT"
|
||||
|
||||
# Step 1: Sample trace (if not already done)
|
||||
if [ ! -f "$SAMPLED_TRACE" ]; then
|
||||
echo ""
|
||||
echo "=== Step 1: Sampling trace ==="
|
||||
python scripts/sample_trace.py \
|
||||
--input "$TRACE_INPUT" \
|
||||
--output "$SAMPLED_TRACE" \
|
||||
--target-requests "$TARGET_REQUESTS" \
|
||||
--seed "$SEED"
|
||||
else
|
||||
echo ""
|
||||
echo "=== Step 1: Using existing sampled trace: $SAMPLED_TRACE ==="
|
||||
fi
|
||||
|
||||
# Step 2: Run replay
|
||||
echo ""
|
||||
echo "=== Step 2: Replaying trace ==="
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
python -m replayer \
|
||||
--trace "$SAMPLED_TRACE" \
|
||||
--output "$OUTPUT_DIR/metrics.jsonl" \
|
||||
--endpoint "$ENDPOINT" \
|
||||
--time-scale "$TIME_SCALE" \
|
||||
--max-inflight-sessions "$MAX_INFLIGHT" \
|
||||
-v
|
||||
|
||||
echo ""
|
||||
echo "=== Done ==="
|
||||
echo " Metrics: $OUTPUT_DIR/metrics.jsonl"
|
||||
echo " Summary: $OUTPUT_DIR/metrics.summary.json"
|
||||
254
scripts/run_experiments.sh
Executable file
254
scripts/run_experiments.sh
Executable file
@@ -0,0 +1,254 @@
|
||||
#!/bin/bash
|
||||
# Run the complete experiment matrix:
|
||||
# 1. Combined TP=2 DP=4 (4 instances, baseline)
|
||||
# 2. Combined TP=1 DP=8 (8 instances, max throughput)
|
||||
# 3. PD-Sep TP=1: P×4 + D×4 via Mooncake/RDMA
|
||||
#
|
||||
# All use the same trace, same concurrency, same timeout.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
VENV="$PROJECT_DIR/.venv/bin"
|
||||
VLLM="$VENV/vllm"
|
||||
PYTHON="$VENV/python"
|
||||
|
||||
MODEL="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
|
||||
TRACE="$PROJECT_DIR/traces/sampled_1000req_seed42.jsonl"
|
||||
|
||||
# Uniform benchmark params
|
||||
MAX_SESSIONS=${MAX_SESSIONS:-8}
|
||||
MAX_CONCURRENT=${MAX_CONCURRENT:-16}
|
||||
TIME_SCALE=10
|
||||
REQUEST_TIMEOUT=${REQUEST_TIMEOUT:-300}
|
||||
REQUEST_LIMIT="${REQUEST_LIMIT:-}" # empty = all 1000
|
||||
|
||||
cleanup_gpu() {
|
||||
pkill -9 -f "vllm" 2>/dev/null || true
|
||||
pkill -9 -f "cache_aware_proxy\|mooncake_connector_proxy\|uvicorn" 2>/dev/null || true
|
||||
fuser 9090/tcp 8000/tcp 2>/dev/null | xargs -r kill -9 2>/dev/null || true
|
||||
sleep 5
|
||||
fuser /dev/nvidia* 2>/dev/null | tr " " "\n" | sort -u | xargs -r kill -9 2>/dev/null || true
|
||||
sleep 10
|
||||
}
|
||||
|
||||
wait_for_server() {
|
||||
local port=$1
|
||||
local timeout=${2:-600}
|
||||
timeout "$timeout" bash -c "until curl -s localhost:$port/v1/models >/dev/null 2>&1; do sleep 5; done"
|
||||
}
|
||||
|
||||
run_benchmark() {
|
||||
local tag=$1
|
||||
local endpoint=$2
|
||||
local extra_args="${3:-}"
|
||||
local outdir="$PROJECT_DIR/outputs/$tag"
|
||||
|
||||
echo " Running benchmark -> $outdir"
|
||||
local limit_arg=""
|
||||
if [ -n "$REQUEST_LIMIT" ]; then
|
||||
limit_arg="--request-limit $REQUEST_LIMIT"
|
||||
fi
|
||||
|
||||
$PYTHON -m replayer \
|
||||
--trace "$TRACE" \
|
||||
--output "$outdir/metrics.jsonl" \
|
||||
--endpoint "$endpoint" \
|
||||
--model "$MODEL" \
|
||||
--time-scale $TIME_SCALE \
|
||||
--max-inflight-sessions $MAX_SESSIONS \
|
||||
--concurrency-limit $MAX_CONCURRENT \
|
||||
--request-timeout $REQUEST_TIMEOUT \
|
||||
$limit_arg \
|
||||
-v
|
||||
|
||||
echo " Done: $(wc -l < "$outdir/metrics.jsonl") requests"
|
||||
}
|
||||
|
||||
#######################################################################
|
||||
# Experiment 1: Combined TP=2 DP=4
|
||||
#######################################################################
|
||||
run_combined_tp2_dp4() {
|
||||
echo ""
|
||||
echo "================================================================"
|
||||
echo " Experiment 1: Combined TP=2 DP=4 (4 instances on 8 GPUs)"
|
||||
echo "================================================================"
|
||||
cleanup_gpu
|
||||
|
||||
for i in 0 1 2 3; do
|
||||
local gpu_start=$((i * 2))
|
||||
local gpu_end=$((gpu_start + 1))
|
||||
local port=$((8000 + i))
|
||||
echo " Starting instance $i: GPUs $gpu_start,$gpu_end, port $port"
|
||||
CUDA_VISIBLE_DEVICES=$gpu_start,$gpu_end $VLLM serve "$MODEL" \
|
||||
--host 0.0.0.0 --port $port \
|
||||
--tensor-parallel-size 2 \
|
||||
--trust-remote-code --enable-prefix-caching --enforce-eager \
|
||||
--dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 &
|
||||
done
|
||||
|
||||
for i in 0 1 2 3; do
|
||||
wait_for_server $((8000 + i))
|
||||
echo " Instance $i ready"
|
||||
done
|
||||
echo " All 4 instances ready"
|
||||
|
||||
# Start global scheduler (cache-aware proxy in combined mode)
|
||||
echo " Starting global scheduler..."
|
||||
$PYTHON "$PROJECT_DIR/scripts/cache_aware_proxy.py" \
|
||||
--combined http://127.0.0.1:8000 http://127.0.0.1:8001 http://127.0.0.1:8002 http://127.0.0.1:8003 \
|
||||
--port 9090 &
|
||||
sleep 5
|
||||
|
||||
run_benchmark "exp1_combined_tp2_dp4" "http://localhost:9090"
|
||||
}
|
||||
|
||||
#######################################################################
|
||||
# Experiment 2: Combined TP=1 DP=8
|
||||
#######################################################################
|
||||
run_combined_tp1_dp8() {
|
||||
echo ""
|
||||
echo "================================================================"
|
||||
echo " Experiment 2: Combined TP=1 DP=8 (8 instances on 8 GPUs)"
|
||||
echo "================================================================"
|
||||
cleanup_gpu
|
||||
|
||||
for i in $(seq 0 7); do
|
||||
local port=$((8000 + i))
|
||||
echo " Starting instance $i: GPU $i, port $port"
|
||||
CUDA_VISIBLE_DEVICES=$i $VLLM serve "$MODEL" \
|
||||
--host 0.0.0.0 --port $port \
|
||||
--tensor-parallel-size 1 \
|
||||
--trust-remote-code --enable-prefix-caching --enforce-eager \
|
||||
--dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 &
|
||||
done
|
||||
|
||||
for i in $(seq 0 7); do
|
||||
wait_for_server $((8000 + i))
|
||||
echo " Instance $i ready"
|
||||
done
|
||||
echo " All 8 instances ready"
|
||||
|
||||
# Start global scheduler (cache-aware proxy in combined mode)
|
||||
echo " Starting global scheduler..."
|
||||
$PYTHON "$PROJECT_DIR/scripts/cache_aware_proxy.py" \
|
||||
--combined http://127.0.0.1:8000 http://127.0.0.1:8001 http://127.0.0.1:8002 http://127.0.0.1:8003 \
|
||||
http://127.0.0.1:8004 http://127.0.0.1:8005 http://127.0.0.1:8006 http://127.0.0.1:8007 \
|
||||
--port 9090 &
|
||||
sleep 5
|
||||
|
||||
run_benchmark "exp2_combined_tp1_dp8" "http://localhost:9090"
|
||||
}
|
||||
|
||||
#######################################################################
|
||||
# Experiment 3: PD-Sep TP=1 P×4 D×4 (Mooncake/RDMA)
|
||||
#######################################################################
|
||||
run_pd_sep_tp1() {
|
||||
echo ""
|
||||
echo "================================================================"
|
||||
echo " Experiment 3: PD-Sep TP=1 P×4 + D×4 (Mooncake/RDMA)"
|
||||
echo "================================================================"
|
||||
cleanup_gpu
|
||||
|
||||
PROXY_SCRIPT="$PROJECT_DIR/scripts/cache_aware_proxy.py"
|
||||
|
||||
# Start 4 prefill instances (GPUs 0-3)
|
||||
local prefill_args=""
|
||||
for i in 0 1 2 3; do
|
||||
local port=$((8010 + i))
|
||||
local bootstrap=$((8998 + i))
|
||||
echo " Prefill $i: GPU $i, port $port, bootstrap $bootstrap"
|
||||
VLLM_MOONCAKE_BOOTSTRAP_PORT=$bootstrap \
|
||||
CUDA_VISIBLE_DEVICES=$i $VLLM serve "$MODEL" \
|
||||
--host 0.0.0.0 --port $port \
|
||||
--tensor-parallel-size 1 \
|
||||
--trust-remote-code --enable-prefix-caching --enforce-eager \
|
||||
--dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 \
|
||||
--kv-transfer-config \
|
||||
"{\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_producer\"}" &
|
||||
prefill_args="$prefill_args --prefill http://127.0.0.1:$port $bootstrap"
|
||||
done
|
||||
|
||||
# Start 4 decode instances (GPUs 4-7)
|
||||
local decode_args=""
|
||||
for i in 0 1 2 3; do
|
||||
local gpu=$((4 + i))
|
||||
local port=$((8020 + i))
|
||||
echo " Decode $i: GPU $gpu, port $port"
|
||||
CUDA_VISIBLE_DEVICES=$gpu $VLLM serve "$MODEL" \
|
||||
--host 0.0.0.0 --port $port \
|
||||
--tensor-parallel-size 1 \
|
||||
--trust-remote-code --enable-prefix-caching --enforce-eager \
|
||||
--dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 \
|
||||
--kv-transfer-config \
|
||||
"{\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_consumer\",\"kv_load_failure_policy\":\"recompute\"}" &
|
||||
decode_args="$decode_args --decode http://127.0.0.1:$port"
|
||||
done
|
||||
|
||||
# Wait for all instances
|
||||
for i in 0 1 2 3; do
|
||||
wait_for_server $((8010 + i))
|
||||
echo " Prefill $i ready"
|
||||
done
|
||||
for i in 0 1 2 3; do
|
||||
wait_for_server $((8020 + i))
|
||||
echo " Decode $i ready"
|
||||
done
|
||||
|
||||
# Start proxy (wait for bootstrap to be queryable first)
|
||||
echo " Waiting for bootstrap servers..."
|
||||
for bp in 8998 8999 9000 9001; do
|
||||
timeout 120 bash -c "until curl -s localhost:$bp/query > /dev/null 2>&1; do sleep 2; done"
|
||||
echo " Bootstrap $bp ready"
|
||||
done
|
||||
|
||||
echo " Starting proxy on port 9000..."
|
||||
$PYTHON "$PROXY_SCRIPT" $prefill_args $decode_args --host 0.0.0.0 --port 9090 &
|
||||
sleep 15
|
||||
|
||||
# Smoke test with retry
|
||||
echo " Smoke test..."
|
||||
for attempt in 1 2 3; do
|
||||
result=$(curl -s -m 120 http://localhost:9090/v1/completions \
|
||||
-X POST -H "Content-Type: application/json" \
|
||||
-d "{\"model\":\"$MODEL\",\"prompt\":[100,200,300],\"max_tokens\":3,\"temperature\":0}" 2>&1)
|
||||
if echo "$result" | grep -q "choices"; then
|
||||
echo " Smoke test passed!"
|
||||
break
|
||||
fi
|
||||
echo " Attempt $attempt failed, retrying..."
|
||||
sleep 10
|
||||
done
|
||||
|
||||
run_benchmark "exp3_pd_sep_tp1_mooncake" "http://localhost:9090"
|
||||
}
|
||||
|
||||
#######################################################################
|
||||
# Main
|
||||
#######################################################################
|
||||
echo "Starting experiment matrix on $(hostname)"
|
||||
echo "Model: $MODEL"
|
||||
echo "Trace: $TRACE"
|
||||
echo "Params: sessions=$MAX_SESSIONS, concurrent=$MAX_CONCURRENT, time_scale=$TIME_SCALE"
|
||||
echo ""
|
||||
|
||||
case "${1:-all}" in
|
||||
1|tp2dp4) run_combined_tp2_dp4 ;;
|
||||
2|tp1dp8) run_combined_tp1_dp8 ;;
|
||||
3|pdsep) run_pd_sep_tp1 ;;
|
||||
all)
|
||||
run_combined_tp2_dp4
|
||||
run_combined_tp1_dp8
|
||||
run_pd_sep_tp1
|
||||
;;
|
||||
*)
|
||||
echo "Usage: $0 {1|2|3|all|tp2dp4|tp1dp8|pdsep}"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
echo ""
|
||||
echo "================================================================"
|
||||
echo " All experiments complete!"
|
||||
echo "================================================================"
|
||||
cleanup_gpu
|
||||
204
scripts/sample_trace.py
Normal file
204
scripts/sample_trace.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""Sample sessions from the full cluster-scale trace to fit a single machine.
|
||||
|
||||
Preserves:
|
||||
- Complete session structure (all turns within a session kept together)
|
||||
- Original arrival timing (inter-session and intra-session gaps)
|
||||
- hash_ids for KV cache reuse patterns
|
||||
- Request type distribution
|
||||
|
||||
Sampling strategy:
|
||||
1. Group requests by session (derived from parent_chat_id chains)
|
||||
2. Randomly sample N sessions (or until target request count reached)
|
||||
3. Re-zero timestamps so first event starts at t=0
|
||||
4. Optionally compress time axis to increase load density
|
||||
|
||||
Usage:
|
||||
python scripts/sample_trace.py \\
|
||||
--input ~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl \\
|
||||
--output traces/sampled.jsonl \\
|
||||
--target-requests 5000 \\
|
||||
--seed 42
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def load_raw_rows(path: Path) -> dict[str, list[dict]]:
|
||||
"""Load trace, group rows by resolved session_id. Preserve file order."""
|
||||
chat_to_session: dict[int, str] = {}
|
||||
rows_by_session: dict[str, list[dict]] = collections.OrderedDict()
|
||||
|
||||
with path.open("r", encoding="utf-8") as fh:
|
||||
for line in fh:
|
||||
row = json.loads(line)
|
||||
cid = int(row["chat_id"])
|
||||
pid = int(row["parent_chat_id"])
|
||||
|
||||
if "session_id" in row:
|
||||
sid = str(row["session_id"])
|
||||
elif pid < 0:
|
||||
sid = str(cid)
|
||||
else:
|
||||
sid = chat_to_session.get(pid, str(pid))
|
||||
chat_to_session[cid] = sid
|
||||
|
||||
row["_session_id"] = sid
|
||||
rows_by_session.setdefault(sid, []).append(row)
|
||||
|
||||
return rows_by_session
|
||||
|
||||
|
||||
def sample_sessions(
|
||||
rows_by_session: dict[str, list[dict]],
|
||||
*,
|
||||
target_requests: int,
|
||||
seed: int,
|
||||
strategy: str = "random",
|
||||
) -> list[str]:
|
||||
"""Select sessions until target request count is reached."""
|
||||
all_sids = list(rows_by_session.keys())
|
||||
rng = random.Random(seed)
|
||||
|
||||
if strategy == "random":
|
||||
rng.shuffle(all_sids)
|
||||
elif strategy == "sequential":
|
||||
pass # keep file order
|
||||
else:
|
||||
raise ValueError(f"Unknown strategy: {strategy}")
|
||||
|
||||
selected = []
|
||||
total = 0
|
||||
for sid in all_sids:
|
||||
selected.append(sid)
|
||||
total += len(rows_by_session[sid])
|
||||
if total >= target_requests:
|
||||
break
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def build_output(
|
||||
rows_by_session: dict[str, list[dict]],
|
||||
selected: list[str],
|
||||
*,
|
||||
time_scale: float = 1.0,
|
||||
) -> list[dict]:
|
||||
"""Build output rows with re-zeroed timestamps."""
|
||||
out_rows = []
|
||||
for sid in selected:
|
||||
for row in rows_by_session[sid]:
|
||||
out = {k: v for k, v in row.items() if not k.startswith("_")}
|
||||
out["session_id"] = sid
|
||||
out_rows.append(out)
|
||||
|
||||
out_rows.sort(key=lambda r: float(r["timestamp"]))
|
||||
|
||||
if not out_rows:
|
||||
return out_rows
|
||||
|
||||
# Re-zero: subtract earliest timestamp
|
||||
t0 = float(out_rows[0]["timestamp"])
|
||||
for row in out_rows:
|
||||
row["timestamp"] = (float(row["timestamp"]) - t0) / time_scale
|
||||
|
||||
return out_rows
|
||||
|
||||
|
||||
def print_summary(
|
||||
rows_by_session: dict[str, list[dict]],
|
||||
selected: list[str],
|
||||
out_rows: list[dict],
|
||||
) -> None:
|
||||
n_sessions = len(selected)
|
||||
n_requests = len(out_rows)
|
||||
turns_per_session = [len(rows_by_session[s]) for s in selected]
|
||||
multi_turn = sum(1 for t in turns_per_session if t > 1)
|
||||
|
||||
input_lens = [r["input_length"] for r in out_rows]
|
||||
output_lens = [r["output_length"] for r in out_rows]
|
||||
|
||||
span_s = float(out_rows[-1]["timestamp"]) if out_rows else 0
|
||||
session_starts = {}
|
||||
for r in out_rows:
|
||||
sid = r["session_id"]
|
||||
ts = float(r["timestamp"])
|
||||
if sid not in session_starts:
|
||||
session_starts[sid] = ts
|
||||
starts_sorted = sorted(session_starts.values())
|
||||
deltas = [starts_sorted[i+1] - starts_sorted[i]
|
||||
for i in range(len(starts_sorted) - 1)]
|
||||
|
||||
# hash_ids overlap: count unique hash_ids across all requests
|
||||
all_hashes = set()
|
||||
for r in out_rows:
|
||||
all_hashes.update(r.get("hash_ids", []))
|
||||
|
||||
print(f"Sampled: {n_sessions} sessions, {n_requests} requests")
|
||||
print(f" Multi-turn sessions: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)")
|
||||
print(f" Turns/session: min={min(turns_per_session)} max={max(turns_per_session)} "
|
||||
f"avg={sum(turns_per_session)/len(turns_per_session):.1f}")
|
||||
print(f" Input length: min={min(input_lens)} max={max(input_lens)} "
|
||||
f"avg={sum(input_lens)/len(input_lens):.0f}")
|
||||
print(f" Output length: min={min(output_lens)} max={max(output_lens)} "
|
||||
f"avg={sum(output_lens)/len(output_lens):.0f}")
|
||||
print(f" Trace span: {span_s:.1f}s ({span_s/60:.1f} min)")
|
||||
print(f" Unique hash blocks: {len(all_hashes)}")
|
||||
if deltas:
|
||||
deltas.sort()
|
||||
p = lambda q: deltas[min(int(q * len(deltas)), len(deltas) - 1)]
|
||||
print(f" Session arrival deltas (s): p10={p(0.1):.2f} p50={p(0.5):.2f} "
|
||||
f"p90={p(0.9):.2f} max={max(deltas):.2f}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
p = argparse.ArgumentParser(description=__doc__,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
p.add_argument("--input", type=Path, required=True,
|
||||
help="Path to the full trace JSONL file")
|
||||
p.add_argument("--output", type=Path, required=True,
|
||||
help="Path to write sampled trace JSONL")
|
||||
p.add_argument("--target-requests", type=int, default=5000,
|
||||
help="Target number of requests (stops after session that crosses it)")
|
||||
p.add_argument("--strategy", choices=["random", "sequential"], default="random",
|
||||
help="Session selection strategy")
|
||||
p.add_argument("--time-scale", type=float, default=1.0,
|
||||
help="Compress time axis by this factor (>1 = faster arrival)")
|
||||
p.add_argument("--seed", type=int, default=42)
|
||||
args = p.parse_args()
|
||||
|
||||
print(f"Loading trace from {args.input} ...")
|
||||
rows_by_session = load_raw_rows(args.input)
|
||||
total_sessions = len(rows_by_session)
|
||||
total_requests = sum(len(v) for v in rows_by_session.values())
|
||||
print(f"Full trace: {total_sessions} sessions, {total_requests} requests")
|
||||
|
||||
selected = sample_sessions(
|
||||
rows_by_session,
|
||||
target_requests=args.target_requests,
|
||||
seed=args.seed,
|
||||
strategy=args.strategy,
|
||||
)
|
||||
|
||||
out_rows = build_output(
|
||||
rows_by_session, selected,
|
||||
time_scale=args.time_scale,
|
||||
)
|
||||
|
||||
print_summary(rows_by_session, selected, out_rows)
|
||||
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
with args.output.open("w", encoding="utf-8") as fh:
|
||||
for row in out_rows:
|
||||
fh.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||
print(f"\nWrote {len(out_rows)} rows to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user