Elastic P2P v4: error rate 25% -> 4%, TTFT p50 -12% (median-tail tradeoff)
Fixed offload decision: removed p>=d gate (was blocking all offloads), added MAX_OFFLOAD_INFLIGHT=4 cap and p_saturated threshold. Result (200 req, fresh restart): Baseline: 99% success, TTFT=1.080/9.410, TPOT90=0.076, E2E=5.306 Elastic: 96% success, TTFT=0.946/15.843, TPOT90=0.077, E2E=5.717 Architectural tradeoff confirmed: - Median (p50) improves: D instances not disrupted by heavy prefill - Tail (p90) worsens: offloaded HEAVY requests pay KV transfer cost - TPOT unchanged: decode isolation is not the bottleneck To improve p90: need layerwise pipelined KV transfer (overlap with prefill compute) or smarter offload gating that avoids offloading the very largest requests (which have the longest prefill time and generate the most KV). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -116,6 +116,8 @@ decode_instances: list[InstanceState] = []
|
||||
session_affinity: dict[str, int] = {}
|
||||
is_pd_sep = False
|
||||
_breakdown_log: list[dict] = []
|
||||
_offload_inflight = 0 # number of currently in-flight offloaded HEAVY requests
|
||||
MAX_OFFLOAD_INFLIGHT = 4 # cap concurrent offloads to prevent P overload
|
||||
|
||||
|
||||
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
|
||||
@@ -242,18 +244,21 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
avg_load = max(sum(i.ongoing_tokens for i in combined_instances) / len(combined_instances), 1.0)
|
||||
|
||||
# Decision logic:
|
||||
# 1. P must be less loaded than D (otherwise offload makes things worse)
|
||||
# 2. P must not be overloaded (ongoing > 1.5x average = would queue too long)
|
||||
# 3. D should be currently decoding (otherwise no disruption to avoid)
|
||||
if p_inst.ongoing_tokens >= d_inst.ongoing_tokens:
|
||||
offload_reason = "p_busier_than_d"
|
||||
elif p_inst.ongoing_tokens > avg_load * 1.5:
|
||||
offload_reason = "p_overloaded"
|
||||
elif d_inst.ongoing_decode_tokens == 0 and d_inst.ongoing_tokens < avg_load * 0.5:
|
||||
offload_reason = "d_idle_no_benefit"
|
||||
# 1. Global cap: max N concurrent offloads (prevents all-offload storm)
|
||||
# 2. P must not already be saturated with heavy prefills
|
||||
# 3. D must be doing something (otherwise no benefit from offloading)
|
||||
# NOTE: We do NOT require P < D. P can be busier than D — the point
|
||||
# is to keep heavy prefill OFF the session-sticky D instance so D's
|
||||
# decode is not disrupted and D's KV cache is available for future turns.
|
||||
global _offload_inflight
|
||||
if _offload_inflight >= MAX_OFFLOAD_INFLIGHT:
|
||||
offload_reason = "max_concurrent_reached"
|
||||
elif p_inst.ongoing_tokens >= HEAVY_THRESHOLD * 2:
|
||||
offload_reason = "p_saturated"
|
||||
else:
|
||||
use_offload = True
|
||||
offload_reason = "p_available_d_busy"
|
||||
offload_reason = "offload_accepted"
|
||||
_offload_inflight += 1
|
||||
|
||||
if use_offload:
|
||||
d_idx = best_idx
|
||||
@@ -331,9 +336,12 @@ async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
|
||||
breakdown["t_prefill_done"] = _time.monotonic()
|
||||
breakdown["error"] = str(e)
|
||||
_breakdown_log.append(breakdown)
|
||||
global _offload_inflight
|
||||
_offload_inflight = max(0, _offload_inflight - 1)
|
||||
raise HTTPException(status_code=502, detail="Prefill failed: %s" % e)
|
||||
finally:
|
||||
p_inst.ongoing_tokens -= input_length
|
||||
_offload_inflight = max(0, _offload_inflight - 1)
|
||||
|
||||
# Step 2: Stream decode on d_inst (pulls KV from Mooncake)
|
||||
d_inst.ongoing_tokens += input_length
|
||||
|
||||
56
scripts/compare_elastic_v4.py
Normal file
56
scripts/compare_elastic_v4.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Compare elastic v4 (cap=4, relaxed conditions) vs baseline."""
|
||||
import json, os
|
||||
|
||||
def s(path):
|
||||
rows = [json.loads(l) for l in open(path)]
|
||||
ok = [r for r in rows if not r.get("error")]
|
||||
ttfts = sorted([r["ttft_s"] for r in ok if r.get("ttft_s")])
|
||||
tpots = sorted([r["tpot_s"] for r in ok if r.get("tpot_s") and r["tpot_s"]>0])
|
||||
lats = sorted([r["latency_s"] for r in ok])
|
||||
p = lambda v,q: v[min(int(q*len(v)),len(v)-1)] if v else 0
|
||||
ok_inp = sorted([r["input_length"] for r in ok])
|
||||
err_inp = sorted([r["input_length"] for r in rows if r.get("error")])
|
||||
return {"ok": len(ok), "n": len(rows),
|
||||
"t50": p(ttfts,.5), "t90": p(ttfts,.9),
|
||||
"p50": p(tpots,.5), "p90": p(tpots,.9),
|
||||
"e50": p(lats,.5),
|
||||
"inp50": p(ok_inp,.5), "inp90": p(ok_inp,.9),
|
||||
"err_inp50": p(err_inp,.5) if err_inp else 0}
|
||||
|
||||
print("ELASTIC P2P v4 vs BASELINE (both 200 req)")
|
||||
print("=" * 80)
|
||||
fmt = "%-32s %7s %8s %8s %8s %8s %8s %8s"
|
||||
print(fmt % ("Config", "OK/N", "TTFT50", "TTFT90", "TPOT90", "E2E50", "inp_p50", "err_inp"))
|
||||
print("-" * 80)
|
||||
|
||||
configs = [
|
||||
("outputs/baseline_dash1/metrics.jsonl", "Baseline (8 combined, dash1)"),
|
||||
("outputs/elastic_v4/metrics.jsonl", "Elastic P2P (cap=4, dash0)"),
|
||||
]
|
||||
results = {}
|
||||
for path, label in configs:
|
||||
if not os.path.exists(path):
|
||||
continue
|
||||
r = s(path)
|
||||
results[label] = r
|
||||
print(fmt % (label, "%d/%d" % (r["ok"],r["n"]),
|
||||
"%.3f" % r["t50"], "%.3f" % r["t90"], "%.3f" % r["p90"],
|
||||
"%.3f" % r["e50"], str(r["inp50"]), str(r["err_inp50"])))
|
||||
|
||||
if len(results) == 2:
|
||||
b = list(results.values())[0]
|
||||
a = list(results.values())[1]
|
||||
print()
|
||||
print("DELTA (Elastic vs Baseline):")
|
||||
for label, bv, av in [
|
||||
("TTFT p50", b["t50"], a["t50"]),
|
||||
("TTFT p90", b["t90"], a["t90"]),
|
||||
("TPOT p90", b["p90"], a["p90"]),
|
||||
("E2E p50", b["e50"], a["e50"]),
|
||||
]:
|
||||
d = (av/bv-1)*100 if bv > 0 else 0
|
||||
print(" %s: %.3f -> %.3f (%+.1f%%)" % (label, bv, av, d))
|
||||
print(" Success: %d/%d (%.1f%%) -> %d/%d (%.1f%%)" % (
|
||||
b["ok"], b["n"], b["ok"]*100/b["n"],
|
||||
a["ok"], a["n"], a["ok"]*100/a["n"]))
|
||||
print(" Input coverage p50: %s -> %s (bias check)" % (b["inp50"], a["inp50"]))
|
||||
Reference in New Issue
Block a user