diff --git a/src/aituner/llm.py b/src/aituner/llm.py index fa27755..f458729 100644 --- a/src/aituner/llm.py +++ b/src/aituner/llm.py @@ -128,15 +128,50 @@ def validate_proposal(proposal: Proposal, study: StudySpec) -> Proposal: return proposal +def _repair_json_text(text: str) -> str: + candidate = text.strip() + if not candidate: + return candidate + stack: list[str] = [] + in_string = False + escape = False + for char in candidate: + if in_string: + if escape: + escape = False + elif char == "\\": + escape = True + elif char == '"': + in_string = False + continue + if char == '"': + in_string = True + elif char == "{": + stack.append("}") + elif char == "[": + stack.append("]") + elif char in {"}", "]"} and stack and stack[-1] == char: + stack.pop() + if in_string: + candidate += '"' + if stack: + candidate += "".join(reversed(stack)) + return candidate + + def _parse_json_object_text(text: str) -> dict[str, Any]: try: payload = json.loads(text) except json.JSONDecodeError: start = text.find("{") - end = text.rfind("}") - if start < 0 or end < start: + if start < 0: raise - payload = json.loads(text[start : end + 1]) + end = text.rfind("}") + candidate = text[start:] if end < start else text[start : end + 1] + try: + payload = json.loads(candidate) + except json.JSONDecodeError: + payload = json.loads(_repair_json_text(candidate)) if not isinstance(payload, dict): raise SpecError("proposal payload must be a JSON object") return payload diff --git a/tests/test_core_flow.py b/tests/test_core_flow.py index 2ee88fc..8c6a392 100644 --- a/tests/test_core_flow.py +++ b/tests/test_core_flow.py @@ -317,6 +317,31 @@ class CoreFlowTests(unittest.TestCase): self.assertIn('"failure_reason": "engine_process_exited_before_ready exit_code=1"', prompt) self.assertIn('"VLLM_ATTENTION_BACKEND": "FLASHINFER"', prompt) + def test_parse_proposal_text_repairs_truncated_json(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + study = load_study_spec(_write_study_assets(tmp_path)) + proposal = parse_proposal_text( + """ + { + "observation": "obs", + "diagnosis": "diag", + "config_patch": { + "env_patch": {}, + "flag_patch": { + "max-num-seqs": 24 + } + }, + "expected_effects": [ + "faster batching" + ], + "why_not_previous_failures": "none" + """, + study, + ) + self.assertEqual(proposal.diagnosis, "diag") + self.assertEqual(proposal.config_patch.flag_patch["max-num-seqs"], 24) + def test_length_only_trace_rows_are_synthesized(self) -> None: with tempfile.TemporaryDirectory() as tmp: tmp_path = Path(tmp)