Harden LLM proposal parsing
This commit is contained in:
@@ -141,6 +141,8 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
|
||||
proposal_source = None
|
||||
proposal_text = call_llm_for_proposal(policy=study.llm, prompt=prompt)
|
||||
proposal_name = f"proposal-{state.next_trial_index:04d}"
|
||||
raw_proposal_path = store.study_root(study.study_id) / "proposals" / f"{proposal_name}.raw.txt"
|
||||
raw_proposal_path.write_text(proposal_text, encoding="utf-8")
|
||||
proposal = parse_proposal_text(proposal_text, study)
|
||||
store.write_proposal(study.study_id, proposal_name, proposal)
|
||||
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
|
||||
|
||||
@@ -395,15 +395,20 @@ class Proposal:
|
||||
|
||||
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "Proposal":
    """Build a ``Proposal`` from a decoded mapping.

    ``expected_effects`` is accepted in two shapes:

    * a single string, which is stripped and wrapped in a one-element
      list (a blank string yields an empty list);
    * anything else, which is handed to ``_coerce_str_list`` for
      validation/coercion.

    ``observation``, ``diagnosis`` and ``config_patch`` are validated by the
    ``_require_*`` helpers; ``why_not_previous_failures`` is optional and
    falls back to an empty string when missing or falsy.
    """
    expected_effects = data.get("expected_effects")
    if isinstance(expected_effects, str):
        # Strip once and reuse the result for both the emptiness check
        # and the stored value.
        stripped = expected_effects.strip()
        expected_effects_value = [stripped] if stripped else []
    else:
        expected_effects_value = _coerce_str_list(
            expected_effects, context="proposal.expected_effects"
        )
    return cls(
        observation=_require_str(data.get("observation"), context="proposal.observation"),
        diagnosis=_require_str(data.get("diagnosis"), context="proposal.diagnosis"),
        config_patch=ConfigPatch.from_dict(
            _require_mapping(data.get("config_patch"), context="proposal.config_patch")
        ),
        # Pass expected_effects exactly once, using the value normalized
        # above; repeating the keyword in a call is a SyntaxError.
        expected_effects=expected_effects_value,
        why_not_previous_failures=str(data.get("why_not_previous_failures") or "").strip(),
    )
|
||||
|
||||
|
||||
@@ -671,6 +671,17 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertEqual(state.best_request_rate, 2.0)
|
||||
self.assertEqual(state.next_trial_index, 3)
|
||||
|
||||
def test_proposal_expected_effects_accepts_string(self) -> None:
    """A bare string for ``expected_effects`` parses into a one-item list."""
    payload = {
        "observation": "obs",
        "diagnosis": "diag",
        "config_patch": {"env_patch": {}, "flag_patch": {}},
        "expected_effects": "higher throughput",
    }
    parsed = Proposal.from_dict(payload)
    self.assertEqual(parsed.expected_effects, ["higher throughput"])
|
||||
|
||||
def test_replay_requests_early_stops_when_slo_is_unrecoverable(self) -> None:
|
||||
requests = [
|
||||
TraceRequest(
|
||||
|
||||
Reference in New Issue
Block a user