"""Tests for the session-coherent sampling helpers in prepare_trace_windows.

The script under test lives at ``scripts/prepare_trace_windows.py`` and is
loaded directly from that file path, then bound to ``ptw``.
"""

from __future__ import annotations

import importlib.util
import sys
import unittest
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parents[1]
_SCRIPT_PATH = REPO_ROOT / "scripts" / "prepare_trace_windows.py"

_SPEC = importlib.util.spec_from_file_location("prepare_trace_windows", _SCRIPT_PATH)
# Raise explicitly instead of using ``assert``: assertions are stripped under
# ``python -O``, which would turn a missing spec into an obscure failure below.
if _SPEC is None or _SPEC.loader is None:
    raise ImportError(f"cannot load prepare_trace_windows from {_SCRIPT_PATH}")
ptw = importlib.util.module_from_spec(_SPEC)
# Register before exec so dataclasses can resolve the module's annotations.
sys.modules[_SPEC.name] = ptw
_SPEC.loader.exec_module(ptw)


class SessionCoherentSamplingTests(unittest.TestCase):
    """Checks for ``resolve_session_root`` and ``session_uniform``."""

    def test_multi_hop_chain_resolves_to_root(self) -> None:
        """Each turn of a parent-linked chain resolves to the first turn's id."""
        root_of: dict[object, object] = {}
        # turn1 root, turn2 -> turn1, turn3 -> turn2 (multi-hop), streamed in order.
        rows = [
            {"chat_id": 1, "parent_chat_id": -1, "turn": 1},
            {"chat_id": 2, "parent_chat_id": 1, "turn": 2},
            {"chat_id": 3, "parent_chat_id": 2, "turn": 3},
        ]
        for row in rows:
            with self.subTest(chat_id=row["chat_id"]):
                self.assertEqual(ptw.resolve_session_root(row, root_of), 1)

    def test_unknown_parent_falls_back_to_parent_id(self) -> None:
        """Rows whose parent was never streamed resolve to the parent's id."""
        root_of: dict[object, object] = {}
        # parent never seen (fell outside the span): group siblings under the parent.
        for chat_id in (50, 51):
            row = {"chat_id": chat_id, "parent_chat_id": 9, "turn": 2}
            with self.subTest(chat_id=chat_id):
                self.assertEqual(ptw.resolve_session_root(row, root_of), 9)

    def test_all_turns_of_a_session_share_one_u(self) -> None:
        """All turns of one session map to a single u, and it lies in [0, 1)."""
        root_of: dict[object, object] = {}
        rows = [
            {"chat_id": 1, "parent_chat_id": -1, "turn": 1},
            {"chat_id": 2, "parent_chat_id": 1, "turn": 2},
            {"chat_id": 3, "parent_chat_id": 2, "turn": 3},
        ]
        us = {
            ptw.session_uniform(
                seed=7,
                window_id="w",
                session_root=ptw.resolve_session_root(row, root_of),
            )
            for row in rows
        }
        self.assertEqual(len(us), 1)
        only = next(iter(us))
        self.assertGreaterEqual(only, 0.0)
        self.assertLess(only, 1.0)

    def test_thresholding_keeps_or_drops_whole_sessions(self) -> None:
        """A ``u <= threshold`` cut never splits the turns of one session."""
        # Every turn of a session must carry the same score, so a threshold
        # either keeps a session's every turn or none of them.
        seed, window_id = 20260325, "chat_w_x"
        sessions = {
            "A": [
                {"chat_id": 10, "parent_chat_id": -1},
                {"chat_id": 11, "parent_chat_id": 10},
            ],
            "B": [
                {"chat_id": 20, "parent_chat_id": -1},
                {"chat_id": 21, "parent_chat_id": 20},
            ],
        }
        root_of: dict[object, object] = {}
        # Distinct u values observed per session; rows are streamed in order
        # through the shared ``root_of`` mapping.
        scores: dict[str, set[float]] = {name: set() for name in sessions}
        for name, rows in sessions.items():
            for row in rows:
                root = ptw.resolve_session_root(row, root_of)
                u = ptw.session_uniform(seed=seed, window_id=window_id, session_root=root)
                scores[name].add(u)

        for name, us in scores.items():
            with self.subTest(session=name):
                self.assertEqual(len(us), 1, f"session {name} turns must share one u")

        for threshold in (0.0, 0.25, 0.5, 0.75, 1.0):
            for name, us in scores.items():
                with self.subTest(threshold=threshold, session=name):
                    kept = {u <= threshold for u in us}
                    self.assertEqual(len(kept), 1, "a session must be kept/dropped as a whole")


if __name__ == "__main__":
    unittest.main()