Add initial config preflight review
This commit is contained in:
@@ -25,7 +25,15 @@ from .lca import (
|
||||
resolve_length_mode,
|
||||
similarity_report,
|
||||
)
|
||||
from .llm import build_prompt, call_llm_for_proposal, load_capability_profile, parse_proposal_text
|
||||
from .llm import (
|
||||
build_initial_config_review_prompt,
|
||||
build_prompt,
|
||||
call_llm_for_initial_config_review,
|
||||
call_llm_for_proposal,
|
||||
load_capability_profile,
|
||||
parse_initial_config_review_text,
|
||||
parse_proposal_text,
|
||||
)
|
||||
from .spec import (
|
||||
ConfigPatch,
|
||||
Proposal,
|
||||
@@ -403,6 +411,77 @@ def _harness_snapshot_payload(
|
||||
}
|
||||
|
||||
|
||||
def _maybe_run_initial_config_review(
|
||||
*,
|
||||
study: StudySpec,
|
||||
spec_path: Path,
|
||||
store: StudyStore,
|
||||
capability_profile: dict[str, Any] | None,
|
||||
) -> dict[str, Any] | None:
|
||||
mode = study.llm.initial_config_review.mode
|
||||
if mode == "off":
|
||||
return None
|
||||
state = store.load_state(study.study_id)
|
||||
if state.trials or state.next_trial_index != 1:
|
||||
return None
|
||||
|
||||
audit_name = "initial-config-0001"
|
||||
audit_root = store.study_root(study.study_id) / "preflight_audits"
|
||||
audit_path = audit_root / f"{audit_name}.json"
|
||||
if audit_path.exists():
|
||||
return json.loads(audit_path.read_text(encoding="utf-8"))
|
||||
|
||||
base_payload: dict[str, Any] = {
|
||||
"schema_version": 1,
|
||||
"audit_type": "initial_config_review",
|
||||
"study_id": study.study_id,
|
||||
"mode": mode,
|
||||
"repair_applied": False,
|
||||
}
|
||||
if study.llm.endpoint is None:
|
||||
payload = {
|
||||
**base_payload,
|
||||
"status": "skipped",
|
||||
"reason": "llm.endpoint_not_configured",
|
||||
}
|
||||
store.write_preflight_audit(study.study_id, audit_name, payload)
|
||||
return payload
|
||||
|
||||
window, requests = load_trace_requests(study, study_spec_path=spec_path)
|
||||
window_summary = summarize_window(requests, window)
|
||||
workload_profile = build_study_workload_profile(study, requests, window)
|
||||
prompt = build_initial_config_review_prompt(
|
||||
study=study,
|
||||
window_summary=window_summary,
|
||||
capability_profile=capability_profile,
|
||||
workload_profile=workload_profile,
|
||||
)
|
||||
prompt_path = audit_root / f"{audit_name}.prompt.txt"
|
||||
raw_path = audit_root / f"{audit_name}.raw.txt"
|
||||
prompt_path.write_text(prompt, encoding="utf-8")
|
||||
try:
|
||||
raw_text = call_llm_for_initial_config_review(policy=study.llm, prompt=prompt)
|
||||
raw_path.write_text(raw_text, encoding="utf-8")
|
||||
review = parse_initial_config_review_text(raw_text, study)
|
||||
payload = {
|
||||
**base_payload,
|
||||
"status": "completed",
|
||||
"prompt_path": str(prompt_path),
|
||||
"raw_response_path": str(raw_path),
|
||||
"review": review,
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
payload = {
|
||||
**base_payload,
|
||||
"status": "failed",
|
||||
"prompt_path": str(prompt_path),
|
||||
"raw_response_path": str(raw_path) if raw_path.exists() else None,
|
||||
"error": str(exc),
|
||||
}
|
||||
store.write_preflight_audit(study.study_id, audit_name, payload)
|
||||
return payload
|
||||
|
||||
|
||||
def _latency_percentiles(summary: object, metric: str) -> dict[str, float]:
|
||||
if not isinstance(summary, dict):
|
||||
return {}
|
||||
@@ -585,6 +664,12 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
|
||||
store = StudyStore(Path(args.store_root) if args.store_root else None)
|
||||
study_root = store.init_study(spec_path=spec_path, study=study)
|
||||
capability_profile = load_capability_profile(study, study_spec_path=spec_path)
|
||||
preflight_audit = _maybe_run_initial_config_review(
|
||||
study=study,
|
||||
spec_path=spec_path,
|
||||
store=store,
|
||||
capability_profile=capability_profile,
|
||||
)
|
||||
proposal_files = [Path(item).resolve() for item in (args.proposal_file or [])]
|
||||
max_trials = args.max_trials or (len(proposal_files) if proposal_files else 2)
|
||||
proposal_policy = args.proposal_policy
|
||||
@@ -892,6 +977,7 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
|
||||
json.dumps(
|
||||
{
|
||||
"study_root": str(study_root),
|
||||
"preflight_audit": preflight_audit,
|
||||
"executed_trials": executed,
|
||||
"best_trial_id": final_state.best_trial_id,
|
||||
"best_request_rate": final_state.best_request_rate,
|
||||
|
||||
@@ -5,9 +5,10 @@ import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .engine_adapters.vllm import default_vllm_descriptors
|
||||
from .harness import _effective_config_signature, build_harness_context, render_harness_context
|
||||
from .http_client import chat_completion, stream_text_completion
|
||||
from .spec import LLMPolicySpec, Proposal, SpecError, StudySpec, StudyState
|
||||
from .spec import LLMPolicySpec, Proposal, SpecError, StudySpec, StudyState, to_jsonable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lca import WorkloadProfile
|
||||
@@ -175,6 +176,108 @@ def _enumerate_parallel_candidates(study: StudySpec) -> list[dict[str, int | boo
|
||||
return candidates
|
||||
|
||||
|
||||
def _descriptor_payload(study: StudySpec) -> list[dict[str, Any]]:
|
||||
if study.engine.engine_name.lower() != "vllm":
|
||||
return []
|
||||
return [
|
||||
to_jsonable(descriptor)
|
||||
for descriptor in default_vllm_descriptors(tunable_flags=study.engine.tunable_flags)
|
||||
]
|
||||
|
||||
|
||||
def build_initial_config_review_prompt(
|
||||
*,
|
||||
study: StudySpec,
|
||||
window_summary: dict[str, Any],
|
||||
capability_profile: dict[str, Any] | None,
|
||||
workload_profile: "WorkloadProfile | None" = None,
|
||||
) -> str:
|
||||
"""Build the static pre-flight audit prompt for the study's base config."""
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"study_id": study.study_id,
|
||||
"objective": "static initial-config audit before any measurement or repair",
|
||||
"review_mode": study.llm.initial_config_review.mode,
|
||||
"output_contract": {
|
||||
"verdict": "one of ok, risky, invalid, unknown",
|
||||
"issues": [
|
||||
{
|
||||
"knob": "flag/env knob name or mechanism name",
|
||||
"mechanism": "affected mechanism if known",
|
||||
"reason": "short reason grounded in the supplied descriptors/context",
|
||||
"severity": "low|medium|high",
|
||||
}
|
||||
],
|
||||
"minimal_repair_patch": {
|
||||
"env_patch": {},
|
||||
"flag_patch": {},
|
||||
},
|
||||
"do_not_change": ["knob names that should remain fixed during repair"],
|
||||
"confidence": "number in [0, 1]",
|
||||
"requires_harness_validation": True,
|
||||
},
|
||||
"hardware": {
|
||||
"gpu_count": study.hardware.gpu_count,
|
||||
"gpu_model": study.hardware.gpu_model,
|
||||
},
|
||||
"model": {
|
||||
"model_id": study.model.model_id,
|
||||
"served_model_name": study.model.served_model_name,
|
||||
},
|
||||
"engine": {
|
||||
"engine_name": study.engine.engine_name,
|
||||
"engine_version": study.engine.engine_version,
|
||||
"base_envs": study.engine.base_envs,
|
||||
"base_flags": study.engine.base_flags,
|
||||
"allowed_env_keys": study.engine.tunable_envs,
|
||||
"allowed_flag_keys": study.engine.tunable_flags,
|
||||
"topology_constraints": (
|
||||
study.engine.topology_constraints.__dict__
|
||||
if study.engine.topology_constraints is not None
|
||||
else None
|
||||
),
|
||||
"effective_topology": _effective_topology(study),
|
||||
},
|
||||
"trace": {
|
||||
"window_id": study.trace.window_id,
|
||||
"request_mode": study.trace.request_mode,
|
||||
"completion_tokens_override": study.trace.completion_tokens_override,
|
||||
"input_length_filter": (
|
||||
{
|
||||
"min_input_tokens": study.trace.input_length_filter.min_input_tokens,
|
||||
"max_input_tokens": study.trace.input_length_filter.max_input_tokens,
|
||||
}
|
||||
if study.trace.input_length_filter is not None
|
||||
else None
|
||||
),
|
||||
"window_summary": window_summary,
|
||||
"workload_lca_profile": (
|
||||
workload_profile.to_dict() if workload_profile is not None else None
|
||||
),
|
||||
},
|
||||
"slo": {
|
||||
"target_pass_rate": study.slo.target_pass_rate,
|
||||
"ttft_rule": study.slo.ttft_rule,
|
||||
"tpot_rule": study.slo.tpot_rule,
|
||||
},
|
||||
"capability_profile": capability_profile or {},
|
||||
"knob_descriptors": _descriptor_payload(study),
|
||||
}
|
||||
sections = [
|
||||
"You are doing a pre-flight review of an LLM serving initial configuration.",
|
||||
"This is not a tuning proposal. Do not claim measured performance.",
|
||||
"Identify obviously risky or inconsistent initial settings before the first baseline trial.",
|
||||
"Use the supplied knob descriptors and constraints; do not use case-specific memorized values.",
|
||||
"Return exactly one JSON object matching output_contract. Do not wrap it in markdown.",
|
||||
"If uncertain, use verdict=unknown and an empty minimal_repair_patch.",
|
||||
"The current warn mode records your audit only; the patch will not be applied automatically.",
|
||||
"",
|
||||
"Audit context:",
|
||||
json.dumps(payload, default=lambda value: value.__dict__, ensure_ascii=False, indent=2),
|
||||
]
|
||||
return "\n".join(sections)
|
||||
|
||||
|
||||
def build_prompt(
|
||||
*,
|
||||
study: StudySpec,
|
||||
@@ -644,6 +747,87 @@ def parse_proposal_text(text: str, study: StudySpec) -> Proposal:
|
||||
return validate_proposal(proposal, study)
|
||||
|
||||
|
||||
def parse_initial_config_review_text(text: str, study: StudySpec) -> dict[str, Any]:
|
||||
payload = _parse_json_object_text(text)
|
||||
verdict = str(payload.get("verdict") or "unknown").strip().lower()
|
||||
if verdict not in {"ok", "risky", "invalid", "unknown"}:
|
||||
raise SpecError(
|
||||
"initial-config review verdict must be one of: ok, risky, invalid, unknown."
|
||||
)
|
||||
|
||||
issues_payload = payload.get("issues", [])
|
||||
if not isinstance(issues_payload, list):
|
||||
raise SpecError("initial-config review issues must be a list.")
|
||||
issues: list[dict[str, str]] = []
|
||||
for idx, item in enumerate(issues_payload):
|
||||
if isinstance(item, str):
|
||||
issues.append(
|
||||
{
|
||||
"knob": "",
|
||||
"mechanism": "",
|
||||
"reason": item.strip(),
|
||||
"severity": "medium",
|
||||
}
|
||||
)
|
||||
continue
|
||||
if not isinstance(item, dict):
|
||||
raise SpecError(f"initial-config review issues[{idx}] must be an object.")
|
||||
severity = str(item.get("severity") or "medium").strip().lower()
|
||||
if severity not in {"low", "medium", "high"}:
|
||||
severity = "medium"
|
||||
issues.append(
|
||||
{
|
||||
"knob": str(item.get("knob") or "").strip(),
|
||||
"mechanism": str(item.get("mechanism") or "").strip(),
|
||||
"reason": str(item.get("reason") or "").strip(),
|
||||
"severity": severity,
|
||||
}
|
||||
)
|
||||
|
||||
repair_payload = payload.get("minimal_repair_patch") or {}
|
||||
if not isinstance(repair_payload, dict):
|
||||
raise SpecError("initial-config review minimal_repair_patch must be an object.")
|
||||
if "env_patch" not in repair_payload and "flag_patch" not in repair_payload:
|
||||
repair_payload = {"env_patch": {}, "flag_patch": repair_payload}
|
||||
repair_proposal = Proposal.from_dict(
|
||||
{
|
||||
"observation": "Initial-config pre-flight repair candidate.",
|
||||
"diagnosis": "Validate the LLM audit's minimal repair patch against study constraints.",
|
||||
"config_patch": repair_payload,
|
||||
"expected_effects": ["pre-flight audit only"],
|
||||
"should_stop": False,
|
||||
}
|
||||
)
|
||||
validate_proposal(repair_proposal, study)
|
||||
|
||||
do_not_change_payload = payload.get("do_not_change", [])
|
||||
if isinstance(do_not_change_payload, str):
|
||||
do_not_change = [do_not_change_payload.strip()] if do_not_change_payload.strip() else []
|
||||
elif isinstance(do_not_change_payload, list):
|
||||
do_not_change = [str(item).strip() for item in do_not_change_payload if str(item).strip()]
|
||||
else:
|
||||
raise SpecError("initial-config review do_not_change must be a list.")
|
||||
|
||||
raw_confidence = payload.get("confidence", 0.0)
|
||||
confidence = float(raw_confidence) if isinstance(raw_confidence, (int, float)) else 0.0
|
||||
confidence = max(0.0, min(1.0, confidence))
|
||||
requires_validation = payload.get("requires_harness_validation")
|
||||
if requires_validation is None:
|
||||
requires_validation = True
|
||||
if not isinstance(requires_validation, bool):
|
||||
raise SpecError("initial-config review requires_harness_validation must be boolean.")
|
||||
|
||||
return {
|
||||
"schema_version": 1,
|
||||
"verdict": verdict,
|
||||
"issues": issues,
|
||||
"minimal_repair_patch": to_jsonable(repair_proposal.config_patch),
|
||||
"do_not_change": do_not_change,
|
||||
"confidence": confidence,
|
||||
"requires_harness_validation": requires_validation,
|
||||
}
|
||||
|
||||
|
||||
def _extract_response_text(response: dict[str, Any]) -> str:
|
||||
output_text = response.get("output_text")
|
||||
if isinstance(output_text, str) and output_text:
|
||||
@@ -683,11 +867,11 @@ def _extract_response_text(response: dict[str, Any]) -> str:
|
||||
raise RuntimeError("LLM response content is empty")
|
||||
|
||||
|
||||
def call_llm_for_proposal(
|
||||
def _call_llm_text(
|
||||
*,
|
||||
policy: LLMPolicySpec,
|
||||
prompt: str,
|
||||
use_harness: bool = True,
|
||||
system_prompt: str = "",
|
||||
) -> str:
|
||||
if policy.endpoint is None:
|
||||
raise RuntimeError("study.llm.endpoint is not configured")
|
||||
@@ -695,7 +879,6 @@ def call_llm_for_proposal(
|
||||
max_attempts = 4
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
system_prompt = policy.system_prompt if use_harness else ""
|
||||
if policy.endpoint.stream:
|
||||
text = stream_text_completion(
|
||||
base_url=policy.endpoint.base_url,
|
||||
@@ -730,3 +913,29 @@ def call_llm_for_proposal(
|
||||
time.sleep(min(30.0, 2.0 * (2**attempt)))
|
||||
continue
|
||||
raise RuntimeError(f"LLM proposal failed after retry: {last_error}") from last_error
|
||||
|
||||
|
||||
def call_llm_for_proposal(
|
||||
*,
|
||||
policy: LLMPolicySpec,
|
||||
prompt: str,
|
||||
use_harness: bool = True,
|
||||
) -> str:
|
||||
system_prompt = policy.system_prompt if use_harness else ""
|
||||
return _call_llm_text(policy=policy, prompt=prompt, system_prompt=system_prompt)
|
||||
|
||||
|
||||
def call_llm_for_initial_config_review(
|
||||
*,
|
||||
policy: LLMPolicySpec,
|
||||
prompt: str,
|
||||
) -> str:
|
||||
review_system = "\n".join(
|
||||
item
|
||||
for item in (
|
||||
policy.system_prompt,
|
||||
"You are auditing an initial serving config. Return only the requested JSON audit.",
|
||||
)
|
||||
if item.strip()
|
||||
)
|
||||
return _call_llm_text(policy=policy, prompt=prompt, system_prompt=review_system)
|
||||
|
||||
@@ -726,6 +726,21 @@ class LLMEndpointSpec:
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InitialConfigReviewSpec:
|
||||
mode: str = "off"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Any) -> "InitialConfigReviewSpec":
|
||||
if data is None:
|
||||
return cls()
|
||||
payload = _require_mapping(data, context="llm.initial_config_review")
|
||||
mode = str(payload.get("mode") or "off").strip().lower()
|
||||
if mode not in {"off", "warn"}:
|
||||
raise SpecError("llm.initial_config_review.mode must be one of: off, warn.")
|
||||
return cls(mode=mode)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMPolicySpec:
|
||||
endpoint: LLMEndpointSpec | None
|
||||
@@ -733,6 +748,9 @@ class LLMPolicySpec:
|
||||
max_history_trials: int
|
||||
use_harness: bool = True
|
||||
harness_candidate_policy: str = "advisory"
|
||||
initial_config_review: InitialConfigReviewSpec = field(
|
||||
default_factory=InitialConfigReviewSpec
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Mapping[str, Any] | None) -> "LLMPolicySpec":
|
||||
@@ -763,6 +781,9 @@ class LLMPolicySpec:
|
||||
else True
|
||||
),
|
||||
harness_candidate_policy=harness_candidate_policy,
|
||||
initial_config_review=InitialConfigReviewSpec.from_dict(
|
||||
payload.get("initial_config_review")
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ class StudyStore:
|
||||
"results",
|
||||
"harness",
|
||||
"candidate_family_gaps",
|
||||
"preflight_audits",
|
||||
):
|
||||
(root / rel).mkdir(parents=True, exist_ok=True)
|
||||
(root / "study_spec.source").write_text(str(spec_path.resolve()) + "\n", encoding="utf-8")
|
||||
@@ -108,6 +109,16 @@ class StudyStore:
|
||||
self.write_json(path, payload)
|
||||
return path
|
||||
|
||||
def write_preflight_audit(
|
||||
self,
|
||||
study_id: str,
|
||||
audit_name: str,
|
||||
payload: dict[str, Any],
|
||||
) -> Path:
|
||||
path = self.study_root(study_id) / "preflight_audits" / f"{audit_name}.json"
|
||||
self.write_json(path, payload)
|
||||
return path
|
||||
|
||||
def materialize_trial(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -40,7 +40,14 @@ from aituner.lca import (
|
||||
resolve_length_mode,
|
||||
similarity_report,
|
||||
)
|
||||
from aituner.llm import _extract_response_text, build_prompt, parse_proposal_text, validate_proposal
|
||||
from aituner.llm import (
|
||||
_extract_response_text,
|
||||
build_initial_config_review_prompt,
|
||||
build_prompt,
|
||||
parse_initial_config_review_text,
|
||||
parse_proposal_text,
|
||||
validate_proposal,
|
||||
)
|
||||
from aituner.search import ThresholdProbe, binary_search_max_feasible
|
||||
from aituner.slo import RequestOutcome, evaluate_request, summarize_evaluations
|
||||
from aituner.spec import (
|
||||
@@ -266,6 +273,62 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertIn("knob_harnesses", prompt)
|
||||
self.assertTrue(study_root.exists())
|
||||
|
||||
def test_initial_config_review_schema_prompt_and_parse(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
study_path = _write_study_assets(tmp_path)
|
||||
payload = json.loads(study_path.read_text(encoding="utf-8"))
|
||||
payload["llm"]["initial_config_review"] = {"mode": "warn"}
|
||||
study_path.write_text(json.dumps(payload), encoding="utf-8")
|
||||
study = load_study_spec(study_path)
|
||||
self.assertEqual(study.llm.initial_config_review.mode, "warn")
|
||||
|
||||
window, requests = load_trace_requests(study, study_spec_path=study_path)
|
||||
prompt = build_initial_config_review_prompt(
|
||||
study=study,
|
||||
window_summary=summarize_window(requests, window),
|
||||
capability_profile={"prefill": "profile"},
|
||||
workload_profile=build_study_workload_profile(study, requests, window),
|
||||
)
|
||||
self.assertIn("pre-flight review", prompt)
|
||||
self.assertIn("knob_descriptors", prompt)
|
||||
self.assertIn("minimal_repair_patch", prompt)
|
||||
|
||||
review = parse_initial_config_review_text(
|
||||
json.dumps(
|
||||
{
|
||||
"verdict": "risky",
|
||||
"issues": [
|
||||
{
|
||||
"knob": "max-num-seqs",
|
||||
"mechanism": "admission_capacity",
|
||||
"reason": "low admission capacity may throttle concurrency",
|
||||
"severity": "high",
|
||||
}
|
||||
],
|
||||
"minimal_repair_patch": {
|
||||
"env_patch": {},
|
||||
"flag_patch": {"max-num-seqs": 64},
|
||||
},
|
||||
"do_not_change": ["tensor-parallel-size"],
|
||||
"confidence": 0.8,
|
||||
"requires_harness_validation": True,
|
||||
}
|
||||
),
|
||||
study,
|
||||
)
|
||||
self.assertEqual(review["verdict"], "risky")
|
||||
self.assertEqual(review["minimal_repair_patch"]["flag_patch"], {"max-num-seqs": 64})
|
||||
self.assertEqual(review["do_not_change"], ["tensor-parallel-size"])
|
||||
|
||||
bad_payload = dict(payload)
|
||||
bad_payload["llm"] = dict(payload["llm"])
|
||||
bad_payload["llm"]["initial_config_review"] = {"mode": "repair"}
|
||||
bad_path = tmp_path / "bad-study.json"
|
||||
bad_path.write_text(json.dumps(bad_payload), encoding="utf-8")
|
||||
with self.assertRaisesRegex(SpecError, "llm.initial_config_review.mode"):
|
||||
load_study_spec(bad_path)
|
||||
|
||||
def test_search_auto_high_schema_is_backward_compatible(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
study_path = _write_study_assets(
|
||||
@@ -7810,6 +7873,98 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertEqual(state.trials[0].config_patch, {"env_patch": {}, "flag_patch": {}})
|
||||
self.assertEqual(state.trials[1].config_patch["flag_patch"], {"max-num-seqs": 64})
|
||||
|
||||
def test_cli_tune_records_warn_initial_config_review_without_repairing_baseline(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
study_path = _write_study_assets(tmp_path)
|
||||
payload = json.loads(study_path.read_text(encoding="utf-8"))
|
||||
payload["llm"]["initial_config_review"] = {"mode": "warn"}
|
||||
payload["llm"]["endpoint"] = {
|
||||
"provider": "custom",
|
||||
"base_url": "http://llm.example/v1",
|
||||
"wire_api": "chat.completions",
|
||||
"model": "test-model",
|
||||
"api_key_env": "OPENAI_API_KEY",
|
||||
}
|
||||
study_path.write_text(json.dumps(payload), encoding="utf-8")
|
||||
store_root = tmp_path / "store"
|
||||
|
||||
def fake_run_trial(trial_spec_path: Path) -> dict[str, object]:
|
||||
payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
|
||||
trial_root = Path(payload["artifact_dir"])
|
||||
result = {
|
||||
"study_id": payload["study_id"],
|
||||
"trial_id": payload["trial_id"],
|
||||
"status": "completed",
|
||||
"best_sampling_u": 0.25,
|
||||
"best_request_rate": 1.0,
|
||||
"best_pass_rate": 1.0,
|
||||
"best_request_count": 2,
|
||||
"probes": [],
|
||||
}
|
||||
(trial_root / "result.json").write_text(json.dumps(result), encoding="utf-8")
|
||||
return result
|
||||
|
||||
audit_payload = json.dumps(
|
||||
{
|
||||
"verdict": "risky",
|
||||
"issues": [
|
||||
{
|
||||
"knob": "max-num-seqs",
|
||||
"mechanism": "admission_capacity",
|
||||
"reason": "initial admission cap may be too low",
|
||||
"severity": "medium",
|
||||
}
|
||||
],
|
||||
"minimal_repair_patch": {
|
||||
"env_patch": {},
|
||||
"flag_patch": {"max-num-seqs": 64},
|
||||
},
|
||||
"do_not_change": ["tensor-parallel-size"],
|
||||
"confidence": 0.7,
|
||||
"requires_harness_validation": True,
|
||||
}
|
||||
)
|
||||
buffer = io.StringIO()
|
||||
with mock.patch("aituner.cli.run_trial", side_effect=fake_run_trial):
|
||||
with mock.patch(
|
||||
"aituner.cli.call_llm_for_initial_config_review",
|
||||
return_value=audit_payload,
|
||||
) as audit_mock:
|
||||
with contextlib.redirect_stdout(buffer):
|
||||
exit_code = cli_main(
|
||||
[
|
||||
"study",
|
||||
"tune",
|
||||
"--spec",
|
||||
str(study_path),
|
||||
"--store-root",
|
||||
str(store_root),
|
||||
"--max-trials",
|
||||
"1",
|
||||
]
|
||||
)
|
||||
self.assertEqual(exit_code, 0)
|
||||
audit_mock.assert_called_once()
|
||||
summary = json.loads(buffer.getvalue())
|
||||
self.assertEqual(summary["preflight_audit"]["status"], "completed")
|
||||
self.assertFalse(summary["preflight_audit"]["repair_applied"])
|
||||
|
||||
store = StudyStore(store_root)
|
||||
state = store.load_state("study-1")
|
||||
self.assertEqual(state.next_trial_index, 2)
|
||||
self.assertEqual(state.trials[0].config_patch, {"env_patch": {}, "flag_patch": {}})
|
||||
audit_dir = store.study_root("study-1") / "preflight_audits"
|
||||
audit = json.loads((audit_dir / "initial-config-0001.json").read_text(encoding="utf-8"))
|
||||
self.assertEqual(audit["status"], "completed")
|
||||
self.assertEqual(audit["review"]["verdict"], "risky")
|
||||
self.assertEqual(
|
||||
audit["review"]["minimal_repair_patch"]["flag_patch"],
|
||||
{"max-num-seqs": 64},
|
||||
)
|
||||
self.assertTrue((audit_dir / "initial-config-0001.prompt.txt").exists())
|
||||
self.assertTrue((audit_dir / "initial-config-0001.raw.txt").exists())
|
||||
|
||||
def test_cli_tune_stops_when_baseline_is_all_infeasible(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
|
||||
Reference in New Issue
Block a user