Support length-only trace windows

2026-04-04 21:31:11 +08:00
parent cdcca1d9d7
commit 647d241725
3 changed files with 109 additions and 2 deletions
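For context: a "length-only" trace row records token counts (input_length / output_length) but carries no chat messages or prompt text. Previously the loader rejected such rows; with this change it synthesizes a placeholder user prompt from the recorded input length, optionally bounded by a new synthetic_prompt_cap_tokens setting. A minimal sketch of such a JSONL row (field names as consumed by the loader below, mirroring the new unit test):

    {"timestamp": 0.0, "sampling_u": 0.1, "input_length": 32, "output_length": 16}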

View File

@@ -151,10 +151,12 @@ class TraceSpec:
    timestamp_field: str
    max_concurrency: int
    max_requests_per_probe: int | None = None
    synthetic_prompt_cap_tokens: int | None = None

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
        max_requests = data.get("max_requests_per_probe")
        synthetic_prompt_cap = data.get("synthetic_prompt_cap_tokens")
        return cls(
            windows_path=_require_str(data.get("windows_path"), context="trace.windows_path"),
            window_id=_require_str(data.get("window_id"), context="trace.window_id"),
@@ -167,6 +169,9 @@ class TraceSpec:
                data.get("max_concurrency", 64), context="trace.max_concurrency"
            ),
            max_requests_per_probe=int(max_requests) if max_requests is not None else None,
            synthetic_prompt_cap_tokens=(
                int(synthetic_prompt_cap) if synthetic_prompt_cap is not None else None
            ),
        )
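Usage note: synthetic_prompt_cap_tokens is optional and defaults to None (no cap). In a study spec it sits alongside the other trace keys; a minimal sketch of the trace block (taken from the new unit test below, with an illustrative windows_path):

    "trace": {
        "windows_path": "trace_windows/windows.json",
        "window_id": "w1",
        "max_concurrency": 1,
        "synthetic_prompt_cap_tokens": 8
    }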

View File

@@ -81,6 +81,14 @@ def _coerce_messages(row: Mapping[str, Any]) -> list[dict[str, Any]]:
    raise TraceError("trace row is missing chat messages/prompt text")


def _synthetic_prompt_from_tokens(token_count: int) -> str:
    if token_count <= 0:
        return "hello"
    # Keep it ASCII and structurally simple so the same trace can be replayed
    # on any OpenAI-compatible engine without extra tokenizer assets.
    return " ".join(["token"] * token_count)


def _coerce_completion_tokens(row: Mapping[str, Any]) -> int | None:
    for key in ("max_completion_tokens", "max_tokens", "output_length", "completion_tokens"):
        value = row.get(key)
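For illustration, the helper's output is deterministic and trivially inspectable:

    >>> _synthetic_prompt_from_tokens(4)
    'token token token token'
    >>> _synthetic_prompt_from_tokens(0)
    'hello'

The cap therefore bounds the number of placeholder words; the exact post-tokenization length still depends on the serving engine's tokenizer.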
@@ -123,9 +131,24 @@ def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[Win
        sampling_u = row.get(study.trace.u_field, 1.0)
        if isinstance(sampling_u, bool) or not isinstance(sampling_u, (int, float)):
            raise TraceError(f"trace row {idx} is missing numeric {study.trace.u_field}")
        prompt_tokens_hint = _coerce_prompt_tokens(row)
        try:
            messages = _coerce_messages(row)
        except TraceError:
            # Length-only rows carry no chat content; synthesize a placeholder
            # prompt from the recorded input length, bounded by the configured cap.
            capped_prompt_tokens = prompt_tokens_hint or 0
            if study.trace.synthetic_prompt_cap_tokens is not None:
                capped_prompt_tokens = min(
                    capped_prompt_tokens, study.trace.synthetic_prompt_cap_tokens
                )
            messages = [
                {
                    "role": "user",
                    "content": _synthetic_prompt_from_tokens(capped_prompt_tokens),
                }
            ]
        body: dict[str, Any] = {
            "model": study.model.served_model_name,
"messages": _coerce_messages(row),
"messages": messages,
"stream": True,
"stream_options": {"include_usage": True},
}
@@ -141,7 +164,7 @@ def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[Win
                arrival_s=float(timestamp),
                sampling_u=float(sampling_u),
                body=body,
                prompt_tokens_hint=prompt_tokens_hint,
                completion_tokens_hint=completion_tokens,
            )
        )
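Putting the pieces together: a length-only row like the sample at the top of this commit, replayed with synthetic_prompt_cap_tokens set to 8, now yields a request body along these lines (min(32, 8) = 8 placeholder words; max_tokens is filled in from the row's output length elsewhere in the loader, as the test below asserts):

    {
        "model": "dummy-model",
        "messages": [{"role": "user", "content": "token token token token token token token token"}],
        "stream": True,
        "stream_options": {"include_usage": True},
        "max_tokens": 16,
    }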

View File

@@ -153,6 +153,85 @@ class CoreFlowTests(unittest.TestCase):
        self.assertIn("queueing_knee_by_bucket", prompt)
        self.assertTrue(study_root.exists())

    def test_length_only_trace_rows_are_synthesized(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            trace_dir = tmp_path / "trace_windows" / "traces"
            trace_dir.mkdir(parents=True)
            trace_path = trace_dir / "chat_len_only.jsonl"
            with trace_path.open("w", encoding="utf-8") as handle:
                handle.write(
                    json.dumps(
                        {
                            "timestamp": 0.0,
                            "sampling_u": 0.1,
                            "input_length": 32,
                            "output_length": 16
                        }
                    )
                    + "\n"
                )
            windows_path = tmp_path / "trace_windows" / "windows.json"
            windows_path.write_text(
                json.dumps(
                    {
                        "windows": [
                            {
                                "window_id": "w1",
                                "trace_type": "chat",
                                "trace_file": "traces/chat_len_only.jsonl",
                                "window_start": 0.0,
                                "window_end": 10.0
                            }
                        ]
                    }
                ),
                encoding="utf-8",
            )
            study_path = tmp_path / "study.json"
            study_path.write_text(
                json.dumps(
                    {
                        "study_id": "study-len-only",
                        "hardware": {"gpu_count": 1},
                        "model": {
                            "model_id": "m1",
                            "served_model_name": "dummy-model"
                        },
                        "engine": {
                            "engine_name": "vllm",
                            "exec_path": "/usr/local/bin/vllm",
                            "host": "127.0.0.1",
                            "port": 8000,
                            "ready_timeout_s": 10,
                            "request_timeout_s": 10,
                            "healthcheck_path": "/v1/models",
                            "launch_args": [],
                            "base_envs": {},
                            "base_flags": {},
                            "tunable_envs": [],
                            "tunable_flags": []
                        },
                        "trace": {
                            "windows_path": str(windows_path),
                            "window_id": "w1",
                            "max_concurrency": 1,
                            "synthetic_prompt_cap_tokens": 8
                        },
                        "slo": {"target_pass_rate": 0.95},
                        "search": {"low": 0.0, "high": 1.0, "tolerance": 0.1, "max_probes": 2, "sample_seed": 1},
                        "llm": {"system_prompt": "", "max_history_trials": 1}
                    }
                ),
                encoding="utf-8",
            )
            study = load_study_spec(study_path)
            _, requests = load_trace_requests(study, study_spec_path=study_path)
            self.assertEqual(len(requests), 1)
            message = requests[0].body["messages"][0]["content"]
            self.assertEqual(message.count("token"), 8)
            self.assertEqual(requests[0].body["max_tokens"], 16)

    def test_slo_evaluation_step_and_fixed_rules(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            study = load_study_spec(_write_study_assets(Path(tmp)))