From 56fa6747d2582f2b1866205606f3dc8ab330adae Mon Sep 17 00:00:00 2001
From: Gahow Wang
Date: Sat, 4 Apr 2026 22:40:49 +0800
Subject: [PATCH] Add replay time scaling for smoke tuning

---
 configs/examples/dash0_smoke_study.json |  3 +-
 src/aituner/spec.py                     |  4 ++
 src/aituner/trace.py                    | 14 ++++-
 tests/test_core_flow.py                 | 76 +++++++++++++++++++++++++
 4 files changed, 95 insertions(+), 2 deletions(-)

diff --git a/configs/examples/dash0_smoke_study.json b/configs/examples/dash0_smoke_study.json
index c10c494..868ae70 100644
--- a/configs/examples/dash0_smoke_study.json
+++ b/configs/examples/dash0_smoke_study.json
@@ -54,7 +54,8 @@
     "u_field": "sampling_u",
     "timestamp_field": "timestamp",
     "max_concurrency": 2,
-    "max_requests_per_probe": 24
+    "max_requests_per_probe": 24,
+    "replay_time_scale": 0.02
   },
   "slo": {
     "target_pass_rate": 0.95,
diff --git a/src/aituner/spec.py b/src/aituner/spec.py
index e39c0a1..895aa91 100644
--- a/src/aituner/spec.py
+++ b/src/aituner/spec.py
@@ -152,6 +152,7 @@ class TraceSpec:
     max_concurrency: int
     max_requests_per_probe: int | None = None
     synthetic_prompt_cap_tokens: int | None = None
+    replay_time_scale: float = 1.0
 
     @classmethod
     def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
@@ -172,6 +173,9 @@
             synthetic_prompt_cap_tokens=(
                 int(synthetic_prompt_cap) if synthetic_prompt_cap is not None else None
             ),
+            replay_time_scale=_require_float(
+                data.get("replay_time_scale", 1.0), context="trace.replay_time_scale"
+            ),
         )
 
 
diff --git a/src/aituner/trace.py b/src/aituner/trace.py
index 398972f..36c7905 100644
--- a/src/aituner/trace.py
+++ b/src/aituner/trace.py
@@ -134,6 +134,18 @@ def _downsample_requests(
 
 def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[WindowRecord, list[TraceRequest]]:
     window = resolve_window_record(study, study_spec_path=study_spec_path)
+    time_scale = float(study.trace.replay_time_scale)
+    if time_scale <= 0:
+        raise TraceError("trace.replay_time_scale must be > 0")
+    if time_scale != 1.0:
+        window = WindowRecord(
+            window_id=window.window_id,
+            trace_path=window.trace_path,
+            trace_type=window.trace_type,
+            window_start=window.window_start * time_scale,
+            window_end=window.window_end * time_scale,
+            source_payload=dict(window.source_payload),
+        )
     requests: list[TraceRequest] = []
     with window.trace_path.open("r", encoding="utf-8") as handle:
         for idx, raw in enumerate(handle):
@@ -181,7 +193,7 @@ def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[Win
             requests.append(
                 TraceRequest(
                     row_id=str(row.get("request_id") or row.get("id") or idx),
-                    arrival_s=float(timestamp),
+                    arrival_s=float(timestamp) * time_scale,
                     sampling_u=float(sampling_u),
                     body=body,
                     prompt_tokens_hint=prompt_tokens_hint,
diff --git a/tests/test_core_flow.py b/tests/test_core_flow.py
index d614b8c..b514ee6 100644
--- a/tests/test_core_flow.py
+++ b/tests/test_core_flow.py
@@ -449,6 +449,82 @@
         _, requests = load_trace_requests(study, study_spec_path=study_path)
         self.assertEqual([item.row_id for item in requests], ["r0", "r2", "r5", "r7"])
 
+    def test_trace_replay_time_scale_scales_arrivals_and_window(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            tmp_path = Path(tmp)
+            trace_dir = tmp_path / "trace_windows" / "traces"
+            trace_dir.mkdir(parents=True)
+            trace_path = trace_dir / "chat_scale.jsonl"
+            trace_path.write_text(
+                json.dumps(
+                    {
+                        "request_id": "r1",
+                        "timestamp": 10.0,
+                        "sampling_u": 0.25,
+                        "messages": [{"role": "user", "content": "hello"}],
+                        "input_length": 16,
+                        "output_length": 4,
+                    }
+                )
+                + "\n",
+                encoding="utf-8",
+            )
+            windows_path = tmp_path / "trace_windows" / "windows.json"
+            windows_path.write_text(
+                json.dumps(
+                    {
+                        "windows": [
+                            {
+                                "window_id": "w1",
+                                "trace_type": "chat",
+                                "trace_file": "traces/chat_scale.jsonl",
+                                "window_start": 0.0,
+                                "window_end": 100.0,
+                            }
+                        ]
+                    }
+                ),
+                encoding="utf-8",
+            )
+            study_path = tmp_path / "study.json"
+            study_path.write_text(
+                json.dumps(
+                    {
+                        "study_id": "study-scale",
+                        "hardware": {"gpu_count": 1},
+                        "model": {"model_id": "m1", "served_model_name": "dummy-model"},
+                        "engine": {
+                            "engine_name": "vllm",
+                            "exec_path": "/usr/local/bin/vllm",
+                            "host": "127.0.0.1",
+                            "port": 8000,
+                            "ready_timeout_s": 10,
+                            "request_timeout_s": 10,
+                            "healthcheck_path": "/v1/models",
+                            "launch_args": [],
+                            "base_envs": {},
+                            "base_flags": {},
+                            "tunable_envs": [],
+                            "tunable_flags": [],
+                        },
+                        "trace": {
+                            "windows_path": str(windows_path),
+                            "window_id": "w1",
+                            "max_concurrency": 1,
+                            "replay_time_scale": 0.1,
+                        },
+                        "slo": {"target_pass_rate": 0.95},
+                        "search": {"low": 0.0, "high": 1.0, "tolerance": 0.1, "max_probes": 2, "sample_seed": 1},
+                        "llm": {"system_prompt": "", "max_history_trials": 1},
+                    }
+                ),
+                encoding="utf-8",
+            )
+            study = load_study_spec(study_path)
+            window, requests = load_trace_requests(study, study_spec_path=study_path)
+            self.assertEqual(window.window_end, 10.0)
+            self.assertEqual(requests[0].arrival_s, 1.0)
+
     def test_proposal_validation_and_job_emission(self) -> None:
         with tempfile.TemporaryDirectory() as tmp:
             tmp_path = Path(tmp)