Breakdown profiling at the proxy level captures: t_proxy_recv → t_prefill_sent → t_prefill_done → t_decode_sent → t_first_token. Key finding: 87.7% of TTFT is spent in the kv+decode phase, NOT in prefill. Root cause: KV cache memory saturation on the decode instances (97.1% usage). With the 6P+2D config, the 2 decode GPUs have only ~56 GB of KV cache in total, and large agentic requests (avg 33.6k tokens) fill it quickly. Small requests (49 tokens, prefill = 0.044 s) wait 114 s for KV cache to be freed as large requests finish decoding. The vLLM log confirms this: Running=0, Waiting=6, KV cache=97.1% — the GPU is idle, but requests are queued waiting for KV cache memory, not compute. This is the fundamental bottleneck of single-machine PD separation for long-context agentic workloads: concentrating decode onto fewer GPUs creates a KV cache memory wall. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
88 lines
3.7 KiB
Python
88 lines
3.7 KiB
Python
"""Deep profile: why fire-and-forget TTFT is 5x worse than await."""

import json
import statistics


def _load_jsonl(path):
    """Return the JSON objects stored one per line in *path*.

    The file is closed deterministically via ``with`` (the previous
    ``open(...)`` inside a comprehension leaked the handle), and blank lines
    -- e.g. a stray trailing newline -- are skipped instead of making
    ``json.loads`` raise.
    """
    with open(path, encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


await_rows = _load_jsonl("outputs/gpu_ab_6p2d/metrics.jsonl")
fnf_rows = _load_jsonl("outputs/gpu_ab_6p2d_fnf/metrics.jsonl")

# Keep only rows whose "error" field is absent or falsy.
await_ok = [r for r in await_rows if not r.get("error")]
fnf_ok = [r for r in fnf_rows if not r.get("error")]

# Match by request_id (if an id repeats within one run, the last row wins).
await_by_id = {r["request_id"]: r for r in await_ok}
fnf_by_id = {r["request_id"]: r for r in fnf_ok}
# Intersection of the two key views is a plain set of ids present in both runs.
common = await_by_id.keys() & fnf_by_id.keys()
# Report banner followed by how many request ids succeeded in both runs.
rule = "=" * 75
print(rule)
print(" PROFILE: Fire-and-Forget vs Await-Prefill (same 6P+2D instances)")
print(rule)
print(" Common requests: %d" % len(common))
# Per-request comparison: pair each request's await and FnF metrics.
# Iterate ids in sorted order so rows that tie on input_length come out in a
# stable order from run to run (plain set iteration order is not guaranteed
# to be reproducible, e.g. under str hash randomization).
diffs = []
for rid in sorted(common):
    a = await_by_id[rid]
    f = fnf_by_id[rid]
    # Skip pairs where either TTFT is missing/zero; a positive await TTFT is
    # also required because it is the ratio's denominator.
    if a.get("ttft_s") and f.get("ttft_s") and a["ttft_s"] > 0:
        diffs.append({
            "id": rid, "input": a["input_length"],
            "a_ttft": a["ttft_s"], "f_ttft": f["ttft_s"],
            "ratio": f["ttft_s"] / a["ttft_s"],
            "a_e2e": a["latency_s"], "f_e2e": f["latency_s"],
            # `or 0` maps a missing key AND a JSON null to 0 -- the same
            # normalization used for the token counts below -- so the %f
            # table formatting and the `> 0` TPOT filters never see None.
            "a_tpot": a.get("tpot_s") or 0,
            "f_tpot": f.get("tpot_s") or 0,
            "a_out": a.get("actual_output_tokens") or 0,
            "f_out": f.get("actual_output_tokens") or 0,
        })

# Smallest prompts first; the table below prints a prefix of this order.
diffs.sort(key=lambda x: x["input"])
# Per-request table: the first 25 entries of `diffs` (sorted by input size).
print("\n Per-request (sorted by input_length):")
_COLUMNS = ("input", "await_TTFT", "fnf_TTFT", "ratio",
            "await_E2E", "fnf_E2E", "a_TPOT", "f_TPOT")
hdr = "%8s %10s %10s %7s %10s %10s %8s %8s" % _COLUMNS
print(" " + hdr)
print(" " + "-" * len(hdr))

# Row layout mirrors the header columns one-for-one.
_ROW_FMT = " %8d %10.3f %10.3f %6.1fx %10.3f %10.3f %8.4f %8.4f"
_ROW_KEYS = ("input", "a_ttft", "f_ttft", "ratio",
             "a_e2e", "f_e2e", "a_tpot", "f_tpot")
for row in diffs[:25]:
    print(_ROW_FMT % tuple(row[key] for key in _ROW_KEYS))
# Statistics


def _pct(values, q):
    """Return the q-quantile of the already-sorted, non-empty list *values*.

    Uses index ``int(q * len(values))`` clamped to the last element, so
    q=0.5 on an even-length list picks the upper of the two middle values.
    """
    return values[min(int(q * len(values)), len(values) - 1)]


if diffs:
    ratios = sorted(d["ratio"] for d in diffs)
    print("\n TTFT ratio (FnF / Await):")
    print(" p10=%.2fx p50=%.2fx p90=%.2fx mean=%.2fx" % (
        _pct(ratios, .1), _pct(ratios, .5), _pct(ratios, .9),
        statistics.fmean(ratios)))

    faster = sum(1 for r in ratios if r < 1.0)
    print(" FnF faster: %d/%d (%.0f%%)" % (
        faster, len(ratios), faster * 100 / len(ratios)))

    # Bucket by input size. Ranges are half-open [lo, hi); the last bucket is
    # unbounded so prompts of any length are counted (a finite cap such as
    # 999999 would silently drop longer requests).
    print("\n TTFT ratio by input size bucket:")
    buckets = [(0, 5000, "<5k"), (5000, 20000, "5-20k"),
               (20000, 50000, "20-50k"), (50000, float("inf"), ">50k")]
    for lo, hi, label in buckets:
        subset = [d for d in diffs if lo <= d["input"] < hi]
        if subset:
            rs = [d["ratio"] for d in subset]
            a_ttfts = [d["a_ttft"] for d in subset]
            f_ttfts = [d["f_ttft"] for d in subset]
            print(" %6s: n=%3d await_TTFT=%.3f fnf_TTFT=%.3f ratio=%.2fx" % (
                label, len(subset), statistics.fmean(a_ttfts),
                statistics.fmean(f_ttfts), statistics.fmean(rs)))

    # TPOT comparison; 0 marks a missing TPOT, so only positive values count.
    a_tpots = sorted(d["a_tpot"] for d in diffs if d["a_tpot"] > 0)
    f_tpots = sorted(d["f_tpot"] for d in diffs if d["f_tpot"] > 0)
    if a_tpots and f_tpots:
        print("\n TPOT comparison:")
        # p50 uses the same quantile helper as the TTFT ratios above.
        print(" Await: mean=%.4f p50=%.4f" % (
            statistics.fmean(a_tpots), _pct(a_tpots, .5)))
        print(" FnF: mean=%.4f p50=%.4f" % (
            statistics.fmean(f_tpots), _pct(f_tpots, .5)))
# Errors seen only in the FnF run: FnF rows with an error whose request_id
# did not also error in the await run.
fnf_err = [r for r in fnf_rows if r.get("error")]
await_err_ids = {r["request_id"] for r in await_rows if r.get("error")}
fnf_only_err = [r for r in fnf_err if r["request_id"] not in await_err_ids]

print("\n Errors unique to FnF: %d" % len(fnf_only_err))
for r in fnf_only_err[:5]:
    # str() lets non-string error payloads be truncated instead of raising,
    # and .get()/%s tolerate error rows lacking input_length (ints print the
    # same under %s as under %d).
    print(" input=%s err=%s" % (r.get("input_length"), str(r["error"])[:60]))