Make the offered-load axis session-coherent

Phase 1 of the two-stop work. Subsampling the trace by per-request uniform score
broke multi-turn sessions (a kept turn-2 could lose its turn-1), which lowered the
realized KV-cache hit rate as offered load dropped — so the feasibility boundary
was measured on a workload with a different C than production, contradicting the
paper's scale-stationary L-C-A premise.

prepare_trace_windows now resolves each row's session root via the parent_chat_id
chain in a single streaming pass and assigns sampling_u per session, so thresholding
keeps or drops whole sessions and preserves intra-session prefix reuse. Rows whose
parent fell outside the span fall back to grouping under the parent id.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 14:16:06 +08:00
parent 6f8e3c95c1
commit 0f15bbc3f1
2 changed files with 137 additions and 11 deletions

View File

@@ -92,17 +92,39 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()
def stable_uniform(*, seed: int, window_id: str, index: int, row: dict[str, Any]) -> float:
def resolve_session_root(row: dict[str, Any], root_of: dict[Any, Any]) -> Any:
"""Resolve the session root chat_id for a trace row.
Sessions are multi-turn chains linked via parent_chat_id (turn>1 points to the
parent turn's chat_id, the root turn has parent_chat_id=-1). Because parent
turns precede their children in time, a single streaming pass that records
chat_id -> root resolves the full chain. Rows whose parent is not yet known
(e.g. it fell outside the materialized span) fall back to the parent id so
siblings still group together.
"""
chat_id = row.get("chat_id")
parent = row.get("parent_chat_id")
parent_is_root = (
parent is None
or (isinstance(parent, (int, float)) and not isinstance(parent, bool) and int(parent) < 0)
)
root = chat_id if parent_is_root else root_of.get(parent, parent)
if chat_id is not None:
root_of[chat_id] = root
return root
def session_uniform(*, seed: int, window_id: str, session_root: Any) -> float:
"""Deterministic per-session uniform score in [0, 1).
All turns of a session share one score, so thresholding sampling_u keeps or
drops whole sessions and preserves intra-session prefix (KV-cache) reuse.
"""
payload = json.dumps(
{
"seed": seed,
"window_id": window_id,
"index": index,
"timestamp": row.get("timestamp"),
"input_length": row.get("input_length"),
"output_length": row.get("output_length"),
"chat_id": row.get("chat_id"),
"turn": row.get("turn"),
"session_root": session_root,
},
sort_keys=True,
separators=(",", ":"),
@@ -241,12 +263,16 @@ def materialize_windows(
bucket = grouped[(trace_path, prompt_path)]
bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"])))
matched_rows = 0
root_of: dict[Any, Any] = {}
with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle:
for trace_raw, prompt_raw in zip(trace_handle, prompt_handle):
trace_raw = trace_raw.strip()
if not trace_raw:
continue
trace_row = json.loads(trace_raw)
# Resolve session linkage for every row (even unmatched ones)
# so multi-turn chains crossing the window edge still group.
session_root = resolve_session_root(trace_row, root_of)
timestamp = float(trace_row.get("timestamp") or 0.0)
matched_window: dict[str, Any] | None = None
for window in bucket:
@@ -267,11 +293,11 @@ def materialize_windows(
start = float(matched_window["window_start"])
out["source_timestamp"] = timestamp
out["timestamp"] = timestamp - start
out["sampling_u"] = stable_uniform(
out["session_root"] = session_root
out["sampling_u"] = session_uniform(
seed=sample_seed,
window_id=window_id,
index=stats_by_window[window_id].num_requests,
row=merged,
session_root=session_root,
)
handles[window_id].write(json.dumps(out, ensure_ascii=False) + "\n")
stats_by_window[window_id].record(out)
@@ -311,7 +337,7 @@ def build_output_window(
output["num_excluded_too_long"] = 0
output["sampling_u_field"] = "sampling_u"
output["sampling_seed"] = int(sample_seed)
output["sampling_strategy"] = "fixed_uniform_score"
output["sampling_strategy"] = "session_coherent_uniform_score"
output["first_request_ts"] = stats.first_request_ts
output["last_request_ts"] = stats.last_request_ts
output["first_request_index"] = stats.first_request_index