Support length-only trace windows
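
A "length-only" trace row records token counts (e.g. input_length / output_length) but carries no chat messages or prompt text. With this change, such rows no longer fail in load_trace_requests: the loader synthesizes a simple user message instead, optionally capped by the new trace.synthetic_prompt_cap_tokens setting. A rough sketch of the resulting behavior (row shape and cap value taken from the test fixture below; this mirrors the fallback in the diff, not the actual call path):

    # Sketch only: assumes the prompt-token hint comes from input_length and
    # reuses the logic of _synthetic_prompt_from_tokens shown further down.
    row = {"timestamp": 0.0, "sampling_u": 0.1, "input_length": 32, "output_length": 16}
    cap = 8  # trace.synthetic_prompt_cap_tokens
    capped = min(row["input_length"], cap)
    prompt = " ".join(["token"] * capped) if capped > 0 else "hello"
    messages = [{"role": "user", "content": prompt}]
    # prompt == "token token token token token token token token"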
@@ -151,10 +151,12 @@ class TraceSpec:
     timestamp_field: str
     max_concurrency: int
     max_requests_per_probe: int | None = None
+    synthetic_prompt_cap_tokens: int | None = None

     @classmethod
     def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
         max_requests = data.get("max_requests_per_probe")
+        synthetic_prompt_cap = data.get("synthetic_prompt_cap_tokens")
         return cls(
             windows_path=_require_str(data.get("windows_path"), context="trace.windows_path"),
             window_id=_require_str(data.get("window_id"), context="trace.window_id"),
@@ -167,6 +169,9 @@ class TraceSpec:
                 data.get("max_concurrency", 64), context="trace.max_concurrency"
             ),
             max_requests_per_probe=int(max_requests) if max_requests is not None else None,
+            synthetic_prompt_cap_tokens=(
+                int(synthetic_prompt_cap) if synthetic_prompt_cap is not None else None
+            ),
         )

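For reference, the trace block of a study spec can now carry the new knob. A minimal example mirroring the test fixture further down (the path is a placeholder; omitted fields keep their defaults, e.g. max_concurrency falls back to 64):

    trace_cfg = {
        "windows_path": "/tmp/trace_windows/windows.json",  # placeholder path
        "window_id": "w1",
        "max_concurrency": 1,
        "synthetic_prompt_cap_tokens": 8,
    }
    spec = TraceSpec.from_dict(trace_cfg)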
@@ -81,6 +81,14 @@ def _coerce_messages(row: Mapping[str, Any]) -> list[dict[str, Any]]:
     raise TraceError("trace row is missing chat messages/prompt text")


+def _synthetic_prompt_from_tokens(token_count: int) -> str:
+    if token_count <= 0:
+        return "hello"
+    # Keep it ASCII and structurally simple so the same trace can be replayed
+    # on any OpenAI-compatible engine without extra tokenizer assets.
+    return " ".join(["token"] * token_count)
+
+
 def _coerce_completion_tokens(row: Mapping[str, Any]) -> int | None:
     for key in ("max_completion_tokens", "max_tokens", "output_length", "completion_tokens"):
         value = row.get(key)
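Note that the whitespace-joined "token" words only approximate the requested length, since the exact count depends on the serving engine's tokenizer; the cap presumably exists to keep synthetic prompts bounded when a trace reports very large input lengths. Expected behavior of the helper, as a quick sketch:

    _synthetic_prompt_from_tokens(0)               # "hello"
    _synthetic_prompt_from_tokens(3)               # "token token token"
    len(_synthetic_prompt_from_tokens(8).split())  # 8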
@@ -123,9 +131,24 @@ def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[Win
         sampling_u = row.get(study.trace.u_field, 1.0)
         if isinstance(sampling_u, bool) or not isinstance(sampling_u, (int, float)):
             raise TraceError(f"trace row {idx} is missing numeric {study.trace.u_field}")
+        prompt_tokens_hint = _coerce_prompt_tokens(row)
+        try:
+            messages = _coerce_messages(row)
+        except TraceError:
+            capped_prompt_tokens = prompt_tokens_hint or 0
+            if study.trace.synthetic_prompt_cap_tokens is not None:
+                capped_prompt_tokens = min(
+                    capped_prompt_tokens, study.trace.synthetic_prompt_cap_tokens
+                )
+            messages = [
+                {
+                    "role": "user",
+                    "content": _synthetic_prompt_from_tokens(capped_prompt_tokens),
+                }
+            ]
         body: dict[str, Any] = {
             "model": study.model.served_model_name,
-            "messages": _coerce_messages(row),
+            "messages": messages,
             "stream": True,
             "stream_options": {"include_usage": True},
         }
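For a length-only row with input_length 32 and a cap of 8, the body assembled above would look roughly like this (model name taken from the test fixture; max_tokens comes from the completion-token hint, which is handled outside this hunk):

    body = {
        "model": "dummy-model",
        "messages": [{"role": "user", "content": " ".join(["token"] * 8)}],
        "stream": True,
        "stream_options": {"include_usage": True},
    }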
@@ -141,7 +164,7 @@ def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[Win
                     arrival_s=float(timestamp),
                     sampling_u=float(sampling_u),
                     body=body,
-                    prompt_tokens_hint=_coerce_prompt_tokens(row),
+                    prompt_tokens_hint=prompt_tokens_hint,
                     completion_tokens_hint=completion_tokens,
                 )
             )

@@ -153,6 +153,85 @@ class CoreFlowTests(unittest.TestCase):
         self.assertIn("queueing_knee_by_bucket", prompt)
         self.assertTrue(study_root.exists())

+    def test_length_only_trace_rows_are_synthesized(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            tmp_path = Path(tmp)
+            trace_dir = tmp_path / "trace_windows" / "traces"
+            trace_dir.mkdir(parents=True)
+            trace_path = trace_dir / "chat_len_only.jsonl"
+            with trace_path.open("w", encoding="utf-8") as handle:
+                handle.write(
+                    json.dumps(
+                        {
+                            "timestamp": 0.0,
+                            "sampling_u": 0.1,
+                            "input_length": 32,
+                            "output_length": 16
+                        }
+                    )
+                    + "\n"
+                )
+            windows_path = tmp_path / "trace_windows" / "windows.json"
+            windows_path.write_text(
+                json.dumps(
+                    {
+                        "windows": [
+                            {
+                                "window_id": "w1",
+                                "trace_type": "chat",
+                                "trace_file": "traces/chat_len_only.jsonl",
+                                "window_start": 0.0,
+                                "window_end": 10.0
+                            }
+                        ]
+                    }
+                ),
+                encoding="utf-8",
+            )
+            study_path = tmp_path / "study.json"
+            study_path.write_text(
+                json.dumps(
+                    {
+                        "study_id": "study-len-only",
+                        "hardware": {"gpu_count": 1},
+                        "model": {
+                            "model_id": "m1",
+                            "served_model_name": "dummy-model"
+                        },
+                        "engine": {
+                            "engine_name": "vllm",
+                            "exec_path": "/usr/local/bin/vllm",
+                            "host": "127.0.0.1",
+                            "port": 8000,
+                            "ready_timeout_s": 10,
+                            "request_timeout_s": 10,
+                            "healthcheck_path": "/v1/models",
+                            "launch_args": [],
+                            "base_envs": {},
+                            "base_flags": {},
+                            "tunable_envs": [],
+                            "tunable_flags": []
+                        },
+                        "trace": {
+                            "windows_path": str(windows_path),
+                            "window_id": "w1",
+                            "max_concurrency": 1,
+                            "synthetic_prompt_cap_tokens": 8
+                        },
+                        "slo": {"target_pass_rate": 0.95},
+                        "search": {"low": 0.0, "high": 1.0, "tolerance": 0.1, "max_probes": 2, "sample_seed": 1},
+                        "llm": {"system_prompt": "", "max_history_trials": 1}
+                    }
+                ),
+                encoding="utf-8",
+            )
+            study = load_study_spec(study_path)
+            _, requests = load_trace_requests(study, study_spec_path=study_path)
+            self.assertEqual(len(requests), 1)
+            message = requests[0].body["messages"][0]["content"]
+            self.assertEqual(message.count("token"), 8)
+            self.assertEqual(requests[0].body["max_tokens"], 16)
+
     def test_slo_evaluation_step_and_fixed_rules(self) -> None:
         with tempfile.TemporaryDirectory() as tmp:
             study = load_study_spec(_write_study_assets(Path(tmp)))