Files
aituner/tests/test_prepare_trace_windows.py
Gahow Wang 0f15bbc3f1 Make the offered-load axis session-coherent
Phase 1 of the two-stop work. Subsampling the trace by per-request uniform score
broke multi-turn sessions (a kept turn-2 could lose its turn-1), which lowered the
realized KV-cache hit rate as offered load dropped — so the feasibility boundary
was measured on a workload with a different C than production, contradicting the
paper's scale-stationary L-C-A premise.

prepare_trace_windows now resolves each row's session root via the parent_chat_id
chain in a single streaming pass and assigns sampling_u per session, so thresholding
keeps or drops whole sessions and preserves intra-session prefix reuse. Rows whose
parent fell outside the span fall back to grouping under the parent id.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 14:16:06 +08:00

101 lines
3.7 KiB
Python

from __future__ import annotations
import importlib.util
import sys
import unittest
from pathlib import Path
# parents[0] is the directory holding this file; parents[1] is one level above it.
REPO_ROOT = Path(__file__).resolve().parents[1]

# Load scripts/prepare_trace_windows.py by file path rather than through a
# regular import statement.
_SPEC = importlib.util.spec_from_file_location(
    name="prepare_trace_windows",
    location=REPO_ROOT.joinpath("scripts", "prepare_trace_windows.py"),
)
assert _SPEC and _SPEC.loader

ptw = importlib.util.module_from_spec(_SPEC)
# The module has to be present in sys.modules before its body runs so that
# dataclasses can resolve the module's annotations.
sys.modules[_SPEC.name] = ptw
_SPEC.loader.exec_module(ptw)
class SessionCoherentSamplingTests(unittest.TestCase):
    """Exercise ``resolve_session_root`` and ``session_uniform`` from ``ptw``.

    Each test feeds rows to ``ptw.resolve_session_root`` together with a root
    map that is shared across the calls of that test, and checks either the
    returned root or the score ``ptw.session_uniform`` derives from it.
    """

    def test_multi_hop_chain_resolves_to_root(self) -> None:
        """Turns 1, 2 and 3 of one chain must all resolve to chat id 1."""
        seen_roots: dict[object, object] = {}
        # turn 1 is the root, turn 2 points at turn 1, turn 3 points at turn 2
        # (multi-hop); the rows are passed in stream order.
        chain = (
            {"chat_id": 1, "parent_chat_id": -1, "turn": 1},
            {"chat_id": 2, "parent_chat_id": 1, "turn": 2},
            {"chat_id": 3, "parent_chat_id": 2, "turn": 3},
        )
        for row in chain:
            self.assertEqual(ptw.resolve_session_root(row, seen_roots), 1)

    def test_unknown_parent_falls_back_to_parent_id(self) -> None:
        """Rows whose parent was never seen must resolve to the parent id."""
        seen_roots: dict[object, object] = {}
        # Parent 9 is never passed in (it fell outside the span), so both
        # siblings are expected to be grouped under id 9.
        for chat_id in (50, 51):
            orphan = {"chat_id": chat_id, "parent_chat_id": 9, "turn": 2}
            self.assertEqual(ptw.resolve_session_root(orphan, seen_roots), 9)

    def test_all_turns_of_a_session_share_one_u(self) -> None:
        """Three turns of one session must yield a single score in [0, 1)."""
        seen_roots: dict[object, object] = {}
        turns = [
            {"chat_id": 1, "parent_chat_id": -1, "turn": 1},
            {"chat_id": 2, "parent_chat_id": 1, "turn": 2},
            {"chat_id": 3, "parent_chat_id": 2, "turn": 3},
        ]
        distinct_scores = set()
        for row in turns:
            session_root = ptw.resolve_session_root(row, seen_roots)
            distinct_scores.add(
                ptw.session_uniform(seed=7, window_id="w", session_root=session_root)
            )
        self.assertEqual(len(distinct_scores), 1)
        # Exactly one element is present once the assertion above has passed.
        (score,) = distinct_scores
        self.assertGreaterEqual(score, 0.0)
        self.assertLess(score, 1.0)

    def test_thresholding_keeps_or_drops_whole_sessions(self) -> None:
        """A threshold must keep every turn of a session or none of them."""
        seed = 20260325
        window_id = "chat_w_x"
        sessions = {
            "A": [
                {"chat_id": 10, "parent_chat_id": -1},
                {"chat_id": 11, "parent_chat_id": 10},
            ],
            "B": [
                {"chat_id": 20, "parent_chat_id": -1},
                {"chat_id": 21, "parent_chat_id": 20},
            ],
        }
        seen_roots: dict[object, object] = {}
        # Scores are collected per session name, in the order the rows are fed.
        scores_by_session: dict[str, list[float]] = {name: [] for name in sessions}
        for name, rows in sessions.items():
            for row in rows:
                session_root = ptw.resolve_session_root(row, seen_roots)
                scores_by_session[name].append(
                    ptw.session_uniform(
                        seed=seed, window_id=window_id, session_root=session_root
                    )
                )

        for name in sessions:
            distinct_scores = set(scores_by_session[name])
            self.assertEqual(
                len(distinct_scores), 1, f"session {name} turns must share one u"
            )

        for threshold in (0.0, 0.25, 0.5, 0.75, 1.0):
            for name in sessions:
                # Collapse the per-turn keep/drop decisions; a single value
                # means every turn got the same decision.
                decisions = {score <= threshold for score in scores_by_session[name]}
                self.assertEqual(
                    len(decisions), 1, "a session must be kept/dropped as a whole"
                )
# Run the tests through unittest's command-line entry point when this file is
# executed directly as a script.
if __name__ == "__main__":
    unittest.main()