Harden trial measurement accounting

This commit is contained in:
2026-05-06 21:18:09 +08:00
parent 871c4cfc02
commit c1ff64381d
8 changed files with 366 additions and 16 deletions

View File

@@ -372,6 +372,7 @@ def _aggregate(rows: list[dict[str, Any]], candidates: list[MultiCompareCandidat
if rates_per_gpu
else None,
"mean_pass_rate": (sum(pass_rates) / len(pass_rates)) if pass_rates else None,
**_candidate_result_counts(rows, name),
}
for row in rows:
wins[row["winner"]] = wins.get(row["winner"], 0) + 1
@@ -382,6 +383,26 @@ def _aggregate(rows: list[dict[str, Any]], candidates: list[MultiCompareCandidat
}
def _candidate_result_counts(rows: list[dict[str, Any]], name: str) -> dict[str, int]:
counts = {
"completed_window_count": 0,
"failed_window_count": 0,
"no_feasible_window_count": 0,
}
for row in rows:
result = row.get("candidates", {}).get(name)
if not isinstance(result, dict):
continue
status = str(result.get("status") or "")
if status == "completed":
counts["completed_window_count"] += 1
elif status == "failed":
counts["failed_window_count"] += 1
if not isinstance(result.get("best_request_rate_per_gpu"), (int, float)):
counts["no_feasible_window_count"] += 1
return counts
def _render_report(summary: dict[str, Any], candidates: list[MultiCompareCandidate]) -> str:
candidate_names = [item.name for item in candidates]
lines = [
@@ -413,6 +434,9 @@ def _render_report(summary: dict[str, Any], candidates: list[MultiCompareCandida
lines.append(
f"- `{name}` mean req/s=`{aggregate['mean_request_rate']}`, mean req/s/gpu=`{aggregate['mean_request_rate_per_gpu']}`, mean pass_rate=`{aggregate['mean_pass_rate']}`"
)
lines.append(
f" completed/failed/no-feasible windows=`{aggregate['completed_window_count']}`/`{aggregate['failed_window_count']}`/`{aggregate['no_feasible_window_count']}`"
)
header = ["Window", "Date"]
for name in candidate_names:
header.extend([f"{name} req/s", f"{name} req/s/gpu"])

View File

@@ -382,6 +382,8 @@ def _aggregate_summary(rows: list[dict[str, Any]]) -> dict[str, Any]:
wins = {"baseline": 0, "tuned": 0, "tie": 0, "incomparable": 0}
for row in rows:
wins[row["delta"]["winner"]] += 1
baseline_counts = _candidate_result_counts(rows, "baseline")
tuned_counts = _candidate_result_counts(rows, "tuned")
return {
"window_count": len(rows),
"wins": wins,
@@ -389,9 +391,31 @@ def _aggregate_summary(rows: list[dict[str, Any]]) -> dict[str, Any]:
"tuned_mean_request_rate": _mean_or_none(tuned_rates),
"baseline_mean_request_rate_per_gpu": _mean_or_none(baseline_per_gpu),
"tuned_mean_request_rate_per_gpu": _mean_or_none(tuned_per_gpu),
"baseline_completed_window_count": baseline_counts["completed"],
"baseline_failed_window_count": baseline_counts["failed"],
"baseline_no_feasible_window_count": baseline_counts["no_feasible"],
"tuned_completed_window_count": tuned_counts["completed"],
"tuned_failed_window_count": tuned_counts["failed"],
"tuned_no_feasible_window_count": tuned_counts["no_feasible"],
}
def _candidate_result_counts(rows: list[dict[str, Any]], name: str) -> dict[str, int]:
counts = {"completed": 0, "failed": 0, "no_feasible": 0}
for row in rows:
result = row.get(name)
if not isinstance(result, dict):
continue
status = str(result.get("status") or "")
if status == "completed":
counts["completed"] += 1
elif status == "failed":
counts["failed"] += 1
if not isinstance(result.get("best_request_rate_per_gpu"), (int, float)):
counts["no_feasible"] += 1
return counts
def _mean_or_none(values: list[float]) -> float | None:
if not values:
return None
@@ -417,6 +441,8 @@ def _render_report(summary: dict[str, Any]) -> str:
f"- Tuned mean request rate: `{summary['aggregate']['tuned_mean_request_rate']}`",
f"- Baseline mean request rate per GPU: `{summary['aggregate']['baseline_mean_request_rate_per_gpu']}`",
f"- Tuned mean request rate per GPU: `{summary['aggregate']['tuned_mean_request_rate_per_gpu']}`",
f"- Baseline completed/failed/no-feasible windows: `{summary['aggregate']['baseline_completed_window_count']}`/`{summary['aggregate']['baseline_failed_window_count']}`/`{summary['aggregate']['baseline_no_feasible_window_count']}`",
f"- Tuned completed/failed/no-feasible windows: `{summary['aggregate']['tuned_completed_window_count']}`/`{summary['aggregate']['tuned_failed_window_count']}`/`{summary['aggregate']['tuned_no_feasible_window_count']}`",
"",
"## Per Window",
"",

View File

@@ -240,6 +240,8 @@ class StreamMetrics:
ttft_ms: float | None
tpot_ms: float | None
completion_tokens: int | None
completion_tokens_source: str = "usage"
streamed_chunk_count: int = 0
def stream_chat_completion(
@@ -260,6 +262,7 @@ def stream_chat_completion(
last_token_at: float | None = None
chunk_token_count = 0
completion_tokens: int | None = None
completion_tokens_source = "none"
try:
with _urlopen(request, timeout=timeout_s) as response:
for raw in _iter_sse_lines(response):
@@ -273,6 +276,7 @@ def stream_chat_completion(
comp = usage.get("completion_tokens")
if isinstance(comp, int) and comp >= 0:
completion_tokens = comp
completion_tokens_source = "usage"
choices = payload.get("choices")
if not isinstance(choices, list) or not choices:
continue
@@ -290,7 +294,10 @@ def stream_chat_completion(
detail = exc.read().decode("utf-8", errors="replace")
raise HttpClientError(f"stream_chat_completion failed: {exc.code} {detail}") from exc
ttft_ms = None if first_token_at is None else (first_token_at - start) * 1000.0
used_tokens = completion_tokens if completion_tokens is not None else chunk_token_count
if completion_tokens is None and chunk_token_count > 0:
completion_tokens = chunk_token_count
completion_tokens_source = "stream_chunks"
used_tokens = completion_tokens
if (
first_token_at is None
or last_token_at is None
@@ -304,6 +311,8 @@ def stream_chat_completion(
ttft_ms=ttft_ms,
tpot_ms=tpot_ms,
completion_tokens=used_tokens if used_tokens > 0 else None,
completion_tokens_source=completion_tokens_source,
streamed_chunk_count=chunk_token_count,
)

View File

@@ -15,6 +15,7 @@ class RequestOutcome:
prompt_tokens: int | None
completion_tokens: int | None
error: str = ""
completion_tokens_source: str = ""
@dataclass(frozen=True)

View File

@@ -354,6 +354,33 @@ class TraceSpec:
)
if completion_tokens_override < 0:
raise SpecError("trace.completion_tokens_override must be >= 0.")
max_requests_value = (
_require_int(max_requests, context="trace.max_requests_per_probe")
if max_requests is not None
else None
)
if max_requests_value is not None and max_requests_value <= 0:
raise SpecError("trace.max_requests_per_probe must be > 0.")
synthetic_prompt_cap_value = (
_require_int(
synthetic_prompt_cap,
context="trace.synthetic_prompt_cap_tokens",
)
if synthetic_prompt_cap is not None
else None
)
if synthetic_prompt_cap_value is not None and synthetic_prompt_cap_value < 0:
raise SpecError("trace.synthetic_prompt_cap_tokens must be >= 0.")
replay_time_scale = _require_float(
data.get("replay_time_scale", 1.0), context="trace.replay_time_scale"
)
if replay_time_scale <= 0:
raise SpecError("trace.replay_time_scale must be > 0.")
max_concurrency = _require_int(
data.get("max_concurrency", 64), context="trace.max_concurrency"
)
if max_concurrency <= 0:
raise SpecError("trace.max_concurrency must be > 0.")
return cls(
windows_path=_require_str(data.get("windows_path"), context="trace.windows_path"),
window_id=_require_str(data.get("window_id"), context="trace.window_id"),
@@ -364,9 +391,7 @@ class TraceSpec:
completion_tokens_override=completion_tokens_override,
u_field=str(data.get("u_field") or "sampling_u").strip(),
timestamp_field=str(data.get("timestamp_field") or "timestamp").strip(),
max_concurrency=_require_int(
data.get("max_concurrency", 64), context="trace.max_concurrency"
),
max_concurrency=max_concurrency,
input_length_filter=(
InputLengthFilterSpec.from_dict(
_require_mapping(
@@ -378,13 +403,9 @@ class TraceSpec:
if data.get("input_length_filter") is not None
else None
),
max_requests_per_probe=int(max_requests) if max_requests is not None else None,
synthetic_prompt_cap_tokens=(
int(synthetic_prompt_cap) if synthetic_prompt_cap is not None else None
),
replay_time_scale=_require_float(
data.get("replay_time_scale", 1.0), context="trace.replay_time_scale"
),
max_requests_per_probe=max_requests_value,
synthetic_prompt_cap_tokens=synthetic_prompt_cap_value,
replay_time_scale=replay_time_scale,
early_stop_max_lag_s=(
_require_float(
data.get("early_stop_max_lag_s"), context="trace.early_stop_max_lag_s"

View File

@@ -98,8 +98,7 @@ class StudyStore:
result_path=str(trial_root / "result.json"),
)
self.write_json(trial_root / "trial_spec.json", to_jsonable(spec))
next_state = replace(state, next_trial_index=state.next_trial_index + 1)
next_state.trials.append(
next_trial = (
TrialSummary(
trial_id=trial_id,
status="queued",
@@ -108,6 +107,11 @@ class StudyStore:
config_patch=to_jsonable(proposal.config_patch),
)
)
next_state = replace(
state,
next_trial_index=state.next_trial_index + 1,
trials=[*state.trials, next_trial],
)
self.save_state(next_state)
return spec, next_state

View File

@@ -105,13 +105,49 @@ def _run_one_request(
) -> RequestOutcome:
try:
metrics = stream_chat_completion(base_url=base_url, body=request.body, timeout_s=timeout_s)
expected_completion_tokens = request.completion_tokens_hint
actual_completion_tokens = metrics.completion_tokens
completion_tokens_source = getattr(metrics, "completion_tokens_source", "")
if expected_completion_tokens is not None:
if completion_tokens_source != "usage":
return RequestOutcome(
request_id=request.row_id,
success=False,
ttft_ms=metrics.ttft_ms,
tpot_ms=metrics.tpot_ms,
prompt_tokens=request.prompt_tokens_hint,
completion_tokens=actual_completion_tokens,
error=(
"completion_tokens_unverified "
f"source={completion_tokens_source or 'unknown'} "
f"expected={expected_completion_tokens} "
f"actual={actual_completion_tokens}"
),
completion_tokens_source=completion_tokens_source,
)
if actual_completion_tokens != expected_completion_tokens:
return RequestOutcome(
request_id=request.row_id,
success=False,
ttft_ms=metrics.ttft_ms,
tpot_ms=metrics.tpot_ms,
prompt_tokens=request.prompt_tokens_hint,
completion_tokens=actual_completion_tokens,
error=(
"completion_tokens_mismatch "
f"expected={expected_completion_tokens} "
f"actual={actual_completion_tokens}"
),
completion_tokens_source=completion_tokens_source,
)
return RequestOutcome(
request_id=request.row_id,
success=True,
ttft_ms=metrics.ttft_ms,
tpot_ms=metrics.tpot_ms,
prompt_tokens=request.prompt_tokens_hint,
completion_tokens=metrics.completion_tokens or request.completion_tokens_hint,
completion_tokens=actual_completion_tokens or request.completion_tokens_hint,
completion_tokens_source=completion_tokens_source,
)
except HttpClientError as exc:
return RequestOutcome(
@@ -125,6 +161,53 @@ def _run_one_request(
)
def _probe_outcome_details(
*,
threshold: float,
selected: list[TraceRequest],
outcomes: list[RequestOutcome],
evaluations: list[Any],
early_stopped: bool,
early_stop_reason: str,
) -> dict[str, Any]:
selected_by_id = {request.row_id: request for request in selected}
return {
"threshold": threshold,
"early_stopped": early_stopped,
"early_stop_reason": early_stop_reason,
"outcomes": [
{
"request_id": outcome.request_id,
"sampling_u": (
selected_by_id[outcome.request_id].sampling_u
if outcome.request_id in selected_by_id
else None
),
"arrival_s": (
selected_by_id[outcome.request_id].arrival_s
if outcome.request_id in selected_by_id
else None
),
"success": outcome.success,
"ttft_ms": outcome.ttft_ms,
"tpot_ms": outcome.tpot_ms,
"prompt_tokens": outcome.prompt_tokens,
"expected_completion_tokens": (
selected_by_id[outcome.request_id].completion_tokens_hint
if outcome.request_id in selected_by_id
else None
),
"completion_tokens": outcome.completion_tokens,
"completion_tokens_source": outcome.completion_tokens_source,
"error": outcome.error,
"evaluation": evaluation.passed,
"reasons": evaluation.reasons,
}
for outcome, evaluation in zip(outcomes, evaluations)
],
}
def _replay_requests(
requests: list[TraceRequest],
*,
@@ -340,6 +423,9 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
artifact_dir = Path(trial.artifact_dir)
artifact_dir.mkdir(parents=True, exist_ok=True)
engine_log_path = Path(trial.engine_log_path)
probe_details_path = artifact_dir / "probe_details.jsonl"
if probe_details_path.exists():
probe_details_path.unlink()
with engine_log_path.open("w", encoding="utf-8") as engine_log:
def launch_process() -> subprocess.Popen[str]:
return subprocess.Popen( # noqa: S603
@@ -380,6 +466,18 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
drain_inflight_on_early_stop=not restart_after_early_stop,
)
evaluations, summary = summarize_evaluations(outcomes, study.slo)
probe_details = _probe_outcome_details(
threshold=threshold,
selected=selected,
outcomes=outcomes,
evaluations=evaluations,
early_stopped=early_stopped,
early_stop_reason=early_stop_reason,
)
with probe_details_path.open("a", encoding="utf-8") as details_handle:
details_handle.write(
json.dumps(probe_details, ensure_ascii=False) + "\n"
)
request_rate = (
len(selected) / max(window.window_end - window.window_start, 1e-9)
if selected
@@ -406,6 +504,7 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
"tpot_ms": outcome.tpot_ms,
"prompt_tokens": outcome.prompt_tokens,
"completion_tokens": outcome.completion_tokens,
"completion_tokens_source": outcome.completion_tokens_source,
"evaluation": evaluation.passed,
"reasons": evaluation.reasons,
}

View File

@@ -9,9 +9,9 @@ from pathlib import Path
from unittest import mock
from aituner.cli import main as cli_main
from aituner.compare import load_compare_spec, run_compare
from aituner.compare import _aggregate_summary, load_compare_spec, run_compare
from aituner.engine import build_launch_recipe
from aituner.http_client import _auth_headers, _openai_url, _should_bypass_proxy
from aituner.http_client import StreamMetrics, _auth_headers, _openai_url, _should_bypass_proxy
from aituner.job import append_job, build_trial_job
from aituner.harness import (
build_harness_context,
@@ -34,9 +34,11 @@ from aituner.store import StudyStore
from aituner.trace import load_trace_requests, summarize_window
from aituner.worker import (
_latency_summary,
_run_one_request,
_replay_requests,
_terminate_process_tree,
_wait_for_server_or_exit,
run_trial,
)
from aituner.trace import TraceRequest
@@ -863,6 +865,24 @@ class CoreFlowTests(unittest.TestCase):
with self.assertRaisesRegex(SpecError, "min_input_tokens must be <="):
load_study_spec(study_path)
def test_trace_rejects_non_positive_max_requests_per_probe(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
study_path = _write_study_assets(
Path(tmp),
trace_overrides={"max_requests_per_probe": 0},
)
with self.assertRaisesRegex(SpecError, "max_requests_per_probe must be > 0"):
load_study_spec(study_path)
def test_trace_rejects_invalid_replay_time_scale(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
study_path = _write_study_assets(
Path(tmp),
trace_overrides={"replay_time_scale": 0.0},
)
with self.assertRaisesRegex(SpecError, "replay_time_scale must be > 0"):
load_study_spec(study_path)
def test_decode_only_mode_is_loaded_and_prompt_mentions_it(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
@@ -1456,6 +1476,34 @@ class CoreFlowTests(unittest.TestCase):
self.assertEqual(requests[2].body["min_tokens"], 1)
self.assertEqual(requests[2].body["max_tokens"], 1)
def test_run_one_request_fails_fixed_length_completion_mismatch(self) -> None:
request = TraceRequest(
row_id="r1",
arrival_s=0.0,
sampling_u=0.1,
body={"model": "m", "messages": [{"role": "user", "content": "x"}]},
prompt_tokens_hint=8,
completion_tokens_hint=2,
)
with mock.patch(
"aituner.worker.stream_chat_completion",
return_value=StreamMetrics(
ttft_ms=10.0,
tpot_ms=5.0,
completion_tokens=1,
),
):
outcome = _run_one_request(
request,
base_url="http://127.0.0.1:8000",
timeout_s=1.0,
)
self.assertFalse(outcome.success)
self.assertEqual(outcome.error, "completion_tokens_mismatch expected=2 actual=1")
self.assertEqual(outcome.completion_tokens, 1)
def test_build_prompt_mentions_completion_tokens_override(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
study_path = _write_study_assets(
@@ -1950,6 +1998,86 @@ class CoreFlowTests(unittest.TestCase):
3.125,
)
def test_run_trial_persists_probe_request_details(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
study_path = _write_study_assets(tmp_path)
payload = json.loads(study_path.read_text(encoding="utf-8"))
payload["search"]["max_probes"] = 1
study_path.write_text(json.dumps(payload), encoding="utf-8")
study = load_study_spec(study_path)
store = StudyStore(tmp_path / ".aituner" / "studies")
store.init_study(spec_path=study_path, study=study)
state = store.load_state(study.study_id)
proposal = Proposal.from_dict(
{
"observation": "baseline",
"diagnosis": "baseline",
"config_patch": {"env_patch": {}, "flag_patch": {}},
"expected_effects": ["measure"],
}
)
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
def fake_replay(requests, **kwargs):
return (
[
RequestOutcome(
request_id=request.row_id,
success=True,
ttft_ms=10.0,
tpot_ms=5.0,
prompt_tokens=request.prompt_tokens_hint,
completion_tokens=request.completion_tokens_hint,
)
for request in requests
],
False,
"",
)
process = mock.Mock()
process.poll.return_value = 0
with mock.patch("aituner.worker.subprocess.Popen", return_value=process):
with mock.patch("aituner.worker._wait_for_server_or_exit", return_value=None):
with mock.patch("aituner.worker._terminate_process_tree", return_value=None):
with mock.patch("aituner.worker._replay_requests", side_effect=fake_replay):
result = run_trial(Path(trial.artifact_dir) / "trial_spec.json")
self.assertEqual(result["status"], "completed")
details_path = Path(trial.artifact_dir) / "probe_details.jsonl"
self.assertTrue(details_path.exists())
rows = [
json.loads(line)
for line in details_path.read_text(encoding="utf-8").splitlines()
]
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["threshold"], 0.5)
self.assertEqual(rows[0]["outcomes"][0]["request_id"], "r1")
self.assertEqual(rows[0]["outcomes"][0]["sampling_u"], 0.1)
def test_materialize_trial_does_not_mutate_input_state_trials(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
study_path = _write_study_assets(tmp_path)
study = load_study_spec(study_path)
store = StudyStore(tmp_path / ".aituner" / "studies")
store.init_study(spec_path=study_path, study=study)
state = store.load_state(study.study_id)
proposal = Proposal.from_dict(
{
"observation": "baseline",
"diagnosis": "baseline",
"config_patch": {"env_patch": {}, "flag_patch": {}},
"expected_effects": ["measure"],
}
)
_, next_state = store.materialize_trial(study=study, state=state, proposal=proposal)
self.assertEqual(state.trials, [])
self.assertEqual(len(next_state.trials), 1)
def test_materialize_trial_uses_incumbent_sampling_u_as_search_floor(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
@@ -2969,6 +3097,44 @@ class CoreFlowTests(unittest.TestCase):
self.assertTrue((tmp_path / ".compare" / "summary.json").exists())
self.assertTrue((tmp_path / ".compare" / "report.md").exists())
def test_compare_aggregate_counts_failed_and_no_feasible_windows(self) -> None:
summary = _aggregate_summary(
[
{
"baseline": {
"status": "completed",
"best_request_rate": 1.0,
"best_request_rate_per_gpu": 1.0,
},
"tuned": {
"status": "completed",
"best_request_rate": None,
"best_request_rate_per_gpu": None,
},
"delta": {"winner": "baseline"},
},
{
"baseline": {
"status": "failed",
"best_request_rate": None,
"best_request_rate_per_gpu": None,
},
"tuned": {
"status": "completed",
"best_request_rate": 2.0,
"best_request_rate_per_gpu": 2.0,
},
"delta": {"winner": "tuned"},
},
]
)
self.assertEqual(summary["baseline_completed_window_count"], 1)
self.assertEqual(summary["baseline_failed_window_count"], 1)
self.assertEqual(summary["baseline_no_feasible_window_count"], 1)
self.assertEqual(summary["tuned_completed_window_count"], 2)
self.assertEqual(summary["tuned_failed_window_count"], 0)
self.assertEqual(summary["tuned_no_feasible_window_count"], 1)
def test_run_compare_resolves_trial_ref_candidate(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)