Stream trace window materialization
This commit is contained in:
@@ -5,6 +5,7 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
@@ -186,52 +187,97 @@ def _merge_trace_and_prompt(trace_row: dict[str, Any], prompt_row: dict[str, Any
|
|||||||
return merged
|
return merged
|
||||||
|
|
||||||
|
|
||||||
def extract_windows(windows: list[dict[str, Any]], *, sample_seed: int) -> dict[str, list[dict[str, Any]]]:
|
@dataclass
|
||||||
|
class WindowStats:
|
||||||
|
num_requests: int = 0
|
||||||
|
sum_input_length: int = 0
|
||||||
|
max_input_length: int = 0
|
||||||
|
first_request_ts: float | None = None
|
||||||
|
last_request_ts: float | None = None
|
||||||
|
first_request_index: int | None = None
|
||||||
|
last_request_index: int | None = None
|
||||||
|
|
||||||
|
def record(self, row: dict[str, Any]) -> None:
|
||||||
|
input_length = int(row.get("input_length") or 0)
|
||||||
|
timestamp = float(row["timestamp"])
|
||||||
|
if self.num_requests == 0:
|
||||||
|
self.first_request_ts = timestamp
|
||||||
|
self.first_request_index = 0
|
||||||
|
self.last_request_ts = timestamp
|
||||||
|
self.last_request_index = self.num_requests
|
||||||
|
self.num_requests += 1
|
||||||
|
self.sum_input_length += input_length
|
||||||
|
self.max_input_length = max(self.max_input_length, input_length)
|
||||||
|
|
||||||
|
|
||||||
|
def materialize_windows(
|
||||||
|
windows: list[dict[str, Any]], *, sample_seed: int, traces_dir: Path
|
||||||
|
) -> dict[str, WindowStats]:
|
||||||
grouped: dict[tuple[Path, Path], list[dict[str, Any]]] = {}
|
grouped: dict[tuple[Path, Path], list[dict[str, Any]]] = {}
|
||||||
for window in windows:
|
for window in windows:
|
||||||
trace_path = Path(window["source_trace_path"])
|
trace_path = Path(window["source_trace_path"])
|
||||||
prompt_path = Path(window["source_prompt_path"])
|
prompt_path = Path(window["source_prompt_path"])
|
||||||
grouped.setdefault((trace_path, prompt_path), []).append(window)
|
grouped.setdefault((trace_path, prompt_path), []).append(window)
|
||||||
|
|
||||||
extracted: dict[str, list[dict[str, Any]]] = {str(window["window_id"]): [] for window in windows}
|
stats_by_window = {str(window["window_id"]): WindowStats() for window in windows}
|
||||||
for (trace_path, prompt_path), bucket in grouped.items():
|
handles: dict[str, Any] = {}
|
||||||
bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"])))
|
try:
|
||||||
with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle:
|
for window in windows:
|
||||||
for index, (trace_raw, prompt_raw) in enumerate(zip(trace_handle, prompt_handle)):
|
window_id = str(window["window_id"])
|
||||||
trace_raw = trace_raw.strip()
|
handles[window_id] = (traces_dir / f"{window_id}.jsonl").open("w", encoding="utf-8")
|
||||||
prompt_raw = prompt_raw.strip()
|
|
||||||
if not trace_raw or not prompt_raw:
|
for trace_path, prompt_path in sorted(grouped.keys()):
|
||||||
continue
|
bucket = grouped[(trace_path, prompt_path)]
|
||||||
trace_row = json.loads(trace_raw)
|
bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"])))
|
||||||
timestamp = float(trace_row.get("timestamp") or 0.0)
|
matched_rows = 0
|
||||||
matched_window: dict[str, Any] | None = None
|
with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle:
|
||||||
for window in bucket:
|
for trace_raw, prompt_raw in zip(trace_handle, prompt_handle):
|
||||||
start = float(window["window_start"])
|
trace_raw = trace_raw.strip()
|
||||||
end = float(window["window_end"])
|
if not trace_raw:
|
||||||
if start <= timestamp < end:
|
continue
|
||||||
matched_window = window
|
trace_row = json.loads(trace_raw)
|
||||||
break
|
timestamp = float(trace_row.get("timestamp") or 0.0)
|
||||||
if matched_window is None:
|
matched_window: dict[str, Any] | None = None
|
||||||
continue
|
for window in bucket:
|
||||||
prompt_row = json.loads(prompt_raw)
|
start = float(window["window_start"])
|
||||||
merged = _merge_trace_and_prompt(trace_row, prompt_row)
|
end = float(window["window_end"])
|
||||||
out = dict(merged)
|
if start <= timestamp < end:
|
||||||
start = float(matched_window["window_start"])
|
matched_window = window
|
||||||
out["source_timestamp"] = timestamp
|
break
|
||||||
out["timestamp"] = timestamp - start
|
if matched_window is None:
|
||||||
out["sampling_u"] = stable_uniform(
|
continue
|
||||||
seed=sample_seed,
|
prompt_raw = prompt_raw.strip()
|
||||||
window_id=str(matched_window["window_id"]),
|
if not prompt_raw:
|
||||||
index=len(extracted[str(matched_window["window_id"])]),
|
continue
|
||||||
row=merged,
|
prompt_row = json.loads(prompt_raw)
|
||||||
)
|
merged = _merge_trace_and_prompt(trace_row, prompt_row)
|
||||||
extracted[str(matched_window["window_id"])].append(out)
|
window_id = str(matched_window["window_id"])
|
||||||
return extracted
|
out = dict(merged)
|
||||||
|
start = float(matched_window["window_start"])
|
||||||
|
out["source_timestamp"] = timestamp
|
||||||
|
out["timestamp"] = timestamp - start
|
||||||
|
out["sampling_u"] = stable_uniform(
|
||||||
|
seed=sample_seed,
|
||||||
|
window_id=window_id,
|
||||||
|
index=stats_by_window[window_id].num_requests,
|
||||||
|
row=merged,
|
||||||
|
)
|
||||||
|
handles[window_id].write(json.dumps(out, ensure_ascii=False) + "\n")
|
||||||
|
stats_by_window[window_id].record(out)
|
||||||
|
matched_rows += 1
|
||||||
|
print(
|
||||||
|
f"materialized {trace_path.name} -> matched_rows={matched_rows}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
for handle in handles.values():
|
||||||
|
handle.close()
|
||||||
|
return stats_by_window
|
||||||
|
|
||||||
|
|
||||||
def build_output_window(
|
def build_output_window(
|
||||||
window: dict[str, Any],
|
window: dict[str, Any],
|
||||||
rows: list[dict[str, Any]],
|
stats: WindowStats,
|
||||||
trace_relpath: str,
|
trace_relpath: str,
|
||||||
*,
|
*,
|
||||||
sample_seed: int,
|
sample_seed: int,
|
||||||
@@ -240,25 +286,17 @@ def build_output_window(
|
|||||||
output["trace_file"] = trace_relpath
|
output["trace_file"] = trace_relpath
|
||||||
output["window_start"] = 0.0
|
output["window_start"] = 0.0
|
||||||
output["window_end"] = float(window["window_end"]) - float(window["window_start"])
|
output["window_end"] = float(window["window_end"]) - float(window["window_start"])
|
||||||
output["num_requests"] = len(rows)
|
output["num_requests"] = stats.num_requests
|
||||||
output["sum_input_length"] = int(sum(int(row.get("input_length") or 0) for row in rows))
|
output["sum_input_length"] = stats.sum_input_length
|
||||||
output["max_input_length"] = int(
|
output["max_input_length"] = stats.max_input_length
|
||||||
max((int(row.get("input_length") or 0) for row in rows), default=0)
|
|
||||||
)
|
|
||||||
output["num_excluded_too_long"] = 0
|
output["num_excluded_too_long"] = 0
|
||||||
output["sampling_u_field"] = "sampling_u"
|
output["sampling_u_field"] = "sampling_u"
|
||||||
output["sampling_seed"] = int(sample_seed)
|
output["sampling_seed"] = int(sample_seed)
|
||||||
output["sampling_strategy"] = "fixed_uniform_score"
|
output["sampling_strategy"] = "fixed_uniform_score"
|
||||||
if rows:
|
output["first_request_ts"] = stats.first_request_ts
|
||||||
output["first_request_ts"] = float(rows[0]["timestamp"])
|
output["last_request_ts"] = stats.last_request_ts
|
||||||
output["last_request_ts"] = float(rows[-1]["timestamp"])
|
output["first_request_index"] = stats.first_request_index
|
||||||
output["first_request_index"] = 0
|
output["last_request_index"] = stats.last_request_index
|
||||||
output["last_request_index"] = len(rows) - 1
|
|
||||||
else:
|
|
||||||
output["first_request_ts"] = None
|
|
||||||
output["last_request_ts"] = None
|
|
||||||
output["first_request_index"] = None
|
|
||||||
output["last_request_index"] = None
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@@ -278,20 +316,19 @@ def main() -> int:
|
|||||||
"thinking": args.thinking_source.resolve(),
|
"thinking": args.thinking_source.resolve(),
|
||||||
}
|
}
|
||||||
windows = build_windows(source_dirs, workloads=workloads)
|
windows = build_windows(source_dirs, workloads=workloads)
|
||||||
extracted = extract_windows(windows, sample_seed=args.sample_seed)
|
stats_by_window = materialize_windows(
|
||||||
|
windows,
|
||||||
|
sample_seed=args.sample_seed,
|
||||||
|
traces_dir=traces_dir,
|
||||||
|
)
|
||||||
|
|
||||||
rendered_windows: list[dict[str, Any]] = []
|
rendered_windows: list[dict[str, Any]] = []
|
||||||
for window in windows:
|
for window in windows:
|
||||||
rows = extracted[str(window["window_id"])]
|
|
||||||
trace_filename = f"{window['window_id']}.jsonl"
|
trace_filename = f"{window['window_id']}.jsonl"
|
||||||
trace_path = traces_dir / trace_filename
|
|
||||||
with trace_path.open("w", encoding="utf-8") as handle:
|
|
||||||
for row in rows:
|
|
||||||
handle.write(json.dumps(row, ensure_ascii=False) + "\n")
|
|
||||||
rendered_windows.append(
|
rendered_windows.append(
|
||||||
build_output_window(
|
build_output_window(
|
||||||
window,
|
window,
|
||||||
rows,
|
stats_by_window[str(window["window_id"])],
|
||||||
trace_relpath=f"traces/{trace_filename}",
|
trace_relpath=f"traces/{trace_filename}",
|
||||||
sample_seed=args.sample_seed,
|
sample_seed=args.sample_seed,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -258,7 +259,101 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
self.assertTrue(evaluations[0].passed)
|
self.assertTrue(evaluations[0].passed)
|
||||||
self.assertFalse(evaluations[1].passed)
|
self.assertFalse(evaluations[1].passed)
|
||||||
self.assertEqual(summary["slo_pass_rate"], 0.5)
|
self.assertEqual(summary["slo_pass_rate"], 0.5)
|
||||||
self.assertFalse(summary["feasible"])
|
|
||||||
|
def test_prepare_trace_windows_materializes_repo_local_assets(self) -> None:
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
tmp_path = Path(tmp)
|
||||||
|
legacy_source = tmp_path / "legacy"
|
||||||
|
thinking_source = tmp_path / "thinking"
|
||||||
|
legacy_source.mkdir()
|
||||||
|
thinking_source.mkdir()
|
||||||
|
|
||||||
|
for filename in [
|
||||||
|
"qwen_chat_blksz_64_031109-031111",
|
||||||
|
"qwen_chat_blksz_64_031121-031123",
|
||||||
|
"qwen_chat_blksz_64_031209-031211",
|
||||||
|
"qwen_chat_blksz_64_031221-031223",
|
||||||
|
"qwen_chat_blksz_64_031309-031311",
|
||||||
|
"qwen_chat_blksz_64_031321-031323",
|
||||||
|
"qwen_chat_blksz_64_031609-031611",
|
||||||
|
"qwen_chat_blksz_64_031621-031623",
|
||||||
|
"qwen_chat_blksz_64_031709-031711",
|
||||||
|
"qwen_chat_blksz_64_031721-031723",
|
||||||
|
]:
|
||||||
|
for suffix in [".jsonl", "_prompt.jsonl"]:
|
||||||
|
path = legacy_source / f"{filename}{suffix}"
|
||||||
|
path.write_text("", encoding="utf-8")
|
||||||
|
|
||||||
|
peak_trace = legacy_source / "qwen_chat_blksz_64_031109-031111.jsonl"
|
||||||
|
peak_prompt = legacy_source / "qwen_chat_blksz_64_031109-031111_prompt.jsonl"
|
||||||
|
peak_trace.write_text(
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"chat_id": "c1",
|
||||||
|
"turn": 1,
|
||||||
|
"timestamp": 3599.0,
|
||||||
|
"input_length": 10,
|
||||||
|
"output_length": 3,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"chat_id": "c2",
|
||||||
|
"turn": 2,
|
||||||
|
"timestamp": 3605.0,
|
||||||
|
"input_length": 20,
|
||||||
|
"output_length": 7,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
+ "\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
peak_prompt.write_text(
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
json.dumps({"chat_id": "c1", "turn": 1, "prompt": "ignore me"}),
|
||||||
|
json.dumps({"chat_id": "c2", "turn": 2, "prompt": "real prompt"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
+ "\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
output_root = tmp_path / "trace_windows"
|
||||||
|
subprocess.run(
|
||||||
|
[
|
||||||
|
"python3",
|
||||||
|
"scripts/prepare_trace_windows.py",
|
||||||
|
"--legacy-source",
|
||||||
|
str(legacy_source),
|
||||||
|
"--thinking-source",
|
||||||
|
str(thinking_source),
|
||||||
|
"--output-root",
|
||||||
|
str(output_root),
|
||||||
|
"--workloads",
|
||||||
|
"chat",
|
||||||
|
"--overwrite",
|
||||||
|
],
|
||||||
|
check=True,
|
||||||
|
cwd="/home/gahow/phd/aituner",
|
||||||
|
)
|
||||||
|
|
||||||
|
windows_payload = json.loads((output_root / "windows.json").read_text(encoding="utf-8"))
|
||||||
|
windows = {item["window_id"]: item for item in windows_payload["windows"]}
|
||||||
|
self.assertIn("chat_w20260311_peak_1000", windows)
|
||||||
|
self.assertEqual(windows["chat_w20260311_peak_1000"]["num_requests"], 1)
|
||||||
|
|
||||||
|
trace_path = output_root / windows["chat_w20260311_peak_1000"]["trace_file"]
|
||||||
|
rows = [json.loads(line) for line in trace_path.read_text(encoding="utf-8").splitlines()]
|
||||||
|
self.assertEqual(len(rows), 1)
|
||||||
|
self.assertEqual(rows[0]["prompt"], "real prompt")
|
||||||
|
self.assertEqual(rows[0]["timestamp"], 5.0)
|
||||||
|
self.assertEqual(rows[0]["output_length"], 7)
|
||||||
|
self.assertIsInstance(rows[0]["sampling_u"], float)
|
||||||
|
|
||||||
def test_binary_search_max_feasible(self) -> None:
|
def test_binary_search_max_feasible(self) -> None:
|
||||||
result = binary_search_max_feasible(
|
result = binary_search_max_feasible(
|
||||||
|
|||||||
Reference in New Issue
Block a user