Speed up raw trace window extraction

This commit is contained in:
2026-04-04 21:42:02 +08:00
parent 65b122fd4b
commit 69f666593e

View File

@@ -203,24 +203,29 @@ def extract_windows(windows: list[dict[str, Any]], *, sample_seed: int) -> dict[
if not trace_raw or not prompt_raw: if not trace_raw or not prompt_raw:
continue continue
trace_row = json.loads(trace_raw) trace_row = json.loads(trace_raw)
prompt_row = json.loads(prompt_raw) timestamp = float(trace_row.get("timestamp") or 0.0)
merged = _merge_trace_and_prompt(trace_row, prompt_row) matched_window: dict[str, Any] | None = None
timestamp = float(merged.get("timestamp") or 0.0)
for window in bucket: for window in bucket:
start = float(window["window_start"]) start = float(window["window_start"])
end = float(window["window_end"]) end = float(window["window_end"])
if start <= timestamp < end: if start <= timestamp < end:
matched_window = window
break
if matched_window is None:
continue
prompt_row = json.loads(prompt_raw)
merged = _merge_trace_and_prompt(trace_row, prompt_row)
out = dict(merged) out = dict(merged)
start = float(matched_window["window_start"])
out["source_timestamp"] = timestamp out["source_timestamp"] = timestamp
out["timestamp"] = timestamp - start out["timestamp"] = timestamp - start
out["sampling_u"] = stable_uniform( out["sampling_u"] = stable_uniform(
seed=sample_seed, seed=sample_seed,
window_id=str(window["window_id"]), window_id=str(matched_window["window_id"]),
index=len(extracted[str(window["window_id"])]), index=len(extracted[str(matched_window["window_id"])]),
row=merged, row=merged,
) )
extracted[str(window["window_id"])].append(out) extracted[str(matched_window["window_id"])].append(out)
break
return extracted return extracted