diff --git a/figs/working_set/glm5_fp8_tp8_b300.png b/figs/working_set/glm5_fp8_tp8_b300.png index e3ba68c..1a75f06 100644 Binary files a/figs/working_set/glm5_fp8_tp8_b300.png and b/figs/working_set/glm5_fp8_tp8_b300.png differ diff --git a/scripts/working_set_analysis.py b/scripts/working_set_analysis.py index 70a684b..43e3641 100644 --- a/scripts/working_set_analysis.py +++ b/scripts/working_set_analysis.py @@ -166,12 +166,12 @@ def plot(ws, hw, block_bytes, label, out_path): ceil = ws["apc_ceiling"] * 100 oracle_nodes = ws["oracle_peak_blocks"] * bgb / pool - # operating points up to the ceiling: beyond oracle, TTL is strictly worse, so drop. - rows = [r for r in ws["taus"] if r["tau"] <= 300] + # all operating points, out to the largest retention window (~50 nodes) + rows = list(ws["taus"]) nodes = np.array([r["peak_blocks"] * bgb / pool for r in rows]) apc = np.array([r["apc"] * 100 for r in rows]) - tau = np.array([r["tau"] for r in rows]) - XMAX = 16 + XMAX_L = 53 # left panel x-axis (nodes), shows up to T=1800s (~52 nodes) + XMAX_R = 16 # right panel y-axis (nodes) fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) @@ -182,30 +182,33 @@ def plot(ws, hw, block_bytes, label, out_path): ax1.scatter([1], [apc_at_1], s=90, facecolors="none", edgecolors="#ff7f0e", lw=2, zorder=6) ax1.annotate(f"1 node -> ~{apc_at_1:.0f}% APC\n(TTL model; real LRU higher)", - (1, apc_at_1), textcoords="offset points", xytext=(12, -2), - fontsize=9, color="#ff7f0e", va="top") + (1, apc_at_1), textcoords="offset points", xytext=(14, 8), + fontsize=9, color="#ff7f0e", va="bottom") # label the well-separated decision-zone points for r, x, y in zip(rows, nodes, apc): if x >= 1.5: - ax1.annotate(f"keep {r['tau']:g}s reuse", (x, y), - textcoords="offset points", xytext=(6, 6), fontsize=8.5) - ax1.annotate("T<=10s reuse:\nall < 1.4 nodes", (0.5, 22), fontsize=8.5, + ax1.annotate(f"{r['tau']:g}s", (x, y), + textcoords="offset points", xytext=(5, 6), fontsize=9) + ax1.annotate("T<=10s reuse:\nall < 1.4 nodes", (1.5, 18), fontsize=8.5, color="#1f77b4", ha="left") + # diminishing returns past the oracle point + ax1.annotate("diminishing returns:\n14 -> 52 nodes buys only +6pp", + (30, 64), fontsize=9, color="#555", ha="center") # budget + ceiling ax1.axvspan(0, 1, color="#2ca02c", alpha=.08) ax1.axvline(1, ls="--", color="#2ca02c", lw=1.8) - ax1.text(1.05, 96, "1 B300 node (your budget)", color="#2ca02c", fontsize=9, va="top") + ax1.text(1.6, 96, "1 B300 node (your budget)", color="#2ca02c", fontsize=9, va="top") ax1.scatter([oracle_nodes], [ceil], marker="*", s=340, color="#d62728", zorder=7) - ax1.annotate(f"ceiling {ceil:.1f}%\noracle: {oracle_nodes:.0f} nodes", - (oracle_nodes, ceil), textcoords="offset points", xytext=(-10, -8), - fontsize=9, color="#d62728", ha="right", va="top") + ax1.annotate(f"ceiling {ceil:.1f}% — oracle/LRU\nreaches it at {oracle_nodes:.0f} nodes", + (oracle_nodes, ceil), textcoords="offset points", xytext=(12, -4), + fontsize=9, color="#d62728", ha="left", va="top") ax1.axhline(ceil, ls=":", color="#d62728", alpha=.5) - ax1.set_xlim(0, XMAX); ax1.set_ylim(0, 100) - ax1.set_xticks(range(0, XMAX + 1, 2)); ax1.set_xticks(range(0, XMAX + 1), minor=True) + ax1.set_xlim(0, XMAX_L); ax1.set_ylim(0, 100) + ax1.set_xticks(range(0, 51, 10)); ax1.set_xticks(range(0, XMAX_L, 5), minor=True) ax1.set_xlabel(f"# nodes of GPU HBM needed (1 node = {gpr}x {hw['gpu']} = {pool:.0f} GB KV)") ax1.set_ylabel("Prefix-cache hit rate (APC %)") ax1.set_title("Benefit vs cost: APC per cluster size", fontweight="bold") - ax1.grid(alpha=.3); ax1.grid(alpha=.15, which="minor"); ax1.legend(loc="center right") + ax1.grid(alpha=.3); ax1.grid(alpha=.15, which="minor"); ax1.legend(loc="lower right") # ===== panel 2: working set W(t) over time (steady -> peak ~ median) ===== apc_of = {r["tau"]: r["apc"] * 100 for r in ws["taus"]} @@ -221,7 +224,7 @@ def plot(ws, hw, block_bytes, label, out_path): ax2.axhline(oracle_nodes, ls="--", color="#d62728", lw=1.6, alpha=.8) ax2.text(t_min.max(), oracle_nodes, " ceiling: 14 nodes", color="#d62728", fontsize=8.5, va="center") - ax2.set_ylim(0, XMAX); ax2.set_yticks(range(0, XMAX + 1, 2)) + ax2.set_ylim(0, XMAX_R); ax2.set_yticks(range(0, XMAX_R + 1, 2)) ax2.set_xlim(0, t_min.max()) ax2.set_xlabel("wall-clock time into the trace (min)") ax2.set_ylabel("# nodes of GPU HBM resident (W(t))")