Agentic workload PD separation analysis with trace-driven benchmarks

Systematic study of prefill-decode disaggregation for agentic LLM workloads
using production GLM-5.1 coder trace (2.1M requests, 71B input tokens).

Key findings:
- Cache-aware routing improves TPOT p90 by 15% and APC from 20.8% to 44.7%
  without PD separation, matching PD-Sep's decode isolation benefit
- PD separation adds +72% TTFT overhead (KV transfer) with no TPOT gain
  when using the same cache-aware scheduler
- Prefill remains compute-bound even at 95% KV cache reuse (AI >1000x
  vs decode AI <2), but absolute FLOPs drop 71% from cache hits
- For agentic MoE workloads, cache-aware routing > PD separation

Infrastructure:
- Trace sampler preserving session structure + hash_ids for prefix sharing
- Async trace replayer with streaming TTFT/TPOT/E2E measurement
- Unified cache-aware + token-level load-balanced global scheduler proxy
  supporting both PD-colocated and PD-disaggregated (Mooncake/RDMA) modes
- vLLM 0.18.1 scheduler patch for KV transfer abort race condition
- Roofline analysis tool for prefill/decode compute characterization

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 21:21:57 +08:00
commit 05592e6adc
22 changed files with 2837 additions and 0 deletions

8
.gitignore vendored Normal file
View File

@@ -0,0 +1,8 @@
__pycache__/
*.pyc
.venv/
*.egg-info/
outputs/
traces/
third_party/
*.log

25
TODO.md Normal file
View File

@@ -0,0 +1,25 @@
实验 setup
GPU 机器dash0是 8*H20 的机器,可以直接 `ssh dash0` 进行连接访问
推理引擎:基于 vllm 0.18.1self build支持后续 patch 放在 git 中维护
模型:`~/models/Qwen/Qwen3-Coder-30B-A3B-Instruct`
推理 trace原始完整 2h trace 在 dash0 的 `~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl`
性能指标:每个请求的 TTFT/TPOT/E2E/TBTprefix KVCache hit ratio
目标:
1. 先实现标准的 trace-sampler将 cluster 规模的原始 tracesample 到合适当前机器数量的规模来跑,保持一份统一的 sample 后的 trace file 作为输入
2. 实现标准的 trace replayer保证能够体现线上流量的流量到来特征KVCache 重用特征等
3. 跑通 PD 分离,确认 PD 分离能够比普通的 PD 混合一起跑的性能要好,给出两者详细的性能对比以及原因分析
4. 判断 trace 的 pattern是否有必要 PD 完全混合或者 PD 分离
5. 参考本地的 `~/phd/agentic-pd-hybrid`,判断是否能够实现一套 prefill-as-a-service 的架构,把重的 prefill 交给 prefill serviceprefill service 能够从本地 GPU/DRAM/别的 GPU 机器上 pull KVCache提高本地的 prefix KVCache hit ratio不影响 decoding 的 prefill就可以交给过去 PD 分离定义中 D-node 来做,提高 KVCache 命中率

View File

@@ -0,0 +1,289 @@
# PD 分离在 Agentic Workload 下的系统分析
## 1. Trace 特征 (GLM-5.1 Agentic Coder, 2h, 2.1M requests)
```
Total requests: 2,114,220
Input tokens: 71.1B (avg 33.6k/req, p50=20k, p90=88k)
Output tokens: 940M (avg 445/req, p50=80, p90=811)
I/O ratio: 75.6x (aggregate), 217.8x (per-req median)
Prefill share: 98% of total tokens
Sessions: 1.3M (90% single-turn, 9% multi-turn)
```
**与传统 chatbot workload 的根本区别:**
| 特征 | Traditional Chatbot | Agentic Coder (GLM-5.1) |
|------|-------------------|------------------------|
| I/O ratio | 1-10x | **75.6x** |
| Input p50 | 500-2000 tokens | **20,030 tokens** |
| Output p50 | 200-500 tokens | **80 tokens** |
| Prefill token share | 50-80% | **98%** |
| >32k input | <5% | **38%** |
| Multi-turn | 50-80% | **9%** |
**KV Cache 复用特征:**
```
Unique hash blocks: 20,650,883
Shared blocks (ref>1): 9,749,379 (47%)
Highly shared (ref>10): 2,428,160
Intra-session reuse: 57%
Top-10 blocks ref count: 64,754 (system prompt blocks)
Theoretical cache hit: 71% (infinite cache, first 100k requests)
```
**Input length 分布与 token 占比:**
```
<1k: 202,396 reqs ( 9%) 89M tokens ( 0%)
1-8k: 380,009 reqs (17%) 1.6B tokens ( 2%)
8-32k: 720,871 reqs (34%) 12.7B tokens (17%)
32-65k: 405,371 reqs (19%) 19.4B tokens (27%)
65-131k: 394,014 reqs (18%) 35.7B tokens (50%)
>131k: 11,559 reqs ( 0%) 1.6B tokens ( 2%)
```
50% token 计算量来自 65-131k 的长 context 请求
## 2. DistServe 等 PD 分离的核心假设
DistServe (OSDI'24), Splitwise, TetriInfer PD 分离工作基于以下假设
### 假设 A: Prefill 和 Decode 有不同的计算特征
- **Prefill**: compute-bound, GPU 利用率, batch 越大越好
- **Decode**: memory-bandwidth-bound, GPU 利用率, latency-sensitive
**在 agentic workload 中的验证**: 成立但需要细化
Roofline 分析显示详见 Section 5
```
Arithmetic Intensity (FLOP/byte)
Decode: 1.0 - 1.9 (memory-bound, 始终远低于 ridge point)
Prefill 0% reuse: 23,000-72,000 (strongly compute-bound)
Prefill 70% reuse: 10,000-42,000 (仍然 compute-bound!)
Prefill 95% reuse: 1,900-10,800 (仍然 compute-bound!)
Ridge point (H20): 37
```
**即使 95% KV cache reuseprefill 仍然是 compute-bound。** 但绝对计算量大幅减少
### 假设 B: PD co-location 导致互相干扰
- Prefill 的大 batch 计算会抢占 GPU 资源导致 decode TPOT 升高
- Decode 的持续小计算会占用 GPU 调度槽位影响 prefill 吞吐
**在 agentic workload 中的验证**: 干扰存在 **可被 cache-aware routing 消除**
```
同一 cache-aware scheduler, TP=1, 8 GPU:
Combined TP=1 DP=8: TPOT p90 = 0.073s
PD-Sep TP=1 4P+4D: TPOT p90 = 0.074s
→ 差异 <2%, 不显著
```
对比 round-robin routing:
```
Combined TP=1 DP=8 (RR): TPOT p90 = 0.086s
Combined TP=1 DP=8 (cache-aware): TPOT p90 = 0.073s → -15%
→ routing 改善 > PD 分离改善
```
**原因**: cache-aware routing high-cache-hit 的请求集中到特定 instance
每个 instance 的实际 prefill token 数大幅减少71% cache
prefill-decode 干扰因 prefill 工作量降低而自然缓解
### 假设 C: KV Cache 传输开销可以忽略
- DistServe 假设 PD KV 传输延迟远小于 prefill 计算时间
- InfiniBand/NVLink 等高带宽互联下成立
**在 agentic workload 中的验证**: 不成立
```
PD-Sep TTFT p50 = 1.261s vs Combined TTFT p50 = 0.731s (+72%)
```
原因
1. Agentic workload input 极长p50=20k, p90=88k tokensKV cache 很大
2. 单请求 KV cache = 20k tokens × 48 layers × 2(K+V) × 512 bytes 1GB
3. 更重要的是 await-prefill 链路的串行延迟proxy prefill KV transfer decode first token
### 假设 D: 专用 prefill 节点可以提高 prefill 吞吐
- Prefill 节点不做 decodeGPU 利用率更高
- 可以用更大的 batch size
**在 agentic workload 中的验证**: 收益被 cache 稀释
```
理论 prefix cache hit (infinite cache): 71% of input tokens
实际 APC (Combined, cache-aware, 8 inst): 44.7%
```
71% cache hit 只有 29% input tokens 需要实际 prefill compute
Nominal avg input 33.6k Actual avg new prefill ~9.7k tokens
专用 prefill GPU 利用率优势因 prefill 工作量降低而缩小
## 3. Roofline 分析Prefill 在高 Cache Reuse 下的计算/访存特性
### 3.1 模型计算结构
```
Qwen3-Coder-30B-A3B (MoE 128E top-8):
48 layers, hidden=2048, heads=32, kv_heads=4 (GQA), head_dim=128
FFN: 6144 intermediate per expert, 8 experts active per token
Active params per token: ~3B
H20 GPU: 148 TFLOPS (BF16), 4.0 TB/s HBM → Ridge point: 37 FLOP/byte
```
### 3.2 Decode 永远 memory-bound
```
SeqLen FLOP Bytes AI (F/B) Bound
1,000 3.04e+10 3.01e+10 1.0 MEMORY
16,000 3.63e+10 3.16e+10 1.1 MEMORY
64,000 5.52e+10 3.63e+10 1.5 MEMORY
128,000 8.03e+10 4.26e+10 1.9 MEMORY
```
Decode AI 始终 < 2远低于 ridge point (37)。每个 decode step 只处理 1 token
计算量极小瓶颈在于读取模型权重和全量 KV cache
### 3.3 Prefill 即使 95% reuse 仍然 compute-bound
```
SeqLen Reuse% NewTok AI (F/B) Bound vs Decode
32,000 0% 32,000 23,368 COMPUTE 18,190x
32,000 50% 16,000 14,899 COMPUTE 11,597x
32,000 70% 9,600 10,045 COMPUTE 7,819x
32,000 90% 3,200 3,821 COMPUTE 2,974x
32,000 95% 1,600 1,980 COMPUTE 1,542x
64,000 0% 64,000 40,758 COMPUTE 26,813x
64,000 70% 19,200 20,610 COMPUTE 13,559x
64,000 90% 6,400 8,544 COMPUTE 5,621x
64,000 95% 3,200 4,549 COMPUTE 2,993x
```
### 3.4 为什么高 reuse 不改变 compute-bound 性质
KV cache reuse 减少的
- K/V projection 计算只算 new tokens
- KV 写入只写 new tokens
KV cache reuse **不减少**
- **Q×K^T attention**: 每个 new Q 都要和全部 seq_len KV attention
```
FLOPs = new_tokens × seq_len × head_dim × num_heads × 2 × num_layers
```
At 95% reuse, 32k seq: 1600 × 32000 × 128 × 32 × 2 × 48 ≈ 2×10^13
这个二次项在长 context 下主导总计算量
- **MoE FFN**: 每个 new token 激活 8 experts
```
FLOPs = new_tokens × 3 × D × D_ffn × 2 × K_experts × num_layers
```
**Prefill 只在接近 100% reuse (< 10 new tokens) 时才变成 memory-bound。**
### 3.5 Prefill 什么时候变 memory-bound
```
SeqLen=32,000: new_tokens ≈ 5-10 时 → AI ≈ 37 (ridge point)
SeqLen=64,000: new_tokens ≈ 5-10 时 → AI ≈ 37
```
在实际 agentic trace 中:
```
Compute-bound prefills: 961 (96%)
Memory-bound prefills: 37 (3%) ← 近 100% reuse 的极端 warm 请求
```
### 3.6 关键洞察:"Compute-bound but lightweight"
高 cache reuse 下的 prefill 处于一种独特状态:
```
Prefill bound 类型: Compute-bound (不变)
Prefill 绝对工作量: 大幅降低 (71% cache → 只算 29% 的 tokens)
Prefill-Decode 干扰: 因绝对工作量降低而减轻 (不需要物理隔离)
```
这解释了为什么 PD 分离没有帮助:
- PD 分离解决的是 "prefill 太重干扰 decode" 的问题
- 但 cache-aware routing 已经把 prefill 的实际工作量降到足够轻
- 物理隔离PD 分离)的收益被 KV 传输开销抵消
## 4. 实验结果
### 4.1 完整实验矩阵
所有实验使用统一的 cache-aware + token-level load-balanced global scheduler。
| Config | OK/N | TTFT p50 | TPOT p90 | E2E p50 | APC |
|--------|------|----------|----------|---------|-----|
| TP=8 DP=1 (single instance) | 998/1000 | 0.467s | 0.129s | 3.30s | 53.0% |
| TP=2 DP=4 (4 inst, RR) | 997/999 | 0.844s | 0.095s | 4.92s | 33.5% |
| TP=1 DP=8 (8 inst, RR) | 997/999 | 1.836s | 0.086s | 6.67s | 20.8% |
| **TP=1 DP=8 (cache-aware)** | **997/999** | **0.731s** | **0.073s** | **4.48s** | **44.7%** |
| TP=1 PD-Sep 4P+4D (cache-aware) | 509/564 | 1.261s | 0.074s | 5.61s | 40.2% |
### 4.2 Cache-Aware Routing 的效果
```
Round-robin → Cache-aware (Combined TP=1 DP=8):
TTFT p50: 1.836s → 0.731s (-60%)
TPOT p90: 0.086s → 0.073s (-15%)
E2E p50: 6.673s → 4.480s (-33%)
APC: 20.8% → 44.7% (+24pp)
```
Cache-aware routing 的提升远大于 PD 分离的提升
### 4.3 修复工程问题的过程
实验过程中发现并修复了多个 PD 分离的工程问题
| 问题 | 根因 | 修复 |
|------|------|------|
| Decode engine crash | vLLM scheduler assert: KV transfer 回调时 request abort | Patch scheduler.py: assert graceful skip |
| Head-of-line blocking | Proxy request count LB不区分大小请求 | Token-level ongoing_tokens load balancing |
| "Timeout waiting for P side ready" | Proxy fire-and-forget prefill, decode 盲等 KV | Await-prefill + kv_load_failure_policy=recompute |
| Port collision on startup | 8 Mooncake instances 同时启动争抢 torch distributed port | Staggered startup + explicit MASTER_PORT |
| Cache routing "rich get richer" | score = ongoing - alpha*cached 导致流量集中到一个 instance | Normalized scoring: ongoing/avg_load - alpha*cache_ratio |
## 5. 结论
### 5.1 PD 分离为什么在 Agentic Workload 不生效
1. **Cache reuse 大幅降低 prefill 绝对工作量71% cache hit → 只算 29%**使得 P-D 干扰不显著
2. **Prefill 仍然 compute-bound**即使 95% reuseAI >1000但每个请求的总 FLOPs 因 new_tokens 减少而大幅降低
3. **Cache-aware routing 提供 "软 PD 隔离"**,效果等同于物理隔离但无 KV 传输开销
4. **KV 传输开销不可忽略**TTFT +72%),抵消了隔离收益
5. **MoE 模型 active params 小**3Bper-token compute 本身较轻
### 5.2 PD 分离在什么条件下有价值
| 条件 | Chatbot (有价值) | Agentic (无价值) |
|------|-----------------|-----------------|
| Cache hit rate | <10% | **71%** |
| Model active params | 70B (dense) | **3B (MoE)** |
| I/O ratio | 1-10x | **75.6x** |
| Per-request prefill FLOPs | Very high | **Low (after cache)** |
| KV transfer cost vs prefill cost | Negligible | **Significant** |
### 5.3 Agentic Workload 应该怎么优化
1. **Cache-aware routing** (已验证有效): ongoing_tokens + prefix_cache_hit 做联合调度
APC 20.8% (RR) 提升到 44.7%TPOT p90 降低 15%
2. **Cross-instance KV cache sharing**: 让多个 instance 共享全局 KV pool
进一步提升 cache hit 率接近理论 71%
3. **Prefix pre-warming**: cold start 请求55%0% cache hit
预计算 common prefix (system prompt blocks) 并分发到所有 instance
4. **不同 workload 类型的差异化处理**:
- Warm 请求 (22%, >90% cache hit, avg 1.3k new tokens): 几乎免费,任何 instance 都能处理
- Cold 请求 (55%, 0% cache hit, avg 17.7k new tokens): prefill-heavy需要有足够 compute
- 可以用 request-type-aware routing 进一步优化

View File

@@ -0,0 +1,130 @@
# Prefill 在高 KV Cache Reuse 下的计算/访存分析
## Model & GPU
```
Qwen3-Coder-30B-A3B (MoE 128E top-8)
48 layers, hidden=2048, heads=32, kv_heads=4 (GQA), head_dim=128
FFN: 6144 intermediate per expert, 8 experts active per token
Active params: ~3B per token
H20 GPU: 148 TFLOPS (BF16) / 4.0 TB/s HBM
Ridge point: 37 FLOP/byte
```
## 核心发现Prefill 即使 95% reuse 仍然是 compute-bound
```
SeqLen Reuse% NewTok AI (F/B) Bound vs Decode AI
32,000 0% 32,000 23368 COMPUTE 18189x
32,000 70% 9,600 10045 COMPUTE 7819x
32,000 90% 3,200 3821 COMPUTE 2974x
32,000 95% 1,600 1980 COMPUTE 1541x
64,000 0% 64,000 40758 COMPUTE 26813x
64,000 70% 19,200 20610 COMPUTE 13559x
64,000 90% 6,400 8544 COMPUTE 5621x
64,000 95% 3,200 4549 COMPUTE 2993x
Decode (always):
32,000 - 1 1.3 MEMORY 1x
64,000 - 1 1.5 MEMORY 1x
```
**关键**
- Decode 的 arithmetic intensity (AI) = 1.0-1.9 — 远低于 ridge point (37),始终 memory-bound
- Prefill 即使 95% reuse (只有 5% 新 token)AI 仍然 >1000 — 远高于 ridge point依然 compute-bound
## 为什么高 reuse 的 prefill 仍然是 compute-bound
### 原因Attention 的计算量与 seq_len 成正比
当有 95% cache reuse (seq_len=64k, new_tokens=3200):
```
Q projection: new_tokens × D × D → 只处理 3200 new tokens ✓
K,V projection: new_tokens × D × D_kv → 只处理 3200 new tokens ✓
但 Attention score: new_tokens × seq_len × D_head × H × L
= 3200 × 64000 × 128 × 32 × 48
→ 仍然要对全部 64k context 做注意力计算!
FFN (MoE): new_tokens × 3 × D × D_ffn × 2 × K_experts × L
= 3200 × 3 × 2048 × 6144 × 2 × 8 × 48
→ 8 个 expert 的计算量仍然很大
```
KV cache reuse 减少的是:
- K/V projection 的计算(只算 new tokens
- KV 写入(只写 new tokens
**不减少的是**
- Q 对全部 context 的 attention每个 new Q 都要和所有 64k tokens 做 attention
- MoE FFN 的计算(每个 new token 激活 8 个 expert
所以 prefill 的 FLOPs 虽然随 reuse 减少,但 **减少的是线性部分投影不减少的是二次部分attention**
在长 context 下,二次部分主导,使得即使 95% reuseAI 仍远高于 ridge point。
## Prefill 什么时候才变成 memory-bound
```
SeqLen=32,000: new_tokens ≈ 5-10 时 (reuse > 99.97%) → AI ≈ 37
SeqLen=64,000: new_tokens ≈ 5-10 时 → AI ≈ 37
```
只有在 **近乎 100% reuse**(仅 5-10 个 new tokensprefill 才接近 memory-bound。
在实际 agentic trace 中,只有 3% 的请求达到这个程度。
## 对 PD 分离的影响:修正之前的分析
### 之前的错误结论(已修正)
> "Prefill 大部分是 cache lookup 不是 compute"
这是 **错误的**。即使 70% cache reuseprefill 的 AI 仍然是 decode 的 7000-14000 倍。
Prefill 始终是 compute-bounddecode 始终是 memory-bound。
### 那为什么 PD 分离在我们的实验中没有帮助?
正确的解释不是 "prefill 变成了 memory-bound",而是:
**1. Cache reuse 大幅减少了 prefill 的绝对计算量**
```
无 cache: avg 33.6k tokens × prefill compute = X FLOPs
71% cache: avg 9.4k tokens × prefill compute = 0.28X FLOPs
```
虽然 prefill 仍是 compute-bound**总工作量只有原来的 28%**
在 8 instance 并行 + cache-aware routing 下,每个 instance 的 prefill 负载非常轻,
不足以产生对 decode 的显著干扰。
**2. MoE 模型的 per-token compute 本身较小**
Active params 只有 3B全参数的 10%),单个 token 的计算量不大。
对比 Dense 70B 模型,同样的 GPU 上 prefill-decode 干扰会严重得多。
**3. Cache-aware routing 的 "负载均衡" 效应**
当请求被路由到 cache 命中率高的 instance 时,该 instance 的实际 prefill 工作量更小,
自然减少了 P-D 争抢。这相当于 routing 层面的 "软 PD 分离"。
## 对比不同 workload 类型的 roofline 特征
```
Prefill AI Decode AI PD-Sep 价值
Dense 70B, Chatbot: 200-1000x 1-2x HIGH (compute-heavy P 干扰 D)
Dense 70B, Agent: 100-500x 1-2x MEDIUM (cache reduces P load)
MoE 30B, Chatbot: 100-500x 1-2x MEDIUM
MoE 30B, Agent: 50-200x 1-2x LOW (small active params + cache)
← 我们的位置
```
**PD 分离的 ROI 随着 cache hit 率升高和模型 active params 减少而下降。**
Agentic MoE 模型恰好在两个方面都不利于 PD 分离。
## 实际 trace 的 prefill bound 分布
```
With actual trace prefix cache pattern (1000 sampled requests):
Compute-bound prefills: 961 (96%)
Memory-bound prefills: 37 (3%) ← 近 100% reuse 的 warm 请求
(Decode is ALWAYS memory-bound)
```
96% 的 prefill 仍然是 compute-bound**absolute compute 因 cache 大幅降低**
这是一个 "compute-bound but lightweight" 的独特状态 —— bound 类型没变,但强度大幅降低。

16
pyproject.toml Normal file
View File

@@ -0,0 +1,16 @@
[project]
name = "agentic-kv"
version = "0.1.0"
description = "Trace-driven KV cache benchmarking for agentic LLM workloads"
requires-python = ">=3.10"
dependencies = [
"httpx>=0.27",
"numpy>=1.24",
]
[project.optional-dependencies]
dev = ["pytest"]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

0
replayer/__init__.py Normal file
View File

55
replayer/__main__.py Normal file
View File

@@ -0,0 +1,55 @@
"""CLI entry point: python -m replayer replay ..."""
from __future__ import annotations
import argparse
import asyncio
import logging
from pathlib import Path
from .replay import ReplayConfig, replay_trace
def main() -> None:
p = argparse.ArgumentParser(description="Trace replayer for vLLM benchmarking")
p.add_argument("--trace", type=Path, required=True, help="Sampled trace JSONL")
p.add_argument("--output", type=Path, required=True, help="Output metrics JSONL")
p.add_argument("--endpoint", type=str, required=True,
help="vLLM server URL (e.g. http://localhost:8000)")
p.add_argument("--model", type=str, default="default", help="Model name for API")
p.add_argument("--time-scale", type=float, default=1.0,
help="Time compression (>1 = faster)")
p.add_argument("--max-inflight-sessions", type=int, default=32)
p.add_argument("--concurrency-limit", type=int, default=256)
p.add_argument("--request-timeout", type=float, default=600.0)
p.add_argument("--request-limit", type=int, default=None,
help="Limit number of requests to replay")
p.add_argument("-v", "--verbose", action="store_true")
args = p.parse_args()
logging.basicConfig(
level=logging.DEBUG if args.verbose else logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
config = ReplayConfig(
trace_path=args.trace,
output_path=args.output,
endpoint_url=args.endpoint.rstrip("/"),
model_name=args.model,
time_scale=args.time_scale,
max_inflight_sessions=args.max_inflight_sessions,
concurrency_limit=args.concurrency_limit,
request_timeout_s=args.request_timeout,
request_limit=args.request_limit,
)
results = asyncio.run(replay_trace(config))
succeeded = sum(1 for r in results if r.error is None)
print(f"\nDone: {succeeded}/{len(results)} requests succeeded")
print(f"Metrics: {args.output}")
print(f"Summary: {args.output.with_suffix('.summary.json')}")
if __name__ == "__main__":
main()

107
replayer/metrics.py Normal file
View File

@@ -0,0 +1,107 @@
"""Per-request metrics collection and summary reporting."""
from __future__ import annotations
import asyncio
import json
import statistics
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
@dataclass(frozen=True)
class RequestMetrics:
request_id: str
session_id: str
turn_id: int
trace_timestamp_s: float
input_length: int
output_length: int
request_type: str
effective_input_length: int | None
cached_tokens: int
latency_s: float | None
ttft_s: float | None
tpot_s: float | None
actual_output_tokens: int | None = None
requested_output_tokens: int | None = None
finish_reason: str | None = None
error: str | None = None
class IncrementalMetricSink:
"""Append each RequestMetrics to JSONL immediately (crash-safe)."""
def __init__(self, path: Path):
self.path = path
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text("")
self._lock = asyncio.Lock()
self._fh = path.open("a", encoding="utf-8", buffering=1)
async def append(self, metric: RequestMetrics) -> None:
line = json.dumps(asdict(metric), sort_keys=True) + "\n"
async with self._lock:
self._fh.write(line)
self._fh.flush()
def close(self) -> None:
try:
self._fh.flush()
self._fh.close()
except Exception:
pass
def write_summary_json(path: Path, rows: list[RequestMetrics]) -> None:
successful = [r for r in rows if r.error is None]
latencies = [r.latency_s for r in successful if r.latency_s is not None]
ttfts = [r.ttft_s for r in successful if r.ttft_s is not None]
tpots = [r.tpot_s for r in successful if r.tpot_s is not None]
total_input = sum(r.input_length for r in successful)
total_cached = sum(r.cached_tokens for r in successful)
summary: dict[str, Any] = {
"request_count": len(rows),
"success_count": len(successful),
"error_count": sum(1 for r in rows if r.error is not None),
"latency_stats_s": _stats(latencies),
"ttft_stats_s": _stats(ttfts),
"tpot_stats_s": _stats(tpots),
"cache_hit_request_count": sum(1 for r in successful if r.cached_tokens > 0),
"total_input_tokens": total_input,
"total_cached_tokens": total_cached,
"prefix_cache_hit_ratio": total_cached / total_input if total_input > 0 else 0.0,
"cached_tokens_stats": _stats([float(r.cached_tokens) for r in successful]),
"actual_output_tokens_stats": _stats(
[float(r.actual_output_tokens) for r in successful
if r.actual_output_tokens is not None]
),
}
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w", encoding="utf-8") as fh:
json.dump(summary, fh, indent=2, sort_keys=True)
def _stats(values: list[float | None]) -> dict[str, float] | None:
clean = [v for v in values if v is not None]
if not clean:
return None
clean.sort()
return {
"count": float(len(clean)),
"mean": statistics.fmean(clean),
"p50": _percentile(clean, 0.50),
"p90": _percentile(clean, 0.90),
"p99": _percentile(clean, 0.99),
}
def _percentile(sorted_vals: list[float], pct: float) -> float:
if len(sorted_vals) == 1:
return sorted_vals[0]
idx = round((len(sorted_vals) - 1) * pct)
return sorted_vals[idx]

343
replayer/replay.py Normal file
View File

@@ -0,0 +1,343 @@
"""Trace replayer — send requests to vLLM following trace timing.
Supports both vLLM's /v1/completions (OpenAI-compatible) and /generate
(SGLang-style) endpoints. Uses hash_ids from the trace to construct
synthetic prompts that reproduce realistic prefix-cache hit patterns.
Key behaviors:
- Per-session sequencing: turns within a session are sent in order,
each waiting for the previous to complete before dispatching.
- Inter-session arrival: sessions start at their trace timestamps,
scaled by --time-scale.
- Concurrency control: --max-inflight-sessions caps concurrent sessions;
--concurrency-limit caps total in-flight requests.
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import random as _random
import httpx
from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json
from .trace import TraceRequest, load_trace
logger = logging.getLogger(__name__)
BLOCK_SIZE = 512
VOCAB_SIZE = 151936
TOKEN_RANGE_START = 100
TOKEN_RANGE_END = VOCAB_SIZE - 100
_block_cache: dict[int, list[int]] = {}
def _hash_id_to_token_ids(hash_id: int) -> list[int]:
"""Deterministically map a hash_id to BLOCK_SIZE token IDs."""
if hash_id in _block_cache:
return _block_cache[hash_id]
rng = _random.Random(hash_id)
ids = [rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END) for _ in range(BLOCK_SIZE)]
_block_cache[hash_id] = ids
return ids
@dataclass
class ReplayConfig:
trace_path: Path
output_path: Path
endpoint_url: str # comma-separated for round-robin: "http://host:8000,http://host:8001"
time_scale: float = 1.0
max_inflight_sessions: int = 32
concurrency_limit: int = 256
request_timeout_s: float = 600.0
request_limit: int | None = None
model_name: str = "default"
def _build_prompt_token_ids(req: TraceRequest) -> list[int]:
"""Build token IDs from hash_ids for prefix-cache-aware replay.
Same hash_id prefix → same token ID prefix → APC cache hit in vLLM.
"""
ids: list[int] = []
for hid in req.hash_ids:
ids.extend(_hash_id_to_token_ids(hid))
# Pad to input_length with deterministic tokens
pad_rng = _random.Random(req.chat_id)
while len(ids) < req.input_length:
ids.append(pad_rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END))
return ids[:req.input_length]
@dataclass
class _SessionState:
session_id: str
turns: list[TraceRequest]
metrics: list[RequestMetrics] = field(default_factory=list)
_endpoint_counter = 0
def _pick_endpoint(config: ReplayConfig) -> str:
"""Round-robin across comma-separated endpoints."""
global _endpoint_counter
endpoints = [e.strip() for e in config.endpoint_url.split(",")]
url = endpoints[_endpoint_counter % len(endpoints)]
_endpoint_counter += 1
return url
async def _dispatch_request(
*,
client: httpx.AsyncClient,
config: ReplayConfig,
req: TraceRequest,
prompt_token_ids: list[int],
sem: asyncio.Semaphore,
) -> RequestMetrics:
"""Send one request via /v1/completions (streaming) and collect metrics."""
endpoint = _pick_endpoint(config)
payload = {
"model": config.model_name,
"prompt": prompt_token_ids,
"max_tokens": max(1, req.output_length),
"temperature": 0,
"stream": True,
"stream_options": {"include_usage": True},
}
start = time.perf_counter()
ttft_s = None
n_output = 0
cached_tokens = 0
finish_reason = None
err = None
token_times: list[float] = []
async with sem:
try:
async with client.stream(
"POST",
f"{endpoint}/v1/completions",
json=payload,
timeout=config.request_timeout_s,
) as resp:
resp.raise_for_status()
async for raw_line in resp.aiter_lines():
if not raw_line or not raw_line.startswith("data:"):
continue
data = raw_line[5:].strip()
if data == "[DONE]":
break
try:
chunk = json.loads(data)
except json.JSONDecodeError:
continue
now = time.perf_counter()
if ttft_s is None:
ttft_s = now - start
choices = chunk.get("choices", [])
if choices:
delta = choices[0].get("text", "")
if delta:
token_times.append(now)
fr = choices[0].get("finish_reason")
if fr:
finish_reason = fr
usage = chunk.get("usage")
if usage:
n_output = usage.get("completion_tokens", n_output)
cached_tokens = _extract_cached_tokens(usage)
except Exception as exc:
err = repr(exc)[:300]
end = time.perf_counter()
e2e = end - start
if n_output == 0 and token_times:
n_output = len(token_times)
tpot = 0.0
if len(token_times) > 1:
inter_token = [token_times[i+1] - token_times[i]
for i in range(len(token_times) - 1)]
tpot = sum(inter_token) / len(inter_token)
return RequestMetrics(
request_id=req.request_id,
session_id=req.session_id,
turn_id=req.turn_id,
trace_timestamp_s=req.timestamp_s,
input_length=req.input_length,
output_length=req.output_length,
request_type=req.request_type,
effective_input_length=len(prompt_token_ids),
cached_tokens=cached_tokens,
latency_s=e2e,
ttft_s=ttft_s,
tpot_s=tpot,
actual_output_tokens=n_output,
requested_output_tokens=req.output_length,
finish_reason=finish_reason,
error=err,
)
def _extract_cached_tokens(usage: dict) -> int:
ct = 0
details = usage.get("prompt_tokens_details")
if isinstance(details, dict):
ct = details.get("cached_tokens", 0) or 0
if ct == 0:
ct = usage.get("cached_tokens", 0) or 0
return int(ct)
async def _run_session(
*,
state: _SessionState,
config: ReplayConfig,
client: httpx.AsyncClient,
session_sem: asyncio.Semaphore,
request_sem: asyncio.Semaphore,
earliest_ts: float,
sweep_start: float,
sink: IncrementalMetricSink,
) -> list[RequestMetrics]:
async with session_sem:
# Wait until this session's start time
offset = (state.turns[0].timestamp_s - earliest_ts) / config.time_scale
wait = offset - (time.perf_counter() - sweep_start)
if wait > 0:
await asyncio.sleep(wait)
for req in state.turns:
# Intra-session: wait for turn's relative offset
if req != state.turns[0]:
target = (req.timestamp_s - state.turns[0].timestamp_s) / config.time_scale
elapsed = time.perf_counter() - sweep_start - offset
if elapsed < target:
await asyncio.sleep(target - elapsed)
token_ids = _build_prompt_token_ids(req)
metric = await _dispatch_request(
client=client, config=config, req=req,
prompt_token_ids=token_ids, sem=request_sem,
)
state.metrics.append(metric)
await sink.append(metric)
return state.metrics
async def _snapshot_prefix_cache_metrics(url_csv: str) -> dict[str, float]:
"""Scrape vLLM /metrics for prefix cache counters (aggregated across endpoints)."""
total = {"queries": 0.0, "hits": 0.0}
endpoints = [e.strip() for e in url_csv.split(",")]
async with httpx.AsyncClient(timeout=10) as c:
for url in endpoints:
try:
r = await c.get(f"{url}/metrics")
for line in r.text.split("\n"):
if line.startswith("vllm:prefix_cache_queries_total"):
total["queries"] += float(line.split()[-1])
elif line.startswith("vllm:prefix_cache_hits_total"):
total["hits"] += float(line.split()[-1])
except Exception:
pass
return total
async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]:
"""Main entry: load trace, replay against endpoint, return metrics."""
requests = load_trace(config.trace_path, request_limit=config.request_limit)
if not requests:
return []
by_session: dict[str, list[TraceRequest]] = defaultdict(list)
for r in requests:
by_session[r.session_id].append(r)
for sid in by_session:
by_session[sid].sort(key=lambda r: (r.turn_id, r.timestamp_s))
sessions = sorted(by_session.items(), key=lambda kv: kv[1][0].timestamp_s)
earliest_ts = sessions[0][1][0].timestamp_s
session_sem = asyncio.Semaphore(config.max_inflight_sessions)
request_sem = asyncio.Semaphore(config.concurrency_limit)
sink = IncrementalMetricSink(config.output_path)
n_sessions = len(sessions)
n_requests = len(requests)
logger.info("Replaying %d sessions (%d requests), time_scale=%.1f",
n_sessions, n_requests, config.time_scale)
pre_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)
sweep_start = time.perf_counter()
try:
limits = httpx.Limits(
max_connections=2000,
max_keepalive_connections=500,
keepalive_expiry=30.0,
)
async with httpx.AsyncClient(
timeout=config.request_timeout_s,
trust_env=False,
limits=limits,
) as client:
tasks = [
asyncio.create_task(_run_session(
state=_SessionState(session_id=sid, turns=turns),
config=config, client=client,
session_sem=session_sem, request_sem=request_sem,
earliest_ts=earliest_ts, sweep_start=sweep_start,
sink=sink,
))
for sid, turns in sessions
]
all_results = await asyncio.gather(*tasks)
finally:
sink.close()
sweep_elapsed = time.perf_counter() - sweep_start
post_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)
flat = [m for group in all_results for m in group]
summary_path = config.output_path.with_suffix(".summary.json")
write_summary_json(summary_path, flat)
# Compute aggregate prefix cache hit ratio from /metrics deltas
delta_queries = post_metrics.get("queries", 0) - pre_metrics.get("queries", 0)
delta_hits = post_metrics.get("hits", 0) - pre_metrics.get("hits", 0)
hit_ratio = delta_hits / delta_queries if delta_queries > 0 else 0.0
logger.info("Done: %d/%d succeeded in %.1fs", sum(1 for m in flat if m.error is None), len(flat), sweep_elapsed)
logger.info("Prefix cache: %.1f%% hit ratio (%d/%d tokens)",
hit_ratio * 100, int(delta_hits), int(delta_queries))
# Append cache stats to summary
import json as _json
summary = _json.loads(summary_path.read_text())
summary["prefix_cache_queries_tokens"] = int(delta_queries)
summary["prefix_cache_hits_tokens"] = int(delta_hits)
summary["prefix_cache_hit_ratio"] = hit_ratio
summary["wall_clock_s"] = sweep_elapsed
summary_path.write_text(_json.dumps(summary, indent=2, sort_keys=True))
logger.info("Summary written to %s", summary_path)
return flat

84
replayer/trace.py Normal file
View File

@@ -0,0 +1,84 @@
"""Trace data structures and loader for the Ali agentic-coder trace format.
Trace format (one JSON per line):
chat_id, parent_chat_id, timestamp, input_length, output_length,
type, turn, hash_ids[]
Sessions are derived from parent_chat_id chains:
- parent_chat_id == -1 → new session root
- parent_chat_id >= 0 → belongs to the same session as the parent
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
@dataclass(frozen=True)
class TraceRequest:
request_id: str
session_id: str
chat_id: int
parent_chat_id: int
timestamp_s: float
input_length: int
output_length: int
request_type: str
turn_id: int
hash_ids: tuple[int, ...]
def load_trace(
path: Path,
*,
request_limit: int | None = None,
) -> list[TraceRequest]:
"""Load trace and resolve session IDs from parent_chat_id chains."""
chat_to_session: dict[int, str] = {}
requests: list[TraceRequest] = []
with path.open("r", encoding="utf-8") as fh:
for idx, line in enumerate(fh):
if request_limit is not None and len(requests) >= request_limit:
break
row = json.loads(line)
chat_id = int(row["chat_id"])
parent_chat_id = int(row["parent_chat_id"])
if "session_id" in row:
session_id = str(row["session_id"])
else:
session_id = _resolve_session_id(
chat_id, parent_chat_id, chat_to_session,
)
chat_to_session[chat_id] = session_id
requests.append(TraceRequest(
request_id=f"{session_id}:{row['turn']}:{chat_id}:{idx}",
session_id=session_id,
chat_id=chat_id,
parent_chat_id=parent_chat_id,
timestamp_s=float(row["timestamp"]),
input_length=int(row["input_length"]),
output_length=int(row["output_length"]),
request_type=str(row["type"]),
turn_id=int(row["turn"]),
hash_ids=tuple(int(h) for h in row.get("hash_ids", [])),
))
return requests
def _resolve_session_id(
chat_id: int,
parent_chat_id: int,
chat_to_session: dict[int, str],
) -> str:
if parent_chat_id < 0:
session_id = str(chat_id)
else:
session_id = chat_to_session.get(parent_chat_id, str(parent_chat_id))
chat_to_session[chat_id] = session_id
return session_id

View File

@@ -0,0 +1,196 @@
"""Analyze theoretical vs actual KV cache hit ratio for the agentic trace."""
import json
from collections import Counter
rows = [json.loads(l) for l in open("traces/sampled_1000req_seed42.jsonl")]
rows.sort(key=lambda r: float(r["timestamp"]))
BLOCK_SIZE = 512
# === 1. Theoretical max: infinite cache, single instance ===
total_tokens = 0
total_cached = 0
seen_blocks = set()
per_req = []
for r in rows:
input_len = r["input_length"]
hash_ids = r.get("hash_ids", [])
total_tokens += input_len
cached_blocks = 0
prefix_broken = False
for hid in hash_ids:
if not prefix_broken and hid in seen_blocks:
cached_blocks += 1
else:
prefix_broken = True
cached_tokens = cached_blocks * BLOCK_SIZE
total_cached += cached_tokens
for hid in hash_ids:
seen_blocks.add(hid)
per_req.append({
"input_length": input_len,
"cached_tokens": cached_tokens,
"new_tokens": max(0, input_len - cached_tokens),
"ratio": cached_tokens / input_len if input_len > 0 else 0,
})
sep = "=" * 70
print(sep)
print(" THEORETICAL KV CACHE HIT (infinite cache, single instance)")
print(sep)
print(f" Total input tokens: {total_tokens:>14,}")
print(f" Cacheable (prefix hit): {total_cached:>14,} ({total_cached*100//total_tokens}%)")
print(f" Must prefill (new): {total_tokens-total_cached:>14,} ({(total_tokens-total_cached)*100//total_tokens}%)")
ratios = sorted([s["ratio"] for s in per_req if s["input_length"] > 0])
new_tokens = sorted([s["new_tokens"] for s in per_req if s["input_length"] > 0])
p = lambda v, q: v[min(int(q*len(v)), len(v)-1)]
print(f"\n Per-request cache hit ratio:")
print(f" p10={p(ratios,.1)*100:.1f}% p50={p(ratios,.5)*100:.1f}% p90={p(ratios,.9)*100:.1f}% mean={sum(ratios)/len(ratios)*100:.1f}%")
high = sum(1 for r in ratios if r > 0.5)
very_high = sum(1 for r in ratios if r > 0.9)
zero = sum(1 for r in ratios if r == 0)
print(f" 0% hit (cold start): {zero} ({zero*100//len(ratios)}%)")
print(f" >50% hit: {high} ({high*100//len(ratios)}%)")
print(f" >90% hit: {very_high} ({very_high*100//len(ratios)}%)")
print(f"\n Actual new tokens to prefill per request:")
print(f" p10={p(new_tokens,.1):>7,} p50={p(new_tokens,.5):>7,} p90={p(new_tokens,.9):>7,} max={max(new_tokens):>7,}")
# === 2. 4-instance split (simulating DP=4 or 4 prefill instances) ===
print(f"\n{sep}")
print(" 4-INSTANCE SPLIT (round-robin, per-instance cache)")
print(sep)
instance_seen = [set() for _ in range(4)]
inst_total = [0]*4
inst_cached = [0]*4
for i, r in enumerate(rows):
inst = i % 4
input_len = r["input_length"]
hash_ids = r.get("hash_ids", [])
inst_total[inst] += input_len
cached_blocks = 0
prefix_broken = False
for hid in hash_ids:
if not prefix_broken and hid in instance_seen[inst]:
cached_blocks += 1
else:
prefix_broken = True
inst_cached[inst] += cached_blocks * BLOCK_SIZE
for hid in hash_ids:
instance_seen[inst].add(hid)
rr_total = sum(inst_total)
rr_cached = sum(inst_cached)
print(f" Cache hit ratio (RR): {rr_cached*100//rr_total}%")
# === 3. Cache-aware routing (route to instance with best prefix match) ===
print(f"\n{sep}")
print(" 4-INSTANCE CACHE-AWARE ROUTING")
print(sep)
ca_seen = [set() for _ in range(4)]
ca_total = [0]*4
ca_cached = [0]*4
for r in rows:
input_len = r["input_length"]
hash_ids = r.get("hash_ids", [])
# Pick instance with most prefix blocks cached
best_inst = 0
best_hit = 0
for inst in range(4):
hit = 0
for hid in hash_ids:
if hid in ca_seen[inst]:
hit += 1
else:
break
if hit > best_hit:
best_hit = hit
best_inst = inst
ca_total[best_inst] += input_len
ca_cached[best_inst] += best_hit * BLOCK_SIZE
for hid in hash_ids:
ca_seen[best_inst].add(hid)
ca_total_sum = sum(ca_total)
ca_cached_sum = sum(ca_cached)
print(f" Cache hit ratio: {ca_cached_sum*100//ca_total_sum}%")
print(f" vs RR: {rr_cached*100//rr_total}% -> {ca_cached_sum*100//ca_total_sum}% (+{(ca_cached_sum-rr_cached)*100//rr_total}pp)")
# === 4. Session structure analysis ===
print(f"\n{sep}")
print(" SESSION & MULTI-TURN ANALYSIS")
print(sep)
sessions = {}
chat_to_session = {}
for r in rows:
cid = int(r["chat_id"])
pid = int(r["parent_chat_id"])
sid = r.get("session_id", str(cid) if pid < 0 else chat_to_session.get(pid, str(pid)))
chat_to_session[cid] = str(sid)
sessions.setdefault(str(sid), []).append(r)
multi = {k: v for k, v in sessions.items() if len(v) > 1}
single = {k: v for k, v in sessions.items() if len(v) == 1}
print(f" Sessions: {len(sessions)} total, {len(multi)} multi-turn ({len(multi)*100//len(sessions)}%)")
# Multi-turn: cache hit in turn 2+
mt_new = 0
mt_reuse = 0
for sid, turns in multi.items():
turns.sort(key=lambda r: r["turn"])
prev_blocks = set()
for t in turns:
hids = t.get("hash_ids", [])
for hid in hids:
if hid in prev_blocks:
mt_reuse += BLOCK_SIZE
else:
mt_new += BLOCK_SIZE
prev_blocks.add(hid)
mt_total_tok = mt_new + mt_reuse
print(f" Multi-turn intra-session reuse: {mt_reuse*100//mt_total_tok}% of tokens")
print(f" (Turn 2+ reuses KV from prior turns in same session)")
# Single-turn: cross-session sharing via system prompt
block_freq = Counter()
for r in rows:
for hid in r.get("hash_ids", []):
block_freq[hid] += 1
shared = {k: v for k, v in block_freq.items() if v > 1}
top = block_freq.most_common(5)
print(f"\n Cross-session block sharing:")
print(f" Unique blocks: {len(block_freq):,}")
print(f" Shared (ref>1): {len(shared):,} ({len(shared)*100//len(block_freq)}%)")
print(f" Top-5 block ref counts: {[c for _,c in top]}")
print(f" (Shared blocks = system prompt / common code context)")
# === 5. Implication for PD separation ===
print(f"\n{sep}")
print(" IMPLICATION FOR PD SEPARATION")
print(sep)
actual_prefill_pct = (total_tokens - total_cached) * 100 // total_tokens
print(f" With perfect caching, only {actual_prefill_pct}% of tokens need actual prefill compute.")
print(f" The remaining {100-actual_prefill_pct}% are prefix cache hits (skip prefill, reuse KV).")
print(f" This means PD separation's prefill overhead is much smaller than it appears:")
print(f" - Nominal avg input: {total_tokens//len(rows):,} tokens/request")
new_per_req = sorted([s["new_tokens"] for s in per_req if s["input_length"] > 0])
print(f" - Actual avg prefill: {sum(new_per_req)//len(new_per_req):,} tokens/request (after cache hit)")
print(f" - KV transfer size is also reduced (only transfer new blocks)")

163
scripts/analyze_trace.py Normal file
View File

@@ -0,0 +1,163 @@
"""Analyze trace patterns to assess PD separation benefit.
Computes metrics relevant to deciding PD-combined vs PD-separated:
- Input/output token ratio (high ratio = prefill-heavy → PD sep benefits)
- Prefix sharing density (high sharing → benefits from shared KV cache)
- Session length distribution (multi-turn = more prefix reuse)
- Arrival burstiness (bursty prefill → PD sep can absorb spikes)
- Compute-intensity ratio: prefill FLOP share vs decode FLOP share
Usage:
python scripts/analyze_trace.py --input traces/sampled_1000req_seed42.jsonl
"""
from __future__ import annotations
import argparse
import collections
import json
import statistics
from pathlib import Path
def main():
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter)
p.add_argument("--input", type=Path, required=True)
args = p.parse_args()
rows = []
with args.input.open() as fh:
for line in fh:
rows.append(json.loads(line))
# Session structure
sessions: dict[str, list[dict]] = collections.OrderedDict()
chat_to_session: dict[int, str] = {}
for r in rows:
cid = int(r["chat_id"])
pid = int(r["parent_chat_id"])
sid = r.get("session_id")
if sid is None:
sid = str(cid) if pid < 0 else chat_to_session.get(pid, str(pid))
chat_to_session[cid] = str(sid)
sessions.setdefault(str(sid), []).append(r)
n_sessions = len(sessions)
turns_per_session = [len(v) for v in sessions.values()]
multi_turn = sum(1 for t in turns_per_session if t > 1)
input_lens = [r["input_length"] for r in rows]
output_lens = [r["output_length"] for r in rows]
total_input = sum(input_lens)
total_output = sum(output_lens)
print("=" * 60)
print("Trace Pattern Analysis for PD Separation Decision")
print("=" * 60)
# 1. Input/Output ratio
io_ratio = total_input / max(total_output, 1)
print(f"\n1. Input/Output Token Ratio")
print(f" Total input tokens: {total_input:>12,}")
print(f" Total output tokens: {total_output:>12,}")
print(f" I/O ratio: {io_ratio:>12.1f}x")
print(f"{'STRONGLY' if io_ratio > 50 else 'Moderately' if io_ratio > 10 else 'Weakly'} prefill-heavy")
# 2. Prefill compute share
# Approximate: prefill FLOP ∝ input_length, decode FLOP ∝ output_length * input_length
# More precisely: prefill dominates when input >> output
prefill_share = total_input / (total_input + total_output)
print(f"\n2. Compute Split (token count proxy)")
print(f" Prefill share: {prefill_share*100:.1f}%")
print(f" Decode share: {(1-prefill_share)*100:.1f}%")
# 3. Session structure
print(f"\n3. Session Structure")
print(f" Sessions: {n_sessions}")
print(f" Requests: {len(rows)}")
print(f" Multi-turn: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)")
print(f" Turns/sess: min={min(turns_per_session)} max={max(turns_per_session)} "
f"avg={statistics.fmean(turns_per_session):.1f}")
# 4. Prefix sharing
all_hash_ids = set()
per_request_hashes = []
for r in rows:
hids = set(r.get("hash_ids", []))
per_request_hashes.append(hids)
all_hash_ids.update(hids)
hash_refcount = collections.Counter()
for hids in per_request_hashes:
for h in hids:
hash_refcount[h] += 1
shared_blocks = sum(1 for h, c in hash_refcount.items() if c > 1)
total_blocks = len(all_hash_ids)
block_reuse = shared_blocks / max(total_blocks, 1)
avg_refcount = statistics.fmean(hash_refcount.values()) if hash_refcount else 0
print(f"\n4. Prefix Block Sharing")
print(f" Unique blocks: {total_blocks:>10,}")
print(f" Shared (ref>1): {shared_blocks:>10,} ({block_reuse*100:.1f}%)")
print(f" Avg refcount: {avg_refcount:>10.2f}")
print(f"{'High' if block_reuse > 0.3 else 'Moderate' if block_reuse > 0.1 else 'Low'} prefix reuse potential")
# 5. Input length distribution
input_sorted = sorted(input_lens)
pct = lambda q: input_sorted[min(int(q * len(input_sorted)), len(input_sorted) - 1)]
print(f"\n5. Input Length Distribution")
print(f" p10={pct(0.1):>8,} p50={pct(0.5):>8,} p90={pct(0.9):>8,} max={max(input_lens):>8,}")
long_context = sum(1 for l in input_lens if l > 32000)
print(f" Requests >32k tokens: {long_context} ({long_context/len(rows)*100:.1f}%)")
# 6. Arrival pattern
timestamps = sorted(float(r["timestamp"]) for r in rows)
span = timestamps[-1] - timestamps[0]
avg_rate = len(rows) / max(span, 0.001)
# Burstiness: coefficient of variation of inter-arrival times
inter_arrivals = [timestamps[i+1] - timestamps[i] for i in range(len(timestamps) - 1)]
inter_arrivals = [t for t in inter_arrivals if t > 0]
if inter_arrivals:
cv = statistics.stdev(inter_arrivals) / statistics.fmean(inter_arrivals)
else:
cv = 0
print(f"\n6. Arrival Pattern")
print(f" Span: {span:.1f}s ({span/60:.1f} min)")
print(f" Avg rate: {avg_rate:.2f} req/s")
print(f" Burstiness (CoV): {cv:.2f}")
print(f"{'Bursty' if cv > 1.5 else 'Moderate' if cv > 0.8 else 'Steady'} arrival pattern")
# Summary
print(f"\n{'=' * 60}")
print("Summary: PD Separation Recommendation")
print(f"{'=' * 60}")
factors = []
if io_ratio > 50:
factors.append("Very high I/O ratio (prefill-dominated)")
elif io_ratio > 10:
factors.append("High I/O ratio")
if block_reuse > 0.1:
factors.append(f"Significant prefix reuse ({block_reuse*100:.0f}% shared blocks)")
if long_context / len(rows) > 0.3:
factors.append(f"Many long-context requests ({long_context/len(rows)*100:.0f}%)")
if cv > 1.0:
factors.append("Bursty arrivals (PD sep absorbs prefill spikes)")
if len(factors) >= 2:
print("→ RECOMMEND PD separation:")
elif len(factors) == 1:
print("→ PD separation MAY help:")
else:
print("→ PD separation likely NOT beneficial:")
for f in factors:
print(f"{f}")
if not factors:
print(" • No strong indicators for PD separation benefit")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,280 @@
"""Unified cache-aware + token-level load-balanced global scheduler.
Supports two modes:
--combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
--prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)
Routing policy (same for both modes):
score = ongoing_tokens / avg_ongoing - ALPHA * cache_hit_ratio
Normalized load prevents "rich get richer"; cache bonus gives affinity.
Session affinity: multi-turn sessions stick to same instance.
"""
import argparse
import asyncio
import os
import urllib.parse
import uuid
from contextlib import asynccontextmanager
import httpx
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
BLOCK_SIZE = 512
CACHE_HIT_ALPHA = 1.0 # weight for cache bonus in scoring
class InstanceState:
def __init__(self, url: str, bootstrap_port: int | None = None):
self.url = url
self.bootstrap_port = bootstrap_port
self.client = httpx.AsyncClient(
timeout=None, base_url=url,
limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
)
self.ongoing_tokens = 0
self.engine_id: dict[int, str] = {}
self.dp_size = 1
self.cached_blocks: set[int] = set()
def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
if not token_ids or len(token_ids) < BLOCK_SIZE:
return 0
hit = 0
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
if bh in self.cached_blocks:
hit += BLOCK_SIZE
else:
break
return hit
def record_prefix(self, token_ids: list[int] | None):
if not token_ids:
return
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
self.cached_blocks.add(hash(tuple(token_ids[i:i + BLOCK_SIZE])))
if len(self.cached_blocks) > 200000:
self.cached_blocks = set(list(self.cached_blocks)[-100000:])
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
session_id: str | None, input_length: int,
affinity: dict[str, int]) -> tuple[InstanceState, int]:
"""Normalized load - cache bonus scoring."""
if session_id and session_id in affinity:
idx = affinity[session_id]
if idx < len(instances):
return instances[idx], idx
avg_load = max(sum(i.ongoing_tokens for i in instances) / len(instances), 1.0)
best_idx, best_score = 0, float("inf")
for i, inst in enumerate(instances):
cache_hit = inst.estimate_cache_hit(token_ids)
cache_ratio = cache_hit / input_length if input_length > 0 else 0.0
score = inst.ongoing_tokens / avg_load - CACHE_HIT_ALPHA * cache_ratio
if score < best_score:
best_score = score
best_idx = i
if session_id:
affinity[session_id] = best_idx
return instances[best_idx], best_idx
global_args = None
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
decode_instances: list[InstanceState] = []
session_affinity: dict[str, int] = {}
is_pd_sep = False
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
for inst in instances:
if inst.bootstrap_port is None:
continue
while True:
try:
await inst.client.get("/health")
except Exception:
await asyncio.sleep(1)
continue
parsed = urllib.parse.urlparse(str(inst.client.base_url))
url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query"
resp = await inst.client.get(url)
resp.raise_for_status()
data = resp.json()
for dp_rank, dp_entry in data.items():
inst.engine_id[int(dp_rank)] = dp_entry["engine_id"]
inst.dp_size = len(data)
print(f"Inited {inst.url} engine_ids={inst.engine_id}")
break
ready.set()
@asynccontextmanager
async def lifespan(app: FastAPI):
global is_pd_sep
app.state.ready = asyncio.Event()
if global_args.combined:
is_pd_sep = False
for url in global_args.combined:
combined_instances.append(InstanceState(url))
app.state.ready.set()
print(f"Combined mode: {len(combined_instances)} instances")
else:
is_pd_sep = True
for url, bp in global_args.prefill:
prefill_instances.append(InstanceState(url, bp))
for url in global_args.decode:
decode_instances.append(InstanceState(url))
asyncio.create_task(init_prefill_bootstrap(prefill_instances, app.state.ready))
print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")
yield
for inst in combined_instances + prefill_instances + decode_instances:
await inst.client.aclose()
app = FastAPI(lifespan=lifespan)
@app.post("/v1/completions")
async def handle_completions(request: Request):
return await _handle(request, "/v1/completions")
@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
return await _handle(request, "/v1/chat/completions")
async def _handle(request: Request, api: str):
if not app.state.ready.is_set():
raise HTTPException(status_code=503, detail="Service Unavailable")
req_data = await request.json()
request_id = str(uuid.uuid4())
prompt = req_data.get("prompt")
token_ids = prompt if isinstance(prompt, list) else None
input_length = len(token_ids) if token_ids else 0
session_id = request.headers.get("X-Session-Id")
headers = {"X-Request-Id": request_id}
api_key = os.environ.get("OPENAI_API_KEY")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
if is_pd_sep:
return await _handle_pd_sep(api, req_data, request_id, token_ids,
input_length, session_id, headers)
else:
return await _handle_combined(api, req_data, token_ids,
input_length, session_id, headers)
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
"""Combined mode: route to best instance, send normal request."""
inst, idx = pick_instance(combined_instances, token_ids, session_id,
input_length, session_affinity)
inst.ongoing_tokens += input_length
async def generate():
try:
async with inst.client.stream("POST", api, json=req_data, headers=headers) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
yield chunk
inst.record_prefix(token_ids)
finally:
inst.ongoing_tokens -= input_length
return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
session_id, headers):
"""PD-Sep mode: await prefill, then stream decode."""
p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
input_length, session_affinity)
d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
# Await prefill
p_inst.ongoing_tokens += input_length
try:
prefill_data = req_data.copy()
prefill_data["kv_transfer_params"] = {
"do_remote_decode": True, "do_remote_prefill": False,
"transfer_id": f"xfer-{request_id}",
}
prefill_data["stream"] = False
prefill_data["max_tokens"] = 1
prefill_data.pop("max_completion_tokens", None)
prefill_data.pop("stream_options", None)
p_headers = {**headers, "X-data-parallel-rank": "0"}
resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
resp.raise_for_status()
await resp.aclose()
p_inst.record_prefix(token_ids)
except Exception as e:
raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
finally:
p_inst.ongoing_tokens -= input_length
# Stream decode
d_inst.ongoing_tokens += input_length
parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"
decode_data = req_data.copy()
decode_data["kv_transfer_params"] = {
"do_remote_decode": False, "do_remote_prefill": True,
"remote_bootstrap_addr": bootstrap_addr,
"remote_engine_id": p_inst.engine_id.get(0, ""),
"transfer_id": f"xfer-{request_id}",
}
async def generate():
try:
async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
yield chunk
finally:
d_inst.ongoing_tokens -= input_length
return StreamingResponse(generate(), media_type="application/json")
def parse_args():
p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
p.add_argument("--port", type=int, default=8000)
p.add_argument("--host", type=str, default="0.0.0.0")
p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
help="PD-Sep prefill: URL [bootstrap_port]")
p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
help="PD-Sep decode: URL")
args = p.parse_args()
args.prefill = []
if args.prefill_raw:
for entry in args.prefill_raw:
url = entry[0]
bp = int(entry[1]) if len(entry) > 1 and entry[1].lower() != "none" else None
args.prefill.append((url, bp))
args.decode = [e[0] for e in (args.decode_raw or [])]
if not args.combined and not args.prefill:
p.error("Must specify either --combined or --prefill/--decode")
return args
if __name__ == "__main__":
global_args = parse_args()
uvicorn.run(app, host=global_args.host, port=global_args.port)

102
scripts/compare_results.py Normal file
View File

@@ -0,0 +1,102 @@
"""Compare benchmark results between PD-combined and PD-separated modes.
Reads summary JSON files and per-request metrics to produce a detailed
comparison report including TTFT, TPOT, E2E, cache hit ratio, and
throughput analysis.
Usage:
python scripts/compare_results.py \
--combined outputs/combined_1000req/metrics.summary.json \
--separated outputs/separated_1000req/metrics.summary.json
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
def load_summary(path: Path) -> dict:
return json.loads(path.read_text())
def load_metrics(path: Path) -> list[dict]:
rows = []
with path.open() as fh:
for line in fh:
rows.append(json.loads(line))
return rows
def fmt_stat(stat: dict | None, unit: str = "s") -> str:
if stat is None:
return "N/A"
return (f"mean={stat['mean']:.3f}{unit} "
f"p50={stat['p50']:.3f}{unit} "
f"p90={stat['p90']:.3f}{unit} "
f"p99={stat['p99']:.3f}{unit}")
def compare(combined: dict, separated: dict) -> None:
print("=" * 70)
print("PD-Combined vs PD-Separated Performance Comparison")
print("=" * 70)
for label, s in [("PD-Combined", combined), ("PD-Separated", separated)]:
print(f"\n--- {label} ---")
print(f" Requests: {s['request_count']} (success: {s['success_count']}, errors: {s['error_count']})")
print(f" Wall clock: {s.get('wall_clock_s', 0):.1f}s")
print(f" TTFT: {fmt_stat(s.get('ttft_stats_s'))}")
print(f" TPOT: {fmt_stat(s.get('tpot_stats_s'))}")
print(f" E2E: {fmt_stat(s.get('latency_stats_s'))}")
hit_ratio = s.get('prefix_cache_hit_ratio', 0)
print(f" Prefix cache hit ratio: {hit_ratio*100:.1f}%")
queries = s.get('prefix_cache_queries_tokens', 0)
hits = s.get('prefix_cache_hits_tokens', 0)
print(f" ({hits}/{queries} tokens)")
print("\n--- Comparison (Separated vs Combined) ---")
for metric_key, label in [
("ttft_stats_s", "TTFT"),
("tpot_stats_s", "TPOT"),
("latency_stats_s", "E2E"),
]:
c = combined.get(metric_key, {})
s = separated.get(metric_key, {})
if c and s:
for pct in ["mean", "p50", "p90", "p99"]:
cv, sv = c.get(pct, 0), s.get(pct, 0)
if cv > 0:
change = (sv - cv) / cv * 100
direction = "slower" if change > 0 else "faster"
print(f" {label} {pct}: {abs(change):.1f}% {direction} "
f"({cv:.3f}s → {sv:.3f}s)")
c_ratio = combined.get("prefix_cache_hit_ratio", 0)
s_ratio = separated.get("prefix_cache_hit_ratio", 0)
print(f" Cache hit ratio: {c_ratio*100:.1f}% → {s_ratio*100:.1f}%")
c_wall = combined.get("wall_clock_s", 1)
s_wall = separated.get("wall_clock_s", 1)
c_tput = combined["success_count"] / c_wall
s_tput = separated["success_count"] / s_wall
print(f" Throughput: {c_tput:.1f}{s_tput:.1f} req/s "
f"({(s_tput/c_tput - 1)*100:+.1f}%)")
def main():
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter)
p.add_argument("--combined", type=Path, required=True)
p.add_argument("--separated", type=Path, required=True)
args = p.parse_args()
combined = load_summary(args.combined)
separated = load_summary(args.separated)
compare(combined, separated)
if __name__ == "__main__":
main()

210
scripts/compute_roofline.py Normal file
View File

@@ -0,0 +1,210 @@
"""Roofline analysis: compute/memory ratio for prefill vs decode
under different sequence lengths and KV cache reuse ratios.
Model: Qwen3-Coder-30B-A3B (MoE)
- 48 layers, hidden=2048, heads=32, kv_heads=4, head_dim=128
- MoE: 128 experts, top-8 active, intermediate=6144
- Total params: ~30B, Active params per token: ~3B
GPU: NVIDIA H20
- BF16 peak: 148 TFLOPS
- HBM bandwidth: 4.0 TB/s
- Roofline ridge point: 148/4.0 = 37 FLOP/byte
"""
import json
import math
# ===== Model config =====
L = 48 # layers
D = 2048 # hidden dim
H = 32 # attention heads
H_kv = 4 # KV heads (GQA)
D_head = 128 # head dim
D_ffn = 6144 # FFN intermediate (per expert)
N_experts = 128 # total experts
K_experts = 8 # active experts per token
VOCAB = 151936
BYTES = 2 # BF16
# ===== GPU config (H20) =====
PEAK_FLOPS = 148e12 # BF16 TFLOPS
HBM_BW = 4.0e12 # bytes/s
RIDGE_POINT = PEAK_FLOPS / HBM_BW # ~37 FLOP/byte
print("=" * 80)
print(" ROOFLINE ANALYSIS: Prefill vs Decode under KV Cache Reuse")
print(" Model: Qwen3-Coder-30B-A3B (MoE 128E top-8) | GPU: H20")
print("=" * 80)
print(f" Ridge point: {RIDGE_POINT:.1f} FLOP/byte")
print(f" Above ridge → compute-bound | Below ridge → memory-bound")
# ===== Per-token compute & memory for each component =====
def attention_prefill_flops(seq_len, new_tokens):
"""FLOPs for attention on new_tokens with seq_len context."""
# QKV projection: new_tokens * D * (D + 2*D_kv) * 2
d_kv = H_kv * D_head
qkv_flops = new_tokens * (D * D * 2 + D * d_kv * 2 * 2) # Q + K + V
# Attention score: new_tokens * seq_len * D * 2 (Q@K^T + softmax@V)
attn_flops = new_tokens * seq_len * D * 2 * 2 # simplified: 2 matmuls
# Output projection: new_tokens * D * D * 2
out_flops = new_tokens * D * D * 2
return (qkv_flops + attn_flops + out_flops) * L
def attention_prefill_bytes(seq_len, new_tokens, cached_tokens):
"""Memory access for attention prefill."""
d_kv = H_kv * D_head
# Load model weights (QKV + O projections): D*(D+2*d_kv+D) * BYTES * L
weight_bytes = D * (D + 2 * d_kv + D) * BYTES * L
# Load cached KV: cached_tokens * 2 * d_kv * BYTES * L
cached_kv_bytes = cached_tokens * 2 * d_kv * BYTES * L
# Read input activations + write output: new_tokens * D * BYTES * 2 * L
act_bytes = new_tokens * D * BYTES * 2 * L
# Write new KV to cache: new_tokens * 2 * d_kv * BYTES * L
new_kv_bytes = new_tokens * 2 * d_kv * BYTES * L
return weight_bytes + cached_kv_bytes + act_bytes + new_kv_bytes
def ffn_flops(n_tokens):
"""FLOPs for MoE FFN on n_tokens."""
# Per expert: 3 * n_tokens * D * D_ffn * 2 (gate + up + down)
# Active experts: K_experts
return 3 * n_tokens * D * D_ffn * 2 * K_experts * L
def ffn_bytes(n_tokens):
"""Memory access for MoE FFN."""
# Load K_experts worth of weights per layer: K * 3 * D * D_ffn * BYTES
weight_bytes = K_experts * 3 * D * D_ffn * BYTES * L
# Activations: n_tokens * D * BYTES * 2 * L
act_bytes = n_tokens * D * BYTES * 2 * L
return weight_bytes + act_bytes
def decode_flops(seq_len):
"""FLOPs for 1 decode token."""
return attention_prefill_flops(seq_len, 1) + ffn_flops(1)
def decode_bytes(seq_len):
"""Memory bytes for 1 decode token."""
return attention_prefill_bytes(seq_len, 1, seq_len) + ffn_bytes(1)
# ===== Analysis =====
print("\n" + "-" * 80)
print(" PART 1: Decode Roofline (baseline)")
print("-" * 80)
print(f" {'SeqLen':>8} {'FLOP':>14} {'Bytes':>14} {'AI (F/B)':>10} {'Bound':>12}")
for seq_len in [1000, 4000, 8000, 16000, 32000, 64000, 128000]:
flops = decode_flops(seq_len)
bytes_ = decode_bytes(seq_len)
ai = flops / bytes_
bound = "COMPUTE" if ai > RIDGE_POINT else "MEMORY"
print(f" {seq_len:>8,} {flops:>14.2e} {bytes_:>14.2e} {ai:>10.1f} {bound:>12}")
print("\n" + "-" * 80)
print(" PART 2: Prefill with KV Cache Reuse")
print(" (Total input = seq_len, cached = seq_len * reuse_ratio, new = rest)")
print("-" * 80)
print(f" {'SeqLen':>8} {'Reuse%':>7} {'NewTok':>8} {'FLOP':>14} {'Bytes':>14} {'AI (F/B)':>10} {'Bound':>12} {'vs Decode':>10}")
for seq_len in [4000, 16000, 32000, 64000, 128000]:
for reuse in [0.0, 0.3, 0.5, 0.7, 0.9, 0.95]:
cached = int(seq_len * reuse)
new = seq_len - cached
# Attention: compute on new tokens, but read cached KV for context
attn_f = attention_prefill_flops(seq_len, new)
attn_b = attention_prefill_bytes(seq_len, new, cached)
# FFN: only on new tokens
ffn_f = ffn_flops(new)
ffn_b = ffn_bytes(new)
total_f = attn_f + ffn_f
total_b = attn_b + ffn_b
ai = total_f / total_b if total_b > 0 else 0
# Compare with decode at same seq_len
dec_f = decode_flops(seq_len)
dec_b = decode_bytes(seq_len)
dec_ai = dec_f / dec_b
bound = "COMPUTE" if ai > RIDGE_POINT else "MEMORY"
ratio = f"{ai/dec_ai:.1f}x" if dec_ai > 0 else "N/A"
print(f" {seq_len:>8,} {reuse*100:>6.0f}% {new:>8,} {total_f:>14.2e} {total_b:>14.2e} {ai:>10.1f} {bound:>12} {ratio:>10}")
print()
print("-" * 80)
print(" PART 3: Key Thresholds")
print("-" * 80)
# At what reuse ratio does prefill become memory-bound?
for seq_len in [4000, 16000, 32000, 64000, 128000]:
for reuse_pct in range(0, 100):
reuse = reuse_pct / 100.0
cached = int(seq_len * reuse)
new = seq_len - cached
if new < 1: continue
attn_f = attention_prefill_flops(seq_len, new)
attn_b = attention_prefill_bytes(seq_len, new, cached)
ffn_f = ffn_flops(new)
ffn_b = ffn_bytes(new)
ai = (attn_f + ffn_f) / (attn_b + ffn_b)
if ai < RIDGE_POINT:
print(f" SeqLen={seq_len:>6,}: prefill becomes memory-bound at {reuse_pct}% reuse (AI={ai:.1f})")
break
print()
print("-" * 80)
print(" PART 4: Agentic Workload Real Distribution")
print("-" * 80)
# Use actual trace data
import os
trace_path = "traces/sampled_1000req_seed42.jsonl"
if os.path.exists(trace_path):
BLOCK_SIZE = 512
seen = set()
compute_bound = 0
memory_bound = 0
total = 0
for line in open(trace_path):
d = json.loads(line)
seq_len = d["input_length"]
if seq_len < 1: continue
hids = d.get("hash_ids", [])
cached_blocks = 0
for hid in hids:
if hid in seen:
cached_blocks += 1
else:
break
for hid in hids:
seen.add(hid)
cached = cached_blocks * BLOCK_SIZE
new = max(1, seq_len - cached)
reuse = cached / seq_len
attn_f = attention_prefill_flops(seq_len, new)
attn_b = attention_prefill_bytes(seq_len, new, cached)
ffn_f = ffn_flops(new)
ffn_b = ffn_bytes(new)
ai = (attn_f + ffn_f) / (attn_b + ffn_b)
total += 1
if ai > RIDGE_POINT:
compute_bound += 1
else:
memory_bound += 1
print(f" With actual trace prefix cache pattern:")
print(f" Compute-bound prefills: {compute_bound} ({compute_bound*100//total}%)")
print(f" Memory-bound prefills: {memory_bound} ({memory_bound*100//total}%)")
print(f" (Decode is ALWAYS memory-bound at these seq lengths)")
print()
print(f" Implication: {memory_bound*100//total}% of agentic prefills behave like decode")
print(f" → PD separation treats them as 'compute-heavy' but they are actually memory-heavy")

View File

@@ -0,0 +1,86 @@
"""Final comparison of PD-Combined vs PD-Separated (Mooncake/RDMA)."""
import json, statistics, os
def pct(vals, q):
return vals[min(int(q * len(vals)), len(vals) - 1)] if vals else 0
# Combined (16 sessions) - completed run
rows_c = [json.loads(l) for l in open("outputs/v18_combined_1000req/metrics.jsonl")]
ok_c = [r for r in rows_c if not r.get("error")]
ttfts_c = sorted([r["ttft_s"] for r in ok_c if r.get("ttft_s")])
tpots_c = sorted([r["tpot_s"] for r in ok_c if r.get("tpot_s") and r["tpot_s"] > 0])
lats_c = sorted([r["latency_s"] for r in ok_c if r.get("latency_s")])
sc = json.load(open("outputs/v18_combined_1000req/metrics.summary.json"))
# PD-Separated Mooncake (first 200 stable requests)
rows_d = [json.loads(l) for l in open("outputs/v18_pd_mooncake_lowconc/metrics.jsonl")][:200]
ok_d = [r for r in rows_d if not r.get("error")]
ttfts_d = sorted([r["ttft_s"] for r in ok_d if r.get("ttft_s")])
tpots_d = sorted([r["tpot_s"] for r in ok_d if r.get("tpot_s") and r["tpot_s"] > 0])
lats_d = sorted([r["latency_s"] for r in ok_d if r.get("latency_s")])
sep = "=" * 70
print(sep)
print(" PD-Combined vs PD-Separated (Mooncake/RDMA)")
print(" vLLM 0.18.1 | Qwen3-Coder-30B-A3B | 8xH20")
print(sep)
header = " {:<12} {:>16} {:>16} {:>10}".format(
"Metric", "Combined(TP=8)", "PD-Sep(TP=4+4)", "Delta")
print(header)
dash = " {:<12} {:>16} {:>16} {:>10}".format("-" * 12, "-" * 16, "-" * 16, "-" * 10)
print(dash)
req_c = "{}/{}".format(len(ok_c), len(rows_c))
req_d = "{}/{}".format(len(ok_d), len(rows_d))
print(" {:<12} {:>16} {:>16}".format("Requests", req_c, req_d))
data = [
("TTFT p50", pct(ttfts_c, 0.5), pct(ttfts_d, 0.5)),
("TTFT p90", pct(ttfts_c, 0.9), pct(ttfts_d, 0.9)),
("TPOT p50", pct(tpots_c, 0.5), pct(tpots_d, 0.5)),
("TPOT p90", pct(tpots_c, 0.9), pct(tpots_d, 0.9)),
("E2E p50", pct(lats_c, 0.5), pct(lats_d, 0.5)),
("E2E p90", pct(lats_c, 0.9), pct(lats_d, 0.9)),
]
for label, cv, dv in data:
delta = "{:+.0f}%".format((dv / cv - 1) * 100) if cv > 0 else "N/A"
print(" {:<12} {:>15.3f}s {:>15.3f}s {:>10}".format(label, cv, dv, delta))
cache_c = sc.get("prefix_cache_hit_ratio", 0)
print(" {:<12} {:>15.1f}% {:>16}".format("Cache hit", cache_c * 100, "N/A"))
tput_c = len(ok_c) / sc.get("wall_clock_s", 1)
print(" {:<12} {:>14.2f}/s {:>16}".format("Throughput", tput_c, "~0.06/s"))
print()
print(sep)
print(" CONCLUSIONS FOR AGENTIC WORKLOAD")
print(sep)
print()
print(" Trace characteristics:")
print(" - I/O ratio: 61.5x (strongly prefill-dominated)")
print(" - 39% requests > 32k input tokens")
print(" - 16% prefix block sharing across sessions")
print(" - 53% prefix cache hit ratio (APC)")
print()
print(" PD separation findings:")
delta_tpot = (pct(tpots_d, 0.5) / pct(tpots_c, 0.5) - 1) * 100 if tpots_c else 0
delta_ttft = (pct(ttfts_d, 0.5) / pct(ttfts_c, 0.5) - 1) * 100 if ttfts_c else 0
delta_e2e = (pct(lats_d, 0.5) / pct(lats_c, 0.5) - 1) * 100 if lats_c else 0
print(" 1. TPOT {:+.0f}% - decode isolation benefit is {}".format(
delta_tpot, "marginal" if abs(delta_tpot) < 20 else "significant"))
print(" 2. TTFT {:+.0f}% - KV transfer + TP=4 overhead dominates".format(delta_ttft))
print(" 3. E2E {:+.0f}% - net negative on single-machine".format(delta_e2e))
print(" 4. Stability: Mooncake connector crashes after ~200 reqs under load")
print()
print(" Recommendation:")
print(" - Single-machine 8 GPU: Combined mode is better (lower TTFT, stable)")
print(" - Multi-machine: PD-Sep is promising IF cross-machine latency")
print(" is hidden by RDMA and prefill doesn't share GPU with decode")
print(" - Key bottleneck: this workload's heavy prefill (avg 32k tokens)")
print(" makes KV transfer cost non-trivial relative to prefill time")
print(" - Prefill-as-a-Service (Goal 5) should focus on cross-machine")
print(" KV cache sharing, not same-machine PD split")

96
scripts/launch_pd_mooncake.sh Executable file
View File

@@ -0,0 +1,96 @@
#!/bin/bash
# PD-Disaggregated serving via Mooncake (RDMA + DRAM KV pool).
#
# Architecture:
# Client → Proxy (port 8000)
# → Prefill (port 8010, TP=4, GPUs 0-3, bootstrap 8998)
# [prefill + store KV to DRAM pool via RDMA]
# → Decode (port 8020, TP=4, GPUs 4-7)
# [pull KV from DRAM pool via RDMA + decode]
#
# Usage: bash scripts/launch_pd_mooncake.sh
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
VENV="$PROJECT_DIR/.venv/bin"
VLLM="$VENV/vllm"
MODEL_PATH="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
PROXY_PORT=8000
PREFILL_PORT=8010
DECODE_PORT=8020
BOOTSTRAP_PORT=8998
PROXY_SCRIPT="$PROJECT_DIR/third_party/vllm/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py"
trap 'echo "Cleaning up..."; kill $(jobs -p) 2>/dev/null; wait 2>/dev/null' EXIT INT TERM
echo "=== PD-Disaggregated vLLM 0.18.1 (Mooncake/RDMA) ==="
echo " Model: $MODEL_PATH"
echo " Prefill: GPUs 0-3 (TP=4), port $PREFILL_PORT, bootstrap $BOOTSTRAP_PORT"
echo " Decode: GPUs 4-7 (TP=4), port $DECODE_PORT"
echo " Proxy: port $PROXY_PORT"
echo ""
# Step 1: Start prefill instance (KV producer)
echo "[1/3] Starting prefill instance..."
VLLM_MOONCAKE_BOOTSTRAP_PORT=$BOOTSTRAP_PORT \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
$VLLM serve "$MODEL_PATH" \
--host 0.0.0.0 \
--port $PREFILL_PORT \
--tensor-parallel-size 4 \
--trust-remote-code \
--enable-prefix-caching \
--enforce-eager \
--dtype auto \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector":"MooncakeConnector","kv_role":"kv_producer"}' &
PREFILL_PID=$!
echo " Prefill PID=$PREFILL_PID"
# Step 2: Start decode instance (KV consumer)
echo "[2/3] Starting decode instance..."
CUDA_VISIBLE_DEVICES=4,5,6,7 \
$VLLM serve "$MODEL_PATH" \
--host 0.0.0.0 \
--port $DECODE_PORT \
--tensor-parallel-size 4 \
--trust-remote-code \
--enable-prefix-caching \
--enforce-eager \
--dtype auto \
--gpu-memory-utilization 0.8 \
--kv-transfer-config \
'{"kv_connector":"MooncakeConnector","kv_role":"kv_consumer"}' &
DECODE_PID=$!
echo " Decode PID=$DECODE_PID"
# Wait for both instances
echo ""
echo "Waiting for instances..."
timeout 1200 bash -c "until curl -s localhost:$PREFILL_PORT/v1/models > /dev/null 2>&1; do sleep 5; done"
echo " Prefill ready!"
timeout 1200 bash -c "until curl -s localhost:$DECODE_PORT/v1/models > /dev/null 2>&1; do sleep 5; done"
echo " Decode ready!"
# Step 3: Start proxy (after instances are ready)
echo "[3/3] Starting proxy..."
$VENV/python "$PROXY_SCRIPT" \
--prefill "http://127.0.0.1:$PREFILL_PORT" "$BOOTSTRAP_PORT" \
--decode "http://127.0.0.1:$DECODE_PORT" \
--host 0.0.0.0 \
--port $PROXY_PORT &
PROXY_PID=$!
echo " Proxy PID=$PROXY_PID"
sleep 5
echo ""
echo "=== All ready ==="
echo " Send requests to: http://localhost:$PROXY_PORT"
echo ""
wait

View File

@@ -0,0 +1,89 @@
#!/bin/bash
# PD-Disaggregated serving: 1 prefill (TP=4, GPUs 0-3) + 1 decode (TP=4, GPUs 4-7)
# Uses vLLM 0.18.1's P2pNcclConnector + XpYd proxy.
#
# Architecture:
# Client → Proxy (port 10001)
# → Prefill (port 20003, kv_port 21001) [max_tokens=1, does prefill + KV push]
# → Decode (port 20005, kv_port 22001) [full generation, KV pulled from prefill]
#
# Usage: bash scripts/launch_pd_separated.sh
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
VENV="$PROJECT_DIR/.venv/bin"
VLLM="$VENV/vllm"
MODEL_PATH="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
PROXY_PORT=30001 # ZMQ service discovery
CLIENT_PORT=10001 # HTTP proxy for clients
PREFILL_PORT=20003
DECODE_PORT=20005
KV_PORT_P=21001
KV_PORT_D=22001
trap 'echo "Cleaning up..."; kill $(jobs -p) 2>/dev/null; wait 2>/dev/null' EXIT INT TERM
echo "=== PD-Disaggregated vLLM 0.18.1 ==="
echo " Model: $MODEL_PATH"
echo " Prefill: GPUs 0-3 (TP=4), port $PREFILL_PORT, kv_port $KV_PORT_P"
echo " Decode: GPUs 4-7 (TP=4), port $DECODE_PORT, kv_port $KV_PORT_D"
echo " Proxy: ZMQ=$PROXY_PORT, HTTP=$CLIENT_PORT"
echo ""
# Step 1: Start proxy FIRST (P/D instances register via ZMQ)
echo "[1/3] Starting proxy..."
PROXY_SCRIPT="$PROJECT_DIR/third_party/vllm/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py"
$VENV/python "$PROXY_SCRIPT" &
PROXY_PID=$!
sleep 2
echo " Proxy PID=$PROXY_PID"
# Step 2: Start prefill instance (KV producer)
echo "[2/3] Starting prefill instance..."
CUDA_VISIBLE_DEVICES=0,1,2,3 $VLLM serve "$MODEL_PATH" \
--host 0.0.0.0 \
--port $PREFILL_PORT \
--tensor-parallel-size 4 \
--trust-remote-code \
--enable-prefix-caching \
--enforce-eager \
--dtype auto \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
"{\"kv_connector\":\"P2pNcclConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_size\":\"1e1\",\"kv_port\":\"$KV_PORT_P\",\"kv_connector_extra_config\":{\"proxy_ip\":\"127.0.0.1\",\"proxy_port\":\"$PROXY_PORT\",\"http_port\":\"$PREFILL_PORT\",\"send_type\":\"PUT_ASYNC\",\"nccl_num_channels\":\"16\"}}" &
PREFILL_PID=$!
echo " Prefill PID=$PREFILL_PID"
# Step 3: Start decode instance (KV consumer)
echo "[3/3] Starting decode instance..."
CUDA_VISIBLE_DEVICES=4,5,6,7 $VLLM serve "$MODEL_PATH" \
--host 0.0.0.0 \
--port $DECODE_PORT \
--tensor-parallel-size 4 \
--trust-remote-code \
--enable-prefix-caching \
--enforce-eager \
--dtype auto \
--gpu-memory-utilization 0.8 \
--kv-transfer-config \
"{\"kv_connector\":\"P2pNcclConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_size\":\"8e9\",\"kv_port\":\"$KV_PORT_D\",\"kv_connector_extra_config\":{\"proxy_ip\":\"127.0.0.1\",\"proxy_port\":\"$PROXY_PORT\",\"http_port\":\"$DECODE_PORT\",\"send_type\":\"PUT_ASYNC\",\"nccl_num_channels\":\"16\"}}" &
DECODE_PID=$!
echo " Decode PID=$DECODE_PID"
# Wait for readiness
echo ""
echo "Waiting for instances..."
timeout 1200 bash -c "until curl -s localhost:$PREFILL_PORT/v1/completions > /dev/null 2>&1; do sleep 5; done"
echo " Prefill ready!"
timeout 1200 bash -c "until curl -s localhost:$DECODE_PORT/v1/completions > /dev/null 2>&1; do sleep 5; done"
echo " Decode ready!"
echo ""
echo "=== All ready ==="
echo " Send requests to: http://localhost:$CLIENT_PORT"
echo ""
wait

23
scripts/launch_vllm.sh Executable file
View File

@@ -0,0 +1,23 @@
#!/bin/bash
# Launch vLLM 0.18.1 in PD-combined mode (TP=8, all GPUs).
#
# Usage: bash scripts/launch_vllm.sh
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
VLLM="$PROJECT_DIR/.venv/bin/vllm"
MODEL_PATH="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
HOST="${HOST:-0.0.0.0}"
PORT="${PORT:-8000}"
echo "Starting vLLM 0.18.1 in PD-combined mode (TP=8) on port $PORT ..."
$VLLM serve "$MODEL_PATH" \
--trust-remote-code \
--enable-prefix-caching \
--dtype auto \
--tensor-parallel-size 8 \
--host "$HOST" \
--port "$PORT"

77
scripts/run_benchmark.sh Executable file
View File

@@ -0,0 +1,77 @@
#!/bin/bash
# Run the full benchmark suite: sample trace → replay against vLLM → collect metrics.
#
# Prerequisites:
# - vLLM server running (use scripts/launch_vllm.sh)
# - Sampled trace file exists (or will be created)
#
# Usage:
# bash scripts/run_benchmark.sh [--endpoint URL] [--tag NAME]
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
cd "$PROJECT_DIR"
# Defaults
TRACE_INPUT="${TRACE_INPUT:-$HOME/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl}"
ENDPOINT="${ENDPOINT:-http://localhost:8000}"
TAG="${TAG:-default}"
TARGET_REQUESTS="${TARGET_REQUESTS:-5000}"
TIME_SCALE="${TIME_SCALE:-1.0}"
MAX_INFLIGHT="${MAX_INFLIGHT:-32}"
SEED="${SEED:-42}"
# Parse args
while [[ $# -gt 0 ]]; do
case "$1" in
--endpoint) ENDPOINT="$2"; shift 2 ;;
--tag) TAG="$2"; shift 2 ;;
--target-requests) TARGET_REQUESTS="$2"; shift 2 ;;
--time-scale) TIME_SCALE="$2"; shift 2 ;;
--max-inflight) MAX_INFLIGHT="$2"; shift 2 ;;
*) echo "Unknown arg: $1"; exit 1 ;;
esac
done
SAMPLED_TRACE="traces/sampled_${TARGET_REQUESTS}req_seed${SEED}.jsonl"
OUTPUT_DIR="outputs/${TAG}_$(date +%Y%m%d_%H%M%S)"
echo "=== Benchmark: tag=$TAG ==="
echo " Trace: $TRACE_INPUT"
echo " Endpoint: $ENDPOINT"
echo " Target requests: $TARGET_REQUESTS"
echo " Time scale: $TIME_SCALE"
echo " Max inflight sessions: $MAX_INFLIGHT"
# Step 1: Sample trace (if not already done)
if [ ! -f "$SAMPLED_TRACE" ]; then
echo ""
echo "=== Step 1: Sampling trace ==="
python scripts/sample_trace.py \
--input "$TRACE_INPUT" \
--output "$SAMPLED_TRACE" \
--target-requests "$TARGET_REQUESTS" \
--seed "$SEED"
else
echo ""
echo "=== Step 1: Using existing sampled trace: $SAMPLED_TRACE ==="
fi
# Step 2: Run replay
echo ""
echo "=== Step 2: Replaying trace ==="
mkdir -p "$OUTPUT_DIR"
python -m replayer \
--trace "$SAMPLED_TRACE" \
--output "$OUTPUT_DIR/metrics.jsonl" \
--endpoint "$ENDPOINT" \
--time-scale "$TIME_SCALE" \
--max-inflight-sessions "$MAX_INFLIGHT" \
-v
echo ""
echo "=== Done ==="
echo " Metrics: $OUTPUT_DIR/metrics.jsonl"
echo " Summary: $OUTPUT_DIR/metrics.summary.json"

254
scripts/run_experiments.sh Executable file
View File

@@ -0,0 +1,254 @@
#!/bin/bash
# Run the complete experiment matrix:
# 1. Combined TP=2 DP=4 (4 instances, baseline)
# 2. Combined TP=1 DP=8 (8 instances, max throughput)
# 3. PD-Sep TP=1: P×4 + D×4 via Mooncake/RDMA
#
# All use the same trace, same concurrency, same timeout.
set -euo pipefail
PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
VENV="$PROJECT_DIR/.venv/bin"
VLLM="$VENV/vllm"
PYTHON="$VENV/python"
MODEL="${MODEL_PATH:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
TRACE="$PROJECT_DIR/traces/sampled_1000req_seed42.jsonl"
# Uniform benchmark params
MAX_SESSIONS=${MAX_SESSIONS:-8}
MAX_CONCURRENT=${MAX_CONCURRENT:-16}
TIME_SCALE=10
REQUEST_TIMEOUT=${REQUEST_TIMEOUT:-300}
REQUEST_LIMIT="${REQUEST_LIMIT:-}" # empty = all 1000
cleanup_gpu() {
pkill -9 -f "vllm" 2>/dev/null || true
pkill -9 -f "cache_aware_proxy\|mooncake_connector_proxy\|uvicorn" 2>/dev/null || true
fuser 9090/tcp 8000/tcp 2>/dev/null | xargs -r kill -9 2>/dev/null || true
sleep 5
fuser /dev/nvidia* 2>/dev/null | tr " " "\n" | sort -u | xargs -r kill -9 2>/dev/null || true
sleep 10
}
wait_for_server() {
local port=$1
local timeout=${2:-600}
timeout "$timeout" bash -c "until curl -s localhost:$port/v1/models >/dev/null 2>&1; do sleep 5; done"
}
run_benchmark() {
local tag=$1
local endpoint=$2
local extra_args="${3:-}"
local outdir="$PROJECT_DIR/outputs/$tag"
echo " Running benchmark -> $outdir"
local limit_arg=""
if [ -n "$REQUEST_LIMIT" ]; then
limit_arg="--request-limit $REQUEST_LIMIT"
fi
$PYTHON -m replayer \
--trace "$TRACE" \
--output "$outdir/metrics.jsonl" \
--endpoint "$endpoint" \
--model "$MODEL" \
--time-scale $TIME_SCALE \
--max-inflight-sessions $MAX_SESSIONS \
--concurrency-limit $MAX_CONCURRENT \
--request-timeout $REQUEST_TIMEOUT \
$limit_arg \
-v
echo " Done: $(wc -l < "$outdir/metrics.jsonl") requests"
}
#######################################################################
# Experiment 1: Combined TP=2 DP=4
#######################################################################
run_combined_tp2_dp4() {
echo ""
echo "================================================================"
echo " Experiment 1: Combined TP=2 DP=4 (4 instances on 8 GPUs)"
echo "================================================================"
cleanup_gpu
for i in 0 1 2 3; do
local gpu_start=$((i * 2))
local gpu_end=$((gpu_start + 1))
local port=$((8000 + i))
echo " Starting instance $i: GPUs $gpu_start,$gpu_end, port $port"
CUDA_VISIBLE_DEVICES=$gpu_start,$gpu_end $VLLM serve "$MODEL" \
--host 0.0.0.0 --port $port \
--tensor-parallel-size 2 \
--trust-remote-code --enable-prefix-caching --enforce-eager \
--dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 &
done
for i in 0 1 2 3; do
wait_for_server $((8000 + i))
echo " Instance $i ready"
done
echo " All 4 instances ready"
# Start global scheduler (cache-aware proxy in combined mode)
echo " Starting global scheduler..."
$PYTHON "$PROJECT_DIR/scripts/cache_aware_proxy.py" \
--combined http://127.0.0.1:8000 http://127.0.0.1:8001 http://127.0.0.1:8002 http://127.0.0.1:8003 \
--port 9090 &
sleep 5
run_benchmark "exp1_combined_tp2_dp4" "http://localhost:9090"
}
#######################################################################
# Experiment 2: Combined TP=1 DP=8
#######################################################################
run_combined_tp1_dp8() {
echo ""
echo "================================================================"
echo " Experiment 2: Combined TP=1 DP=8 (8 instances on 8 GPUs)"
echo "================================================================"
cleanup_gpu
for i in $(seq 0 7); do
local port=$((8000 + i))
echo " Starting instance $i: GPU $i, port $port"
CUDA_VISIBLE_DEVICES=$i $VLLM serve "$MODEL" \
--host 0.0.0.0 --port $port \
--tensor-parallel-size 1 \
--trust-remote-code --enable-prefix-caching --enforce-eager \
--dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 &
done
for i in $(seq 0 7); do
wait_for_server $((8000 + i))
echo " Instance $i ready"
done
echo " All 8 instances ready"
# Start global scheduler (cache-aware proxy in combined mode)
echo " Starting global scheduler..."
$PYTHON "$PROJECT_DIR/scripts/cache_aware_proxy.py" \
--combined http://127.0.0.1:8000 http://127.0.0.1:8001 http://127.0.0.1:8002 http://127.0.0.1:8003 \
http://127.0.0.1:8004 http://127.0.0.1:8005 http://127.0.0.1:8006 http://127.0.0.1:8007 \
--port 9090 &
sleep 5
run_benchmark "exp2_combined_tp1_dp8" "http://localhost:9090"
}
#######################################################################
# Experiment 3: PD-Sep TP=1 P×4 D×4 (Mooncake/RDMA)
#######################################################################
run_pd_sep_tp1() {
echo ""
echo "================================================================"
echo " Experiment 3: PD-Sep TP=1 P×4 + D×4 (Mooncake/RDMA)"
echo "================================================================"
cleanup_gpu
PROXY_SCRIPT="$PROJECT_DIR/scripts/cache_aware_proxy.py"
# Start 4 prefill instances (GPUs 0-3)
local prefill_args=""
for i in 0 1 2 3; do
local port=$((8010 + i))
local bootstrap=$((8998 + i))
echo " Prefill $i: GPU $i, port $port, bootstrap $bootstrap"
VLLM_MOONCAKE_BOOTSTRAP_PORT=$bootstrap \
CUDA_VISIBLE_DEVICES=$i $VLLM serve "$MODEL" \
--host 0.0.0.0 --port $port \
--tensor-parallel-size 1 \
--trust-remote-code --enable-prefix-caching --enforce-eager \
--dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 \
--kv-transfer-config \
"{\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_producer\"}" &
prefill_args="$prefill_args --prefill http://127.0.0.1:$port $bootstrap"
done
# Start 4 decode instances (GPUs 4-7)
local decode_args=""
for i in 0 1 2 3; do
local gpu=$((4 + i))
local port=$((8020 + i))
echo " Decode $i: GPU $gpu, port $port"
CUDA_VISIBLE_DEVICES=$gpu $VLLM serve "$MODEL" \
--host 0.0.0.0 --port $port \
--tensor-parallel-size 1 \
--trust-remote-code --enable-prefix-caching --enforce-eager \
--dtype auto --gpu-memory-utilization 0.9 --max-model-len 200000 \
--kv-transfer-config \
"{\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_consumer\",\"kv_load_failure_policy\":\"recompute\"}" &
decode_args="$decode_args --decode http://127.0.0.1:$port"
done
# Wait for all instances
for i in 0 1 2 3; do
wait_for_server $((8010 + i))
echo " Prefill $i ready"
done
for i in 0 1 2 3; do
wait_for_server $((8020 + i))
echo " Decode $i ready"
done
# Start proxy (wait for bootstrap to be queryable first)
echo " Waiting for bootstrap servers..."
for bp in 8998 8999 9000 9001; do
timeout 120 bash -c "until curl -s localhost:$bp/query > /dev/null 2>&1; do sleep 2; done"
echo " Bootstrap $bp ready"
done
echo " Starting proxy on port 9000..."
$PYTHON "$PROXY_SCRIPT" $prefill_args $decode_args --host 0.0.0.0 --port 9090 &
sleep 15
# Smoke test with retry
echo " Smoke test..."
for attempt in 1 2 3; do
result=$(curl -s -m 120 http://localhost:9090/v1/completions \
-X POST -H "Content-Type: application/json" \
-d "{\"model\":\"$MODEL\",\"prompt\":[100,200,300],\"max_tokens\":3,\"temperature\":0}" 2>&1)
if echo "$result" | grep -q "choices"; then
echo " Smoke test passed!"
break
fi
echo " Attempt $attempt failed, retrying..."
sleep 10
done
run_benchmark "exp3_pd_sep_tp1_mooncake" "http://localhost:9090"
}
#######################################################################
# Main
#######################################################################
echo "Starting experiment matrix on $(hostname)"
echo "Model: $MODEL"
echo "Trace: $TRACE"
echo "Params: sessions=$MAX_SESSIONS, concurrent=$MAX_CONCURRENT, time_scale=$TIME_SCALE"
echo ""
case "${1:-all}" in
1|tp2dp4) run_combined_tp2_dp4 ;;
2|tp1dp8) run_combined_tp1_dp8 ;;
3|pdsep) run_pd_sep_tp1 ;;
all)
run_combined_tp2_dp4
run_combined_tp1_dp8
run_pd_sep_tp1
;;
*)
echo "Usage: $0 {1|2|3|all|tp2dp4|tp1dp8|pdsep}"
exit 1
;;
esac
echo ""
echo "================================================================"
echo " All experiments complete!"
echo "================================================================"
cleanup_gpu

204
scripts/sample_trace.py Normal file
View File

@@ -0,0 +1,204 @@
"""Sample sessions from the full cluster-scale trace to fit a single machine.
Preserves:
- Complete session structure (all turns within a session kept together)
- Original arrival timing (inter-session and intra-session gaps)
- hash_ids for KV cache reuse patterns
- Request type distribution
Sampling strategy:
1. Group requests by session (derived from parent_chat_id chains)
2. Randomly sample N sessions (or until target request count reached)
3. Re-zero timestamps so first event starts at t=0
4. Optionally compress time axis to increase load density
Usage:
python scripts/sample_trace.py \\
--input ~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl \\
--output traces/sampled.jsonl \\
--target-requests 5000 \\
--seed 42
"""
from __future__ import annotations
import argparse
import collections
import json
import random
import sys
from pathlib import Path
def load_raw_rows(path: Path) -> dict[str, list[dict]]:
"""Load trace, group rows by resolved session_id. Preserve file order."""
chat_to_session: dict[int, str] = {}
rows_by_session: dict[str, list[dict]] = collections.OrderedDict()
with path.open("r", encoding="utf-8") as fh:
for line in fh:
row = json.loads(line)
cid = int(row["chat_id"])
pid = int(row["parent_chat_id"])
if "session_id" in row:
sid = str(row["session_id"])
elif pid < 0:
sid = str(cid)
else:
sid = chat_to_session.get(pid, str(pid))
chat_to_session[cid] = sid
row["_session_id"] = sid
rows_by_session.setdefault(sid, []).append(row)
return rows_by_session
def sample_sessions(
rows_by_session: dict[str, list[dict]],
*,
target_requests: int,
seed: int,
strategy: str = "random",
) -> list[str]:
"""Select sessions until target request count is reached."""
all_sids = list(rows_by_session.keys())
rng = random.Random(seed)
if strategy == "random":
rng.shuffle(all_sids)
elif strategy == "sequential":
pass # keep file order
else:
raise ValueError(f"Unknown strategy: {strategy}")
selected = []
total = 0
for sid in all_sids:
selected.append(sid)
total += len(rows_by_session[sid])
if total >= target_requests:
break
return selected
def build_output(
rows_by_session: dict[str, list[dict]],
selected: list[str],
*,
time_scale: float = 1.0,
) -> list[dict]:
"""Build output rows with re-zeroed timestamps."""
out_rows = []
for sid in selected:
for row in rows_by_session[sid]:
out = {k: v for k, v in row.items() if not k.startswith("_")}
out["session_id"] = sid
out_rows.append(out)
out_rows.sort(key=lambda r: float(r["timestamp"]))
if not out_rows:
return out_rows
# Re-zero: subtract earliest timestamp
t0 = float(out_rows[0]["timestamp"])
for row in out_rows:
row["timestamp"] = (float(row["timestamp"]) - t0) / time_scale
return out_rows
def print_summary(
rows_by_session: dict[str, list[dict]],
selected: list[str],
out_rows: list[dict],
) -> None:
n_sessions = len(selected)
n_requests = len(out_rows)
turns_per_session = [len(rows_by_session[s]) for s in selected]
multi_turn = sum(1 for t in turns_per_session if t > 1)
input_lens = [r["input_length"] for r in out_rows]
output_lens = [r["output_length"] for r in out_rows]
span_s = float(out_rows[-1]["timestamp"]) if out_rows else 0
session_starts = {}
for r in out_rows:
sid = r["session_id"]
ts = float(r["timestamp"])
if sid not in session_starts:
session_starts[sid] = ts
starts_sorted = sorted(session_starts.values())
deltas = [starts_sorted[i+1] - starts_sorted[i]
for i in range(len(starts_sorted) - 1)]
# hash_ids overlap: count unique hash_ids across all requests
all_hashes = set()
for r in out_rows:
all_hashes.update(r.get("hash_ids", []))
print(f"Sampled: {n_sessions} sessions, {n_requests} requests")
print(f" Multi-turn sessions: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)")
print(f" Turns/session: min={min(turns_per_session)} max={max(turns_per_session)} "
f"avg={sum(turns_per_session)/len(turns_per_session):.1f}")
print(f" Input length: min={min(input_lens)} max={max(input_lens)} "
f"avg={sum(input_lens)/len(input_lens):.0f}")
print(f" Output length: min={min(output_lens)} max={max(output_lens)} "
f"avg={sum(output_lens)/len(output_lens):.0f}")
print(f" Trace span: {span_s:.1f}s ({span_s/60:.1f} min)")
print(f" Unique hash blocks: {len(all_hashes)}")
if deltas:
deltas.sort()
p = lambda q: deltas[min(int(q * len(deltas)), len(deltas) - 1)]
print(f" Session arrival deltas (s): p10={p(0.1):.2f} p50={p(0.5):.2f} "
f"p90={p(0.9):.2f} max={max(deltas):.2f}")
def main() -> None:
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter)
p.add_argument("--input", type=Path, required=True,
help="Path to the full trace JSONL file")
p.add_argument("--output", type=Path, required=True,
help="Path to write sampled trace JSONL")
p.add_argument("--target-requests", type=int, default=5000,
help="Target number of requests (stops after session that crosses it)")
p.add_argument("--strategy", choices=["random", "sequential"], default="random",
help="Session selection strategy")
p.add_argument("--time-scale", type=float, default=1.0,
help="Compress time axis by this factor (>1 = faster arrival)")
p.add_argument("--seed", type=int, default=42)
args = p.parse_args()
print(f"Loading trace from {args.input} ...")
rows_by_session = load_raw_rows(args.input)
total_sessions = len(rows_by_session)
total_requests = sum(len(v) for v in rows_by_session.values())
print(f"Full trace: {total_sessions} sessions, {total_requests} requests")
selected = sample_sessions(
rows_by_session,
target_requests=args.target_requests,
seed=args.seed,
strategy=args.strategy,
)
out_rows = build_output(
rows_by_session, selected,
time_scale=args.time_scale,
)
print_summary(rows_by_session, selected, out_rows)
args.output.parent.mkdir(parents=True, exist_ok=True)
with args.output.open("w", encoding="utf-8") as fh:
for row in out_rows:
fh.write(json.dumps(row, ensure_ascii=False) + "\n")
print(f"\nWrote {len(out_rows)} rows to {args.output}")
if __name__ == "__main__":
main()