Speed up raw trace window extraction
This commit is contained in:
@@ -203,24 +203,29 @@ def extract_windows(windows: list[dict[str, Any]], *, sample_seed: int) -> dict[
|
||||
if not trace_raw or not prompt_raw:
|
||||
continue
|
||||
trace_row = json.loads(trace_raw)
|
||||
prompt_row = json.loads(prompt_raw)
|
||||
merged = _merge_trace_and_prompt(trace_row, prompt_row)
|
||||
timestamp = float(merged.get("timestamp") or 0.0)
|
||||
timestamp = float(trace_row.get("timestamp") or 0.0)
|
||||
matched_window: dict[str, Any] | None = None
|
||||
for window in bucket:
|
||||
start = float(window["window_start"])
|
||||
end = float(window["window_end"])
|
||||
if start <= timestamp < end:
|
||||
matched_window = window
|
||||
break
|
||||
if matched_window is None:
|
||||
continue
|
||||
prompt_row = json.loads(prompt_raw)
|
||||
merged = _merge_trace_and_prompt(trace_row, prompt_row)
|
||||
out = dict(merged)
|
||||
start = float(matched_window["window_start"])
|
||||
out["source_timestamp"] = timestamp
|
||||
out["timestamp"] = timestamp - start
|
||||
out["sampling_u"] = stable_uniform(
|
||||
seed=sample_seed,
|
||||
window_id=str(window["window_id"]),
|
||||
index=len(extracted[str(window["window_id"])]),
|
||||
window_id=str(matched_window["window_id"]),
|
||||
index=len(extracted[str(matched_window["window_id"])]),
|
||||
row=merged,
|
||||
)
|
||||
extracted[str(window["window_id"])].append(out)
|
||||
break
|
||||
extracted[str(matched_window["window_id"])].append(out)
|
||||
return extracted
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user