Harden LLM proposal parsing

This commit is contained in:
2026-04-04 23:19:42 +08:00
parent 0b7cad7da3
commit 00778eff42
3 changed files with 21 additions and 3 deletions

View File

@@ -141,6 +141,8 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
proposal_source = None proposal_source = None
proposal_text = call_llm_for_proposal(policy=study.llm, prompt=prompt) proposal_text = call_llm_for_proposal(policy=study.llm, prompt=prompt)
proposal_name = f"proposal-{state.next_trial_index:04d}" proposal_name = f"proposal-{state.next_trial_index:04d}"
raw_proposal_path = store.study_root(study.study_id) / "proposals" / f"{proposal_name}.raw.txt"
raw_proposal_path.write_text(proposal_text, encoding="utf-8")
proposal = parse_proposal_text(proposal_text, study) proposal = parse_proposal_text(proposal_text, study)
store.write_proposal(study.study_id, proposal_name, proposal) store.write_proposal(study.study_id, proposal_name, proposal)
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal) trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)

View File

@@ -395,15 +395,20 @@ class Proposal:
@classmethod @classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "Proposal": def from_dict(cls, data: Mapping[str, Any]) -> "Proposal":
expected_effects = data.get("expected_effects")
if isinstance(expected_effects, str):
expected_effects_value = [expected_effects.strip()] if expected_effects.strip() else []
else:
expected_effects_value = _coerce_str_list(
expected_effects, context="proposal.expected_effects"
)
return cls( return cls(
observation=_require_str(data.get("observation"), context="proposal.observation"), observation=_require_str(data.get("observation"), context="proposal.observation"),
diagnosis=_require_str(data.get("diagnosis"), context="proposal.diagnosis"), diagnosis=_require_str(data.get("diagnosis"), context="proposal.diagnosis"),
config_patch=ConfigPatch.from_dict( config_patch=ConfigPatch.from_dict(
_require_mapping(data.get("config_patch"), context="proposal.config_patch") _require_mapping(data.get("config_patch"), context="proposal.config_patch")
), ),
expected_effects=_coerce_str_list( expected_effects=expected_effects_value,
data.get("expected_effects"), context="proposal.expected_effects"
),
why_not_previous_failures=str(data.get("why_not_previous_failures") or "").strip(), why_not_previous_failures=str(data.get("why_not_previous_failures") or "").strip(),
) )

View File

@@ -671,6 +671,17 @@ class CoreFlowTests(unittest.TestCase):
self.assertEqual(state.best_request_rate, 2.0) self.assertEqual(state.best_request_rate, 2.0)
self.assertEqual(state.next_trial_index, 3) self.assertEqual(state.next_trial_index, 3)
def test_proposal_expected_effects_accepts_string(self) -> None:
proposal = Proposal.from_dict(
{
"observation": "obs",
"diagnosis": "diag",
"config_patch": {"env_patch": {}, "flag_patch": {}},
"expected_effects": "higher throughput",
}
)
self.assertEqual(proposal.expected_effects, ["higher throughput"])
def test_replay_requests_early_stops_when_slo_is_unrecoverable(self) -> None: def test_replay_requests_early_stops_when_slo_is_unrecoverable(self) -> None:
requests = [ requests = [
TraceRequest( TraceRequest(