Make the offered-load axis session-coherent
Phase 1 of the two-stop work.

Subsampling the trace by per-request uniform score broke multi-turn sessions (a kept turn-2 could lose its turn-1), which lowered the realized KV-cache hit rate as offered load dropped — so the feasibility boundary was measured on a workload with a different C than production, contradicting the paper's scale-stationary L-C-A premise.

prepare_trace_windows now resolves each row's session root via the parent_chat_id chain in a single streaming pass and assigns sampling_u per session, so thresholding keeps or drops whole sessions and preserves intra-session prefix reuse. Rows whose parent fell outside the span fall back to grouping under the parent id.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -92,17 +92,39 @@ def parse_args() -> argparse.Namespace:
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def stable_uniform(*, seed: int, window_id: str, index: int, row: dict[str, Any]) -> float:
|
||||
def resolve_session_root(row: dict[str, Any], root_of: dict[Any, Any]) -> Any:
    """Resolve the session root chat_id for a trace row.

    Sessions are multi-turn chains linked via parent_chat_id (turn>1 points to
    the parent turn's chat_id, the root turn has parent_chat_id=-1). Because
    parent turns precede their children in time, a single streaming pass that
    records chat_id -> root resolves the full chain. Rows whose parent is not
    yet known (e.g. it fell outside the materialized span) fall back to the
    parent id so siblings still group together.

    Args:
        row: Parsed trace row; only "chat_id" and "parent_chat_id" are read.
        root_of: Running chat_id -> session root map. It is mutated in place:
            the row's own chat_id is recorded so later children can resolve
            through it. Pass the same dict for every row of one stream.

    Returns:
        The session root for the row: the row's own chat_id when it has no
        parent, the parent's recorded root when the parent was seen, or the
        raw parent id otherwise. This is None for a parentless row that also
        lacks a chat_id.
    """
    chat_id = row.get("chat_id")
    parent = row.get("parent_chat_id")

    if parent is None:
        has_parent = False
    elif isinstance(parent, bool) or not isinstance(parent, (int, float)):
        # Non-numeric ids (and bools) are never the "no parent" sentinel.
        has_parent = True
    elif parent != parent:
        # NaN (json.loads accepts the NaN literal) means the value is missing.
        # int(NaN) would raise ValueError, so treat it as "no parent".
        has_parent = False
    else:
        # Same result as int(parent) < 0 for finite numbers (truncation toward
        # zero makes that "parent <= -1"), but does not raise OverflowError on
        # +/-Infinity.
        has_parent = parent > -1

    root = root_of.get(parent, parent) if has_parent else chat_id
    if chat_id is not None:
        root_of[chat_id] = root
    return root
|
||||
|
||||
|
||||
def session_uniform(*, seed: int, window_id: str, session_root: Any) -> float:
|
||||
"""Deterministic per-session uniform score in [0, 1).
|
||||
|
||||
All turns of a session share one score, so thresholding sampling_u keeps or
|
||||
drops whole sessions and preserves intra-session prefix (KV-cache) reuse.
|
||||
"""
|
||||
payload = json.dumps(
|
||||
{
|
||||
"seed": seed,
|
||||
"window_id": window_id,
|
||||
"index": index,
|
||||
"timestamp": row.get("timestamp"),
|
||||
"input_length": row.get("input_length"),
|
||||
"output_length": row.get("output_length"),
|
||||
"chat_id": row.get("chat_id"),
|
||||
"turn": row.get("turn"),
|
||||
"session_root": session_root,
|
||||
},
|
||||
sort_keys=True,
|
||||
separators=(",", ":"),
|
||||
@@ -241,12 +263,16 @@ def materialize_windows(
|
||||
bucket = grouped[(trace_path, prompt_path)]
|
||||
bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"])))
|
||||
matched_rows = 0
|
||||
root_of: dict[Any, Any] = {}
|
||||
with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle:
|
||||
for trace_raw, prompt_raw in zip(trace_handle, prompt_handle):
|
||||
trace_raw = trace_raw.strip()
|
||||
if not trace_raw:
|
||||
continue
|
||||
trace_row = json.loads(trace_raw)
|
||||
# Resolve session linkage for every row (even unmatched ones)
|
||||
# so multi-turn chains crossing the window edge still group.
|
||||
session_root = resolve_session_root(trace_row, root_of)
|
||||
timestamp = float(trace_row.get("timestamp") or 0.0)
|
||||
matched_window: dict[str, Any] | None = None
|
||||
for window in bucket:
|
||||
@@ -267,11 +293,11 @@ def materialize_windows(
|
||||
start = float(matched_window["window_start"])
|
||||
out["source_timestamp"] = timestamp
|
||||
out["timestamp"] = timestamp - start
|
||||
out["sampling_u"] = stable_uniform(
|
||||
out["session_root"] = session_root
|
||||
out["sampling_u"] = session_uniform(
|
||||
seed=sample_seed,
|
||||
window_id=window_id,
|
||||
index=stats_by_window[window_id].num_requests,
|
||||
row=merged,
|
||||
session_root=session_root,
|
||||
)
|
||||
handles[window_id].write(json.dumps(out, ensure_ascii=False) + "\n")
|
||||
stats_by_window[window_id].record(out)
|
||||
@@ -311,7 +337,7 @@ def build_output_window(
|
||||
output["num_excluded_too_long"] = 0
|
||||
output["sampling_u_field"] = "sampling_u"
|
||||
output["sampling_seed"] = int(sample_seed)
|
||||
output["sampling_strategy"] = "fixed_uniform_score"
|
||||
output["sampling_strategy"] = "session_coherent_uniform_score"
|
||||
output["first_request_ts"] = stats.first_request_ts
|
||||
output["last_request_ts"] = stats.last_request_ts
|
||||
output["first_request_index"] = stats.first_request_index
|
||||
|
||||
100
tests/test_prepare_trace_windows.py
Normal file
100
tests/test_prepare_trace_windows.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
_SPEC = importlib.util.spec_from_file_location(
|
||||
"prepare_trace_windows",
|
||||
REPO_ROOT / "scripts" / "prepare_trace_windows.py",
|
||||
)
|
||||
assert _SPEC and _SPEC.loader
|
||||
ptw = importlib.util.module_from_spec(_SPEC)
|
||||
# Register before exec so dataclasses can resolve the module's annotations.
|
||||
sys.modules[_SPEC.name] = ptw
|
||||
_SPEC.loader.exec_module(ptw)
|
||||
|
||||
|
||||
class SessionCoherentSamplingTests(unittest.TestCase):
    """Exercise session-root resolution and the per-session sampling score."""

    def test_multi_hop_chain_resolves_to_root(self) -> None:
        """Every turn of an in-order three-turn chain resolves to chat_id 1."""
        root_of: dict[object, object] = {}
        # turn1 is the root, turn2 -> turn1, turn3 -> turn2 (multi-hop).
        chain = (
            {"chat_id": 1, "parent_chat_id": -1, "turn": 1},
            {"chat_id": 2, "parent_chat_id": 1, "turn": 2},
            {"chat_id": 3, "parent_chat_id": 2, "turn": 3},
        )
        for turn_row in chain:
            self.assertEqual(ptw.resolve_session_root(turn_row, root_of), 1)

    def test_unknown_parent_falls_back_to_parent_id(self) -> None:
        """Siblings whose parent was never streamed group under the parent id."""
        root_of: dict[object, object] = {}
        # Parent 9 is never seen (it fell outside the span).
        for orphan_id in (50, 51):
            orphan = {"chat_id": orphan_id, "parent_chat_id": 9, "turn": 2}
            self.assertEqual(ptw.resolve_session_root(orphan, root_of), 9)

    def test_all_turns_of_a_session_share_one_u(self) -> None:
        """All turns of one chain get a single score, and it lies in [0, 1)."""
        root_of: dict[object, object] = {}
        rows = [
            {"chat_id": 1, "parent_chat_id": -1, "turn": 1},
            {"chat_id": 2, "parent_chat_id": 1, "turn": 2},
            {"chat_id": 3, "parent_chat_id": 2, "turn": 3},
        ]
        distinct_scores = set()
        for row in rows:
            root = ptw.resolve_session_root(row, root_of)
            distinct_scores.add(
                ptw.session_uniform(seed=7, window_id="w", session_root=root)
            )
        self.assertEqual(len(distinct_scores), 1)
        (only,) = distinct_scores
        self.assertGreaterEqual(only, 0.0)
        self.assertLess(only, 1.0)

    def test_thresholding_keeps_or_drops_whole_sessions(self) -> None:
        """For any threshold, a session's turns are all kept or all dropped."""
        seed, window_id = 20260325, "chat_w_x"
        sessions = {
            "A": [
                {"chat_id": 10, "parent_chat_id": -1},
                {"chat_id": 11, "parent_chat_id": 10},
            ],
            "B": [
                {"chat_id": 20, "parent_chat_id": -1},
                {"chat_id": 21, "parent_chat_id": 20},
            ],
        }
        root_of: dict[object, object] = {}

        # Score every turn in stream order, bucketed by the session it belongs to.
        scores_by_session: dict[str, list[float]] = {}
        for name, rows in sessions.items():
            turn_scores = scores_by_session.setdefault(name, [])
            for row in rows:
                root = ptw.resolve_session_root(row, root_of)
                turn_scores.append(
                    ptw.session_uniform(
                        seed=seed, window_id=window_id, session_root=root
                    )
                )

        for name in sessions:
            self.assertEqual(
                len(set(scores_by_session[name])),
                1,
                f"session {name} turns must share one u",
            )

        for threshold in (0.0, 0.25, 0.5, 0.75, 1.0):
            for name in sessions:
                decisions = {score <= threshold for score in scores_by_session[name]}
                self.assertEqual(
                    len(decisions), 1, "a session must be kept/dropped as a whole"
                )
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user