Compare commits
5 Commits
d7df1ebdac...main
| Author | SHA1 | Date |
| --- | --- | --- |
|  | adc4351e5d |  |
|  | eb137a0b62 |  |
|  | f212673f44 |  |
|  | a7a5e9ad80 |  |
|  | 7263587cb6 |  |
.github/workflows/ci.yml (vendored), 25 lines deleted
@@ -1,25 +0,0 @@
-name: CI
-
-on:
-  pull_request:
-  push:
-    branches:
-      - main
-
-jobs:
-  test:
-    runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python-version:
-          - "3.11"
-          - "3.12"
-    steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-python@v5
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Install
-        run: python -m pip install -e .
-      - name: Test
-        run: python -m unittest discover -s tests -v
docs/qwen27b-chat-0-8k-tpot40-baseline-infeasible-20260507.md (new file), 131 lines
@@ -0,0 +1,131 @@

# Qwen27B Chat 0-8k TPOT 40ms Baseline Infeasible Run

Date: 2026-05-07

## Goal

Re-run the internal vLLM + Qwen3.5-27B chat 0-8k tuning comparison after adding a study-level guard:

- if the automatic baseline trial has no feasible probe;
- and the lowest sampled request rate still fails the SLO target pass rate;
- then AITuner stops the whole study and reports that the SLO is too tight for the current setup.

This prevents spending the remaining tuning budget on LLM or harness proposals when the baseline itself demonstrates that the workload/SLO is infeasible at the search floor.

## Implementation

Commit: `f212673 Stop tuning when baseline is infeasible`
Changed behavior:

- `study tune` now persists `tuning_stop_reason` and `tuning_stop_diagnosis` in `state.json` (see the excerpt below).
- `study tune` also persists `tuning_stop_details`, including the lowest sampled probe's TTFT/TPOT mean, p50, p95, and p99.
- After the automatic baseline trial is ingested, AITuner checks the worker result for:
  - `status == completed`
  - `best_request_rate is None`
  - at least one probe exists
  - all probes are infeasible
- If all of these hold, AITuner stops before asking the LLM or harness for any proposal.
- Re-running the same study respects the persisted stop state and does not resume tuning.
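For illustration, a `state.json` excerpt with the persisted stop state, reconstructed from the fields above and the no-harness run below (the exact layout is an assumption; surrounding keys are elided):

```json
{
  "tuning_stop_reason": "baseline_all_infeasible",
  "tuning_stop_diagnosis": "Baseline configuration has no feasible probe under the current SLO. ...",
  "tuning_stop_details": {
    "lowest_sampled_request_rate": 0.035,
    "lowest_sampling_u": 0.0009765625,
    "lowest_probe_pass_rate": 0.142857,
    "early_stop_reason": "slo_pass_rate_unrecoverable",
    "lowest_probe_latency_ms": {
      "ttft": {"mean": 1288.953, "p50": 446.586, "p95": 3011.814, "p99": 3011.814},
      "tpot": {"mean": 12.661, "p50": 13.141, "p95": 15.097, "p99": 15.097}
    }
  }
}
```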
Validation:

```bash
python3 -m compileall -q src tests
PYTHONPATH=src python3 -m unittest tests.test_core_flow
```

The suite passed both locally and on `dash0`.
## Setup

Host: `dash0`

Remote repo: `/home/admin/cpfs/wjh/aituner/aituner`

Base spec: `configs/examples/dash0_qwen27b_tight_slo_run4_0_8k.json`

Model: `/home/admin/resource/model/464482ce/qwen3.5-27b/256k-0223-internal`

Workload: chat, 0-8k input window

SLO:

- TTFT: existing step rule from the base spec
- TPOT: fixed `40ms`
- target pass rate: `0.95`

Search:

- Direct AITuner command: `python3 -m aituner.cli study tune ... --max-trials 12`
- No manual proposal/state edits during either run.
- Both variants used `CUDA_VISIBLE_DEVICES=0,1,2,4,5,6,7`; this was identical for both specs.
- The two specs were verified to be equal after normalizing only `study_id` and `llm.use_harness` (see the sketch below).

Specs:

- no-harness: `.aituner-tight/specs/dash0-qwen27b-chat-0-8k-tpot40-gpu3skip-12iter-noharness-20260507.json`
- harness: `.aituner-tight/specs/dash0-qwen27b-chat-0-8k-tpot40-gpu3skip-12iter-harness-20260507.json`
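The equality check in the last Search bullet can be reproduced mechanically. A minimal sketch, assuming the specs are plain JSON files and `llm` is a top-level object (the helper name is illustrative, not from the codebase):

```python
import json
from pathlib import Path

def normalized(path: str) -> dict:
    # Load a spec and drop the two fields the variants intentionally differ in.
    spec = json.loads(Path(path).read_text(encoding="utf-8"))
    spec.pop("study_id", None)
    spec.get("llm", {}).pop("use_harness", None)
    return spec

# normalized(no_harness_path) == normalized(harness_path) should hold.
```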
## Commands

No harness:

```bash
PYTHONPATH=src python3 -m aituner.cli study tune \
  --spec .aituner-tight/specs/dash0-qwen27b-chat-0-8k-tpot40-gpu3skip-12iter-noharness-20260507.json \
  --store-root .aituner-tight \
  --max-trials 12
```

Harness:

```bash
PYTHONPATH=src python3 -m aituner.cli study tune \
  --spec .aituner-tight/specs/dash0-qwen27b-chat-0-8k-tpot40-gpu3skip-12iter-harness-20260507.json \
  --store-root .aituner-tight \
  --max-trials 12
```
## Results

Both runs stopped after the baseline trial. No LLM/harness proposal was evaluated because the baseline had no feasible probe.

| Variant | Trials executed | Best request rate | Best request rate / GPU | Stop reason |
| --- | ---: | ---: | ---: | --- |
| no-harness | 1 | - | - | `baseline_all_infeasible` |
| harness | 1 | - | - | `baseline_all_infeasible` |

Baseline probe curve:

| sampling_u | request rate | pass rate | feasible | early stop reason |
| ---: | ---: | ---: | --- | --- |
| 0.03125 | 0.895 | 0.000000 | false | `slo_pass_rate_unrecoverable` |
| 0.015625 | 0.483333 | 0.137931 | false | `slo_pass_rate_unrecoverable` |
| 0.0078125 | 0.246667 | 0.236486 | false | `slo_pass_rate_unrecoverable` |
| 0.00390625 | 0.123333 | 0.189189 | false | `slo_pass_rate_unrecoverable` |
| 0.001953125 | 0.065000 | 0.205128 | false | `slo_pass_rate_unrecoverable` |
| 0.0009765625 | 0.035000 | 0.142857 | false | `slo_pass_rate_unrecoverable` |

Lowest request rate latency summary:

| Variant | request rate | pass rate | TTFT mean | TTFT p50 | TTFT p95 | TTFT p99 | TPOT mean | TPOT p50 | TPOT p95 | TPOT p99 |
| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| no-harness | 0.035000 | 0.142857 | 1288.953ms | 446.586ms | 3011.814ms | 3011.814ms | 12.661ms | 13.141ms | 15.097ms | 15.097ms |
| harness | 0.035000 | 0.142857 | 1268.090ms | 445.274ms | 2889.080ms | 2889.080ms | 12.658ms | 13.170ms | 15.102ms | 15.102ms |

This shows that the TPOT threshold of `40ms` is not the binding constraint at the lowest sampled rate. The observed TPOT p99 is about `15.1ms`; failures are driven by TTFT and by the unrecoverable-pass-rate early stop after too many requests have already failed or been skipped.

Final diagnosis written by AITuner:

```text
Baseline configuration has no feasible probe under the current SLO. Stopping tuning because even the lowest sampled request rate did not meet the target pass rate. lowest_sampled_request_rate=0.035 lowest_sampling_u=0.000976562 lowest_probe_pass_rate=0.142857 early_stop_reason=slo_pass_rate_unrecoverable
```
## Interpretation

This run does not measure harness acceleration. It shows that the TPOT 40ms setup is infeasible for the current baseline and search floor: even at an aggregate request rate of `0.035`, only `14.29%` of requests pass the SLO, far below the `95%` target.

The correct behavior is to stop the study early and report SLO infeasibility instead of spending the remaining 11 trial slots. The harness cannot accelerate convergence when there is no feasible baseline point and no incumbent for guided tuning.

For a Fig. 18-style convergence comparison, the next setup must first have at least one feasible baseline or feasible low-rate point under the same metric definitions.
@@ -19,6 +19,88 @@ from .trace import load_trace_requests, summarize_window
 from .worker import run_trial
 
 
+def _is_empty_config_patch(proposal: Proposal) -> bool:
+    return not proposal.config_patch.env_patch and not proposal.config_patch.flag_patch
+
+
+def _latency_percentiles(summary: object, metric: str) -> dict[str, float]:
+    if not isinstance(summary, dict):
+        return {}
+    payload = summary.get(metric)
+    if not isinstance(payload, dict):
+        return {}
+    selected: dict[str, float] = {}
+    for key in ("mean", "p50", "p95", "p99"):
+        value = payload.get(key)
+        if isinstance(value, (int, float)):
+            selected[key] = float(value)
+    return selected
+
+
+def _format_latency_percentiles(metric: str, values: dict[str, float]) -> str:
+    if not values:
+        return ""
+    ordered = ", ".join(
+        f"{key}={values[key]:.3f}"
+        for key in ("mean", "p50", "p95", "p99")
+        if key in values
+    )
+    return f"{metric}({ordered})"
+
+
+def _baseline_all_infeasible_stop(result: dict[str, object]) -> tuple[str, dict[str, object]] | None:
+    if result.get("status") != "completed":
+        return None
+    if isinstance(result.get("best_request_rate"), (int, float)):
+        return None
+    probes = result.get("probes")
+    if not isinstance(probes, list) or not probes:
+        return None
+    if any(isinstance(probe, dict) and probe.get("feasible") for probe in probes):
+        return None
+
+    diagnostics = result.get("all_infeasible_diagnostics")
+    if not isinstance(diagnostics, dict):
+        diagnostics = {}
+    lowest_rate = diagnostics.get("request_rate")
+    lowest_threshold = diagnostics.get("threshold")
+    pass_rate = diagnostics.get("pass_rate")
+    early_stop_reason = str(diagnostics.get("early_stop_reason") or "").strip()
+    latency_summary = diagnostics.get("latency_summary")
+    ttft = _latency_percentiles(latency_summary, "ttft_ms")
+    tpot = _latency_percentiles(latency_summary, "tpot_ms")
+    details: dict[str, object] = {
+        "lowest_sampled_request_rate": lowest_rate,
+        "lowest_sampling_u": lowest_threshold,
+        "lowest_probe_pass_rate": pass_rate,
+        "early_stop_reason": early_stop_reason,
+        "lowest_probe_latency_ms": {
+            "ttft": ttft,
+            "tpot": tpot,
+        },
+        "lowest_probe_latency_summary": latency_summary if isinstance(latency_summary, dict) else {},
+    }
+    pieces = [
+        "Baseline configuration has no feasible probe under the current SLO.",
+        "Stopping tuning because even the lowest sampled request rate did not meet the target pass rate.",
+    ]
+    if isinstance(lowest_rate, (int, float)):
+        pieces.append(f"lowest_sampled_request_rate={float(lowest_rate):.6g}")
+    if isinstance(lowest_threshold, (int, float)):
+        pieces.append(f"lowest_sampling_u={float(lowest_threshold):.6g}")
+    if isinstance(pass_rate, (int, float)):
+        pieces.append(f"lowest_probe_pass_rate={float(pass_rate):.6g}")
+    if early_stop_reason:
+        pieces.append(f"early_stop_reason={early_stop_reason}")
+    for item in (
+        _format_latency_percentiles("lowest_probe_ttft_ms", ttft),
+        _format_latency_percentiles("lowest_probe_tpot_ms", tpot),
+    ):
+        if item:
+            pieces.append(item)
+    return " ".join(pieces), details
+
+
 def _study_source_path(study_root: Path) -> Path:
     return Path((study_root / "study_spec.source").read_text(encoding="utf-8").strip())
 
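A quick usage sketch of the new stop check (an annotation, not part of the diff), using a worker result shaped like the test fixture further down:

```python
# Assumes _baseline_all_infeasible_stop from the hunk above is in scope.
result = {
    "status": "completed",
    "best_request_rate": None,
    "probes": [{"threshold": 0.25, "feasible": False}],
    "all_infeasible_diagnostics": {
        "request_rate": 1.0,
        "threshold": 0.25,
        "pass_rate": 0.5,
        "early_stop_reason": "slo_pass_rate_unrecoverable",
    },
}

stop = _baseline_all_infeasible_stop(result)
assert stop is not None
diagnosis, details = stop
# diagnosis contains "lowest_sampled_request_rate=1" and the early stop reason;
# details carries the structured fields persisted into state.json.
assert details["lowest_probe_pass_rate"] == 0.5
```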
@@ -126,6 +208,21 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
     executed: list[dict[str, object]] = []
     for idx in range(max_trials):
         state = store.load_state(study.study_id)
+        if state.tuning_stop_reason:
+            executed.append(
+                {
+                    "trial_id": None,
+                    "stopped": True,
+                    "reason": state.tuning_stop_reason,
+                    "diagnosis": state.tuning_stop_diagnosis,
+                    "details": state.tuning_stop_details,
+                    "state_best_trial_id": state.best_trial_id,
+                    "state_best_request_rate": state.best_request_rate,
+                }
+            )
+            break
         if state.next_trial_index > max_trials:
             break
         window, requests = load_trace_requests(study, study_spec_path=spec_path)
         window_summary = summarize_window(requests, window)
         harness_context = (
@@ -169,7 +266,10 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
                 ensure_ascii=False,
             )
         elif proposal_files:
-            proposal_source = proposal_files[idx]
+            proposal_index = state.next_trial_index - 1
+            if proposal_index >= len(proposal_files):
+                break
+            proposal_source = proposal_files[proposal_index]
             proposal_text = proposal_source.read_text(encoding="utf-8")
             proposal_name = proposal_source.stem
         else:
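Annotation: the old `proposal_files[idx]` indexed by the in-process loop counter, which restarts at 0 on every invocation, while `state.next_trial_index` persists across runs; indexing by the persisted value makes a resumed study skip proposals that already ran. A small sketch with hypothetical values:

```python
# Resume after two executed trials, with three proposal files on disk.
proposal_files = ["p1.json", "p2.json", "p3.json"]
state_next_trial_index = 3  # persisted in state.json
idx = 0                     # fresh loop counter on this invocation

# Old: proposal_files[idx] would replay p1.json.
# New: pick up where the study left off; past-the-end indexes stop the loop.
proposal_index = state_next_trial_index - 1
assert proposal_files[proposal_index] == "p3.json"
```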
@@ -223,6 +323,13 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
                 }
             )
             break
+        is_auto_baseline = (
+            not proposal_files
+            and not args.skip_baseline
+            and state.next_trial_index == 1
+            and not state.trials
+            and _is_empty_config_patch(proposal)
+        )
         trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
         trial_spec_path = Path(trial.artifact_dir) / "trial_spec.json"
         result = run_trial(trial_spec_path)
@@ -243,6 +350,26 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
                 "state_best_request_rate": state.best_request_rate,
             }
         )
+        if is_auto_baseline:
+            stop = _baseline_all_infeasible_stop(result)
+            if stop is not None:
+                diagnosis, details = stop
+                state.tuning_stop_reason = "baseline_all_infeasible"
+                state.tuning_stop_diagnosis = diagnosis
+                state.tuning_stop_details = details
+                store.save_state(state)
+                executed.append(
+                    {
+                        "trial_id": None,
+                        "stopped": True,
+                        "reason": state.tuning_stop_reason,
+                        "diagnosis": diagnosis,
+                        "details": details,
+                        "state_best_trial_id": state.best_trial_id,
+                        "state_best_request_rate": state.best_request_rate,
+                    }
+                )
+                break
 
     final_state = store.load_state(study.study_id)
     print(
@@ -252,6 +379,9 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
             "executed_trials": executed,
             "best_trial_id": final_state.best_trial_id,
             "best_request_rate": final_state.best_request_rate,
+            "tuning_stop_reason": final_state.tuning_stop_reason,
+            "tuning_stop_diagnosis": final_state.tuning_stop_diagnosis,
+            "tuning_stop_details": final_state.tuning_stop_details,
         },
         ensure_ascii=False,
     )
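For a stopped study, the final summary printed by `study tune` therefore carries the stop fields; an illustrative excerpt (values from the no-harness run above, other fields and most details elided):

```json
{
  "best_trial_id": null,
  "best_request_rate": null,
  "tuning_stop_reason": "baseline_all_infeasible",
  "tuning_stop_details": {"early_stop_reason": "slo_pass_rate_unrecoverable"}
}
```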
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import json
+import time
 from pathlib import Path
 from typing import Any
@@ -604,7 +605,8 @@ def call_llm_for_proposal(
     if policy.endpoint is None:
         raise RuntimeError("study.llm.endpoint is not configured")
     last_error: Exception | None = None
-    for attempt in range(2):
+    max_attempts = 4
+    for attempt in range(max_attempts):
         try:
             if policy.endpoint.stream:
                 text = stream_text_completion(
@@ -636,6 +638,7 @@ def call_llm_for_proposal(
                 last_error = RuntimeError("LLM response content is empty")
         except Exception as exc:  # noqa: BLE001
             last_error = exc
-            if attempt == 0:
+            if attempt < max_attempts - 1:
+                time.sleep(min(30.0, 2.0 * (2**attempt)))
                 continue
     raise RuntimeError(f"LLM proposal failed after retry: {last_error}") from last_error
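Annotation: the retry budget grows from 2 to 4 attempts, with capped exponential backoff between attempts instead of an immediate single retry. The implied sleep schedule, as a quick check:

```python
# A sleep happens only while attempts remain; the final failure raises.
max_attempts = 4
schedule = [min(30.0, 2.0 * (2 ** attempt)) for attempt in range(max_attempts - 1)]
assert schedule == [2.0, 4.0, 8.0]
```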
@@ -764,6 +764,9 @@ class StudyState:
     best_request_rate: float | None = None
     best_request_rate_per_gpu: float | None = None
     next_trial_index: int = 1
+    tuning_stop_reason: str = ""
+    tuning_stop_diagnosis: str = ""
+    tuning_stop_details: dict[str, Any] = field(default_factory=dict)
     best_by_parallel_size: dict[str, dict[str, Any]] = field(default_factory=dict)
     trials: list[TrialSummary] = field(default_factory=list)
@@ -45,6 +45,9 @@ class StudyStore:
             best_request_rate=payload.get("best_request_rate"),
             best_request_rate_per_gpu=payload.get("best_request_rate_per_gpu"),
             next_trial_index=int(payload.get("next_trial_index", 1)),
+            tuning_stop_reason=str(payload.get("tuning_stop_reason") or ""),
+            tuning_stop_diagnosis=str(payload.get("tuning_stop_diagnosis") or ""),
+            tuning_stop_details=dict(payload.get("tuning_stop_details") or {}),
             best_by_parallel_size={
                 str(key): value
                 for key, value in (payload.get("best_by_parallel_size") or {}).items()
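Annotation: the `or` fallbacks matter for compatibility. State files written before this change simply lack the three keys, and they load as empty defaults rather than erroring; a minimal sketch (assuming `payload` is the parsed `state.json` dict):

```python
payload = {"next_trial_index": 2}  # pre-change state.json, no stop fields

# Mirrors the load path above: missing or null keys collapse to "" / {}.
assert str(payload.get("tuning_stop_reason") or "") == ""
assert str(payload.get("tuning_stop_diagnosis") or "") == ""
assert dict(payload.get("tuning_stop_details") or {}) == {}
```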
@@ -2919,7 +2919,7 @@ class CoreFlowTests(unittest.TestCase):
                 "--store-root",
                 str(store_root),
                 "--max-trials",
-                "1",
+                "5",
             ]
         )
@@ -2997,6 +2997,169 @@ class CoreFlowTests(unittest.TestCase):
         self.assertEqual(state.trials[0].config_patch, {"env_patch": {}, "flag_patch": {}})
         self.assertEqual(state.trials[1].config_patch["flag_patch"], {"max-num-seqs": 64})
 
+    def test_cli_tune_stops_when_baseline_is_all_infeasible(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            tmp_path = Path(tmp)
+            study_path = _write_study_assets(tmp_path)
+            payload = json.loads(study_path.read_text(encoding="utf-8"))
+            payload["llm"]["endpoint"] = {
+                "provider": "custom",
+                "base_url": "http://llm.example/v1",
+                "wire_api": "chat.completions",
+                "model": "test-model",
+                "api_key_env": "OPENAI_API_KEY",
+            }
+            study_path.write_text(json.dumps(payload), encoding="utf-8")
+            store_root = tmp_path / "store"
+
+            def fake_run_trial(trial_spec_path: Path) -> dict[str, object]:
+                payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
+                trial_root = Path(payload["artifact_dir"])
+                result = {
+                    "study_id": payload["study_id"],
+                    "trial_id": payload["trial_id"],
+                    "status": "completed",
+                    "best_sampling_u": None,
+                    "best_request_rate": None,
+                    "best_pass_rate": None,
+                    "best_request_count": None,
+                    "probes": [
+                        {
+                            "threshold": 0.5,
+                            "feasible": False,
+                            "payload": {"pass_rate": 0.0, "request_rate": 2.0},
+                        },
+                        {
+                            "threshold": 0.25,
+                            "feasible": False,
+                            "payload": {"pass_rate": 0.5, "request_rate": 1.0},
+                        },
+                    ],
+                    "all_infeasible_diagnostics": {
+                        "threshold": 0.25,
+                        "request_rate": 1.0,
+                        "pass_rate": 0.5,
+                        "early_stop_reason": "slo_pass_rate_unrecoverable",
+                        "latency_summary": {
+                            "ttft_ms": {
+                                "count": 2,
+                                "mean": 1200.0,
+                                "p50": 1100.0,
+                                "p95": 1900.0,
+                                "p99": 1980.0,
+                            },
+                            "tpot_ms": {
+                                "count": 2,
+                                "mean": 35.0,
+                                "p50": 32.0,
+                                "p95": 48.0,
+                                "p99": 49.0,
+                            },
+                        },
+                    },
+                }
+                (trial_root / "result.json").write_text(json.dumps(result), encoding="utf-8")
+                return result
+
+            with mock.patch("aituner.cli.run_trial", side_effect=fake_run_trial):
+                with mock.patch("aituner.cli.call_llm_for_proposal") as llm_mock:
+                    exit_code = cli_main(
+                        [
+                            "study",
+                            "tune",
+                            "--spec",
+                            str(study_path),
+                            "--store-root",
+                            str(store_root),
+                            "--max-trials",
+                            "3",
+                        ]
+                    )
+
+            self.assertEqual(exit_code, 0)
+            llm_mock.assert_not_called()
+            store = StudyStore(store_root)
+            state = store.load_state("study-1")
+            self.assertEqual(state.next_trial_index, 2)
+            self.assertEqual(len(state.trials), 1)
+            self.assertEqual(state.tuning_stop_reason, "baseline_all_infeasible")
+            self.assertIn("lowest_sampled_request_rate=1", state.tuning_stop_diagnosis)
+            self.assertIn("lowest_probe_ttft_ms", state.tuning_stop_diagnosis)
+            self.assertEqual(
+                state.tuning_stop_details["lowest_probe_latency_ms"]["ttft"]["p95"],
+                1900.0,
+            )
+            self.assertEqual(
+                state.tuning_stop_details["lowest_probe_latency_ms"]["tpot"]["p99"],
+                49.0,
+            )
+
+            with mock.patch("aituner.cli.run_trial") as run_trial_mock:
+                with mock.patch("aituner.cli.call_llm_for_proposal") as llm_mock:
+                    exit_code = cli_main(
+                        [
+                            "study",
+                            "tune",
+                            "--spec",
+                            str(study_path),
+                            "--store-root",
+                            str(store_root),
+                            "--max-trials",
+                            "3",
+                        ]
+                    )
+
+            self.assertEqual(exit_code, 0)
+            run_trial_mock.assert_not_called()
+            llm_mock.assert_not_called()
+
+    def test_cli_tune_max_trials_is_total_budget_on_resume(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            tmp_path = Path(tmp)
+            study_path = _write_study_assets(tmp_path)
+            payload = json.loads(study_path.read_text(encoding="utf-8"))
+            payload["llm"]["endpoint"] = {
+                "provider": "custom",
+                "base_url": "http://llm.example/v1",
+                "wire_api": "chat.completions",
+                "model": "test-model",
+                "api_key_env": "OPENAI_API_KEY",
+            }
+            study_path.write_text(json.dumps(payload), encoding="utf-8")
+            store_root = tmp_path / "store"
+            study = load_study_spec(study_path)
+            store = StudyStore(store_root)
+            store.init_study(spec_path=study_path, study=study)
+            state = StudyState(
+                study_id=study.study_id,
+                next_trial_index=3,
+                trials=[
+                    TrialSummary(trial_id="trial-0001", status="completed"),
+                    TrialSummary(trial_id="trial-0002", status="completed"),
+                ],
+            )
+            store.save_state(state)
+
+            with mock.patch("aituner.cli.call_llm_for_proposal") as llm_mock:
+                with mock.patch("aituner.cli.run_trial") as run_trial_mock:
+                    exit_code = cli_main(
+                        [
+                            "study",
+                            "tune",
+                            "--spec",
+                            str(study_path),
+                            "--store-root",
+                            str(store_root),
+                            "--max-trials",
+                            "2",
+                        ]
+                    )
+
+            self.assertEqual(exit_code, 0)
+            llm_mock.assert_not_called()
+            run_trial_mock.assert_not_called()
+            self.assertEqual(store.load_state(study.study_id).next_trial_index, 3)
+
     def test_load_compare_spec_requires_window_selection(self) -> None:
         with tempfile.TemporaryDirectory() as tmp:
             tmp_path = Path(tmp)