Add replay time scaling for smoke tuning
This commit is contained in:
@@ -54,7 +54,8 @@
|
||||
"u_field": "sampling_u",
|
||||
"timestamp_field": "timestamp",
|
||||
"max_concurrency": 2,
|
||||
"max_requests_per_probe": 24
|
||||
"max_requests_per_probe": 24,
|
||||
"replay_time_scale": 0.02
|
||||
},
|
||||
"slo": {
|
||||
"target_pass_rate": 0.95,
|
||||
|
||||
@@ -152,6 +152,7 @@ class TraceSpec:
|
||||
max_concurrency: int
|
||||
max_requests_per_probe: int | None = None
|
||||
synthetic_prompt_cap_tokens: int | None = None
|
||||
replay_time_scale: float = 1.0
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
|
||||
@@ -172,6 +173,9 @@ class TraceSpec:
|
||||
synthetic_prompt_cap_tokens=(
|
||||
int(synthetic_prompt_cap) if synthetic_prompt_cap is not None else None
|
||||
),
|
||||
replay_time_scale=_require_float(
|
||||
data.get("replay_time_scale", 1.0), context="trace.replay_time_scale"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -134,6 +134,18 @@ def _downsample_requests(
|
||||
|
||||
def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[WindowRecord, list[TraceRequest]]:
|
||||
window = resolve_window_record(study, study_spec_path=study_spec_path)
|
||||
time_scale = float(study.trace.replay_time_scale)
|
||||
if time_scale <= 0:
|
||||
raise TraceError("trace.replay_time_scale must be > 0")
|
||||
if time_scale != 1.0:
|
||||
window = WindowRecord(
|
||||
window_id=window.window_id,
|
||||
trace_path=window.trace_path,
|
||||
trace_type=window.trace_type,
|
||||
window_start=window.window_start * time_scale,
|
||||
window_end=window.window_end * time_scale,
|
||||
source_payload=dict(window.source_payload),
|
||||
)
|
||||
requests: list[TraceRequest] = []
|
||||
with window.trace_path.open("r", encoding="utf-8") as handle:
|
||||
for idx, raw in enumerate(handle):
|
||||
@@ -181,7 +193,7 @@ def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[Win
|
||||
requests.append(
|
||||
TraceRequest(
|
||||
row_id=str(row.get("request_id") or row.get("id") or idx),
|
||||
arrival_s=float(timestamp),
|
||||
arrival_s=float(timestamp) * time_scale,
|
||||
sampling_u=float(sampling_u),
|
||||
body=body,
|
||||
prompt_tokens_hint=prompt_tokens_hint,
|
||||
|
||||
@@ -449,6 +449,82 @@ class CoreFlowTests(unittest.TestCase):
|
||||
_, requests = load_trace_requests(study, study_spec_path=study_path)
|
||||
self.assertEqual([item.row_id for item in requests], ["r0", "r2", "r5", "r7"])
|
||||
|
||||
def test_trace_replay_time_scale_scales_arrivals_and_window(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
trace_dir = tmp_path / "trace_windows" / "traces"
|
||||
trace_dir.mkdir(parents=True)
|
||||
trace_path = trace_dir / "chat_scale.jsonl"
|
||||
trace_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"request_id": "r1",
|
||||
"timestamp": 10.0,
|
||||
"sampling_u": 0.25,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"input_length": 16,
|
||||
"output_length": 4,
|
||||
}
|
||||
)
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
windows_path = tmp_path / "trace_windows" / "windows.json"
|
||||
windows_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"windows": [
|
||||
{
|
||||
"window_id": "w1",
|
||||
"trace_type": "chat",
|
||||
"trace_file": "traces/chat_scale.jsonl",
|
||||
"window_start": 0.0,
|
||||
"window_end": 100.0,
|
||||
}
|
||||
]
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
study_path = tmp_path / "study.json"
|
||||
study_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"study_id": "study-scale",
|
||||
"hardware": {"gpu_count": 1},
|
||||
"model": {"model_id": "m1", "served_model_name": "dummy-model"},
|
||||
"engine": {
|
||||
"engine_name": "vllm",
|
||||
"exec_path": "/usr/local/bin/vllm",
|
||||
"host": "127.0.0.1",
|
||||
"port": 8000,
|
||||
"ready_timeout_s": 10,
|
||||
"request_timeout_s": 10,
|
||||
"healthcheck_path": "/v1/models",
|
||||
"launch_args": [],
|
||||
"base_envs": {},
|
||||
"base_flags": {},
|
||||
"tunable_envs": [],
|
||||
"tunable_flags": [],
|
||||
},
|
||||
"trace": {
|
||||
"windows_path": str(windows_path),
|
||||
"window_id": "w1",
|
||||
"max_concurrency": 1,
|
||||
"replay_time_scale": 0.1,
|
||||
},
|
||||
"slo": {"target_pass_rate": 0.95},
|
||||
"search": {"low": 0.0, "high": 1.0, "tolerance": 0.1, "max_probes": 2, "sample_seed": 1},
|
||||
"llm": {"system_prompt": "", "max_history_trials": 1},
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
study = load_study_spec(study_path)
|
||||
window, requests = load_trace_requests(study, study_spec_path=study_path)
|
||||
self.assertEqual(window.window_end, 10.0)
|
||||
self.assertEqual(requests[0].arrival_s, 1.0)
|
||||
|
||||
def test_proposal_validation_and_job_emission(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
|
||||
Reference in New Issue
Block a user