Repair truncated LLM proposal JSON
This commit is contained in:
@@ -128,15 +128,50 @@ def validate_proposal(proposal: Proposal, study: StudySpec) -> Proposal:
|
||||
return proposal
|
||||
|
||||
|
||||
def _repair_json_text(text: str) -> str:
|
||||
candidate = text.strip()
|
||||
if not candidate:
|
||||
return candidate
|
||||
stack: list[str] = []
|
||||
in_string = False
|
||||
escape = False
|
||||
for char in candidate:
|
||||
if in_string:
|
||||
if escape:
|
||||
escape = False
|
||||
elif char == "\\":
|
||||
escape = True
|
||||
elif char == '"':
|
||||
in_string = False
|
||||
continue
|
||||
if char == '"':
|
||||
in_string = True
|
||||
elif char == "{":
|
||||
stack.append("}")
|
||||
elif char == "[":
|
||||
stack.append("]")
|
||||
elif char in {"}", "]"} and stack and stack[-1] == char:
|
||||
stack.pop()
|
||||
if in_string:
|
||||
candidate += '"'
|
||||
if stack:
|
||||
candidate += "".join(reversed(stack))
|
||||
return candidate
|
||||
|
||||
|
||||
def _parse_json_object_text(text: str) -> dict[str, Any]:
    """Parse ``text`` into a JSON object, tolerating wrapping and truncation.

    Attempts, in order:

    1. ``text`` exactly as given.
    2. The span from the first ``{`` to the last ``}`` (or to the end of
       the text when no ``}`` follows the first ``{``), which discards
       anything surrounding the object.
    3. The whole tail from the first ``{``, passed through
       ``_repair_json_text``. This is tried before repairing the span from
       step 2 because, in a truncated reply, the last ``}`` may belong to
       a nested object; cutting there would silently drop every field
       that follows it.
    4. The span from step 2, passed through ``_repair_json_text``.

    Args:
        text: Raw text expected to contain a JSON object.

    Returns:
        The decoded JSON object.

    Raises:
        json.JSONDecodeError: If ``text`` contains no ``{`` or no attempt
            yields valid JSON.
        SpecError: If the decoded payload is not a JSON object.
    """
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        start = text.find("{")
        if start < 0:
            # Nothing that could be an object; surface the original error.
            raise
        end = text.rfind("}")
        tail = text[start:]
        candidate = tail if end < start else text[start : end + 1]
        try:
            payload = json.loads(candidate)
        except json.JSONDecodeError:
            try:
                payload = json.loads(_repair_json_text(tail))
            except json.JSONDecodeError:
                if candidate == tail:
                    # The only repairable fragment was already tried.
                    raise
                payload = json.loads(_repair_json_text(candidate))
    if not isinstance(payload, dict):
        raise SpecError("proposal payload must be a JSON object")
    return payload
|
||||
|
||||
@@ -317,6 +317,31 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertIn('"failure_reason": "engine_process_exited_before_ready exit_code=1"', prompt)
|
||||
self.assertIn('"VLLM_ATTENTION_BACKEND": "FLASHINFER"', prompt)
|
||||
|
||||
def test_parse_proposal_text_repairs_truncated_json(self) -> None:
    """A proposal whose outermost object is never closed still parses."""
    # The reply below stops before the final closing brace.
    truncated_reply = """
        {
          "observation": "obs",
          "diagnosis": "diag",
          "config_patch": {
            "env_patch": {},
            "flag_patch": {
              "max-num-seqs": 24
            }
          },
          "expected_effects": [
            "faster batching"
          ],
          "why_not_previous_failures": "none"
        """
    with tempfile.TemporaryDirectory() as workdir:
        spec_path = _write_study_assets(Path(workdir))
        study = load_study_spec(spec_path)
        proposal = parse_proposal_text(truncated_reply, study)
        self.assertEqual(proposal.diagnosis, "diag")
        self.assertEqual(proposal.config_patch.flag_patch["max-num-seqs"], 24)
|
||||
|
||||
def test_length_only_trace_rows_are_synthesized(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
|
||||
Reference in New Issue
Block a user