Systematic study of prefill-decode (PD) disaggregation for agentic LLM workloads, using a production GLM-5.1 coder trace (2.1M requests, 71B input tokens).

Key findings:
- Cache-aware routing improves TPOT p90 by 15% and raises APC from 20.8% to 44.7% without PD separation, matching PD-Sep's decode isolation benefit.
- PD separation adds +72% TTFT overhead (KV transfer) with no TPOT gain when using the same cache-aware scheduler.
- Prefill remains compute-bound even at 95% KV cache reuse (arithmetic intensity, AI, >1000x vs decode AI <2), but absolute FLOPs drop 71% from cache hits.
- For agentic MoE workloads, cache-aware routing > PD separation.

Infrastructure:
- Trace sampler preserving session structure + hash_ids for prefix sharing.
- Async trace replayer with streaming TTFT/TPOT/E2E measurement.
- Unified cache-aware + token-level load-balanced global scheduler proxy supporting both PD-colocated and PD-disaggregated (Mooncake/RDMA) modes.
- vLLM 0.18.1 scheduler patch for KV transfer abort race condition.
- Roofline analysis tool for prefill/decode compute characterization.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
164 lines
6.2 KiB
Python
164 lines
6.2 KiB
Python
"""Analyze trace patterns to assess PD separation benefit.
|
|
|
|
Computes metrics relevant to deciding PD-combined vs PD-separated:
|
|
- Input/output token ratio (high ratio = prefill-heavy → PD sep benefits)
|
|
- Prefix sharing density (high sharing → benefits from shared KV cache)
|
|
- Session length distribution (multi-turn = more prefix reuse)
|
|
- Arrival burstiness (bursty prefill → PD sep can absorb spikes)
|
|
- Compute-intensity ratio: prefill FLOP share vs decode FLOP share
|
|
|
|
Usage:
|
|
python scripts/analyze_trace.py --input traces/sampled_1000req_seed42.jsonl
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import collections
|
|
import json
|
|
import statistics
|
|
from pathlib import Path
|
|
|
|
|
|
def _load_rows(path):
    """Read a JSONL trace and return one parsed object per non-blank line."""
    with path.open(encoding="utf-8") as fh:
        # Blank lines (e.g. a stray trailing newline) would make json.loads
        # raise, so they are skipped rather than treated as records.
        return [json.loads(line) for line in fh if line.strip()]


def _group_sessions(rows):
    """Group requests into sessions, preserving first-seen session order.

    An explicit ``session_id`` on a row wins. Otherwise a request whose
    ``parent_chat_id`` is negative starts a session keyed by its own
    ``chat_id``; any other request joins the session recorded for its parent,
    falling back to the parent id itself when the parent has not been seen.

    Returns an ordered mapping of session id (str) -> list of request dicts.
    """
    sessions = collections.OrderedDict()
    chat_to_session = {}
    for r in rows:
        cid = int(r["chat_id"])
        pid = int(r["parent_chat_id"])
        sid = r.get("session_id")
        if sid is None:
            sid = str(cid) if pid < 0 else chat_to_session.get(pid, str(pid))
        sid = str(sid)
        chat_to_session[cid] = sid
        sessions.setdefault(sid, []).append(r)
    return sessions


def _interarrival_cv(timestamps):
    """Return the coefficient of variation of positive inter-arrival gaps.

    ``timestamps`` must already be sorted ascending. Zero-length gaps are
    dropped. ``statistics.stdev`` needs at least two samples, so 0.0 is
    returned when fewer than two positive gaps exist.
    """
    gaps = [later - earlier for earlier, later in zip(timestamps, timestamps[1:])]
    gaps = [g for g in gaps if g > 0]
    if len(gaps) < 2:
        return 0.0
    return statistics.stdev(gaps) / statistics.fmean(gaps)


def main(argv=None):
    """Print a PD-separation suitability report for a JSONL request trace.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When None, argparse reads ``sys.argv[1:]`` as before.

    Each trace row is read for ``chat_id``, ``parent_chat_id``, ``timestamp``,
    ``input_length`` and ``output_length``; ``session_id`` and ``hash_ids``
    are optional. The report is written to stdout. An empty trace prints a
    short notice instead of failing.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--input", type=Path, required=True)
    args = p.parse_args(argv)

    rows = _load_rows(args.input)
    if not rows:
        # Every statistic below divides by, or takes min/max over, the request
        # set, so there is nothing meaningful to report for an empty trace.
        print(f"No requests found in {args.input}; nothing to analyze.")
        return
    n_rows = len(rows)

    # Session structure
    sessions = _group_sessions(rows)
    n_sessions = len(sessions)
    turns_per_session = [len(v) for v in sessions.values()]
    multi_turn = sum(1 for t in turns_per_session if t > 1)

    input_lens = [r["input_length"] for r in rows]
    output_lens = [r["output_length"] for r in rows]
    total_input = sum(input_lens)
    total_output = sum(output_lens)

    print("=" * 60)
    print("Trace Pattern Analysis for PD Separation Decision")
    print("=" * 60)

    # 1. Input/Output ratio
    io_ratio = total_input / max(total_output, 1)
    print("\n1. Input/Output Token Ratio")
    print(f" Total input tokens: {total_input:>12,}")
    print(f" Total output tokens: {total_output:>12,}")
    print(f" I/O ratio: {io_ratio:>12.1f}x")
    print(f" → {'STRONGLY' if io_ratio > 50 else 'Moderately' if io_ratio > 10 else 'Weakly'} prefill-heavy")

    # 2. Prefill compute share
    # Token counts are used as a proxy for compute: the share is simply
    # input tokens over all tokens. max(..., 1) guards an all-zero trace.
    prefill_share = total_input / max(total_input + total_output, 1)
    print("\n2. Compute Split (token count proxy)")
    print(f" Prefill share: {prefill_share*100:.1f}%")
    print(f" Decode share: {(1-prefill_share)*100:.1f}%")

    # 3. Session structure
    print("\n3. Session Structure")
    print(f" Sessions: {n_sessions}")
    print(f" Requests: {n_rows}")
    print(f" Multi-turn: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)")
    print(f" Turns/sess: min={min(turns_per_session)} max={max(turns_per_session)} "
          f"avg={statistics.fmean(turns_per_session):.1f}")

    # 4. Prefix sharing
    # Count, per block hash, how many requests reference it. Each request's
    # hashes are de-duplicated first so a repeat within one request does not
    # count as sharing. `or ()` tolerates a missing or null "hash_ids".
    hash_refcount = collections.Counter()
    for r in rows:
        hash_refcount.update(set(r.get("hash_ids") or ()))

    shared_blocks = sum(1 for c in hash_refcount.values() if c > 1)
    total_blocks = len(hash_refcount)
    block_reuse = shared_blocks / max(total_blocks, 1)
    avg_refcount = statistics.fmean(hash_refcount.values()) if hash_refcount else 0

    print("\n4. Prefix Block Sharing")
    print(f" Unique blocks: {total_blocks:>10,}")
    print(f" Shared (ref>1): {shared_blocks:>10,} ({block_reuse*100:.1f}%)")
    print(f" Avg refcount: {avg_refcount:>10.2f}")
    print(f" → {'High' if block_reuse > 0.3 else 'Moderate' if block_reuse > 0.1 else 'Low'} prefix reuse potential")

    # 5. Input length distribution
    input_sorted = sorted(input_lens)

    def pct(q):
        # Nearest-rank style lookup, clamped to the last element.
        return input_sorted[min(int(q * len(input_sorted)), len(input_sorted) - 1)]

    print("\n5. Input Length Distribution")
    print(f" p10={pct(0.1):>8,} p50={pct(0.5):>8,} p90={pct(0.9):>8,} max={max(input_lens):>8,}")
    long_context = sum(1 for n in input_lens if n > 32000)
    long_share = long_context / n_rows
    print(f" Requests >32k tokens: {long_context} ({long_share*100:.1f}%)")

    # 6. Arrival pattern
    timestamps = sorted(float(r["timestamp"]) for r in rows)
    span = timestamps[-1] - timestamps[0]
    avg_rate = n_rows / max(span, 0.001)

    # Burstiness: coefficient of variation of inter-arrival times
    cv = _interarrival_cv(timestamps)
    print("\n6. Arrival Pattern")
    print(f" Span: {span:.1f}s ({span/60:.1f} min)")
    print(f" Avg rate: {avg_rate:.2f} req/s")
    print(f" Burstiness (CoV): {cv:.2f}")
    print(f" → {'Bursty' if cv > 1.5 else 'Moderate' if cv > 0.8 else 'Steady'} arrival pattern")

    # Summary
    print(f"\n{'=' * 60}")
    print("Summary: PD Separation Recommendation")
    print(f"{'=' * 60}")
    factors = []
    if io_ratio > 50:
        factors.append("Very high I/O ratio (prefill-dominated)")
    elif io_ratio > 10:
        factors.append("High I/O ratio")
    if block_reuse > 0.1:
        factors.append(f"Significant prefix reuse ({block_reuse*100:.0f}% shared blocks)")
    if long_share > 0.3:
        factors.append(f"Many long-context requests ({long_share*100:.0f}%)")
    # NOTE: this factor fires at CoV > 1.0, while the label printed in
    # section 6 only says "Bursty" above 1.5.
    if cv > 1.0:
        factors.append("Bursty arrivals (PD sep absorbs prefill spikes)")

    if len(factors) >= 2:
        print("→ RECOMMEND PD separation:")
    elif len(factors) == 1:
        print("→ PD separation MAY help:")
    else:
        print("→ PD separation likely NOT beneficial:")

    for factor in factors:
        print(f" • {factor}")
    if not factors:
        print(" • No strong indicators for PD separation benefit")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the report only when executed as a script, not when imported.
    main()
|