Make the offered-load axis session-coherent

Phase 1 of the two-stop work. Subsampling the trace by per-request uniform score
broke multi-turn sessions (a kept turn-2 could lose its turn-1), which lowered the
realized KV-cache hit rate as offered load dropped — so the feasibility boundary
was measured on a workload with a different C than production, contradicting the
paper's scale-stationary L-C-A premise.

prepare_trace_windows now resolves each row's session root via the parent_chat_id
chain in a single streaming pass and assigns sampling_u per session, so thresholding
keeps or drops whole sessions and preserves intra-session prefix reuse. Rows whose
parent fell outside the span fall back to grouping under the parent id.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 14:16:06 +08:00
parent 6f8e3c95c1
commit 0f15bbc3f1
2 changed files with 137 additions and 11 deletions

View File

@@ -92,17 +92,39 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args() return parser.parse_args()
def stable_uniform(*, seed: int, window_id: str, index: int, row: dict[str, Any]) -> float: def resolve_session_root(row: dict[str, Any], root_of: dict[Any, Any]) -> Any:
"""Resolve the session root chat_id for a trace row.
Sessions are multi-turn chains linked via parent_chat_id (turn>1 points to the
parent turn's chat_id, the root turn has parent_chat_id=-1). Because parent
turns precede their children in time, a single streaming pass that records
chat_id -> root resolves the full chain. Rows whose parent is not yet known
(e.g. it fell outside the materialized span) fall back to the parent id so
siblings still group together.
"""
chat_id = row.get("chat_id")
parent = row.get("parent_chat_id")
parent_is_root = (
parent is None
or (isinstance(parent, (int, float)) and not isinstance(parent, bool) and int(parent) < 0)
)
root = chat_id if parent_is_root else root_of.get(parent, parent)
if chat_id is not None:
root_of[chat_id] = root
return root
def session_uniform(*, seed: int, window_id: str, session_root: Any) -> float:
"""Deterministic per-session uniform score in [0, 1).
All turns of a session share one score, so thresholding sampling_u keeps or
drops whole sessions and preserves intra-session prefix (KV-cache) reuse.
"""
payload = json.dumps( payload = json.dumps(
{ {
"seed": seed, "seed": seed,
"window_id": window_id, "window_id": window_id,
"index": index, "session_root": session_root,
"timestamp": row.get("timestamp"),
"input_length": row.get("input_length"),
"output_length": row.get("output_length"),
"chat_id": row.get("chat_id"),
"turn": row.get("turn"),
}, },
sort_keys=True, sort_keys=True,
separators=(",", ":"), separators=(",", ":"),
@@ -241,12 +263,16 @@ def materialize_windows(
bucket = grouped[(trace_path, prompt_path)] bucket = grouped[(trace_path, prompt_path)]
bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"]))) bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"])))
matched_rows = 0 matched_rows = 0
root_of: dict[Any, Any] = {}
with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle: with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle:
for trace_raw, prompt_raw in zip(trace_handle, prompt_handle): for trace_raw, prompt_raw in zip(trace_handle, prompt_handle):
trace_raw = trace_raw.strip() trace_raw = trace_raw.strip()
if not trace_raw: if not trace_raw:
continue continue
trace_row = json.loads(trace_raw) trace_row = json.loads(trace_raw)
# Resolve session linkage for every row (even unmatched ones)
# so multi-turn chains crossing the window edge still group.
session_root = resolve_session_root(trace_row, root_of)
timestamp = float(trace_row.get("timestamp") or 0.0) timestamp = float(trace_row.get("timestamp") or 0.0)
matched_window: dict[str, Any] | None = None matched_window: dict[str, Any] | None = None
for window in bucket: for window in bucket:
@@ -267,11 +293,11 @@ def materialize_windows(
start = float(matched_window["window_start"]) start = float(matched_window["window_start"])
out["source_timestamp"] = timestamp out["source_timestamp"] = timestamp
out["timestamp"] = timestamp - start out["timestamp"] = timestamp - start
out["sampling_u"] = stable_uniform( out["session_root"] = session_root
out["sampling_u"] = session_uniform(
seed=sample_seed, seed=sample_seed,
window_id=window_id, window_id=window_id,
index=stats_by_window[window_id].num_requests, session_root=session_root,
row=merged,
) )
handles[window_id].write(json.dumps(out, ensure_ascii=False) + "\n") handles[window_id].write(json.dumps(out, ensure_ascii=False) + "\n")
stats_by_window[window_id].record(out) stats_by_window[window_id].record(out)
@@ -311,7 +337,7 @@ def build_output_window(
output["num_excluded_too_long"] = 0 output["num_excluded_too_long"] = 0
output["sampling_u_field"] = "sampling_u" output["sampling_u_field"] = "sampling_u"
output["sampling_seed"] = int(sample_seed) output["sampling_seed"] = int(sample_seed)
output["sampling_strategy"] = "fixed_uniform_score" output["sampling_strategy"] = "session_coherent_uniform_score"
output["first_request_ts"] = stats.first_request_ts output["first_request_ts"] = stats.first_request_ts
output["last_request_ts"] = stats.last_request_ts output["last_request_ts"] = stats.last_request_ts
output["first_request_index"] = stats.first_request_index output["first_request_index"] = stats.first_request_index

View File

@@ -0,0 +1,100 @@
from __future__ import annotations
import importlib.util
import sys
import unittest
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parents[1]
_SPEC = importlib.util.spec_from_file_location(
"prepare_trace_windows",
REPO_ROOT / "scripts" / "prepare_trace_windows.py",
)
assert _SPEC and _SPEC.loader
ptw = importlib.util.module_from_spec(_SPEC)
# Register before exec so dataclasses can resolve the module's annotations.
sys.modules[_SPEC.name] = ptw
_SPEC.loader.exec_module(ptw)
class SessionCoherentSamplingTests(unittest.TestCase):
def test_multi_hop_chain_resolves_to_root(self) -> None:
root_of: dict[object, object] = {}
# turn1 root, turn2 -> turn1, turn3 -> turn2 (multi-hop), streamed in order.
self.assertEqual(
ptw.resolve_session_root({"chat_id": 1, "parent_chat_id": -1, "turn": 1}, root_of),
1,
)
self.assertEqual(
ptw.resolve_session_root({"chat_id": 2, "parent_chat_id": 1, "turn": 2}, root_of),
1,
)
self.assertEqual(
ptw.resolve_session_root({"chat_id": 3, "parent_chat_id": 2, "turn": 3}, root_of),
1,
)
def test_unknown_parent_falls_back_to_parent_id(self) -> None:
root_of: dict[object, object] = {}
# parent never seen (fell outside the span): group siblings under the parent.
self.assertEqual(
ptw.resolve_session_root({"chat_id": 50, "parent_chat_id": 9, "turn": 2}, root_of),
9,
)
self.assertEqual(
ptw.resolve_session_root({"chat_id": 51, "parent_chat_id": 9, "turn": 2}, root_of),
9,
)
def test_all_turns_of_a_session_share_one_u(self) -> None:
root_of: dict[object, object] = {}
rows = [
{"chat_id": 1, "parent_chat_id": -1, "turn": 1},
{"chat_id": 2, "parent_chat_id": 1, "turn": 2},
{"chat_id": 3, "parent_chat_id": 2, "turn": 3},
]
us = {
ptw.session_uniform(
seed=7,
window_id="w",
session_root=ptw.resolve_session_root(row, root_of),
)
for row in rows
}
self.assertEqual(len(us), 1)
only = next(iter(us))
self.assertGreaterEqual(only, 0.0)
self.assertLess(only, 1.0)
def test_thresholding_keeps_or_drops_whole_sessions(self) -> None:
# Two distinct sessions get distinct scores; a threshold either keeps a
# session's every turn or none of them.
seed, window_id = 20260325, "chat_w_x"
sessions = {
"A": [
{"chat_id": 10, "parent_chat_id": -1},
{"chat_id": 11, "parent_chat_id": 10},
],
"B": [
{"chat_id": 20, "parent_chat_id": -1},
{"chat_id": 21, "parent_chat_id": 20},
],
}
root_of: dict[object, object] = {}
scored: list[tuple[str, float]] = []
for name, rows in sessions.items():
for row in rows:
root = ptw.resolve_session_root(row, root_of)
u = ptw.session_uniform(seed=seed, window_id=window_id, session_root=root)
scored.append((name, u))
for name in sessions:
us = {u for n, u in scored if n == name}
self.assertEqual(len(us), 1, f"session {name} turns must share one u")
for threshold in (0.0, 0.25, 0.5, 0.75, 1.0):
for name in sessions:
kept = {u <= threshold for n, u in scored if n == name}
self.assertEqual(len(kept), 1, "a session must be kept/dropped as a whole")
if __name__ == "__main__":
unittest.main()