Files
agentic-pd-hybrid/scripts/analysis/analyze_ts1_validation.py
kzlin 2ec0debef4 feat(kvc): session migration with reset-on-success + direct-append threshold tuning
KVC v2 beats 4DP at ts=1 same-scale on 7/8 metrics:
  TTFT mean -24%, p50 -54%, p90 -64%; lat mean -0.8%, p50 -12.6%, p90 -0.7%.
  Direct-to-D rate jumped 42.8% -> 91.7%. REFACTOR_PLAN_V1 scenario C achieved.

Two-knob fix:
- reset-on-success blacklist decay: clear (sess, D) reject counter on
  successful direct-to-D path. Eliminates v1 thrashing where session 6880
  was stable on decode-1 for 70 turns then collapsed to 75 D-changes after
  cumulative transient pressure tripped the permanent blacklist.
- bump --kvcache-direct-max-uncached-tokens default 2048 -> 8192 via CLI flag.
  41% of v1 fallbacks were 'real-large-append' (>2048 token append); raising
  the threshold lets these go through the direct-to-D fast path.

Code:
- policies.py: RoutingState.session_d_rejects counter + KvAwarePolicy
  migration_reject_threshold; degenerate fallback picks least-rejected D.
- replay.py: record_admission_reject + reset-on-success in _run_request;
  _fallthrough_reason classifies turn-2+ fall-throughs as session-not-resident
  / real-large-append / etc, replacing misleading 'large-append' suffix
  (TEAM_REPORT §2.7).
- cli.py + benchmark.py: --kvcache-migration-reject-threshold flag wiring.

Docs:
- REFACTOR_PLAN_V1_ZH.md: forward-looking plan after ts=1 validation.
- MIGRATION_V1_FINDINGS_ZH.md: v1 thrashing root-cause analysis.
- V2_RESULTS_ZH.md: v2 results, scenario C achievement, attribution.
- TEAM_REPORT_AGENTIC_PD_HYBRID_ZH.md: comprehensive team report.

Scripts:
- sweep_ts1_kvc_n3_plus_dp.sh: ts=1 baseline (KVC 1P3D N=3 + 4DP CA).
- sweep_ts1_migration_v1.sh / v2.sh: validation runs.
- analyze_ts1_validation.py: 4-way comparison analyzer.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-09 01:18:13 +08:00

317 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""TS=1 validation analysis: KVC 1P3D × N=3 + 4DP × 1.
Reads metrics from outputs/qwen3-30b-tp1-ts1-validation/{kvc_1p3d_run{1,2,3},dp4}_metrics.jsonl
and reports per the structural claims in docs/AGENTIC_FIT_ANALYSIS_ZH.md and TEAM_REPORT.
Sections:
1. Headline summary table (errors, latency p50/p90/p99, TTFT p50)
2. §1 (session pinning): distinct-D-per-session distribution + direct-to-D bimodal
3. §1 (cross-run consistency): sessions consistently starved across all 3 runs + size ratio
4. §2 (LRU): KVTransferError counts per D + peak token_usage from worker logs
5. §7 (ts=1 vs ts=10): direct-to-D rate, fallback rate, per-D load balance
6. KVC vs DP same-scale comparison
Usage: python scripts/analysis/analyze_ts1_validation.py [--root PATH]
"""
import argparse
import json
import re
from collections import Counter, defaultdict
from pathlib import Path
import numpy as np
def load_metrics(path):
    """Read a JSON-Lines file and return its records as a list.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped, and every remaining line is parsed with
    ``json.loads``.
    """
    with open(path) as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]
def load_summary(path):
    """Return the parsed contents of the JSON document stored at ``path``."""
    with open(path) as handle:
        payload = handle.read()
    return json.loads(payload)
def pct(arr, p):
    """Return the ``p``-th percentile of ``arr`` as a plain float.

    A falsy ``arr`` (e.g. an empty list) yields NaN instead of calling
    ``np.percentile``.
    """
    return float(np.percentile(arr, p)) if arr else float("nan")
def summarize_run(label, rows, summary):
    """Build the headline statistics dict for one run.

    Rows whose ``"error"`` value is ``None`` (or absent) count as successful;
    all others count as errors. Latency and TTFT statistics are computed over
    successful rows only, skipping rows where the respective field is ``None``
    or missing. Means are NaN when no values are available. ``summary`` is
    passed through unchanged under the ``"summary"`` key.
    """
    succeeded, failed = [], []
    for row in rows:
        bucket = succeeded if row.get("error") is None else failed
        bucket.append(row)

    def _recorded(field):
        # Values of ``field`` from successful rows that actually carry one.
        return [row[field] for row in succeeded if row.get(field) is not None]

    latencies = _recorded("latency_s")
    first_tokens = _recorded("ttft_s")
    nan = float("nan")

    return {
        "label": label,
        "n": len(rows),
        "ok": len(succeeded),
        "err": len(failed),
        "lat_mean": float(np.mean(latencies)) if latencies else nan,
        "lat_p50": pct(latencies, 50),
        "lat_p90": pct(latencies, 90),
        "lat_p99": pct(latencies, 99),
        "ttft_mean": float(np.mean(first_tokens)) if first_tokens else nan,
        "ttft_p50": pct(first_tokens, 50),
        "summary": summary,
    }
def headline_table(stats):
    """Print the headline comparison table, one line per run.

    ``stats`` is an iterable of dicts carrying the keys ``label``, ``ok``,
    ``n``, ``err`` and the six timing keys listed in ``timing_keys`` below
    (the shape returned by ``summarize_run``).
    """
    rule = "=" * 110
    print("\n" + rule)
    print("HEADLINE: same trace, same scale, same ts=1")
    print(rule)

    # The timing column headers are identical to the dict keys they display.
    timing_keys = ("lat_mean", "lat_p50", "lat_p90", "lat_p99", "ttft_mean", "ttft_p50")
    header = f"{'label':<22}{'ok/n':>12}{'err':>6}"
    header += "".join(f"{name:>10}" for name in timing_keys)
    print(header)

    for entry in stats:
        completed = f"{entry['ok']}/{entry['n']}"
        timings = "".join(f"{entry[key]:>9.3f}s" for key in timing_keys)
        print(f"{entry['label']:<22}{completed:>12}{entry['err']:>6}{timings}")
def session_pinning(rows, label):
    """§1: count the distinct decode nodes each session was served by.

    The node for a row is ``assigned_decode_node`` if truthy, otherwise
    ``decode_node``. Rows with no session id or no node are ignored.

    Returns ``None`` when no (session, node) pair was found; otherwise a dict
    with the label, the number of sessions, the mean and max number of
    distinct nodes per session, and a ``sess_d`` map of session id to its
    sorted node list.
    """
    nodes_by_session = defaultdict(set)
    for row in rows:
        session = row.get("session_id")
        if session is None:
            continue
        node = row.get("assigned_decode_node") or row.get("decode_node")
        if node is None:
            continue
        nodes_by_session[session].add(node)

    if not nodes_by_session:
        return None

    fanout = [len(nodes) for nodes in nodes_by_session.values()]
    per_session = {session: sorted(nodes) for session, nodes in nodes_by_session.items()}
    return {
        "label": label,
        "n_sessions": len(nodes_by_session),
        "avg_distinct_D": float(np.mean(fanout)),
        "max_distinct_D": max(fanout),
        "sess_d": per_session,
    }
def direct_to_d_distribution(rows, label):
    """§1: per-session direct-to-D rate; print a histogram to check for bimodality.

    A row counts as "direct" when its ``execution_mode`` equals
    ``"kvcache-direct-to-d-session"``. Rows without a session id are ignored.

    Prints a five-bucket histogram (20% wide buckets) of the per-session rate,
    with one ``#`` per session in each bucket, and returns a list of
    ``(session_id, rate, total_requests)`` tuples, one per session.
    """
    sess_total = Counter()
    sess_direct = Counter()
    for row in rows:
        sid = row.get("session_id")
        if sid is None:
            continue
        sess_total[sid] += 1
        if row.get("execution_mode", "") == "kvcache-direct-to-d-session":
            sess_direct[sid] += 1

    rates = [(sid, sess_direct[sid] / total, total) for sid, total in sess_total.items()]

    # The upper edge is 1.01 so that a rate of exactly 1.0 lands in the last bucket.
    edges = (0, 0.2, 0.4, 0.6, 0.8, 1.01)
    bin_labels = ["0-20%", "20-40%", "40-60%", "60-80%", "80-100%"]
    counts = [0] * len(bin_labels)
    for _, rate, _ in rates:
        for idx, (lo, hi) in enumerate(zip(edges, edges[1:])):
            if lo <= rate < hi:
                counts[idx] += 1
                break

    print(f"\n [{label}] direct-to-D rate distribution (n={len(rates)} sessions):")
    for lbl, cnt in zip(bin_labels, counts):
        # The bar glyph used to be an empty string, so no bar was ever drawn.
        bar = "#" * cnt
        print(f" {lbl:<10}: {cnt:>3} {bar}")
    return rates
def starved_cross_run(per_run_rates, threshold=0.20, lucky_threshold=0.80):
    """§1: find sessions that are starved, or lucky, in ALL runs.

    Args:
        per_run_rates: one entry per run, each an iterable of
            ``(session_id, rate, total)`` tuples (the shape returned by
            ``direct_to_d_distribution``).
        threshold: a session is "starved" in a run when its rate is strictly
            below this value.
        lucky_threshold: a session is "lucky" in a run when it is not starved
            and its rate is strictly above this value.

    Returns:
        ``None`` when fewer than two runs are supplied; otherwise a dict with
        ``n_runs`` and the lists ``consistently_starved`` /
        ``consistently_lucky``. A session missing from any run can never reach
        the required count, so it appears in neither list.
    """
    n_runs = len(per_run_rates)
    if n_runs < 2:
        return None

    starved_hits = Counter()
    lucky_hits = Counter()
    for rates in per_run_rates:
        for sid, rate, _ in rates:
            if rate < threshold:
                starved_hits[sid] += 1
            elif rate > lucky_threshold:
                lucky_hits[sid] += 1

    return {
        "n_runs": n_runs,
        "consistently_starved": [sid for sid, hits in starved_hits.items() if hits == n_runs],
        "consistently_lucky": [sid for sid, hits in lucky_hits.items() if hits == n_runs],
    }
def session_size_comparison(rows, sids_a, sids_b, label_a="A", label_b="B"):
    """Print the mean peak ``input_length`` of two session groups and their ratio.

    The peak for a session is the largest ``input_length`` seen across its
    rows, floored at 0 (missing or falsy lengths count as 0). Session ids in
    ``sids_a`` / ``sids_b`` that never occur in ``rows`` are dropped. Nothing
    is printed unless both groups end up non-empty.
    """
    peak_input = {}
    for row in rows:
        sid = row.get("session_id")
        if sid is None:
            continue
        length = row.get("input_length") or 0
        peak_input[sid] = max(peak_input.get(sid, 0), length)

    group_a = [peak_input[sid] for sid in sids_a if sid in peak_input]
    group_b = [peak_input[sid] for sid in sids_b if sid in peak_input]
    if not (group_a and group_b):
        return

    mean_a = np.mean(group_a)
    mean_b = np.mean(group_b)
    ratio = mean_a / mean_b
    print("\n Cross-run starvation correlates with session size?")
    print(f" consistently {label_a} (n={len(group_a)}): peak_input mean = {mean_a:.0f}")
    print(f" consistently {label_b} (n={len(group_b)}): peak_input mean = {mean_b:.0f}")
    print(f" {label_a}/{label_b} ratio = {ratio:.2f}x (ts=10 baseline was 1.98x)")
def per_d_balance(rows, label):
    """§7: print the request count per decode node and its relative spread.

    The node for a row is ``assigned_decode_node`` if truthy, otherwise
    ``decode_node``; rows with no truthy node are ignored. The spread is
    ``(max - min) / max(mean, 1)`` over the per-node counts. Nothing is
    printed when no row carries a node.
    """
    nodes = (row.get("assigned_decode_node") or row.get("decode_node") for row in rows)
    load = Counter(node for node in nodes if node)
    if not load:
        return

    hits = list(load.values())
    # The denominator is clamped to at least 1.
    denominator = max(np.mean(hits), 1)
    spread = (max(hits) - min(hits)) / denominator

    ordered = dict(sorted(load.items()))
    print(f"\n [{label}] per-D load: {ordered}")
    print(f" spread (max-min)/mean = {spread*100:.1f}% "
          f"(ts=10 KVC 2P6D = ±26%, 8DP CA = ±10%)")
def execution_modes_table(rows, label):
    """Print the eight most frequent execution modes among successful rows.

    Successful rows are those whose ``"error"`` is ``None`` or absent. For
    each mode the count, its share of successful rows, and latency p50/p90
    plus TTFT p50 are printed. A mode with no recorded ``latency_s`` values is
    skipped. Rows without an ``execution_mode`` (missing or ``None``) are
    grouped under ``<unknown>`` rather than raising ``KeyError``, matching the
    tolerance of ``direct_to_d_distribution``.
    """
    ok = [row for row in rows if row.get("error") is None]
    if not ok:
        return

    # Group once instead of rescanning every row for each mode.
    rows_by_mode = defaultdict(list)
    for row in ok:
        mode = row.get("execution_mode")
        if mode is None:
            mode = "<unknown>"
        rows_by_mode[mode].append(row)

    ranked = Counter({mode: len(group) for mode, group in rows_by_mode.items()})
    print(f"\n [{label}] execution modes (n_ok={len(ok)}):")
    for mode, cnt in ranked.most_common(8):
        mode_rows = rows_by_mode[mode]
        lats = [r["latency_s"] for r in mode_rows if r.get("latency_s") is not None]
        ttfts = [r["ttft_s"] for r in mode_rows if r.get("ttft_s") is not None]
        if lats:
            print(f" {mode:<55} {cnt:>5} ({cnt/len(ok)*100:>4.1f}%) "
                  f"lat p50={pct(lats,50):.3f}s p90={pct(lats,90):.3f}s ttft p50={pct(ttfts,50):.3f}s")
# Matches the number after "token usage: ". A plain character class such as
# ``[\d.]+`` would also swallow a trailing sentence period ("0.85.") or a bare
# "...", which then makes float() raise ValueError.
_TOKEN_USAGE_RE = re.compile(r"token usage: (\d+\.?\d*|\.\d+)")


def lru_vs_errors(run_dir, label):
    """§2: print trim events vs KVTransferError per decode worker log.

    Scans ``<run_dir>/logs/decode-*.log``. For each file it prints the number
    of occurrences of ``"Trimmed decode session cache"`` and
    ``"KVTransferError"``, and the largest value reported after
    ``"token usage: "`` (0.0 when the log reports none). Returns silently when
    the ``logs`` directory does not exist.

    Args:
        run_dir: run directory, as a ``Path`` or a path string.
        label: run label shown in the section header.
    """
    log_dir = Path(run_dir) / "logs"
    if not log_dir.exists():
        return
    print(f"\n [{label}] D-side LRU vs errors (from worker logs):")
    print(f" {'worker':<14}{'trim':>8}{'KVTransferError':>20}{'peak_token_usage':>20}")
    for log_file in sorted(log_dir.glob("decode-*.log")):
        worker = log_file.stem
        # Undecodable bytes are dropped so a corrupt log cannot abort the report.
        text = log_file.read_text(errors="ignore")
        trim_count = text.count("Trimmed decode session cache")
        err_count = text.count("KVTransferError")
        peak = max(map(float, _TOKEN_USAGE_RE.findall(text)), default=0.0)
        print(f" {worker:<14}{trim_count:>8}{err_count:>20}{peak:>20.3f}")
def _print_banner(title):
    """Print ``title`` framed by two 110-character ``=`` rules, after a blank line."""
    rule = "=" * 110
    print("\n" + rule)
    print(title)
    print(rule)


def _resolve_root(raw_root):
    """Turn the ``--root`` argument into a usable path.

    Absolute paths are returned unchanged. Relative paths are anchored at the
    repository root, taken to be two directories above this script's own
    directory (the script lives under ``scripts/analysis/``), instead of a
    machine-specific absolute path. If the script sits too close to the
    filesystem root for that, the current working directory is used.
    """
    root = Path(raw_root)
    if root.is_absolute():
        return root
    here = Path(__file__).resolve()
    base = here.parents[2] if len(here.parents) > 2 else Path.cwd()
    return base / root


def main():
    """Load every available run under ``--root`` and print all report sections."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", default="outputs/qwen3-30b-tp1-ts1-validation",
                        help="Sweep output root")
    args = parser.parse_args()
    root = _resolve_root(args.root)

    # Load all available runs
    stats = []
    rows_by_run = {}
    for label in ("kvc_1p3d_run1", "kvc_1p3d_run2", "kvc_1p3d_run3", "dp4"):
        metrics_path = root / f"{label}_metrics.jsonl"
        summary_path = root / f"{label}_summary.json"
        # Name the file(s) actually missing rather than always the metrics file.
        missing = [p.name for p in (metrics_path, summary_path) if not p.exists()]
        if missing:
            print(f" [{label}] not yet available ({', '.join(missing)})")
            continue
        rows = load_metrics(metrics_path)
        summary = load_summary(summary_path)
        rows_by_run[label] = rows
        stats.append(summarize_run(label, rows, summary))
    if not stats:
        print("No runs available yet.")
        return

    # 1. Headline table
    headline_table(stats)

    kvc_runs = {label: rows for label, rows in rows_by_run.items() if label.startswith("kvc_")}

    # 2. §1 session pinning per KVC run + per-D balance + execution modes
    _print_banner("§1 / §7: SESSION PINNING + LOAD BALANCE")
    per_run_rates = []
    for label, rows in kvc_runs.items():
        pin = session_pinning(rows, label)
        if pin:
            print(f"\n [{label}] sessions={pin['n_sessions']} "
                  f"avg_distinct_D={pin['avg_distinct_D']:.2f} "
                  f"max_distinct_D={pin['max_distinct_D']} "
                  f"(ts=10 baseline avg=1.00 → 100% pin)")
        per_run_rates.append(direct_to_d_distribution(rows, label))
        per_d_balance(rows, label)
        execution_modes_table(rows, label)

    # 3. §1 cross-run starvation
    if len(per_run_rates) >= 2:
        _print_banner(f"§1 CROSS-RUN STARVATION (across {len(per_run_rates)} KVC runs)")
        cross = starved_cross_run(per_run_rates)
        if cross:
            n_starved = len(cross["consistently_starved"])
            n_lucky = len(cross["consistently_lucky"])
            print(f"\n Sessions starved (<20% direct-to-D) in all {cross['n_runs']} runs: {n_starved}")
            print(f" Sessions lucky (>80% direct-to-D) in all {cross['n_runs']} runs: {n_lucky}")
            print(" (ts=10 baseline: 13/52 starved, 14/52 lucky — extreme bimodal)")
            # session size comparison from run 1
            if "kvc_1p3d_run1" in rows_by_run and n_starved and n_lucky:
                session_size_comparison(rows_by_run["kvc_1p3d_run1"],
                                        cross["consistently_starved"],
                                        cross["consistently_lucky"],
                                        "starved", "lucky")

    # 4. §2 D-side LRU vs errors from raw logs
    _print_banner("§2: D-SIDE LRU TRIM vs KVTransferError (from worker logs)")
    # The directory listing does not depend on the label, so glob once.
    run_dirs = sorted(root.glob("kvcache-centric-*/"))
    for label in kvc_runs:
        # naive: index matches run order; could be wrong if dirs got reordered
        idx = int(label.split("run")[-1]) - 1
        if idx < len(run_dirs):
            lru_vs_errors(run_dirs[idx], label)

    # 5. DP-only inspection
    if "dp4" in rows_by_run:
        _print_banner("4DP CA SANITY")
        per_d_balance(rows_by_run["dp4"], "dp4")
        execution_modes_table(rows_by_run["dp4"], "dp4")


if __name__ == "__main__":
    main()