Make the offered-load axis session-coherent
Phase 1 of the two-stop work. Subsampling the trace by a per-request uniform score broke multi-turn sessions (a kept turn-2 could lose its turn-1), which lowered the realized KV-cache hit rate as offered load dropped. As a result, the feasibility boundary was measured on a workload whose C differed from production's, contradicting the paper's scale-stationary L-C-A premise. prepare_trace_windows now resolves each row's session root via the parent_chat_id chain in a single streaming pass and assigns sampling_u per session, so thresholding keeps or drops whole sessions and preserves intra-session prefix reuse. Rows whose parent fell outside the span fall back to grouping under the parent id. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
100
tests/test_prepare_trace_windows.py
Normal file
100
tests/test_prepare_trace_windows.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
# The repository root is two directory levels above this test file.
REPO_ROOT = Path(__file__).resolve().parent.parent

# The script under test is loaded straight from its file path under the name
# "prepare_trace_windows" rather than through a regular package import.
_SPEC = importlib.util.spec_from_file_location(
    "prepare_trace_windows",
    REPO_ROOT.joinpath("scripts", "prepare_trace_windows.py"),
)
assert _SPEC and _SPEC.loader

ptw = importlib.util.module_from_spec(_SPEC)
# Publish the module in sys.modules *before* executing it, so that dataclasses
# can resolve the module's annotations while its body runs.
sys.modules[_SPEC.name] = ptw
_SPEC.loader.exec_module(ptw)
|
||||
|
||||
|
||||
class SessionCoherentSamplingTests(unittest.TestCase):
    """Exercise ``ptw.resolve_session_root`` and ``ptw.session_uniform``.

    Every test feeds rows in streaming order through one shared mapping, so a
    later turn is always resolved after the turns that precede it.
    NOTE(review): the mapping is presumably where the helper records the roots
    it has already resolved -- confirm against the script.
    """

    def test_multi_hop_chain_resolves_to_root(self) -> None:
        # Turn 1 is the root (parent -1), turn 2 points at turn 1 and turn 3
        # points at turn 2, so the last row sits two hops away from the root.
        seen: dict[object, object] = {}
        chain = (
            {"chat_id": 1, "parent_chat_id": -1, "turn": 1},
            {"chat_id": 2, "parent_chat_id": 1, "turn": 2},
            {"chat_id": 3, "parent_chat_id": 2, "turn": 3},
        )
        for turn_row in chain:
            self.assertEqual(ptw.resolve_session_root(turn_row, seen), 1)

    def test_unknown_parent_falls_back_to_parent_id(self) -> None:
        # Parent 9 is never streamed (it fell outside the span), so both
        # siblings are expected to be grouped under the parent id itself.
        seen: dict[object, object] = {}
        orphans = (
            {"chat_id": 50, "parent_chat_id": 9, "turn": 2},
            {"chat_id": 51, "parent_chat_id": 9, "turn": 2},
        )
        for turn_row in orphans:
            self.assertEqual(ptw.resolve_session_root(turn_row, seen), 9)

    def test_all_turns_of_a_session_share_one_u(self) -> None:
        # Score each turn of a three-turn session and collect the distinct
        # values; a session-coherent score leaves exactly one.
        seen: dict[object, object] = {}
        distinct_scores: set[float] = set()
        for turn_row in (
            {"chat_id": 1, "parent_chat_id": -1, "turn": 1},
            {"chat_id": 2, "parent_chat_id": 1, "turn": 2},
            {"chat_id": 3, "parent_chat_id": 2, "turn": 3},
        ):
            root = ptw.resolve_session_root(turn_row, seen)
            distinct_scores.add(
                ptw.session_uniform(seed=7, window_id="w", session_root=root)
            )
        self.assertEqual(len(distinct_scores), 1)
        # The shared score must lie in the half-open interval [0, 1).
        (score,) = distinct_scores
        self.assertGreaterEqual(score, 0.0)
        self.assertLess(score, 1.0)

    def test_thresholding_keeps_or_drops_whole_sessions(self) -> None:
        # Whatever scores the two sessions receive, comparing them against a
        # threshold must keep every turn of a session or none of its turns.
        seed = 20260325
        window_id = "chat_w_x"
        sessions = {
            "A": [
                {"chat_id": 10, "parent_chat_id": -1},
                {"chat_id": 11, "parent_chat_id": 10},
            ],
            "B": [
                {"chat_id": 20, "parent_chat_id": -1},
                {"chat_id": 21, "parent_chat_id": 20},
            ],
        }
        seen: dict[object, object] = {}
        scores_by_session: dict[str, list[float]] = {}
        for label, turn_rows in sessions.items():
            session_scores = scores_by_session.setdefault(label, [])
            for turn_row in turn_rows:
                root = ptw.resolve_session_root(turn_row, seen)
                session_scores.append(
                    ptw.session_uniform(
                        seed=seed, window_id=window_id, session_root=root
                    )
                )

        # First, every turn of a session carries the same score.
        for name, session_scores in scores_by_session.items():
            self.assertEqual(
                len(set(session_scores)),
                1,
                f"session {name} turns must share one u",
            )

        # Then, at each threshold, the keep/drop decision is unanimous
        # within a session.
        for threshold in (0.0, 0.25, 0.5, 0.75, 1.0):
            for session_scores in scores_by_session.values():
                decisions = {score <= threshold for score in session_scores}
                self.assertEqual(
                    len(decisions), 1, "a session must be kept/dropped as a whole"
                )
|
||||
|
||||
|
||||
# Executing this file directly hands control to unittest's command-line runner.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user