diff --git a/src/aituner/cli.py b/src/aituner/cli.py index ad59ce7..3a74b3e 100644 --- a/src/aituner/cli.py +++ b/src/aituner/cli.py @@ -141,6 +141,8 @@ def cmd_study_tune(args: argparse.Namespace) -> int: proposal_source = None proposal_text = call_llm_for_proposal(policy=study.llm, prompt=prompt) proposal_name = f"proposal-{state.next_trial_index:04d}" + raw_proposal_path = store.study_root(study.study_id) / "proposals" / f"{proposal_name}.raw.txt" + raw_proposal_path.write_text(proposal_text, encoding="utf-8") proposal = parse_proposal_text(proposal_text, study) store.write_proposal(study.study_id, proposal_name, proposal) trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal) diff --git a/src/aituner/spec.py b/src/aituner/spec.py index dffe693..fc51b8b 100644 --- a/src/aituner/spec.py +++ b/src/aituner/spec.py @@ -395,15 +395,20 @@ class Proposal: @classmethod def from_dict(cls, data: Mapping[str, Any]) -> "Proposal": + expected_effects = data.get("expected_effects") + if isinstance(expected_effects, str): + expected_effects_value = [expected_effects.strip()] if expected_effects.strip() else [] + else: + expected_effects_value = _coerce_str_list( + expected_effects, context="proposal.expected_effects" + ) return cls( observation=_require_str(data.get("observation"), context="proposal.observation"), diagnosis=_require_str(data.get("diagnosis"), context="proposal.diagnosis"), config_patch=ConfigPatch.from_dict( _require_mapping(data.get("config_patch"), context="proposal.config_patch") ), - expected_effects=_coerce_str_list( - data.get("expected_effects"), context="proposal.expected_effects" - ), + expected_effects=expected_effects_value, why_not_previous_failures=str(data.get("why_not_previous_failures") or "").strip(), ) diff --git a/tests/test_core_flow.py b/tests/test_core_flow.py index 2304fe5..ba182be 100644 --- a/tests/test_core_flow.py +++ b/tests/test_core_flow.py @@ -671,6 +671,17 @@ class CoreFlowTests(unittest.TestCase): self.assertEqual(state.best_request_rate, 2.0) self.assertEqual(state.next_trial_index, 3) + def test_proposal_expected_effects_accepts_string(self) -> None: + proposal = Proposal.from_dict( + { + "observation": "obs", + "diagnosis": "diag", + "config_patch": {"env_patch": {}, "flag_patch": {}}, + "expected_effects": "higher throughput", + } + ) + self.assertEqual(proposal.expected_effects, ["higher throughput"]) + def test_replay_requests_early_stops_when_slo_is_unrecoverable(self) -> None: requests = [ TraceRequest(