Harden LLM proposal parsing
This commit is contained in:
@@ -141,6 +141,8 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
|
|||||||
proposal_source = None
|
proposal_source = None
|
||||||
proposal_text = call_llm_for_proposal(policy=study.llm, prompt=prompt)
|
proposal_text = call_llm_for_proposal(policy=study.llm, prompt=prompt)
|
||||||
proposal_name = f"proposal-{state.next_trial_index:04d}"
|
proposal_name = f"proposal-{state.next_trial_index:04d}"
|
||||||
|
raw_proposal_path = store.study_root(study.study_id) / "proposals" / f"{proposal_name}.raw.txt"
|
||||||
|
raw_proposal_path.write_text(proposal_text, encoding="utf-8")
|
||||||
proposal = parse_proposal_text(proposal_text, study)
|
proposal = parse_proposal_text(proposal_text, study)
|
||||||
store.write_proposal(study.study_id, proposal_name, proposal)
|
store.write_proposal(study.study_id, proposal_name, proposal)
|
||||||
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
|
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
|
||||||
|
|||||||
@@ -395,15 +395,20 @@ class Proposal:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Mapping[str, Any]) -> "Proposal":
|
def from_dict(cls, data: Mapping[str, Any]) -> "Proposal":
|
||||||
|
expected_effects = data.get("expected_effects")
|
||||||
|
if isinstance(expected_effects, str):
|
||||||
|
expected_effects_value = [expected_effects.strip()] if expected_effects.strip() else []
|
||||||
|
else:
|
||||||
|
expected_effects_value = _coerce_str_list(
|
||||||
|
expected_effects, context="proposal.expected_effects"
|
||||||
|
)
|
||||||
return cls(
|
return cls(
|
||||||
observation=_require_str(data.get("observation"), context="proposal.observation"),
|
observation=_require_str(data.get("observation"), context="proposal.observation"),
|
||||||
diagnosis=_require_str(data.get("diagnosis"), context="proposal.diagnosis"),
|
diagnosis=_require_str(data.get("diagnosis"), context="proposal.diagnosis"),
|
||||||
config_patch=ConfigPatch.from_dict(
|
config_patch=ConfigPatch.from_dict(
|
||||||
_require_mapping(data.get("config_patch"), context="proposal.config_patch")
|
_require_mapping(data.get("config_patch"), context="proposal.config_patch")
|
||||||
),
|
),
|
||||||
expected_effects=_coerce_str_list(
|
expected_effects=expected_effects_value,
|
||||||
data.get("expected_effects"), context="proposal.expected_effects"
|
|
||||||
),
|
|
||||||
why_not_previous_failures=str(data.get("why_not_previous_failures") or "").strip(),
|
why_not_previous_failures=str(data.get("why_not_previous_failures") or "").strip(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -671,6 +671,17 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
self.assertEqual(state.best_request_rate, 2.0)
|
self.assertEqual(state.best_request_rate, 2.0)
|
||||||
self.assertEqual(state.next_trial_index, 3)
|
self.assertEqual(state.next_trial_index, 3)
|
||||||
|
|
||||||
|
def test_proposal_expected_effects_accepts_string(self) -> None:
|
||||||
|
proposal = Proposal.from_dict(
|
||||||
|
{
|
||||||
|
"observation": "obs",
|
||||||
|
"diagnosis": "diag",
|
||||||
|
"config_patch": {"env_patch": {}, "flag_patch": {}},
|
||||||
|
"expected_effects": "higher throughput",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(proposal.expected_effects, ["higher throughput"])
|
||||||
|
|
||||||
def test_replay_requests_early_stops_when_slo_is_unrecoverable(self) -> None:
|
def test_replay_requests_early_stops_when_slo_is_unrecoverable(self) -> None:
|
||||||
requests = [
|
requests = [
|
||||||
TraceRequest(
|
TraceRequest(
|
||||||
|
|||||||
Reference in New Issue
Block a user