Add Stop-A: offered-L-C-A convergence early-stop for replay

Phase 2 of the two-stop work. The L-C-A vector is a deterministic function of the
trace's offered metadata, so the convergence of prefix-vs-full L-C-A (the paper's
Fig. 9 curve) can be computed up front rather than monitored live, with identical
result and no per-request overhead.

- lca.find_convergence_prefix: earliest arrival-ordered prefix whose L and A family
  similarities reach tau and the slow C family reaches the stricter tau_c for
  stable_checks consecutive checkpoints. Self-similarity uses the raw log-feature
  vector (same window -> identical per-dim spread; RobustScaler is reserved for the
  cross-window Stop-C). If C never converges it reports the full set, which is the
  C-gate: no early stop on a cold/under-warmed cache. The checkpoint sims double as
  Phase 3 calibration data.
- spec.AdaptiveStopSpec (trace.adaptive_stop), disabled by default until the
  thresholds are calibrated, so existing studies are unaffected.
- worker._adaptive_replay_set truncates each probe's replay to the convergence
  prefix and records a certificate (converged, fraction, family similarity) into
  probe history and probe_details. Offered request_rate at the threshold is
  unchanged; only wall-clock replay shrinks.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 14:23:49 +08:00
parent 0f15bbc3f1
commit 51a9e4a007
4 changed files with 379 additions and 2 deletions

View File

@@ -259,6 +259,151 @@ def similarity_report(profiles: Sequence[WorkloadProfile]) -> dict[str, Any]:
}
@dataclass(frozen=True)
class ConvergencePoint:
converged: bool
stop_index: int
stop_time_s: float
fraction: float
family_similarity: dict[str, float]
checks: list[dict[str, Any]]
def to_dict(self) -> dict[str, Any]:
return {
"converged": self.converged,
"stop_index": self.stop_index,
"stop_time_s": self.stop_time_s,
"fraction": self.fraction,
"family_similarity": self.family_similarity,
"checks": self.checks,
}
def find_convergence_prefix(
requests: list[TraceRequest],
window: WindowRecord,
*,
gpu_count: int,
length_mode: str = "total",
tau: float = 0.9,
tau_c: float = 0.92,
stable_checks: int = 3,
max_checks: int = 20,
min_fraction: float = 0.1,
) -> ConvergencePoint:
"""Earliest arrival-ordered prefix whose offered L-C-A converges to the full set.
The L-C-A vector is a deterministic function of the trace metadata, so the
convergence of prefix-vs-full is itself deterministic (the paper's Fig. 9
curve). Stop-A replays only up to this prefix. A prefix counts as converged
when the L and A family similarities reach ``tau`` and the (slowest) C family
similarity reaches the stricter ``tau_c`` for ``stable_checks`` consecutive
checkpoints. If that never happens within the window the point reports the
full set (converged=False), which keeps the C-gate honest: an unconverged C
means the probe must replay the whole window rather than stop early.
"""
total = len(requests)
if total == 0:
return ConvergencePoint(
converged=False,
stop_index=0,
stop_time_s=0.0,
fraction=1.0,
family_similarity={"L": 1.0, "C": 1.0, "A": 1.0},
checks=[],
)
# Compare each arrival-ordered prefix to the whole set, both measured over
# their own elapsed span so the A (rate) dimension is comparable rather than
# diluted by the fixed window length.
target = _prefix_profile(
requests, total, window, gpu_count=gpu_count, length_mode=length_mode
)
indices = _checkpoint_indices(
total, max_checks=max_checks, min_fraction=min_fraction
)
checks: list[dict[str, Any]] = []
consecutive = 0
converged_index: int | None = None
converged_sims: dict[str, float] | None = None
for index in indices:
prefix = _prefix_profile(
requests, index, window, gpu_count=gpu_count, length_mode=length_mode
)
sims = _family_similarity(target.vector, prefix.vector)
checks.append(
{
"index": index,
"fraction": float(index / total),
"time_s": float(requests[index - 1].arrival_s),
"family_similarity": sims,
}
)
passed = sims["L"] >= tau and sims["A"] >= tau and sims["C"] >= tau_c
consecutive = consecutive + 1 if passed else 0
if consecutive >= stable_checks and converged_index is None:
converged_index = index
converged_sims = sims
break
if converged_index is None:
last_sims = checks[-1]["family_similarity"] if checks else {"L": 1.0, "C": 1.0, "A": 1.0}
return ConvergencePoint(
converged=False,
stop_index=total,
stop_time_s=float(requests[-1].arrival_s),
fraction=1.0,
family_similarity=last_sims,
checks=checks,
)
return ConvergencePoint(
converged=True,
stop_index=converged_index,
stop_time_s=float(requests[converged_index - 1].arrival_s),
fraction=float(converged_index / total),
family_similarity=converged_sims or {},
checks=checks,
)
def _prefix_profile(
requests: list[TraceRequest],
index: int,
window: WindowRecord,
*,
gpu_count: int,
length_mode: str,
) -> WorkloadProfile:
prefix = requests[:index]
end = float(prefix[-1].arrival_s) if prefix else float(window.window_start)
prefix_window = WindowRecord(
window_id=window.window_id,
trace_path=window.trace_path,
trace_type=window.trace_type,
window_start=window.window_start,
window_end=end,
source_payload=window.source_payload,
)
return build_workload_profile(
prefix, prefix_window, gpu_count=gpu_count, length_mode=length_mode
)
def _checkpoint_indices(total: int, *, max_checks: int, min_fraction: float) -> list[int]:
start = max(1, int(math.ceil(min_fraction * total)))
if total <= max_checks:
candidates = range(start, total + 1)
else:
step = max(1, total // max_checks)
candidates = list(range(start, total + 1, step))
if candidates and candidates[-1] != total:
candidates.append(total)
seen: list[int] = []
for value in candidates:
clamped = min(total, max(1, int(value)))
if not seen or seen[-1] != clamped:
seen.append(clamped)
return seen
def dumps_profile(profile: WorkloadProfile) -> str:
return json.dumps(profile.to_dict(), ensure_ascii=False, indent=2) + "\n"

View File

@@ -321,6 +321,59 @@ class InputLengthFilterSpec:
return spec
@dataclass(frozen=True)
class AdaptiveStopSpec:
"""Stop-A: truncate per-probe replay once the offered L-C-A converges.
Disabled by default; the thresholds are calibrated per workload (Phase 3)
before being switched on, so existing studies are unaffected.
"""
enabled: bool = False
tau: float = 0.9
tau_c: float = 0.92
stable_checks: int = 3
max_checks: int = 20
min_fraction: float = 0.1
@classmethod
def from_dict(cls, data: Any) -> "AdaptiveStopSpec":
if data is None:
return cls()
m = _require_mapping(data, context="trace.adaptive_stop")
enabled = (
_require_bool(m.get("enabled"), context="trace.adaptive_stop.enabled")
if m.get("enabled") is not None
else False
)
tau = _require_float(m.get("tau", 0.9), context="trace.adaptive_stop.tau")
tau_c = _require_float(m.get("tau_c", 0.92), context="trace.adaptive_stop.tau_c")
stable_checks = _require_int(
m.get("stable_checks", 3), context="trace.adaptive_stop.stable_checks"
)
max_checks = _require_int(
m.get("max_checks", 20), context="trace.adaptive_stop.max_checks"
)
min_fraction = _require_float(
m.get("min_fraction", 0.1), context="trace.adaptive_stop.min_fraction"
)
for name, value in (("tau", tau), ("tau_c", tau_c), ("min_fraction", min_fraction)):
if not 0.0 < value <= 1.0:
raise SpecError(f"trace.adaptive_stop.{name} must be in (0, 1].")
if stable_checks <= 0 or max_checks <= 0:
raise SpecError(
"trace.adaptive_stop.stable_checks and max_checks must be > 0."
)
return cls(
enabled=enabled,
tau=tau,
tau_c=tau_c,
stable_checks=stable_checks,
max_checks=max_checks,
min_fraction=min_fraction,
)
@dataclass(frozen=True)
class TraceSpec:
windows_path: str
@@ -338,6 +391,7 @@ class TraceSpec:
early_stop_max_lag_s: float | None = None
early_stop_max_elapsed_s: float | None = None
restart_engine_after_early_stop: bool = False
adaptive_stop: AdaptiveStopSpec = AdaptiveStopSpec()
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
@@ -429,6 +483,7 @@ class TraceSpec:
if data.get("restart_engine_after_early_stop") is not None
else request_mode == "decode_only"
),
adaptive_stop=AdaptiveStopSpec.from_dict(data.get("adaptive_stop")),
)

View File

@@ -16,6 +16,7 @@ from typing import Any, Callable
from .engine import build_launch_recipe
from .http_client import HttpClientError, stream_chat_completion, wait_for_server
from .lca import find_convergence_prefix, resolve_length_mode
from .search import ThresholdProbe, binary_search_max_feasible
from .slo import RequestOutcome, evaluate_request, summarize_evaluations
from .spec import ConfigPatch, SamplingSearchSpec, TrialSpec, load_study_spec, to_jsonable
@@ -209,6 +210,45 @@ def _probe_outcome_details(
}
def _adaptive_replay_set(
selected: list[TraceRequest],
*,
study: Any,
window: Any,
) -> tuple[list[TraceRequest], dict[str, Any] | None]:
"""Stop-A: truncate the replay to the offered-L-C-A convergence prefix.
Returns the (possibly shortened) request list to replay and a certificate of
the convergence decision. When Stop-A is disabled, or C never converges, the
full selected set is replayed (the C-gate: no early stop on a cold cache).
"""
spec = study.trace.adaptive_stop
if not getattr(spec, "enabled", False) or not selected:
return selected, None
point = find_convergence_prefix(
selected,
window,
gpu_count=study.hardware.gpu_count,
length_mode=resolve_length_mode(request_mode=study.trace.request_mode),
tau=spec.tau,
tau_c=spec.tau_c,
stable_checks=spec.stable_checks,
max_checks=spec.max_checks,
min_fraction=spec.min_fraction,
)
replay = selected[: point.stop_index] if point.stop_index > 0 else selected
certificate = {
"enabled": True,
"converged": point.converged,
"stop_index": point.stop_index,
"total_selected": len(selected),
"fraction": point.fraction,
"stop_time_s": point.stop_time_s,
"family_similarity": point.family_similarity,
}
return replay, certificate
def _best_feasible_probe_record(probe_history: list[dict[str, Any]]) -> dict[str, Any] | None:
feasible = [
item
@@ -519,9 +559,12 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
def evaluator(threshold: float) -> ThresholdProbe[ProbePayload]:
nonlocal process
selected = select_requests_for_threshold(requests, threshold=threshold)
replay_set, adaptive_stop_certificate = _adaptive_replay_set(
selected, study=study, window=window
)
restart_after_early_stop = study.trace.restart_engine_after_early_stop
outcomes, early_stopped, early_stop_reason = _replay_requests(
selected,
replay_set,
base_url=recipe.base_url,
timeout_s=recipe.request_timeout_s,
max_concurrency=study.trace.max_concurrency,
@@ -534,12 +577,13 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
evaluations, summary = summarize_evaluations(outcomes, study.slo)
probe_details = _probe_outcome_details(
threshold=threshold,
selected=selected,
selected=replay_set,
outcomes=outcomes,
evaluations=evaluations,
early_stopped=early_stopped,
early_stop_reason=early_stop_reason,
)
probe_details["adaptive_stop"] = adaptive_stop_certificate
with probe_details_path.open("a", encoding="utf-8") as details_handle:
details_handle.write(
json.dumps(probe_details, ensure_ascii=False) + "\n"
@@ -580,12 +624,14 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
probe_record = {
"threshold": threshold,
"request_count": payload.request_count,
"replayed_request_count": len(replay_set),
"pass_rate": payload.pass_rate,
"request_rate": payload.request_rate,
"feasible": payload.feasible,
"early_stopped": payload.early_stopped,
"early_stop_reason": payload.early_stop_reason,
"latency_summary": payload.latency_summary,
"adaptive_stop": adaptive_stop_certificate,
}
probe_history.append(probe_record)
StudyStore.write_json(Path(trial.probe_log_path), probe_history)

View File

@@ -30,6 +30,7 @@ from aituner.harness import (
from aituner.lca import (
build_study_workload_profile,
build_workload_profile,
find_convergence_prefix,
profile_similarity,
resolve_length_mode,
similarity_report,
@@ -38,6 +39,7 @@ from aituner.llm import _extract_response_text, build_prompt, parse_proposal_tex
from aituner.search import ThresholdProbe, binary_search_max_feasible
from aituner.slo import RequestOutcome, evaluate_request, summarize_evaluations
from aituner.spec import (
AdaptiveStopSpec,
ConfigPatch,
LLMEndpointSpec,
Proposal,
@@ -49,6 +51,7 @@ from aituner.spec import (
from aituner.store import StudyStore
from aituner.trace import load_trace_requests, summarize_window
from aituner.worker import (
_adaptive_replay_set,
_best_feasible_probe_record,
_latency_summary,
_run_one_request,
@@ -327,6 +330,134 @@ class CoreFlowTests(unittest.TestCase):
)["workload_lca_profile"]
self.assertNotIn("vector", legacy)
def _steady_requests(self, count: int, *, input_tokens: int = 100) -> list:
return [
TraceRequest(
row_id=f"r{i}",
arrival_s=float(i),
sampling_u=1.0,
body={},
prompt_tokens_hint=input_tokens,
completion_tokens_hint=16,
metadata={"hash_ids": None},
)
for i in range(count)
]
def _conv_window(self) -> WindowRecord:
return WindowRecord(
window_id="conv",
trace_path=Path("trace.jsonl"),
trace_type="chat",
window_start=0.0,
window_end=0.0,
source_payload={"block_size": 64},
)
def test_convergence_prefix_stops_early_on_stationary_trace(self) -> None:
requests = self._steady_requests(60)
point = find_convergence_prefix(
requests,
self._conv_window(),
gpu_count=1,
length_mode="total",
tau=0.9,
tau_c=0.9,
stable_checks=3,
max_checks=20,
min_fraction=0.1,
)
self.assertTrue(point.converged)
# A stationary workload should be trustworthy well before the full window.
self.assertLess(point.stop_index, len(requests))
self.assertLess(point.fraction, 1.0)
self.assertTrue(point.checks)
def test_convergence_prefix_waits_when_cache_warms_late(self) -> None:
window = self._conv_window()
# First half: no prefix reuse. Second half: every request reuses block 1,
# so the C dimension only stabilizes once the reuse regime is exercised.
requests = []
for i in range(30):
requests.append(
TraceRequest(
row_id=f"cold{i}",
arrival_s=float(i),
sampling_u=1.0,
body={},
prompt_tokens_hint=640,
completion_tokens_hint=16,
metadata={"hash_ids": [10_000 + i]},
)
)
for i in range(30):
requests.append(
TraceRequest(
row_id=f"warm{i}",
arrival_s=float(30 + i),
sampling_u=1.0,
body={},
prompt_tokens_hint=640,
completion_tokens_hint=16,
metadata={"hash_ids": [1, 2, 3, 4, 5]},
)
)
point = find_convergence_prefix(
requests,
window,
gpu_count=1,
length_mode="total",
tau=0.9,
tau_c=0.95,
stable_checks=2,
max_checks=20,
min_fraction=0.1,
)
# The C family similarity must be low while only the cold half is seen.
early = [c for c in point.checks if c["fraction"] <= 0.4]
self.assertTrue(early)
self.assertTrue(any(c["family_similarity"]["C"] < 0.9 for c in early))
def test_adaptive_replay_set_truncates_only_when_enabled(self) -> None:
from types import SimpleNamespace
requests = self._steady_requests(60)
window = self._conv_window()
enabled_study = SimpleNamespace(
trace=SimpleNamespace(
adaptive_stop=AdaptiveStopSpec(
enabled=True,
tau=0.9,
tau_c=0.9,
stable_checks=3,
max_checks=20,
min_fraction=0.1,
),
request_mode="chat",
),
hardware=SimpleNamespace(gpu_count=1),
)
replay, certificate = _adaptive_replay_set(
requests, study=enabled_study, window=window
)
self.assertIsNotNone(certificate)
self.assertTrue(certificate["enabled"])
self.assertEqual(len(replay), certificate["stop_index"])
self.assertLessEqual(len(replay), len(requests))
disabled_study = SimpleNamespace(
trace=SimpleNamespace(
adaptive_stop=AdaptiveStopSpec(enabled=False),
request_mode="chat",
),
hardware=SimpleNamespace(gpu_count=1),
)
passthrough, no_cert = _adaptive_replay_set(
requests, study=disabled_study, window=window
)
self.assertIsNone(no_cert)
self.assertEqual(len(passthrough), len(requests))
def test_lca_similarity_matrix_separates_different_profiles(self) -> None:
window = WindowRecord(
window_id="base",