Make the offered-load axis session-coherent
Phase 1 of the two-stop work. Subsampling the trace by per-request uniform score broke multi-turn sessions (a kept turn-2 could lose its turn-1), which lowered the realized KV-cache hit rate as offered load dropped — so the feasibility boundary was measured on a workload with a different C than production, contradicting the paper's scale-stationary L-C-A premise. prepare_trace_windows now resolves each row's session root via the parent_chat_id chain in a single streaming pass and assigns sampling_u per session, so thresholding keeps or drops whole sessions and preserves intra-session prefix reuse. Rows whose parent fell outside the span fall back to grouping under the parent id. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -92,17 +92,39 @@ def parse_args() -> argparse.Namespace:
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def stable_uniform(*, seed: int, window_id: str, index: int, row: dict[str, Any]) -> float:
|
||||
def resolve_session_root(row: dict[str, Any], root_of: dict[Any, Any]) -> Any:
    """Return the session-root chat_id for ``row`` and record it in ``root_of``.

    A row is treated as the first turn of its own session when its
    ``parent_chat_id`` is absent/``None``, a negative number (the trace uses
    ``-1`` as the "no parent" sentinel), or NaN (a missing numeric value).
    Otherwise the root is whatever was previously recorded for the parent in
    ``root_of``; if the parent has not been seen, the parent id itself is used
    as the root so that siblings pointing at the same unseen parent still
    share one root.

    Args:
        row: Parsed trace row; only ``chat_id`` and ``parent_chat_id`` are read.
        root_of: Mutable ``chat_id -> session root`` map. It is updated in
            place with this row's mapping whenever ``chat_id`` is not ``None``,
            so rows must be fed in an order where parents precede children for
            full chains to resolve.

    Returns:
        The resolved session root. This is ``None`` when the row is a root
        turn and carries no ``chat_id``.
    """
    chat_id = row.get("chat_id")
    parent = row.get("parent_chat_id")

    if isinstance(parent, (int, float)) and not isinstance(parent, bool):
        # Compare the number directly rather than via int(): int() raises on
        # NaN/inf (both of which json.loads can produce) and truncates
        # fractional negatives such as -0.5 to 0. ``parent != parent`` is
        # true only for NaN, which is treated as "no parent".
        parent_is_root = parent < 0 or parent != parent
    else:
        # Non-numeric ids (e.g. strings) are always real parent references;
        # bools are deliberately not interpreted as the numeric sentinel.
        parent_is_root = parent is None

    if parent_is_root:
        root = chat_id
    else:
        # Unknown parent -> fall back to the parent id itself.
        root = root_of.get(parent, parent)

    if chat_id is not None:
        root_of[chat_id] = root
    return root
|
||||
|
||||
|
||||
def session_uniform(*, seed: int, window_id: str, session_root: Any) -> float:
|
||||
"""Deterministic per-session uniform score in [0, 1).
|
||||
|
||||
All turns of a session share one score, so thresholding sampling_u keeps or
|
||||
drops whole sessions and preserves intra-session prefix (KV-cache) reuse.
|
||||
"""
|
||||
payload = json.dumps(
|
||||
{
|
||||
"seed": seed,
|
||||
"window_id": window_id,
|
||||
"index": index,
|
||||
"timestamp": row.get("timestamp"),
|
||||
"input_length": row.get("input_length"),
|
||||
"output_length": row.get("output_length"),
|
||||
"chat_id": row.get("chat_id"),
|
||||
"turn": row.get("turn"),
|
||||
"session_root": session_root,
|
||||
},
|
||||
sort_keys=True,
|
||||
separators=(",", ":"),
|
||||
@@ -241,12 +263,16 @@ def materialize_windows(
|
||||
bucket = grouped[(trace_path, prompt_path)]
|
||||
bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"])))
|
||||
matched_rows = 0
|
||||
root_of: dict[Any, Any] = {}
|
||||
with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle:
|
||||
for trace_raw, prompt_raw in zip(trace_handle, prompt_handle):
|
||||
trace_raw = trace_raw.strip()
|
||||
if not trace_raw:
|
||||
continue
|
||||
trace_row = json.loads(trace_raw)
|
||||
# Resolve session linkage for every row (even unmatched ones)
|
||||
# so multi-turn chains crossing the window edge still group.
|
||||
session_root = resolve_session_root(trace_row, root_of)
|
||||
timestamp = float(trace_row.get("timestamp") or 0.0)
|
||||
matched_window: dict[str, Any] | None = None
|
||||
for window in bucket:
|
||||
@@ -267,11 +293,11 @@ def materialize_windows(
|
||||
start = float(matched_window["window_start"])
|
||||
out["source_timestamp"] = timestamp
|
||||
out["timestamp"] = timestamp - start
|
||||
out["sampling_u"] = stable_uniform(
|
||||
out["session_root"] = session_root
|
||||
out["sampling_u"] = session_uniform(
|
||||
seed=sample_seed,
|
||||
window_id=window_id,
|
||||
index=stats_by_window[window_id].num_requests,
|
||||
row=merged,
|
||||
session_root=session_root,
|
||||
)
|
||||
handles[window_id].write(json.dumps(out, ensure_ascii=False) + "\n")
|
||||
stats_by_window[window_id].record(out)
|
||||
@@ -311,7 +337,7 @@ def build_output_window(
|
||||
output["num_excluded_too_long"] = 0
|
||||
output["sampling_u_field"] = "sampling_u"
|
||||
output["sampling_seed"] = int(sample_seed)
|
||||
output["sampling_strategy"] = "fixed_uniform_score"
|
||||
output["sampling_strategy"] = "session_coherent_uniform_score"
|
||||
output["first_request_ts"] = stats.first_request_ts
|
||||
output["last_request_ts"] = stats.last_request_ts
|
||||
output["first_request_index"] = stats.first_request_index
|
||||
|
||||
Reference in New Issue
Block a user