Repair truncated LLM proposal JSON

This commit is contained in:
2026-04-07 11:38:08 +08:00
parent 94c89e1103
commit 79ba8a50c8
2 changed files with 63 additions and 3 deletions

View File

@@ -128,15 +128,50 @@ def validate_proposal(proposal: Proposal, study: StudySpec) -> Proposal:
return proposal
def _repair_json_text(text: str) -> str:
    """Best-effort completion of JSON text that was cut off mid-document.

    The text is scanned once, tracking whether the scanner is inside a string
    literal and which ``{`` / ``[`` containers are still open.  Whatever is
    left open at the end is closed, innermost first.  The result is not
    validated: it can still fail to parse (for example when the cut falls
    between a key and its value), so callers must be prepared for
    ``json.loads`` to raise on the returned text.

    Args:
        text: Possibly truncated JSON text.

    Returns:
        The stripped text with the missing closers appended.  Empty or
        whitespace-only input yields an empty string; text with nothing left
        open is returned stripped but otherwise unchanged.
    """
    candidate = text.strip()
    if not candidate:
        return candidate
    stack: list[str] = []
    in_string = False
    escape = False
    for char in candidate:
        if in_string:
            if escape:
                escape = False
            elif char == "\\":
                escape = True
            elif char == '"':
                in_string = False
            # Brackets inside a string literal are data, not structure.
            continue
        if char == '"':
            in_string = True
        elif char == "{":
            stack.append("}")
        elif char == "[":
            stack.append("]")
        elif char in {"}", "]"} and stack and stack[-1] == char:
            # Closers that do not match the innermost opener are ignored.
            stack.pop()
    if in_string:
        if escape:
            # The cut fell right after a backslash; keeping it would escape
            # the quote appended below and leave the string unterminated.
            candidate = candidate[:-1]
        candidate += '"'
    elif stack and candidate.endswith(","):
        # A separator with no following element would make the closer
        # appended below invalid JSON (e.g. "[1, 2,]").
        candidate = candidate[:-1]
    if stack:
        candidate += "".join(reversed(stack))
    return candidate
def _parse_json_object_text(text: str) -> dict[str, Any]:
    """Parse ``text`` into a JSON object, tolerating wrapping and truncation.

    Attempts, in order:

    1. ``text`` as-is.
    2. The span from the first ``{`` to the last ``}`` (drops surrounding
       non-JSON text).
    3. Everything from the first ``{`` onward, passed through
       ``_repair_json_text``.  This keeps fields that follow the last ``}``
       in a truncated document, which the trimmed span would discard.
    4. The trimmed span from step 2, passed through ``_repair_json_text``.

    Args:
        text: Raw text expected to contain a JSON object.

    Returns:
        The decoded JSON object.

    Raises:
        json.JSONDecodeError: If ``text`` contains no ``{`` or every attempt
            fails to parse.
        SpecError: If the decoded payload is not a JSON object.
    """
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        start = text.find("{")
        if start < 0:
            # Nothing that could be an object; surface the original error.
            raise
        end = text.rfind("}")
        tail = text[start:]
        trimmed = tail if end < start else text[start : end + 1]
        try:
            payload = json.loads(trimmed)
        except json.JSONDecodeError:
            try:
                # Repair the untrimmed tail first so content after the last
                # closing brace is not lost.
                payload = json.loads(_repair_json_text(tail))
            except json.JSONDecodeError:
                if trimmed == tail:
                    raise
                payload = json.loads(_repair_json_text(trimmed))
    if not isinstance(payload, dict):
        raise SpecError("proposal payload must be a JSON object")
    return payload

View File

@@ -317,6 +317,31 @@ class CoreFlowTests(unittest.TestCase):
self.assertIn('"failure_reason": "engine_process_exited_before_ready exit_code=1"', prompt)
self.assertIn('"VLLM_ATTENTION_BACKEND": "FLASHINFER"', prompt)
def test_parse_proposal_text_repairs_truncated_json(self) -> None:
    """A proposal reply missing its final closing brace is still parsed."""
    # The outer object is never closed: the text stops after the last field.
    truncated_reply = """
{
"observation": "obs",
"diagnosis": "diag",
"config_patch": {
"env_patch": {},
"flag_patch": {
"max-num-seqs": 24
}
},
"expected_effects": [
"faster batching"
],
"why_not_previous_failures": "none"
"""
    with tempfile.TemporaryDirectory() as workdir:
        assets = _write_study_assets(Path(workdir))
        study = load_study_spec(assets)
        proposal = parse_proposal_text(truncated_reply, study)
        # Fields before the truncation point must come through intact.
        self.assertEqual(proposal.diagnosis, "diag")
        patched_flags = proposal.config_patch.flag_patch
        self.assertEqual(patched_flags["max-num-seqs"], 24)
def test_length_only_trace_rows_are_synthesized(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)