Stream trace window materialization
This commit is contained in:
@@ -5,6 +5,7 @@ import hashlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
@@ -186,21 +187,53 @@ def _merge_trace_and_prompt(trace_row: dict[str, Any], prompt_row: dict[str, Any
|
||||
return merged
|
||||
|
||||
|
||||
def extract_windows(windows: list[dict[str, Any]], *, sample_seed: int) -> dict[str, list[dict[str, Any]]]:
|
||||
@dataclass
|
||||
class WindowStats:
|
||||
num_requests: int = 0
|
||||
sum_input_length: int = 0
|
||||
max_input_length: int = 0
|
||||
first_request_ts: float | None = None
|
||||
last_request_ts: float | None = None
|
||||
first_request_index: int | None = None
|
||||
last_request_index: int | None = None
|
||||
|
||||
def record(self, row: dict[str, Any]) -> None:
|
||||
input_length = int(row.get("input_length") or 0)
|
||||
timestamp = float(row["timestamp"])
|
||||
if self.num_requests == 0:
|
||||
self.first_request_ts = timestamp
|
||||
self.first_request_index = 0
|
||||
self.last_request_ts = timestamp
|
||||
self.last_request_index = self.num_requests
|
||||
self.num_requests += 1
|
||||
self.sum_input_length += input_length
|
||||
self.max_input_length = max(self.max_input_length, input_length)
|
||||
|
||||
|
||||
def materialize_windows(
|
||||
windows: list[dict[str, Any]], *, sample_seed: int, traces_dir: Path
|
||||
) -> dict[str, WindowStats]:
|
||||
grouped: dict[tuple[Path, Path], list[dict[str, Any]]] = {}
|
||||
for window in windows:
|
||||
trace_path = Path(window["source_trace_path"])
|
||||
prompt_path = Path(window["source_prompt_path"])
|
||||
grouped.setdefault((trace_path, prompt_path), []).append(window)
|
||||
|
||||
extracted: dict[str, list[dict[str, Any]]] = {str(window["window_id"]): [] for window in windows}
|
||||
for (trace_path, prompt_path), bucket in grouped.items():
|
||||
stats_by_window = {str(window["window_id"]): WindowStats() for window in windows}
|
||||
handles: dict[str, Any] = {}
|
||||
try:
|
||||
for window in windows:
|
||||
window_id = str(window["window_id"])
|
||||
handles[window_id] = (traces_dir / f"{window_id}.jsonl").open("w", encoding="utf-8")
|
||||
|
||||
for trace_path, prompt_path in sorted(grouped.keys()):
|
||||
bucket = grouped[(trace_path, prompt_path)]
|
||||
bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"])))
|
||||
matched_rows = 0
|
||||
with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle:
|
||||
for index, (trace_raw, prompt_raw) in enumerate(zip(trace_handle, prompt_handle)):
|
||||
for trace_raw, prompt_raw in zip(trace_handle, prompt_handle):
|
||||
trace_raw = trace_raw.strip()
|
||||
prompt_raw = prompt_raw.strip()
|
||||
if not trace_raw or not prompt_raw:
|
||||
if not trace_raw:
|
||||
continue
|
||||
trace_row = json.loads(trace_raw)
|
||||
timestamp = float(trace_row.get("timestamp") or 0.0)
|
||||
@@ -213,25 +246,38 @@ def extract_windows(windows: list[dict[str, Any]], *, sample_seed: int) -> dict[
|
||||
break
|
||||
if matched_window is None:
|
||||
continue
|
||||
prompt_raw = prompt_raw.strip()
|
||||
if not prompt_raw:
|
||||
continue
|
||||
prompt_row = json.loads(prompt_raw)
|
||||
merged = _merge_trace_and_prompt(trace_row, prompt_row)
|
||||
window_id = str(matched_window["window_id"])
|
||||
out = dict(merged)
|
||||
start = float(matched_window["window_start"])
|
||||
out["source_timestamp"] = timestamp
|
||||
out["timestamp"] = timestamp - start
|
||||
out["sampling_u"] = stable_uniform(
|
||||
seed=sample_seed,
|
||||
window_id=str(matched_window["window_id"]),
|
||||
index=len(extracted[str(matched_window["window_id"])]),
|
||||
window_id=window_id,
|
||||
index=stats_by_window[window_id].num_requests,
|
||||
row=merged,
|
||||
)
|
||||
extracted[str(matched_window["window_id"])].append(out)
|
||||
return extracted
|
||||
handles[window_id].write(json.dumps(out, ensure_ascii=False) + "\n")
|
||||
stats_by_window[window_id].record(out)
|
||||
matched_rows += 1
|
||||
print(
|
||||
f"materialized {trace_path.name} -> matched_rows={matched_rows}",
|
||||
flush=True,
|
||||
)
|
||||
finally:
|
||||
for handle in handles.values():
|
||||
handle.close()
|
||||
return stats_by_window
|
||||
|
||||
|
||||
def build_output_window(
|
||||
window: dict[str, Any],
|
||||
rows: list[dict[str, Any]],
|
||||
stats: WindowStats,
|
||||
trace_relpath: str,
|
||||
*,
|
||||
sample_seed: int,
|
||||
@@ -240,25 +286,17 @@ def build_output_window(
|
||||
output["trace_file"] = trace_relpath
|
||||
output["window_start"] = 0.0
|
||||
output["window_end"] = float(window["window_end"]) - float(window["window_start"])
|
||||
output["num_requests"] = len(rows)
|
||||
output["sum_input_length"] = int(sum(int(row.get("input_length") or 0) for row in rows))
|
||||
output["max_input_length"] = int(
|
||||
max((int(row.get("input_length") or 0) for row in rows), default=0)
|
||||
)
|
||||
output["num_requests"] = stats.num_requests
|
||||
output["sum_input_length"] = stats.sum_input_length
|
||||
output["max_input_length"] = stats.max_input_length
|
||||
output["num_excluded_too_long"] = 0
|
||||
output["sampling_u_field"] = "sampling_u"
|
||||
output["sampling_seed"] = int(sample_seed)
|
||||
output["sampling_strategy"] = "fixed_uniform_score"
|
||||
if rows:
|
||||
output["first_request_ts"] = float(rows[0]["timestamp"])
|
||||
output["last_request_ts"] = float(rows[-1]["timestamp"])
|
||||
output["first_request_index"] = 0
|
||||
output["last_request_index"] = len(rows) - 1
|
||||
else:
|
||||
output["first_request_ts"] = None
|
||||
output["last_request_ts"] = None
|
||||
output["first_request_index"] = None
|
||||
output["last_request_index"] = None
|
||||
output["first_request_ts"] = stats.first_request_ts
|
||||
output["last_request_ts"] = stats.last_request_ts
|
||||
output["first_request_index"] = stats.first_request_index
|
||||
output["last_request_index"] = stats.last_request_index
|
||||
return output
|
||||
|
||||
|
||||
@@ -278,20 +316,19 @@ def main() -> int:
|
||||
"thinking": args.thinking_source.resolve(),
|
||||
}
|
||||
windows = build_windows(source_dirs, workloads=workloads)
|
||||
extracted = extract_windows(windows, sample_seed=args.sample_seed)
|
||||
stats_by_window = materialize_windows(
|
||||
windows,
|
||||
sample_seed=args.sample_seed,
|
||||
traces_dir=traces_dir,
|
||||
)
|
||||
|
||||
rendered_windows: list[dict[str, Any]] = []
|
||||
for window in windows:
|
||||
rows = extracted[str(window["window_id"])]
|
||||
trace_filename = f"{window['window_id']}.jsonl"
|
||||
trace_path = traces_dir / trace_filename
|
||||
with trace_path.open("w", encoding="utf-8") as handle:
|
||||
for row in rows:
|
||||
handle.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||
rendered_windows.append(
|
||||
build_output_window(
|
||||
window,
|
||||
rows,
|
||||
stats_by_window[str(window["window_id"])],
|
||||
trace_relpath=f"traces/{trace_filename}",
|
||||
sample_seed=args.sample_seed,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
@@ -258,7 +259,101 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertTrue(evaluations[0].passed)
|
||||
self.assertFalse(evaluations[1].passed)
|
||||
self.assertEqual(summary["slo_pass_rate"], 0.5)
|
||||
self.assertFalse(summary["feasible"])
|
||||
|
||||
def test_prepare_trace_windows_materializes_repo_local_assets(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
legacy_source = tmp_path / "legacy"
|
||||
thinking_source = tmp_path / "thinking"
|
||||
legacy_source.mkdir()
|
||||
thinking_source.mkdir()
|
||||
|
||||
for filename in [
|
||||
"qwen_chat_blksz_64_031109-031111",
|
||||
"qwen_chat_blksz_64_031121-031123",
|
||||
"qwen_chat_blksz_64_031209-031211",
|
||||
"qwen_chat_blksz_64_031221-031223",
|
||||
"qwen_chat_blksz_64_031309-031311",
|
||||
"qwen_chat_blksz_64_031321-031323",
|
||||
"qwen_chat_blksz_64_031609-031611",
|
||||
"qwen_chat_blksz_64_031621-031623",
|
||||
"qwen_chat_blksz_64_031709-031711",
|
||||
"qwen_chat_blksz_64_031721-031723",
|
||||
]:
|
||||
for suffix in [".jsonl", "_prompt.jsonl"]:
|
||||
path = legacy_source / f"{filename}{suffix}"
|
||||
path.write_text("", encoding="utf-8")
|
||||
|
||||
peak_trace = legacy_source / "qwen_chat_blksz_64_031109-031111.jsonl"
|
||||
peak_prompt = legacy_source / "qwen_chat_blksz_64_031109-031111_prompt.jsonl"
|
||||
peak_trace.write_text(
|
||||
"\n".join(
|
||||
[
|
||||
json.dumps(
|
||||
{
|
||||
"chat_id": "c1",
|
||||
"turn": 1,
|
||||
"timestamp": 3599.0,
|
||||
"input_length": 10,
|
||||
"output_length": 3,
|
||||
}
|
||||
),
|
||||
json.dumps(
|
||||
{
|
||||
"chat_id": "c2",
|
||||
"turn": 2,
|
||||
"timestamp": 3605.0,
|
||||
"input_length": 20,
|
||||
"output_length": 7,
|
||||
}
|
||||
),
|
||||
]
|
||||
)
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
peak_prompt.write_text(
|
||||
"\n".join(
|
||||
[
|
||||
json.dumps({"chat_id": "c1", "turn": 1, "prompt": "ignore me"}),
|
||||
json.dumps({"chat_id": "c2", "turn": 2, "prompt": "real prompt"}),
|
||||
]
|
||||
)
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
output_root = tmp_path / "trace_windows"
|
||||
subprocess.run(
|
||||
[
|
||||
"python3",
|
||||
"scripts/prepare_trace_windows.py",
|
||||
"--legacy-source",
|
||||
str(legacy_source),
|
||||
"--thinking-source",
|
||||
str(thinking_source),
|
||||
"--output-root",
|
||||
str(output_root),
|
||||
"--workloads",
|
||||
"chat",
|
||||
"--overwrite",
|
||||
],
|
||||
check=True,
|
||||
cwd="/home/gahow/phd/aituner",
|
||||
)
|
||||
|
||||
windows_payload = json.loads((output_root / "windows.json").read_text(encoding="utf-8"))
|
||||
windows = {item["window_id"]: item for item in windows_payload["windows"]}
|
||||
self.assertIn("chat_w20260311_peak_1000", windows)
|
||||
self.assertEqual(windows["chat_w20260311_peak_1000"]["num_requests"], 1)
|
||||
|
||||
trace_path = output_root / windows["chat_w20260311_peak_1000"]["trace_file"]
|
||||
rows = [json.loads(line) for line in trace_path.read_text(encoding="utf-8").splitlines()]
|
||||
self.assertEqual(len(rows), 1)
|
||||
self.assertEqual(rows[0]["prompt"], "real prompt")
|
||||
self.assertEqual(rows[0]["timestamp"], 5.0)
|
||||
self.assertEqual(rows[0]["output_length"], 7)
|
||||
self.assertIsInstance(rows[0]["sampling_u"], float)
|
||||
|
||||
def test_binary_search_max_feasible(self) -> None:
|
||||
result = binary_search_max_feasible(
|
||||
|
||||
Reference in New Issue
Block a user