diff --git a/src/aituner/spec.py b/src/aituner/spec.py
index 3b7fbca..e39c0a1 100644
--- a/src/aituner/spec.py
+++ b/src/aituner/spec.py
@@ -151,10 +151,12 @@ class TraceSpec:
     timestamp_field: str
     max_concurrency: int
     max_requests_per_probe: int | None = None
+    synthetic_prompt_cap_tokens: int | None = None
 
     @classmethod
     def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
         max_requests = data.get("max_requests_per_probe")
+        synthetic_prompt_cap = data.get("synthetic_prompt_cap_tokens")
         return cls(
             windows_path=_require_str(data.get("windows_path"), context="trace.windows_path"),
             window_id=_require_str(data.get("window_id"), context="trace.window_id"),
@@ -167,6 +169,9 @@ class TraceSpec:
                 data.get("max_concurrency", 64), context="trace.max_concurrency"
             ),
             max_requests_per_probe=int(max_requests) if max_requests is not None else None,
+            synthetic_prompt_cap_tokens=(
+                int(synthetic_prompt_cap) if synthetic_prompt_cap is not None else None
+            ),
         )
 
 
diff --git a/src/aituner/trace.py b/src/aituner/trace.py
index 8eec760..1fbd707 100644
--- a/src/aituner/trace.py
+++ b/src/aituner/trace.py
@@ -81,6 +81,14 @@ def _coerce_messages(row: Mapping[str, Any]) -> list[dict[str, Any]]:
     raise TraceError("trace row is missing chat messages/prompt text")
 
 
+def _synthetic_prompt_from_tokens(token_count: int) -> str:
+    if token_count <= 0:
+        return "hello"
+    # Keep it ASCII and structurally simple so the same trace can be replayed
+    # on any OpenAI-compatible engine without extra tokenizer assets.
+    return " ".join(["token"] * token_count)
+
+
 def _coerce_completion_tokens(row: Mapping[str, Any]) -> int | None:
     for key in ("max_completion_tokens", "max_tokens", "output_length", "completion_tokens"):
         value = row.get(key)
@@ -123,9 +131,24 @@ def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[Win
         sampling_u = row.get(study.trace.u_field, 1.0)
         if isinstance(sampling_u, bool) or not isinstance(sampling_u, (int, float)):
             raise TraceError(f"trace row {idx} is missing numeric {study.trace.u_field}")
+        prompt_tokens_hint = _coerce_prompt_tokens(row)
+        try:
+            messages = _coerce_messages(row)
+        except TraceError:
+            capped_prompt_tokens = prompt_tokens_hint or 0
+            if study.trace.synthetic_prompt_cap_tokens is not None:
+                capped_prompt_tokens = min(
+                    capped_prompt_tokens, study.trace.synthetic_prompt_cap_tokens
+                )
+            messages = [
+                {
+                    "role": "user",
+                    "content": _synthetic_prompt_from_tokens(capped_prompt_tokens),
+                }
+            ]
         body: dict[str, Any] = {
             "model": study.model.served_model_name,
-            "messages": _coerce_messages(row),
+            "messages": messages,
             "stream": True,
             "stream_options": {"include_usage": True},
         }
@@ -141,7 +164,7 @@
                 arrival_s=float(timestamp),
                 sampling_u=float(sampling_u),
                 body=body,
-                prompt_tokens_hint=_coerce_prompt_tokens(row),
+                prompt_tokens_hint=prompt_tokens_hint,
                 completion_tokens_hint=completion_tokens,
             )
         )
diff --git a/tests/test_core_flow.py b/tests/test_core_flow.py
index 9734ee5..7b7c4bd 100644
--- a/tests/test_core_flow.py
+++ b/tests/test_core_flow.py
@@ -153,6 +153,85 @@ class CoreFlowTests(unittest.TestCase):
         self.assertIn("queueing_knee_by_bucket", prompt)
         self.assertTrue(study_root.exists())
 
+    def test_length_only_trace_rows_are_synthesized(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            tmp_path = Path(tmp)
+            trace_dir = tmp_path / "trace_windows" / "traces"
+            trace_dir.mkdir(parents=True)
+            trace_path = trace_dir / "chat_len_only.jsonl"
+            with trace_path.open("w", encoding="utf-8") as handle:
+                handle.write(
+                    json.dumps(
+                        {
+                            "timestamp": 0.0,
+                            "sampling_u": 0.1,
+                            "input_length": 32,
+                            "output_length": 16
+                        }
+                    )
+                    + "\n"
+                )
+            windows_path = tmp_path / "trace_windows" / "windows.json"
+            windows_path.write_text(
+                json.dumps(
+                    {
+                        "windows": [
+                            {
+                                "window_id": "w1",
+                                "trace_type": "chat",
+                                "trace_file": "traces/chat_len_only.jsonl",
+                                "window_start": 0.0,
+                                "window_end": 10.0
+                            }
+                        ]
+                    }
+                ),
+                encoding="utf-8",
+            )
+            study_path = tmp_path / "study.json"
+            study_path.write_text(
+                json.dumps(
+                    {
+                        "study_id": "study-len-only",
+                        "hardware": {"gpu_count": 1},
+                        "model": {
+                            "model_id": "m1",
+                            "served_model_name": "dummy-model"
+                        },
+                        "engine": {
+                            "engine_name": "vllm",
+                            "exec_path": "/usr/local/bin/vllm",
+                            "host": "127.0.0.1",
+                            "port": 8000,
+                            "ready_timeout_s": 10,
+                            "request_timeout_s": 10,
+                            "healthcheck_path": "/v1/models",
+                            "launch_args": [],
+                            "base_envs": {},
+                            "base_flags": {},
+                            "tunable_envs": [],
+                            "tunable_flags": []
+                        },
+                        "trace": {
+                            "windows_path": str(windows_path),
+                            "window_id": "w1",
+                            "max_concurrency": 1,
+                            "synthetic_prompt_cap_tokens": 8
+                        },
+                        "slo": {"target_pass_rate": 0.95},
+                        "search": {"low": 0.0, "high": 1.0, "tolerance": 0.1, "max_probes": 2, "sample_seed": 1},
+                        "llm": {"system_prompt": "", "max_history_trials": 1}
+                    }
+                ),
+                encoding="utf-8",
+            )
+            study = load_study_spec(study_path)
+            _, requests = load_trace_requests(study, study_spec_path=study_path)
+            self.assertEqual(len(requests), 1)
+            message = requests[0].body["messages"][0]["content"]
+            self.assertEqual(message.count("token"), 8)
+            self.assertEqual(requests[0].body["max_tokens"], 16)
+
     def test_slo_evaluation_step_and_fixed_rules(self) -> None:
         with tempfile.TemporaryDirectory() as tmp:
             study = load_study_spec(_write_study_assets(Path(tmp)))