Support length-only trace windows

2026-04-04 21:31:11 +08:00
parent cdcca1d9d7
commit 647d241725
3 changed files with 109 additions and 2 deletions

@@ -153,6 +153,85 @@ class CoreFlowTests(unittest.TestCase):
            self.assertIn("queueing_knee_by_bucket", prompt)
            self.assertTrue(study_root.exists())

    def test_length_only_trace_rows_are_synthesized(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            trace_dir = tmp_path / "trace_windows" / "traces"
            trace_dir.mkdir(parents=True)
            # A length-only trace row: token lengths are recorded, but there
            # is no prompt text to replay.
            trace_path = trace_dir / "chat_len_only.jsonl"
            with trace_path.open("w", encoding="utf-8") as handle:
                handle.write(
                    json.dumps(
                        {
                            "timestamp": 0.0,
                            "sampling_u": 0.1,
                            "input_length": 32,
                            "output_length": 16
                        }
                    )
                    + "\n"
                )
            windows_path = tmp_path / "trace_windows" / "windows.json"
            windows_path.write_text(
                json.dumps(
                    {
                        "windows": [
                            {
                                "window_id": "w1",
                                "trace_type": "chat",
                                "trace_file": "traces/chat_len_only.jsonl",
                                "window_start": 0.0,
                                "window_end": 10.0
                            }
                        ]
                    }
                ),
                encoding="utf-8",
            )
            study_path = tmp_path / "study.json"
            study_path.write_text(
                json.dumps(
                    {
                        "study_id": "study-len-only",
                        "hardware": {"gpu_count": 1},
                        "model": {
                            "model_id": "m1",
                            "served_model_name": "dummy-model"
                        },
                        "engine": {
                            "engine_name": "vllm",
                            "exec_path": "/usr/local/bin/vllm",
                            "host": "127.0.0.1",
                            "port": 8000,
                            "ready_timeout_s": 10,
                            "request_timeout_s": 10,
                            "healthcheck_path": "/v1/models",
                            "launch_args": [],
                            "base_envs": {},
                            "base_flags": {},
                            "tunable_envs": [],
                            "tunable_flags": []
                        },
                        "trace": {
                            "windows_path": str(windows_path),
                            "window_id": "w1",
                            "max_concurrency": 1,
                            "synthetic_prompt_cap_tokens": 8
                        },
                        "slo": {"target_pass_rate": 0.95},
                        "search": {"low": 0.0, "high": 1.0, "tolerance": 0.1, "max_probes": 2, "sample_seed": 1},
                        "llm": {"system_prompt": "", "max_history_trials": 1}
                    }
                ),
                encoding="utf-8",
            )
            study = load_study_spec(study_path)
            _, requests = load_trace_requests(study, study_spec_path=study_path)
            self.assertEqual(len(requests), 1)
            message = requests[0].body["messages"][0]["content"]
            # The synthetic prompt is capped at 8 filler tokens, and the
            # recorded output_length becomes the request's max_tokens.
            self.assertEqual(message.count("token"), 8)
            self.assertEqual(requests[0].body["max_tokens"], 16)

    def test_slo_evaluation_step_and_fixed_rules(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            study = load_study_spec(_write_study_assets(Path(tmp)))
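
For reference, a minimal sketch of the behavior this test pins down, assuming a length-only row is expanded into a chat request by repeating a filler token up to synthetic_prompt_cap_tokens and passing output_length through as max_tokens; the helper name synthesize_request is hypothetical and only illustrates what load_trace_requests is asserted to produce, not the repository's actual API.

import json

def synthesize_request(row: dict, prompt_cap_tokens: int) -> dict:
    # Hypothetical helper for illustration only. Field names match the trace
    # row and study spec used in the test above.
    # Repeat a filler word once per prompt token, capped by the study's
    # synthetic_prompt_cap_tokens (8 here, although input_length is 32).
    prompt_tokens = min(row["input_length"], prompt_cap_tokens)
    content = " ".join(["token"] * prompt_tokens)
    return {
        "messages": [{"role": "user", "content": content}],
        # The recorded output length bounds generation directly.
        "max_tokens": row["output_length"],
    }

row = json.loads('{"timestamp": 0.0, "sampling_u": 0.1, "input_length": 32, "output_length": 16}')
body = synthesize_request(row, prompt_cap_tokens=8)
assert body["messages"][0]["content"].count("token") == 8
assert body["max_tokens"] == 16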