diff --git a/scripts/prepare_trace_windows.py b/scripts/prepare_trace_windows.py index beca987..09d2d48 100644 --- a/scripts/prepare_trace_windows.py +++ b/scripts/prepare_trace_windows.py @@ -203,24 +203,29 @@ def extract_windows(windows: list[dict[str, Any]], *, sample_seed: int) -> dict[ if not trace_raw or not prompt_raw: continue trace_row = json.loads(trace_raw) - prompt_row = json.loads(prompt_raw) - merged = _merge_trace_and_prompt(trace_row, prompt_row) - timestamp = float(merged.get("timestamp") or 0.0) + timestamp = float(trace_row.get("timestamp") or 0.0) + matched_window: dict[str, Any] | None = None for window in bucket: start = float(window["window_start"]) end = float(window["window_end"]) if start <= timestamp < end: - out = dict(merged) - out["source_timestamp"] = timestamp - out["timestamp"] = timestamp - start - out["sampling_u"] = stable_uniform( - seed=sample_seed, - window_id=str(window["window_id"]), - index=len(extracted[str(window["window_id"])]), - row=merged, - ) - extracted[str(window["window_id"])].append(out) + matched_window = window break + if matched_window is None: + continue + prompt_row = json.loads(prompt_raw) + merged = _merge_trace_and_prompt(trace_row, prompt_row) + out = dict(merged) + start = float(matched_window["window_start"]) + out["source_timestamp"] = timestamp + out["timestamp"] = timestamp - start + out["sampling_u"] = stable_uniform( + seed=sample_seed, + window_id=str(matched_window["window_id"]), + index=len(extracted[str(matched_window["window_id"])]), + row=merged, + ) + extracted[str(matched_window["window_id"])].append(out) return extracted