Add auto search high measurement policy

This commit is contained in:
2026-06-26 20:05:22 +08:00
parent 95ad124a1b
commit 1dd3eaebaa
5 changed files with 415 additions and 27 deletions

View File

@@ -426,6 +426,9 @@ def _recent_trial_diagnostics(state: StudyState) -> list[dict[str, Any]]:
}
result = _load_result(trial)
if result:
measurement = result.get("measurement")
if isinstance(measurement, dict):
item["measurement_evidence"] = measurement
probes = result.get("probes")
if isinstance(probes, list) and probes:
best_probe = _best_feasible_probe(probes)
@@ -785,11 +788,19 @@ def _harness_stop_decision(
experiment_plan: dict[str, Any] | None = None,
) -> dict[str, Any]:
high_saturation = _search_high_saturation_guard(study, state, recent_diagnostics)
if high_saturation["saturated"]:
if high_saturation["saturated"] and _parallel_size_can_vary(study):
return {
"should_stop": True,
"reason": high_saturation["reason"],
"evidence": high_saturation,
"should_stop": False,
"reason": "search_high_saturation_requires_parallel_size_evidence",
"evidence": {
"summary": (
"search_high_saturation is measurement evidence only; "
"request_rate_per_gpu studies with variable topology need a "
"parallel-size/topology comparison before stop can be authorized."
),
"objective": "request_rate_per_gpu",
"search_high_saturation": high_saturation,
},
}
topology_frontier = _topology_frontier_status(study, state, recent_diagnostics)
if topology_frontier["frontier_open"]:
@@ -1737,6 +1748,40 @@ def _effective_gpu_count(study: StudySpec) -> int:
return min(study.hardware.gpu_count, len(devices))
def _parallel_size_can_vary(study: StudySpec) -> bool:
tunable = set(study.engine.tunable_flags)
if not ({"tensor-parallel-size", "data-parallel-size"} & tunable):
return False
effective_gpu_count = _effective_gpu_count(study)
if effective_gpu_count <= 1:
return False
constraints = study.engine.topology_constraints
if constraints is not None and constraints.allowed_tp_dp_products:
legal_products = {
item for item in constraints.allowed_tp_dp_products if item <= effective_gpu_count
}
return len(legal_products) > 1
if constraints is not None:
tp_values = (
constraints.allowed_tensor_parallel_sizes
if constraints.allowed_tensor_parallel_sizes
else [1, 2, 4, 8]
)
dp_values = (
constraints.allowed_data_parallel_sizes
if constraints.allowed_data_parallel_sizes
else [1]
)
products = {
int(tp) * int(dp)
for tp in tp_values
for dp in dp_values
if int(tp) > 0 and int(dp) > 0 and int(tp) * int(dp) <= effective_gpu_count
}
return len(products) > 1
return True
def _score_topology_candidate(
top_bottleneck: str,
bottleneck_hypotheses: list[dict[str, Any]],
@@ -2087,6 +2132,7 @@ def _search_high_saturation_guard(
"search_high": study.search.high,
"last_threshold": None,
"threshold_gap_to_high": None,
"measurement_evidence": None,
}
if not state.best_trial_id:
return default
@@ -2115,11 +2161,23 @@ def _search_high_saturation_guard(
**default,
"reason": "incumbent_last_probe_missing",
}
measurement_evidence = (
incumbent.get("measurement_evidence")
if isinstance(incumbent.get("measurement_evidence"), dict)
else None
)
search_high = (
_as_float(measurement_evidence.get("search_high"))
if isinstance(measurement_evidence, dict)
and isinstance(measurement_evidence.get("search_high"), (int, float))
else float(study.search.high)
)
last_threshold = _as_float(last_probe.get("threshold"))
threshold_gap = float(study.search.high) - last_threshold
threshold_gap = search_high - last_threshold
binary_probe_resolution = max(
float(study.search.tolerance),
(float(study.search.high) - float(study.search.low)) / float(2 ** max(study.search.max_probes, 1)),
(search_high - float(study.search.low))
/ float(2 ** max(study.search.max_probes, 1)),
)
if not last_probe.get("feasible"):
return {
@@ -2136,18 +2194,30 @@ def _search_high_saturation_guard(
"threshold_gap_to_high": threshold_gap,
"binary_probe_resolution": binary_probe_resolution,
}
reason = "search_high_saturated_by_incumbent"
summary = (
"The incumbent's highest measured probe is feasible and is within "
"the configured binary-search resolution of search.high."
)
if (
isinstance(measurement_evidence, dict)
and measurement_evidence.get("measurement_ceiling_insufficient")
):
reason = "measurement_ceiling_insufficient"
summary = (
"The incumbent saturated the available trace measurement ceiling; "
"this is insufficient measurement evidence, not stop authorization."
)
return {
"saturated": True,
"reason": "search_high_saturated_by_incumbent",
"summary": (
"The incumbent's highest measured probe is feasible and is within "
"the configured binary-search resolution of search.high."
),
"reason": reason,
"summary": summary,
"incumbent_trial_id": state.best_trial_id,
"search_high": study.search.high,
"search_high": search_high,
"last_threshold": last_threshold,
"threshold_gap_to_high": threshold_gap,
"binary_probe_resolution": binary_probe_resolution,
"measurement_evidence": measurement_evidence,
}

View File

@@ -585,6 +585,42 @@ class SloSpec:
)
@dataclass(frozen=True)
class SearchAutoHighSpec:
enabled: bool = False
max_sampling_u: float = 1.0
require_human_confirmation_beyond_trace: bool = True
@classmethod
def from_dict(cls, data: Any) -> "SearchAutoHighSpec":
if data is None:
return cls()
m = _require_mapping(data, context="search.auto_high")
enabled = (
_require_bool(m.get("enabled"), context="search.auto_high.enabled")
if m.get("enabled") is not None
else False
)
max_sampling_u = _require_float(
m.get("max_sampling_u", 1.0), context="search.auto_high.max_sampling_u"
)
if not 0.0 < max_sampling_u <= 1.0:
raise SpecError("search.auto_high.max_sampling_u must be in (0, 1].")
require_confirmation = (
_require_bool(
m.get("require_human_confirmation_beyond_trace"),
context="search.auto_high.require_human_confirmation_beyond_trace",
)
if m.get("require_human_confirmation_beyond_trace") is not None
else True
)
return cls(
enabled=enabled,
max_sampling_u=max_sampling_u,
require_human_confirmation_beyond_trace=require_confirmation,
)
@dataclass(frozen=True)
class SamplingSearchSpec:
low: float
@@ -593,16 +629,27 @@ class SamplingSearchSpec:
max_probes: int
sample_seed: int
inherit_incumbent_floor: bool = False
auto_high: SearchAutoHighSpec = field(default_factory=SearchAutoHighSpec)
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "SamplingSearchSpec":
low = _require_float(data.get("low", 0.0), context="search.low")
high = _require_float(data.get("high", 1.0), context="search.high")
tolerance = _require_float(data.get("tolerance", 0.01), context="search.tolerance")
max_probes = _require_int(data.get("max_probes", 8), context="search.max_probes")
if low < 0:
raise SpecError("search.low must be >= 0.")
if high < low:
raise SpecError("search.high must be >= search.low.")
if tolerance <= 0:
raise SpecError("search.tolerance must be > 0.")
if max_probes <= 0:
raise SpecError("search.max_probes must be > 0.")
return cls(
low=_require_float(data.get("low", 0.0), context="search.low"),
high=_require_float(data.get("high", 1.0), context="search.high"),
tolerance=_require_float(
data.get("tolerance", 0.01), context="search.tolerance"
),
max_probes=_require_int(data.get("max_probes", 8), context="search.max_probes"),
low=low,
high=high,
tolerance=tolerance,
max_probes=max_probes,
sample_seed=_require_int(
data.get("sample_seed", 20260325), context="search.sample_seed"
),
@@ -610,6 +657,7 @@ class SamplingSearchSpec:
data.get("inherit_incumbent_floor", False),
context="search.inherit_incumbent_floor",
),
auto_high=SearchAutoHighSpec.from_dict(data.get("auto_high")),
)
@@ -823,6 +871,7 @@ class TrialSpec:
probe_log_path: str
engine_log_path: str
result_path: str
search_evidence: dict[str, Any] = field(default_factory=dict)
@dataclass

View File

@@ -5,7 +5,16 @@ from dataclasses import replace
from pathlib import Path
from typing import Any
from .spec import ConfigPatch, Proposal, StudySpec, StudyState, TrialSpec, TrialSummary, to_jsonable
from .spec import (
ConfigPatch,
Proposal,
SamplingSearchSpec,
StudySpec,
StudyState,
TrialSpec,
TrialSummary,
to_jsonable,
)
_TOPOLOGY_FLAG_KEYS = {
@@ -95,6 +104,13 @@ class StudyStore:
parallel_size=parallel_size,
),
)
search, search_evidence = resolve_auto_high_search(
search=search,
sampling_us=_sampling_us_for_study_source(
study=study,
study_spec_source_path=self.study_root(study.study_id) / "study_spec.source",
),
)
spec = TrialSpec(
study_id=study.study_id,
trial_id=trial_id,
@@ -105,6 +121,7 @@ class StudyStore:
probe_log_path=str(trial_root / "probe_history.json"),
engine_log_path=str(trial_root / "engine.log"),
result_path=str(trial_root / "result.json"),
search_evidence=search_evidence,
)
self.write_json(trial_root / "trial_spec.json", to_jsonable(spec))
next_trial = (
@@ -323,3 +340,55 @@ def _derive_search_floor(*, study: StudySpec, state: StudyState, parallel_size:
else:
candidate = low
return min(high, max(low, candidate))
def _sampling_us_for_study_source(
*,
study: StudySpec,
study_spec_source_path: Path,
) -> list[float]:
if not study.search.auto_high.enabled:
return []
from .trace import load_trace_requests
study_spec_path = Path(study_spec_source_path.read_text(encoding="utf-8").strip())
_, requests = load_trace_requests(study, study_spec_path=study_spec_path)
return [float(request.sampling_u) for request in requests]
def resolve_auto_high_search(
*,
search: SamplingSearchSpec,
sampling_us: list[float],
) -> tuple[SamplingSearchSpec, dict[str, Any]]:
policy = search.auto_high
trace_max_sampling_u = max(sampling_us) if sampling_us else None
evidence = {
"enabled": policy.enabled,
"original_high": search.high,
"effective_high": search.high,
"trace_max_sampling_u": trace_max_sampling_u,
"max_sampling_u": policy.max_sampling_u,
"require_human_confirmation_beyond_trace": (
policy.require_human_confirmation_beyond_trace
),
"reason": "auto_high_disabled",
}
if not policy.enabled:
return search, evidence
if trace_max_sampling_u is None:
evidence["reason"] = "trace_has_no_sampling_u"
return search, evidence
ceiling = min(float(policy.max_sampling_u), 1.0, float(trace_max_sampling_u))
evidence["effective_ceiling"] = ceiling
if abs(float(search.high) - ceiling) <= 1e-12:
evidence["reason"] = "search_high_already_at_auto_high_ceiling"
return search, evidence
updated = replace(search, high=ceiling)
evidence["effective_high"] = updated.high
evidence["reason"] = (
"search_high_raised_to_trace_ceiling"
if float(search.high) < ceiling
else "search_high_lowered_to_trace_ceiling"
)
return updated, evidence

View File

@@ -96,6 +96,7 @@ def _trial_spec_from_json(path: Path) -> TrialSpec:
probe_log_path=str(payload["probe_log_path"]),
engine_log_path=str(payload["engine_log_path"]),
result_path=str(payload["result_path"]),
search_evidence=dict(payload.get("search_evidence") or {}),
)
@@ -355,6 +356,59 @@ def _best_feasible_probe_record(probe_history: list[dict[str, Any]]) -> dict[str
return max(feasible, key=lambda item: float(item["request_rate"]))
def _binary_probe_resolution(search: SamplingSearchSpec) -> float:
return max(
float(search.tolerance),
(float(search.high) - float(search.low)) / float(2 ** max(search.max_probes, 1)),
)
def _measurement_ceiling_evidence(
*,
search: SamplingSearchSpec,
requests: list[TraceRequest],
best_threshold: float | None,
best_payload: ProbePayload | None,
) -> dict[str, Any]:
trace_max_sampling_u = max((float(request.sampling_u) for request in requests), default=None)
policy = search.auto_high
evidence: dict[str, Any] = {
"auto_high": {
"enabled": policy.enabled,
"max_sampling_u": policy.max_sampling_u,
"require_human_confirmation_beyond_trace": (
policy.require_human_confirmation_beyond_trace
),
},
"search_high": search.high,
"trace_max_sampling_u": trace_max_sampling_u,
"measurement_ceiling_insufficient": False,
"reason": "measurement_ceiling_not_reached",
}
if trace_max_sampling_u is None:
evidence["reason"] = "trace_has_no_requests"
return evidence
if best_threshold is None or best_payload is None:
evidence["reason"] = "no_feasible_probe"
return evidence
resolution = _binary_probe_resolution(search)
threshold_gap_to_high = float(search.high) - float(best_threshold)
evidence["best_sampling_u"] = best_threshold
evidence["best_request_count"] = best_payload.request_count
evidence["threshold_gap_to_high"] = threshold_gap_to_high
evidence["binary_probe_resolution"] = resolution
full_trace_selected = best_payload.request_count >= len(requests)
high_reaches_trace = float(search.high) + 1e-12 >= float(trace_max_sampling_u)
if (
full_trace_selected
and high_reaches_trace
and threshold_gap_to_high <= resolution + 1e-12
):
evidence["measurement_ceiling_insufficient"] = True
evidence["reason"] = "measurement_ceiling_insufficient"
return evidence
def _replay_requests(
requests: list[TraceRequest],
*,
@@ -822,11 +876,19 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
*primary_search.probes,
*((fallback_search.probes if fallback_search is not None else [])),
]
measurement = _measurement_ceiling_evidence(
search=trial.search,
requests=requests,
best_threshold=search_for_best.best_threshold if best is not None else None,
best_payload=best,
)
measurement["auto_high_resolution"] = trial.search_evidence
result = {
"study_id": trial.study_id,
"trial_id": trial.trial_id,
"status": "completed",
"config_patch": to_jsonable(trial.config_patch),
"measurement": measurement,
"best_source": best_source,
"best_sampling_u": search_for_best.best_threshold if best is not None else None,
"best_request_rate": best.request_rate if best is not None else None,

View File

@@ -51,7 +51,7 @@ from aituner.spec import (
TrialSummary,
load_study_spec,
)
from aituner.store import StudyStore
from aituner.store import StudyStore, resolve_auto_high_search
from aituner.trace import load_trace_requests, summarize_window
from aituner.worker import (
_adaptive_replay_set,
@@ -79,6 +79,7 @@ def _write_study_assets(
trace_overrides: dict[str, object] | None = None,
slo_overrides: dict[str, object] | None = None,
engine_overrides: dict[str, object] | None = None,
search_overrides: dict[str, object] | None = None,
) -> Path:
trace_dir = tmp_path / "trace_windows" / "traces"
trace_dir.mkdir(parents=True)
@@ -196,6 +197,8 @@ def _write_study_assets(
study_payload["slo"].update(slo_overrides)
if engine_overrides:
study_payload["engine"].update(engine_overrides)
if search_overrides:
study_payload["search"].update(search_overrides)
study_path.write_text(json.dumps(study_payload), encoding="utf-8")
return study_path
@@ -260,6 +263,76 @@ class CoreFlowTests(unittest.TestCase):
self.assertIn("knob_harnesses", prompt)
self.assertTrue(study_root.exists())
def test_search_auto_high_schema_is_backward_compatible(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
study_path = _write_study_assets(
Path(tmp),
search_overrides={"high": 0.4},
)
study = load_study_spec(study_path)
self.assertFalse(study.search.auto_high.enabled)
updated, evidence = resolve_auto_high_search(
search=study.search,
sampling_us=[0.1, 0.9],
)
self.assertEqual(updated.high, 0.4)
self.assertEqual(evidence["reason"], "auto_high_disabled")
def test_search_auto_high_caps_at_policy_and_trace(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
study_path = _write_study_assets(
Path(tmp),
search_overrides={
"high": 0.2,
"auto_high": {
"enabled": True,
"max_sampling_u": 0.8,
"require_human_confirmation_beyond_trace": True,
},
},
)
study = load_study_spec(study_path)
capped_by_policy, policy_evidence = resolve_auto_high_search(
search=study.search,
sampling_us=[0.1, 0.9],
)
self.assertEqual(capped_by_policy.high, 0.8)
self.assertEqual(
policy_evidence["reason"],
"search_high_raised_to_trace_ceiling",
)
capped_by_trace, trace_evidence = resolve_auto_high_search(
search=study.search,
sampling_us=[0.1, 0.7],
)
self.assertEqual(capped_by_trace.high, 0.7)
self.assertEqual(trace_evidence["effective_ceiling"], 0.7)
high_search = study.search.__class__.from_dict(
{
"low": 0.0,
"high": 0.95,
"tolerance": study.search.tolerance,
"max_probes": study.search.max_probes,
"sample_seed": study.search.sample_seed,
"auto_high": {
"enabled": True,
"max_sampling_u": 0.8,
"require_human_confirmation_beyond_trace": True,
},
}
)
lowered, lowered_evidence = resolve_auto_high_search(
search=high_search,
sampling_us=[0.1, 0.9],
)
self.assertEqual(lowered.high, 0.8)
self.assertEqual(
lowered_evidence["reason"],
"search_high_lowered_to_trace_ceiling",
)
def test_lca_workload_profile_uses_standard_10d_features(self) -> None:
window = WindowRecord(
window_id="w1",
@@ -1381,11 +1454,17 @@ class CoreFlowTests(unittest.TestCase):
window_summary={"prompt_tokens_p95": 2048},
state=state,
)
self.assertTrue(context["harness_stop"]["should_stop"])
self.assertEqual(context["harness_stop"]["reason"], "search_high_saturated_by_incumbent")
self.assertFalse(context["harness_stop"]["should_stop"])
self.assertEqual(
context["harness_stop"]["reason"],
"search_high_saturation_requires_parallel_size_evidence",
)
self.assertEqual(
context["harness_stop"]["evidence"]["objective"],
"request_rate_per_gpu",
)
proposal = build_harness_stop_proposal(context)
self.assertIsNotNone(proposal)
self.assertTrue(proposal.should_stop)
self.assertIsNone(proposal)
def test_harness_stop_allows_feasible_high_probe_with_some_failures(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
@@ -1446,8 +1525,11 @@ class CoreFlowTests(unittest.TestCase):
window_summary={"prompt_tokens_p95": 2048},
state=state,
)
self.assertTrue(context["harness_stop"]["should_stop"])
self.assertEqual(context["harness_stop"]["reason"], "search_high_saturated_by_incumbent")
self.assertFalse(context["harness_stop"]["should_stop"])
self.assertEqual(
context["harness_stop"]["reason"],
"search_high_saturation_requires_parallel_size_evidence",
)
def test_harness_guided_first_tp_probe_for_latency_bottleneck(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
@@ -4498,7 +4580,9 @@ class CoreFlowTests(unittest.TestCase):
with mock.patch("aituner.worker._wait_for_server_or_exit", return_value=None):
with mock.patch("aituner.worker._terminate_process_tree", return_value=None):
with mock.patch("aituner.worker._replay_requests", side_effect=fake_replay):
result = run_trial(Path(trial.artifact_dir) / "trial_spec.json")
result = run_trial(
Path(trial.artifact_dir) / "trial_spec.json"
)
self.assertEqual(result["status"], "completed")
details_path = Path(trial.artifact_dir) / "probe_details.jsonl"
@@ -4512,6 +4596,60 @@ class CoreFlowTests(unittest.TestCase):
self.assertEqual(rows[0]["outcomes"][0]["request_id"], "r1")
self.assertEqual(rows[0]["outcomes"][0]["sampling_u"], 0.1)
def test_run_trial_marks_full_trace_saturation_as_measurement_ceiling_insufficient(
self,
) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
study_path = _write_study_assets(tmp_path)
study = load_study_spec(study_path)
store = StudyStore(tmp_path / ".aituner" / "studies")
store.init_study(spec_path=study_path, study=study)
state = store.load_state(study.study_id)
proposal = Proposal.from_dict(
{
"observation": "baseline",
"diagnosis": "baseline",
"config_patch": {"env_patch": {}, "flag_patch": {}},
"expected_effects": ["measure"],
}
)
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
def fake_replay(requests, **kwargs):
return (
[
RequestOutcome(
request_id=request.row_id,
success=True,
ttft_ms=10.0,
tpot_ms=5.0,
prompt_tokens=request.prompt_tokens_hint,
completion_tokens=request.completion_tokens_hint,
)
for request in requests
],
False,
"",
)
process = mock.Mock()
process.poll.return_value = 0
with mock.patch("aituner.worker.subprocess.Popen", return_value=process):
with mock.patch("aituner.worker._wait_for_server_or_exit", return_value=None):
with mock.patch("aituner.worker._terminate_process_tree", return_value=None):
with mock.patch(
"aituner.worker._replay_requests",
side_effect=fake_replay,
):
result = run_trial(Path(trial.artifact_dir) / "trial_spec.json")
self.assertEqual(result["status"], "completed")
self.assertEqual(result["best_request_count"], 3)
self.assertTrue(result["measurement"]["measurement_ceiling_insufficient"])
self.assertEqual(result["measurement"]["reason"], "measurement_ceiling_insufficient")
self.assertIn("auto_high_resolution", result["measurement"])
def test_run_trial_falls_back_below_inherited_search_floor(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)