Support length-only trace windows
This commit is contained in:
@@ -151,10 +151,12 @@ class TraceSpec:
|
|||||||
timestamp_field: str
|
timestamp_field: str
|
||||||
max_concurrency: int
|
max_concurrency: int
|
||||||
max_requests_per_probe: int | None = None
|
max_requests_per_probe: int | None = None
|
||||||
|
synthetic_prompt_cap_tokens: int | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
|
def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
|
||||||
max_requests = data.get("max_requests_per_probe")
|
max_requests = data.get("max_requests_per_probe")
|
||||||
|
synthetic_prompt_cap = data.get("synthetic_prompt_cap_tokens")
|
||||||
return cls(
|
return cls(
|
||||||
windows_path=_require_str(data.get("windows_path"), context="trace.windows_path"),
|
windows_path=_require_str(data.get("windows_path"), context="trace.windows_path"),
|
||||||
window_id=_require_str(data.get("window_id"), context="trace.window_id"),
|
window_id=_require_str(data.get("window_id"), context="trace.window_id"),
|
||||||
@@ -167,6 +169,9 @@ class TraceSpec:
|
|||||||
data.get("max_concurrency", 64), context="trace.max_concurrency"
|
data.get("max_concurrency", 64), context="trace.max_concurrency"
|
||||||
),
|
),
|
||||||
max_requests_per_probe=int(max_requests) if max_requests is not None else None,
|
max_requests_per_probe=int(max_requests) if max_requests is not None else None,
|
||||||
|
synthetic_prompt_cap_tokens=(
|
||||||
|
int(synthetic_prompt_cap) if synthetic_prompt_cap is not None else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -81,6 +81,14 @@ def _coerce_messages(row: Mapping[str, Any]) -> list[dict[str, Any]]:
|
|||||||
raise TraceError("trace row is missing chat messages/prompt text")
|
raise TraceError("trace row is missing chat messages/prompt text")
|
||||||
|
|
||||||
|
|
||||||
|
def _synthetic_prompt_from_tokens(token_count: int) -> str:
|
||||||
|
if token_count <= 0:
|
||||||
|
return "hello"
|
||||||
|
# Keep it ASCII and structurally simple so the same trace can be replayed
|
||||||
|
# on any OpenAI-compatible engine without extra tokenizer assets.
|
||||||
|
return " ".join(["token"] * token_count)
|
||||||
|
|
||||||
|
|
||||||
def _coerce_completion_tokens(row: Mapping[str, Any]) -> int | None:
|
def _coerce_completion_tokens(row: Mapping[str, Any]) -> int | None:
|
||||||
for key in ("max_completion_tokens", "max_tokens", "output_length", "completion_tokens"):
|
for key in ("max_completion_tokens", "max_tokens", "output_length", "completion_tokens"):
|
||||||
value = row.get(key)
|
value = row.get(key)
|
||||||
@@ -123,9 +131,24 @@ def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[Win
|
|||||||
sampling_u = row.get(study.trace.u_field, 1.0)
|
sampling_u = row.get(study.trace.u_field, 1.0)
|
||||||
if isinstance(sampling_u, bool) or not isinstance(sampling_u, (int, float)):
|
if isinstance(sampling_u, bool) or not isinstance(sampling_u, (int, float)):
|
||||||
raise TraceError(f"trace row {idx} is missing numeric {study.trace.u_field}")
|
raise TraceError(f"trace row {idx} is missing numeric {study.trace.u_field}")
|
||||||
|
prompt_tokens_hint = _coerce_prompt_tokens(row)
|
||||||
|
try:
|
||||||
|
messages = _coerce_messages(row)
|
||||||
|
except TraceError:
|
||||||
|
capped_prompt_tokens = prompt_tokens_hint or 0
|
||||||
|
if study.trace.synthetic_prompt_cap_tokens is not None:
|
||||||
|
capped_prompt_tokens = min(
|
||||||
|
capped_prompt_tokens, study.trace.synthetic_prompt_cap_tokens
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": _synthetic_prompt_from_tokens(capped_prompt_tokens),
|
||||||
|
}
|
||||||
|
]
|
||||||
body: dict[str, Any] = {
|
body: dict[str, Any] = {
|
||||||
"model": study.model.served_model_name,
|
"model": study.model.served_model_name,
|
||||||
"messages": _coerce_messages(row),
|
"messages": messages,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"stream_options": {"include_usage": True},
|
"stream_options": {"include_usage": True},
|
||||||
}
|
}
|
||||||
@@ -141,7 +164,7 @@ def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[Win
|
|||||||
arrival_s=float(timestamp),
|
arrival_s=float(timestamp),
|
||||||
sampling_u=float(sampling_u),
|
sampling_u=float(sampling_u),
|
||||||
body=body,
|
body=body,
|
||||||
prompt_tokens_hint=_coerce_prompt_tokens(row),
|
prompt_tokens_hint=prompt_tokens_hint,
|
||||||
completion_tokens_hint=completion_tokens,
|
completion_tokens_hint=completion_tokens,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -153,6 +153,85 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
self.assertIn("queueing_knee_by_bucket", prompt)
|
self.assertIn("queueing_knee_by_bucket", prompt)
|
||||||
self.assertTrue(study_root.exists())
|
self.assertTrue(study_root.exists())
|
||||||
|
|
||||||
|
def test_length_only_trace_rows_are_synthesized(self) -> None:
    # The single trace row carries only length fields (no chat messages or
    # prompt text). The assertions at the end expect a synthetic prompt with
    # 8 "token" occurrences (the configured synthetic_prompt_cap_tokens)
    # and a max_tokens of 16 (the row's output_length value).
    length_only_row = {
        "timestamp": 0.0,
        "sampling_u": 0.1,
        "input_length": 32,
        "output_length": 16,
    }
    window_entry = {
        "window_id": "w1",
        "trace_type": "chat",
        "trace_file": "traces/chat_len_only.jsonl",
        "window_start": 0.0,
        "window_end": 10.0,
    }
    engine_section = {
        "engine_name": "vllm",
        "exec_path": "/usr/local/bin/vllm",
        "host": "127.0.0.1",
        "port": 8000,
        "ready_timeout_s": 10,
        "request_timeout_s": 10,
        "healthcheck_path": "/v1/models",
        "launch_args": [],
        "base_envs": {},
        "base_flags": {},
        "tunable_envs": [],
        "tunable_flags": [],
    }
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        windows_root = root / "trace_windows"
        traces_root = windows_root / "traces"
        traces_root.mkdir(parents=True)

        # One JSON object per line, newline-terminated.
        jsonl_path = traces_root / "chat_len_only.jsonl"
        jsonl_path.write_text(json.dumps(length_only_row) + "\n", encoding="utf-8")

        windows_path = windows_root / "windows.json"
        windows_path.write_text(
            json.dumps({"windows": [window_entry]}),
            encoding="utf-8",
        )

        study_document = {
            "study_id": "study-len-only",
            "hardware": {"gpu_count": 1},
            "model": {
                "model_id": "m1",
                "served_model_name": "dummy-model",
            },
            "engine": engine_section,
            "trace": {
                "windows_path": str(windows_path),
                "window_id": "w1",
                "max_concurrency": 1,
                "synthetic_prompt_cap_tokens": 8,
            },
            "slo": {"target_pass_rate": 0.95},
            "search": {"low": 0.0, "high": 1.0, "tolerance": 0.1, "max_probes": 2, "sample_seed": 1},
            "llm": {"system_prompt": "", "max_history_trials": 1},
        }
        study_path = root / "study.json"
        study_path.write_text(json.dumps(study_document), encoding="utf-8")

        study = load_study_spec(study_path)
        _, requests = load_trace_requests(study, study_spec_path=study_path)

        self.assertEqual(len(requests), 1)
        request_body = requests[0].body
        message = request_body["messages"][0]["content"]
        self.assertEqual(message.count("token"), 8)
        self.assertEqual(request_body["max_tokens"], 16)
|
||||||
|
|
||||||
def test_slo_evaluation_step_and_fixed_rules(self) -> None:
|
def test_slo_evaluation_step_and_fixed_rules(self) -> None:
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
study = load_study_spec(_write_study_assets(Path(tmp)))
|
study = load_study_spec(_write_study_assets(Path(tmp)))
|
||||||
|
|||||||
Reference in New Issue
Block a user