Repair truncated LLM proposal JSON
This commit is contained in:
@@ -128,15 +128,50 @@ def validate_proposal(proposal: Proposal, study: StudySpec) -> Proposal:
|
|||||||
return proposal
|
return proposal
|
||||||
|
|
||||||
|
|
||||||
|
def _repair_json_text(text: str) -> str:
    """Best-effort completion of JSON text that was cut off mid-document.

    The text is scanned once while tracking whether the cursor is inside a
    string literal and which ``{`` / ``[`` containers are still open.  The
    result is the stripped input with the missing terminators appended:

    * an unterminated string gets its closing ``"``;
    * every container that is still open gets its matching ``}`` / ``]``,
      innermost first.

    Two truncation points need extra care so that the appended terminators
    actually produce valid JSON:

    * a dangling backslash at the very end of an open string is dropped,
      otherwise the appended quote would be read as the escape ``\\"`` and
      the string would stay open;
    * a trailing ``,`` directly before the appended closers is dropped,
      because strict JSON does not allow ``,}`` or ``,]``.

    Args:
        text: Possibly truncated JSON text.

    Returns:
        The stripped text with terminators appended.  Empty or
        whitespace-only input yields an empty string.  The result is not
        guaranteed to be valid JSON (for example when the text was cut in
        the middle of a key or a literal); callers still have to parse it.
    """
    candidate = text.strip()
    if not candidate:
        return candidate

    stack: list[str] = []  # closers owed for containers that are still open
    in_string = False
    escape = False  # True when the previous string character was a backslash
    for char in candidate:
        if in_string:
            if escape:
                escape = False
            elif char == "\\":
                escape = True
            elif char == '"':
                in_string = False
            # Brackets inside string literals are data, not structure.
            continue
        if char == '"':
            in_string = True
        elif char == "{":
            stack.append("}")
        elif char == "[":
            stack.append("]")
        elif char in {"}", "]"} and stack and stack[-1] == char:
            # Closers that do not match the innermost open container are
            # ignored rather than treated as an error.
            stack.pop()

    if in_string:
        if escape:
            # Cut right after a backslash: drop it so the quote appended
            # below terminates the string instead of being escaped.
            candidate = candidate[:-1]
        candidate += '"'
    elif stack and candidate.endswith(","):
        # Cut right after a separator: remove it so the closers appended
        # below do not form an invalid trailing comma.
        candidate = candidate[:-1].rstrip()

    if stack:
        candidate += "".join(reversed(stack))
    return candidate
|
||||||
|
|
||||||
|
|
||||||
def _parse_json_object_text(text: str) -> dict[str, Any]:
    """Parse *text* into a JSON object, tolerating wrapping prose and truncation.

    Parsing is attempted in this order, stopping at the first success:

    1. the text as given;
    2. the slice from the first ``{`` to the last ``}`` (drops prose around
       a complete object);
    3. the repaired tail starting at the first ``{`` (completes a document
       that was cut off, keeping everything that was received);
    4. the repaired slice from step 2 (a truncated object followed by
       unrelated trailing text).

    Step 3 runs before step 4 on purpose: clipping a truncated document at
    its last ``}`` would silently throw away every member that follows the
    last *nested* object, even though that content is perfectly usable.

    Args:
        text: Raw text expected to contain one JSON object.

    Returns:
        The decoded JSON object.

    Raises:
        json.JSONDecodeError: If the text contains no ``{`` or none of the
            attempts yields valid JSON.
        SpecError: If the decoded payload is not a JSON object.
    """
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        start = text.find("{")
        if start < 0:
            # Nothing that could be the start of an object: surface the
            # original decode error.
            raise
        end = text.rfind("}")
        tail = text[start:]
        clipped = tail if end < start else text[start : end + 1]
        try:
            payload = json.loads(clipped)
        except json.JSONDecodeError:
            try:
                # Prefer the full tail so members after the last nested
                # "}" survive the repair.
                payload = json.loads(_repair_json_text(tail))
            except json.JSONDecodeError:
                if clipped == tail:
                    raise
                payload = json.loads(_repair_json_text(clipped))
    if not isinstance(payload, dict):
        raise SpecError("proposal payload must be a JSON object")
    return payload
|
||||||
|
|||||||
@@ -317,6 +317,31 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
self.assertIn('"failure_reason": "engine_process_exited_before_ready exit_code=1"', prompt)
|
self.assertIn('"failure_reason": "engine_process_exited_before_ready exit_code=1"', prompt)
|
||||||
self.assertIn('"VLLM_ATTENTION_BACKEND": "FLASHINFER"', prompt)
|
self.assertIn('"VLLM_ATTENTION_BACKEND": "FLASHINFER"', prompt)
|
||||||
|
|
||||||
|
def test_parse_proposal_text_repairs_truncated_json(self) -> None:
    """A proposal whose JSON lacks the final closing brace is still parsed."""
    # The outermost object is intentionally left unterminated to mimic a
    # response that was cut off before its last "}".
    truncated_payload = """
{
"observation": "obs",
"diagnosis": "diag",
"config_patch": {
"env_patch": {},
"flag_patch": {
"max-num-seqs": 24
}
},
"expected_effects": [
"faster batching"
],
"why_not_previous_failures": "none"
"""
    with tempfile.TemporaryDirectory() as workdir:
        assets_root = Path(workdir)
        spec = load_study_spec(_write_study_assets(assets_root))
        parsed = parse_proposal_text(truncated_payload, spec)
        self.assertEqual(parsed.diagnosis, "diag")
        self.assertEqual(parsed.config_patch.flag_patch["max-num-seqs"], 24)
|
||||||
|
|
||||||
def test_length_only_trace_rows_are_synthesized(self) -> None:
|
def test_length_only_trace_rows_are_synthesized(self) -> None:
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
tmp_path = Path(tmp)
|
tmp_path = Path(tmp)
|
||||||
|
|||||||
Reference in New Issue
Block a user