#!/usr/bin/env python3
"""Compare KVC variants vs baseline, EXCLUDING errors and truncated requests.

Reads per-request metrics (one JSON object per line) for every entry in
``DATASETS`` and prints:

1. a latency table and a TTFT table over "clean" requests,
2. a per-execution-mode latency breakdown for ``BREAKDOWN_LABEL``,
3. requests served in ``DIRECT_MODE`` for the 1P7D/2P6D variants, followed by
   the baseline for reference.

A request is "clean" when its ``error`` field is None, it is not truncated
(see ``is_truncated``) and it carries a non-null ``latency_s``.
"""
import json
from collections import Counter, defaultdict
from pathlib import Path

import numpy as np

OUT = Path("/mnt/kzlin/workflow/pd-hybrid/agentic-pd-hybrid/outputs")

BASELINE_LABEL = "baseline 8DP"
BREAKDOWN_LABEL = "v4 2P6D"
DIRECT_MODE = "kvcache-direct-to-d-session"
RULE = "=" * 100

DATASETS = [
    (BASELINE_LABEL, OUT / "qwen3-30b-tp1-v2-fixed/exp1_8way_dp_cache_aware_metrics.jsonl"),
    ("v3 1P7D", OUT / "qwen3-30b-tp1-v3-kvaware/exp1_1p7d_kvc_kvaware_metrics.jsonl"),
    ("v3 2P6D", OUT / "qwen3-30b-tp1-v3-kvaware/exp2_2p6d_kvc_kvaware_metrics.jsonl"),
    ("v4 1P7D", OUT / "qwen3-30b-tp1-v4-cap16/exp1_1p7d_kvc_cap16_metrics.jsonl"),
    (BREAKDOWN_LABEL, OUT / "qwen3-30b-tp1-v4-cap16/exp2_2p6d_kvc_cap16_metrics.jsonl"),
]


def load_rows(path):
    """Parse a JSONL file into a list of objects, one per non-blank line.

    Blank lines (e.g. a trailing newline pair) are skipped rather than being
    handed to ``json.loads``. Malformed JSON still raises
    ``json.JSONDecodeError``.
    """
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def is_truncated(row):
    """Return True when the row produced fewer than half the requested tokens.

    Rows missing either token count, or requesting at most one token, are
    never considered truncated.
    """
    a = row.get("actual_output_tokens")
    r = row.get("requested_output_tokens")
    if a is not None and r is not None and r > 1:
        return a < r * 0.5
    return False


def stats(values):
    """Summarise *values* as count, mean and the 50th/90th/99th percentiles.

    Returns ``{"n": 0}`` for an empty input so callers can detect "no data"
    (``fmt`` renders that as ``N/A``).
    """
    if not values:
        return {"n": 0}
    a = np.asarray(values, dtype=float)
    return {
        "n": len(a),
        "mean": float(np.mean(a)),
        "p50": float(np.percentile(a, 50)),
        "p90": float(np.percentile(a, 90)),
        "p99": float(np.percentile(a, 99)),
    }


def fmt(s, key):
    """Format ``s[key]`` with an ``s`` suffix, or ``N/A`` when *s* is empty.

    Values below 100 get three decimals, larger values one decimal.
    """
    if s["n"] == 0:
        return "N/A"
    v = s[key]
    return f"{v:.3f}s" if v < 100 else f"{v:.1f}s"


def clean_rows(rows):
    """Keep rows with error=None, not truncated, and latency_s present."""
    return [
        r for r in rows
        if r.get("error") is None
        and not is_truncated(r)
        and r.get("latency_s") is not None
    ]


def summarize(label, rows):
    """Build the per-dataset summary used by the latency and TTFT tables."""
    clean = clean_rows(rows)
    return {
        "label": label,
        "total": len(rows),
        "err": sum(1 for r in rows if r.get("error") is not None),
        # Truncation is only counted among rows that did not error.
        "trunc": sum(1 for r in rows if r.get("error") is None and is_truncated(r)),
        "clean_n": len(clean),
        "lat": stats([r["latency_s"] for r in clean]),
        "ttft": stats([r["ttft_s"] for r in clean if r.get("ttft_s") is not None]),
    }


def _banner(title):
    """Print a section title framed by two horizontal rules."""
    print(f"\n{RULE}")
    print(title)
    print(RULE)


def _cells(s):
    """Render the mean/P50/P90/P99 columns of a stats dict."""
    return f"{fmt(s,'mean'):>9}{fmt(s,'p50'):>9}{fmt(s,'p90'):>9}{fmt(s,'p99'):>9}"


def _lat_ttft(rows):
    """Return (latency stats, TTFT stats) for already-clean rows."""
    s_lat = stats([r["latency_s"] for r in rows])
    s_ttft = stats([r["ttft_s"] for r in rows if r.get("ttft_s") is not None])
    return s_lat, s_ttft


def _print_summary_tables(results):
    """Print the latency table followed by the TTFT table."""
    _banner("LATENCY (excluding errors AND truncated)")
    print(f"{'config':<16}{'total':>7}{'err':>6}{'trunc':>7}{'clean':>7} {'mean':>9}{'P50':>9}{'P90':>9}{'P99':>9}")
    for r in results:
        print(f"{r['label']:<16}{r['total']:>7}{r['err']:>6}{r['trunc']:>7}{r['clean_n']:>7} "
              f"{_cells(r['lat'])}")

    _banner("TTFT (excluding errors AND truncated)")
    print(f"{'config':<16}{'clean':>7} {'mean':>9}{'P50':>9}{'P90':>9}{'P99':>9}")
    for r in results:
        print(f"{r['label']:<16}{r['clean_n']:>7} "
              f"{_cells(r['ttft'])}")


def _print_mode_breakdown(rows):
    """Print latency stats for the ten most common execution modes in *rows*."""
    clean = clean_rows(rows)
    latencies_by_mode = defaultdict(list)
    for r in clean:
        # Rows without an execution_mode are grouped instead of raising KeyError.
        latencies_by_mode[r.get("execution_mode") or "<none>"].append(r["latency_s"])
    modes = Counter({mode: len(vals) for mode, vals in latencies_by_mode.items()})

    print(f"{'mode':<55}{'n':>7}{'%':>7} {'mean':>9}{'P50':>9}{'P90':>9}{'P99':>9}")
    for mode, count in modes.most_common(10):
        s = stats(latencies_by_mode[mode])
        pct = count / len(clean) * 100
        # Two-space indent + 53 columns lines the row up with the 55-wide header.
        print(f"  {mode:<53}{count:>7}{pct:>6.1f}% {_cells(s)}")


def _print_direct_line(label, rows):
    """Print the one-line P50/P90 latency and TTFT summary for *rows*."""
    s_lat, s_ttft = _lat_ttft(rows)
    print(f"{label:<16}n={s_lat['n']:>5} lat: P50={fmt(s_lat,'p50')} P90={fmt(s_lat,'p90')} ttft: P50={fmt(s_ttft,'p50')} P90={fmt(s_ttft,'p90')}")


def main(datasets=None):
    """Load every available dataset once and print all report sections.

    Args:
        datasets: iterable of ``(label, Path)`` pairs; defaults to ``DATASETS``.

    Returns:
        The list of per-dataset summary dicts (see ``summarize``) for the
        datasets whose metrics file exists.
    """
    datasets = DATASETS if datasets is None else datasets

    # Each file is parsed exactly once; missing files are reported and skipped
    # everywhere, so no later section can hit FileNotFoundError.
    loaded = []
    for label, path in datasets:
        if not path.exists():
            print(f"SKIP {label}")
            continue
        loaded.append((label, load_rows(path)))
    by_label = dict(loaded)

    results = [summarize(label, rows) for label, rows in loaded]
    _print_summary_tables(results)

    # Per-execution-mode breakdown for v4 only (the most interesting).
    _banner("V4 2P6D: per-execution-mode (excluding errors and truncated)")
    if BREAKDOWN_LABEL in by_label:
        _print_mode_breakdown(by_label[BREAKDOWN_LABEL])
    else:
        print(f"SKIP {BREAKDOWN_LABEL}")

    # What if we only count direct-to-D requests? (Pure KVC performance)
    _banner("Pure KVC (kvcache-direct-to-d-session ONLY) vs Baseline")
    for label, rows in loaded:
        if "1P7D" not in label and "2P6D" not in label:
            continue
        direct = [r for r in clean_rows(rows) if r.get("execution_mode") == DIRECT_MODE]
        if direct:
            _print_direct_line(label, direct)

    # Baseline for reference (already non-fallback by definition).
    print()
    if BASELINE_LABEL in by_label:
        _print_direct_line(BASELINE_LABEL, clean_rows(by_label[BASELINE_LABEL]))
    else:
        print(f"SKIP {BASELINE_LABEL}")

    return results


if __name__ == "__main__":
    main()