Add replay time scaling for smoke tuning

This commit is contained in:
2026-04-04 22:40:49 +08:00
parent dcb972014a
commit 56fa6747d2
4 changed files with 95 additions and 2 deletions

View File

@@ -54,7 +54,8 @@
"u_field": "sampling_u",
"timestamp_field": "timestamp",
"max_concurrency": 2,
"max_requests_per_probe": 24
"max_requests_per_probe": 24,
"replay_time_scale": 0.02
},
"slo": {
"target_pass_rate": 0.95,

View File

@@ -152,6 +152,7 @@ class TraceSpec:
max_concurrency: int
max_requests_per_probe: int | None = None
synthetic_prompt_cap_tokens: int | None = None
replay_time_scale: float = 1.0
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
@@ -172,6 +173,9 @@ class TraceSpec:
synthetic_prompt_cap_tokens=(
int(synthetic_prompt_cap) if synthetic_prompt_cap is not None else None
),
replay_time_scale=_require_float(
data.get("replay_time_scale", 1.0), context="trace.replay_time_scale"
),
)

View File

@@ -134,6 +134,18 @@ def _downsample_requests(
def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[WindowRecord, list[TraceRequest]]:
window = resolve_window_record(study, study_spec_path=study_spec_path)
time_scale = float(study.trace.replay_time_scale)
if time_scale <= 0:
raise TraceError("trace.replay_time_scale must be > 0")
if time_scale != 1.0:
window = WindowRecord(
window_id=window.window_id,
trace_path=window.trace_path,
trace_type=window.trace_type,
window_start=window.window_start * time_scale,
window_end=window.window_end * time_scale,
source_payload=dict(window.source_payload),
)
requests: list[TraceRequest] = []
with window.trace_path.open("r", encoding="utf-8") as handle:
for idx, raw in enumerate(handle):
@@ -181,7 +193,7 @@ def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[Win
requests.append(
TraceRequest(
row_id=str(row.get("request_id") or row.get("id") or idx),
arrival_s=float(timestamp),
arrival_s=float(timestamp) * time_scale,
sampling_u=float(sampling_u),
body=body,
prompt_tokens_hint=prompt_tokens_hint,

View File

@@ -449,6 +449,82 @@ class CoreFlowTests(unittest.TestCase):
_, requests = load_trace_requests(study, study_spec_path=study_path)
self.assertEqual([item.row_id for item in requests], ["r0", "r2", "r5", "r7"])
def test_trace_replay_time_scale_scales_arrivals_and_window(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
trace_dir = tmp_path / "trace_windows" / "traces"
trace_dir.mkdir(parents=True)
trace_path = trace_dir / "chat_scale.jsonl"
trace_path.write_text(
json.dumps(
{
"request_id": "r1",
"timestamp": 10.0,
"sampling_u": 0.25,
"messages": [{"role": "user", "content": "hello"}],
"input_length": 16,
"output_length": 4,
}
)
+ "\n",
encoding="utf-8",
)
windows_path = tmp_path / "trace_windows" / "windows.json"
windows_path.write_text(
json.dumps(
{
"windows": [
{
"window_id": "w1",
"trace_type": "chat",
"trace_file": "traces/chat_scale.jsonl",
"window_start": 0.0,
"window_end": 100.0,
}
]
}
),
encoding="utf-8",
)
study_path = tmp_path / "study.json"
study_path.write_text(
json.dumps(
{
"study_id": "study-scale",
"hardware": {"gpu_count": 1},
"model": {"model_id": "m1", "served_model_name": "dummy-model"},
"engine": {
"engine_name": "vllm",
"exec_path": "/usr/local/bin/vllm",
"host": "127.0.0.1",
"port": 8000,
"ready_timeout_s": 10,
"request_timeout_s": 10,
"healthcheck_path": "/v1/models",
"launch_args": [],
"base_envs": {},
"base_flags": {},
"tunable_envs": [],
"tunable_flags": [],
},
"trace": {
"windows_path": str(windows_path),
"window_id": "w1",
"max_concurrency": 1,
"replay_time_scale": 0.1,
},
"slo": {"target_pass_rate": 0.95},
"search": {"low": 0.0, "high": 1.0, "tolerance": 0.1, "max_probes": 2, "sample_seed": 1},
"llm": {"system_prompt": "", "max_history_trials": 1},
}
),
encoding="utf-8",
)
study = load_study_spec(study_path)
window, requests = load_trace_requests(study, study_spec_path=study_path)
self.assertEqual(window.window_end, 10.0)
self.assertEqual(requests[0].arrival_s, 1.0)
def test_proposal_validation_and_job_emission(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)