Add Stop-A: offered-L-C-A convergence early-stop for replay
Phase 2 of the two-stop work. The L-C-A vector is a deterministic function of the trace's offered metadata, so the convergence of prefix-vs-full L-C-A (the paper's Fig. 9 curve) can be computed up front rather than monitored live, with identical result and no per-request overhead. - lca.find_convergence_prefix: earliest arrival-ordered prefix whose L and A family similarities reach tau and the slow C family reaches the stricter tau_c for stable_checks consecutive checkpoints. Self-similarity uses the raw log-feature vector (same window -> identical per-dim spread; RobustScaler is reserved for the cross-window Stop-C). If C never converges it reports the full set, which is the C-gate: no early stop on a cold/under-warmed cache. The checkpoint sims double as Phase 3 calibration data. - spec.AdaptiveStopSpec (trace.adaptive_stop), disabled by default until the thresholds are calibrated, so existing studies are unaffected. - worker._adaptive_replay_set truncates each probe's replay to the convergence prefix and records a certificate (converged, fraction, family similarity) into probe history and probe_details. Offered request_rate at the threshold is unchanged; only wall-clock replay shrinks. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -259,6 +259,151 @@ def similarity_report(profiles: Sequence[WorkloadProfile]) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConvergencePoint:
|
||||
converged: bool
|
||||
stop_index: int
|
||||
stop_time_s: float
|
||||
fraction: float
|
||||
family_similarity: dict[str, float]
|
||||
checks: list[dict[str, Any]]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"converged": self.converged,
|
||||
"stop_index": self.stop_index,
|
||||
"stop_time_s": self.stop_time_s,
|
||||
"fraction": self.fraction,
|
||||
"family_similarity": self.family_similarity,
|
||||
"checks": self.checks,
|
||||
}
|
||||
|
||||
|
||||
def find_convergence_prefix(
|
||||
requests: list[TraceRequest],
|
||||
window: WindowRecord,
|
||||
*,
|
||||
gpu_count: int,
|
||||
length_mode: str = "total",
|
||||
tau: float = 0.9,
|
||||
tau_c: float = 0.92,
|
||||
stable_checks: int = 3,
|
||||
max_checks: int = 20,
|
||||
min_fraction: float = 0.1,
|
||||
) -> ConvergencePoint:
|
||||
"""Earliest arrival-ordered prefix whose offered L-C-A converges to the full set.
|
||||
|
||||
The L-C-A vector is a deterministic function of the trace metadata, so the
|
||||
convergence of prefix-vs-full is itself deterministic (the paper's Fig. 9
|
||||
curve). Stop-A replays only up to this prefix. A prefix counts as converged
|
||||
when the L and A family similarities reach ``tau`` and the (slowest) C family
|
||||
similarity reaches the stricter ``tau_c`` for ``stable_checks`` consecutive
|
||||
checkpoints. If that never happens within the window the point reports the
|
||||
full set (converged=False), which keeps the C-gate honest: an unconverged C
|
||||
means the probe must replay the whole window rather than stop early.
|
||||
"""
|
||||
total = len(requests)
|
||||
if total == 0:
|
||||
return ConvergencePoint(
|
||||
converged=False,
|
||||
stop_index=0,
|
||||
stop_time_s=0.0,
|
||||
fraction=1.0,
|
||||
family_similarity={"L": 1.0, "C": 1.0, "A": 1.0},
|
||||
checks=[],
|
||||
)
|
||||
# Compare each arrival-ordered prefix to the whole set, both measured over
|
||||
# their own elapsed span so the A (rate) dimension is comparable rather than
|
||||
# diluted by the fixed window length.
|
||||
target = _prefix_profile(
|
||||
requests, total, window, gpu_count=gpu_count, length_mode=length_mode
|
||||
)
|
||||
indices = _checkpoint_indices(
|
||||
total, max_checks=max_checks, min_fraction=min_fraction
|
||||
)
|
||||
checks: list[dict[str, Any]] = []
|
||||
consecutive = 0
|
||||
converged_index: int | None = None
|
||||
converged_sims: dict[str, float] | None = None
|
||||
for index in indices:
|
||||
prefix = _prefix_profile(
|
||||
requests, index, window, gpu_count=gpu_count, length_mode=length_mode
|
||||
)
|
||||
sims = _family_similarity(target.vector, prefix.vector)
|
||||
checks.append(
|
||||
{
|
||||
"index": index,
|
||||
"fraction": float(index / total),
|
||||
"time_s": float(requests[index - 1].arrival_s),
|
||||
"family_similarity": sims,
|
||||
}
|
||||
)
|
||||
passed = sims["L"] >= tau and sims["A"] >= tau and sims["C"] >= tau_c
|
||||
consecutive = consecutive + 1 if passed else 0
|
||||
if consecutive >= stable_checks and converged_index is None:
|
||||
converged_index = index
|
||||
converged_sims = sims
|
||||
break
|
||||
if converged_index is None:
|
||||
last_sims = checks[-1]["family_similarity"] if checks else {"L": 1.0, "C": 1.0, "A": 1.0}
|
||||
return ConvergencePoint(
|
||||
converged=False,
|
||||
stop_index=total,
|
||||
stop_time_s=float(requests[-1].arrival_s),
|
||||
fraction=1.0,
|
||||
family_similarity=last_sims,
|
||||
checks=checks,
|
||||
)
|
||||
return ConvergencePoint(
|
||||
converged=True,
|
||||
stop_index=converged_index,
|
||||
stop_time_s=float(requests[converged_index - 1].arrival_s),
|
||||
fraction=float(converged_index / total),
|
||||
family_similarity=converged_sims or {},
|
||||
checks=checks,
|
||||
)
|
||||
|
||||
|
||||
def _prefix_profile(
|
||||
requests: list[TraceRequest],
|
||||
index: int,
|
||||
window: WindowRecord,
|
||||
*,
|
||||
gpu_count: int,
|
||||
length_mode: str,
|
||||
) -> WorkloadProfile:
|
||||
prefix = requests[:index]
|
||||
end = float(prefix[-1].arrival_s) if prefix else float(window.window_start)
|
||||
prefix_window = WindowRecord(
|
||||
window_id=window.window_id,
|
||||
trace_path=window.trace_path,
|
||||
trace_type=window.trace_type,
|
||||
window_start=window.window_start,
|
||||
window_end=end,
|
||||
source_payload=window.source_payload,
|
||||
)
|
||||
return build_workload_profile(
|
||||
prefix, prefix_window, gpu_count=gpu_count, length_mode=length_mode
|
||||
)
|
||||
|
||||
|
||||
def _checkpoint_indices(total: int, *, max_checks: int, min_fraction: float) -> list[int]:
|
||||
start = max(1, int(math.ceil(min_fraction * total)))
|
||||
if total <= max_checks:
|
||||
candidates = range(start, total + 1)
|
||||
else:
|
||||
step = max(1, total // max_checks)
|
||||
candidates = list(range(start, total + 1, step))
|
||||
if candidates and candidates[-1] != total:
|
||||
candidates.append(total)
|
||||
seen: list[int] = []
|
||||
for value in candidates:
|
||||
clamped = min(total, max(1, int(value)))
|
||||
if not seen or seen[-1] != clamped:
|
||||
seen.append(clamped)
|
||||
return seen
|
||||
|
||||
|
||||
def dumps_profile(profile: WorkloadProfile) -> str:
|
||||
return json.dumps(profile.to_dict(), ensure_ascii=False, indent=2) + "\n"
|
||||
|
||||
|
||||
@@ -321,6 +321,59 @@ class InputLengthFilterSpec:
|
||||
return spec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AdaptiveStopSpec:
|
||||
"""Stop-A: truncate per-probe replay once the offered L-C-A converges.
|
||||
|
||||
Disabled by default; the thresholds are calibrated per workload (Phase 3)
|
||||
before being switched on, so existing studies are unaffected.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
tau: float = 0.9
|
||||
tau_c: float = 0.92
|
||||
stable_checks: int = 3
|
||||
max_checks: int = 20
|
||||
min_fraction: float = 0.1
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Any) -> "AdaptiveStopSpec":
|
||||
if data is None:
|
||||
return cls()
|
||||
m = _require_mapping(data, context="trace.adaptive_stop")
|
||||
enabled = (
|
||||
_require_bool(m.get("enabled"), context="trace.adaptive_stop.enabled")
|
||||
if m.get("enabled") is not None
|
||||
else False
|
||||
)
|
||||
tau = _require_float(m.get("tau", 0.9), context="trace.adaptive_stop.tau")
|
||||
tau_c = _require_float(m.get("tau_c", 0.92), context="trace.adaptive_stop.tau_c")
|
||||
stable_checks = _require_int(
|
||||
m.get("stable_checks", 3), context="trace.adaptive_stop.stable_checks"
|
||||
)
|
||||
max_checks = _require_int(
|
||||
m.get("max_checks", 20), context="trace.adaptive_stop.max_checks"
|
||||
)
|
||||
min_fraction = _require_float(
|
||||
m.get("min_fraction", 0.1), context="trace.adaptive_stop.min_fraction"
|
||||
)
|
||||
for name, value in (("tau", tau), ("tau_c", tau_c), ("min_fraction", min_fraction)):
|
||||
if not 0.0 < value <= 1.0:
|
||||
raise SpecError(f"trace.adaptive_stop.{name} must be in (0, 1].")
|
||||
if stable_checks <= 0 or max_checks <= 0:
|
||||
raise SpecError(
|
||||
"trace.adaptive_stop.stable_checks and max_checks must be > 0."
|
||||
)
|
||||
return cls(
|
||||
enabled=enabled,
|
||||
tau=tau,
|
||||
tau_c=tau_c,
|
||||
stable_checks=stable_checks,
|
||||
max_checks=max_checks,
|
||||
min_fraction=min_fraction,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TraceSpec:
|
||||
windows_path: str
|
||||
@@ -338,6 +391,7 @@ class TraceSpec:
|
||||
early_stop_max_lag_s: float | None = None
|
||||
early_stop_max_elapsed_s: float | None = None
|
||||
restart_engine_after_early_stop: bool = False
|
||||
adaptive_stop: AdaptiveStopSpec = AdaptiveStopSpec()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
|
||||
@@ -429,6 +483,7 @@ class TraceSpec:
|
||||
if data.get("restart_engine_after_early_stop") is not None
|
||||
else request_mode == "decode_only"
|
||||
),
|
||||
adaptive_stop=AdaptiveStopSpec.from_dict(data.get("adaptive_stop")),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from typing import Any, Callable
|
||||
|
||||
from .engine import build_launch_recipe
|
||||
from .http_client import HttpClientError, stream_chat_completion, wait_for_server
|
||||
from .lca import find_convergence_prefix, resolve_length_mode
|
||||
from .search import ThresholdProbe, binary_search_max_feasible
|
||||
from .slo import RequestOutcome, evaluate_request, summarize_evaluations
|
||||
from .spec import ConfigPatch, SamplingSearchSpec, TrialSpec, load_study_spec, to_jsonable
|
||||
@@ -209,6 +210,45 @@ def _probe_outcome_details(
|
||||
}
|
||||
|
||||
|
||||
def _adaptive_replay_set(
|
||||
selected: list[TraceRequest],
|
||||
*,
|
||||
study: Any,
|
||||
window: Any,
|
||||
) -> tuple[list[TraceRequest], dict[str, Any] | None]:
|
||||
"""Stop-A: truncate the replay to the offered-L-C-A convergence prefix.
|
||||
|
||||
Returns the (possibly shortened) request list to replay and a certificate of
|
||||
the convergence decision. When Stop-A is disabled, or C never converges, the
|
||||
full selected set is replayed (the C-gate: no early stop on a cold cache).
|
||||
"""
|
||||
spec = study.trace.adaptive_stop
|
||||
if not getattr(spec, "enabled", False) or not selected:
|
||||
return selected, None
|
||||
point = find_convergence_prefix(
|
||||
selected,
|
||||
window,
|
||||
gpu_count=study.hardware.gpu_count,
|
||||
length_mode=resolve_length_mode(request_mode=study.trace.request_mode),
|
||||
tau=spec.tau,
|
||||
tau_c=spec.tau_c,
|
||||
stable_checks=spec.stable_checks,
|
||||
max_checks=spec.max_checks,
|
||||
min_fraction=spec.min_fraction,
|
||||
)
|
||||
replay = selected[: point.stop_index] if point.stop_index > 0 else selected
|
||||
certificate = {
|
||||
"enabled": True,
|
||||
"converged": point.converged,
|
||||
"stop_index": point.stop_index,
|
||||
"total_selected": len(selected),
|
||||
"fraction": point.fraction,
|
||||
"stop_time_s": point.stop_time_s,
|
||||
"family_similarity": point.family_similarity,
|
||||
}
|
||||
return replay, certificate
|
||||
|
||||
|
||||
def _best_feasible_probe_record(probe_history: list[dict[str, Any]]) -> dict[str, Any] | None:
|
||||
feasible = [
|
||||
item
|
||||
@@ -519,9 +559,12 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
|
||||
def evaluator(threshold: float) -> ThresholdProbe[ProbePayload]:
|
||||
nonlocal process
|
||||
selected = select_requests_for_threshold(requests, threshold=threshold)
|
||||
replay_set, adaptive_stop_certificate = _adaptive_replay_set(
|
||||
selected, study=study, window=window
|
||||
)
|
||||
restart_after_early_stop = study.trace.restart_engine_after_early_stop
|
||||
outcomes, early_stopped, early_stop_reason = _replay_requests(
|
||||
selected,
|
||||
replay_set,
|
||||
base_url=recipe.base_url,
|
||||
timeout_s=recipe.request_timeout_s,
|
||||
max_concurrency=study.trace.max_concurrency,
|
||||
@@ -534,12 +577,13 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
|
||||
evaluations, summary = summarize_evaluations(outcomes, study.slo)
|
||||
probe_details = _probe_outcome_details(
|
||||
threshold=threshold,
|
||||
selected=selected,
|
||||
selected=replay_set,
|
||||
outcomes=outcomes,
|
||||
evaluations=evaluations,
|
||||
early_stopped=early_stopped,
|
||||
early_stop_reason=early_stop_reason,
|
||||
)
|
||||
probe_details["adaptive_stop"] = adaptive_stop_certificate
|
||||
with probe_details_path.open("a", encoding="utf-8") as details_handle:
|
||||
details_handle.write(
|
||||
json.dumps(probe_details, ensure_ascii=False) + "\n"
|
||||
@@ -580,12 +624,14 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
|
||||
probe_record = {
|
||||
"threshold": threshold,
|
||||
"request_count": payload.request_count,
|
||||
"replayed_request_count": len(replay_set),
|
||||
"pass_rate": payload.pass_rate,
|
||||
"request_rate": payload.request_rate,
|
||||
"feasible": payload.feasible,
|
||||
"early_stopped": payload.early_stopped,
|
||||
"early_stop_reason": payload.early_stop_reason,
|
||||
"latency_summary": payload.latency_summary,
|
||||
"adaptive_stop": adaptive_stop_certificate,
|
||||
}
|
||||
probe_history.append(probe_record)
|
||||
StudyStore.write_json(Path(trial.probe_log_path), probe_history)
|
||||
|
||||
@@ -30,6 +30,7 @@ from aituner.harness import (
|
||||
from aituner.lca import (
|
||||
build_study_workload_profile,
|
||||
build_workload_profile,
|
||||
find_convergence_prefix,
|
||||
profile_similarity,
|
||||
resolve_length_mode,
|
||||
similarity_report,
|
||||
@@ -38,6 +39,7 @@ from aituner.llm import _extract_response_text, build_prompt, parse_proposal_tex
|
||||
from aituner.search import ThresholdProbe, binary_search_max_feasible
|
||||
from aituner.slo import RequestOutcome, evaluate_request, summarize_evaluations
|
||||
from aituner.spec import (
|
||||
AdaptiveStopSpec,
|
||||
ConfigPatch,
|
||||
LLMEndpointSpec,
|
||||
Proposal,
|
||||
@@ -49,6 +51,7 @@ from aituner.spec import (
|
||||
from aituner.store import StudyStore
|
||||
from aituner.trace import load_trace_requests, summarize_window
|
||||
from aituner.worker import (
|
||||
_adaptive_replay_set,
|
||||
_best_feasible_probe_record,
|
||||
_latency_summary,
|
||||
_run_one_request,
|
||||
@@ -327,6 +330,134 @@ class CoreFlowTests(unittest.TestCase):
|
||||
)["workload_lca_profile"]
|
||||
self.assertNotIn("vector", legacy)
|
||||
|
||||
def _steady_requests(self, count: int, *, input_tokens: int = 100) -> list:
|
||||
return [
|
||||
TraceRequest(
|
||||
row_id=f"r{i}",
|
||||
arrival_s=float(i),
|
||||
sampling_u=1.0,
|
||||
body={},
|
||||
prompt_tokens_hint=input_tokens,
|
||||
completion_tokens_hint=16,
|
||||
metadata={"hash_ids": None},
|
||||
)
|
||||
for i in range(count)
|
||||
]
|
||||
|
||||
def _conv_window(self) -> WindowRecord:
|
||||
return WindowRecord(
|
||||
window_id="conv",
|
||||
trace_path=Path("trace.jsonl"),
|
||||
trace_type="chat",
|
||||
window_start=0.0,
|
||||
window_end=0.0,
|
||||
source_payload={"block_size": 64},
|
||||
)
|
||||
|
||||
def test_convergence_prefix_stops_early_on_stationary_trace(self) -> None:
|
||||
requests = self._steady_requests(60)
|
||||
point = find_convergence_prefix(
|
||||
requests,
|
||||
self._conv_window(),
|
||||
gpu_count=1,
|
||||
length_mode="total",
|
||||
tau=0.9,
|
||||
tau_c=0.9,
|
||||
stable_checks=3,
|
||||
max_checks=20,
|
||||
min_fraction=0.1,
|
||||
)
|
||||
self.assertTrue(point.converged)
|
||||
# A stationary workload should be trustworthy well before the full window.
|
||||
self.assertLess(point.stop_index, len(requests))
|
||||
self.assertLess(point.fraction, 1.0)
|
||||
self.assertTrue(point.checks)
|
||||
|
||||
def test_convergence_prefix_waits_when_cache_warms_late(self) -> None:
|
||||
window = self._conv_window()
|
||||
# First half: no prefix reuse. Second half: every request reuses block 1,
|
||||
# so the C dimension only stabilizes once the reuse regime is exercised.
|
||||
requests = []
|
||||
for i in range(30):
|
||||
requests.append(
|
||||
TraceRequest(
|
||||
row_id=f"cold{i}",
|
||||
arrival_s=float(i),
|
||||
sampling_u=1.0,
|
||||
body={},
|
||||
prompt_tokens_hint=640,
|
||||
completion_tokens_hint=16,
|
||||
metadata={"hash_ids": [10_000 + i]},
|
||||
)
|
||||
)
|
||||
for i in range(30):
|
||||
requests.append(
|
||||
TraceRequest(
|
||||
row_id=f"warm{i}",
|
||||
arrival_s=float(30 + i),
|
||||
sampling_u=1.0,
|
||||
body={},
|
||||
prompt_tokens_hint=640,
|
||||
completion_tokens_hint=16,
|
||||
metadata={"hash_ids": [1, 2, 3, 4, 5]},
|
||||
)
|
||||
)
|
||||
point = find_convergence_prefix(
|
||||
requests,
|
||||
window,
|
||||
gpu_count=1,
|
||||
length_mode="total",
|
||||
tau=0.9,
|
||||
tau_c=0.95,
|
||||
stable_checks=2,
|
||||
max_checks=20,
|
||||
min_fraction=0.1,
|
||||
)
|
||||
# The C family similarity must be low while only the cold half is seen.
|
||||
early = [c for c in point.checks if c["fraction"] <= 0.4]
|
||||
self.assertTrue(early)
|
||||
self.assertTrue(any(c["family_similarity"]["C"] < 0.9 for c in early))
|
||||
|
||||
def test_adaptive_replay_set_truncates_only_when_enabled(self) -> None:
|
||||
from types import SimpleNamespace
|
||||
|
||||
requests = self._steady_requests(60)
|
||||
window = self._conv_window()
|
||||
enabled_study = SimpleNamespace(
|
||||
trace=SimpleNamespace(
|
||||
adaptive_stop=AdaptiveStopSpec(
|
||||
enabled=True,
|
||||
tau=0.9,
|
||||
tau_c=0.9,
|
||||
stable_checks=3,
|
||||
max_checks=20,
|
||||
min_fraction=0.1,
|
||||
),
|
||||
request_mode="chat",
|
||||
),
|
||||
hardware=SimpleNamespace(gpu_count=1),
|
||||
)
|
||||
replay, certificate = _adaptive_replay_set(
|
||||
requests, study=enabled_study, window=window
|
||||
)
|
||||
self.assertIsNotNone(certificate)
|
||||
self.assertTrue(certificate["enabled"])
|
||||
self.assertEqual(len(replay), certificate["stop_index"])
|
||||
self.assertLessEqual(len(replay), len(requests))
|
||||
|
||||
disabled_study = SimpleNamespace(
|
||||
trace=SimpleNamespace(
|
||||
adaptive_stop=AdaptiveStopSpec(enabled=False),
|
||||
request_mode="chat",
|
||||
),
|
||||
hardware=SimpleNamespace(gpu_count=1),
|
||||
)
|
||||
passthrough, no_cert = _adaptive_replay_set(
|
||||
requests, study=disabled_study, window=window
|
||||
)
|
||||
self.assertIsNone(no_cert)
|
||||
self.assertEqual(len(passthrough), len(requests))
|
||||
|
||||
def test_lca_similarity_matrix_separates_different_profiles(self) -> None:
|
||||
window = WindowRecord(
|
||||
window_id="base",
|
||||
|
||||
Reference in New Issue
Block a user