Add linear_ms SLO rule (length-aware TTFT budget)
threshold_ms = intercept_ms + per_token_ms * input_tokens. Lets the TTFT target scale with prefill work, e.g. "4s + L_in/8k" => intercept_ms=4000, per_token_ms=0.125 (4s base, +1s per 8k input tokens). slo + spec + test. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -29,6 +29,9 @@ def _rule_threshold_ms(rule: ThresholdRule, prompt_tokens: int | None) -> float:
|
||||
if rule.kind == "fixed_ms":
|
||||
assert rule.threshold_ms is not None
|
||||
return rule.threshold_ms
|
||||
if rule.kind == "linear_ms":
|
||||
assert rule.intercept_ms is not None and rule.per_token_ms is not None
|
||||
return float(rule.intercept_ms) + float(rule.per_token_ms) * float(prompt_tokens or 0)
|
||||
if rule.kind != "step_ms":
|
||||
raise ValueError(f"Unsupported threshold rule: {rule.kind}")
|
||||
prompt = float(prompt_tokens or 0)
|
||||
|
||||
@@ -504,6 +504,8 @@ class ThresholdRule:
|
||||
kind: str
|
||||
threshold_ms: float | None = None
|
||||
buckets: list[dict[str, float]] = field(default_factory=list)
|
||||
intercept_ms: float | None = None
|
||||
per_token_ms: float | None = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Mapping[str, Any], *, context: str) -> "ThresholdRule":
|
||||
@@ -515,6 +517,18 @@ class ThresholdRule:
|
||||
data.get("threshold_ms"), context=f"{context}.threshold_ms"
|
||||
),
|
||||
)
|
||||
if kind == "linear_ms":
|
||||
# threshold = intercept_ms + per_token_ms * input_tokens
|
||||
# e.g. "4s + L_in/8k" -> intercept_ms=4000, per_token_ms=0.125
|
||||
intercept_ms = _require_float(
|
||||
data.get("intercept_ms"), context=f"{context}.intercept_ms"
|
||||
)
|
||||
per_token_ms = _require_float(
|
||||
data.get("per_token_ms"), context=f"{context}.per_token_ms"
|
||||
)
|
||||
if intercept_ms < 0 or per_token_ms < 0:
|
||||
raise SpecError(f"{context}.intercept_ms/per_token_ms must be >= 0.")
|
||||
return cls(kind=kind, intercept_ms=intercept_ms, per_token_ms=per_token_ms)
|
||||
if kind == "step_ms":
|
||||
raw = data.get("buckets")
|
||||
if not isinstance(raw, list) or not raw:
|
||||
|
||||
@@ -44,6 +44,7 @@ from aituner.spec import (
|
||||
ConfigPatch,
|
||||
LLMEndpointSpec,
|
||||
Proposal,
|
||||
SloSpec,
|
||||
SpecError,
|
||||
StudyState,
|
||||
TrialSummary,
|
||||
@@ -531,6 +532,34 @@ class CoreFlowTests(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_linear_ms_ttft_rule_scales_with_input_length(self) -> None:
|
||||
slo = SloSpec.from_dict(
|
||||
{
|
||||
"target_pass_rate": 0.95,
|
||||
"ttft_rule": {"kind": "linear_ms", "intercept_ms": 4000, "per_token_ms": 0.125},
|
||||
"tpot_rule": {"kind": "fixed_ms", "threshold_ms": 50},
|
||||
}
|
||||
)
|
||||
|
||||
def ev(prompt_tokens: int, ttft_ms: float):
|
||||
return evaluate_request(
|
||||
RequestOutcome(
|
||||
request_id="r",
|
||||
success=True,
|
||||
ttft_ms=ttft_ms,
|
||||
tpot_ms=10.0,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=8,
|
||||
),
|
||||
slo,
|
||||
)
|
||||
|
||||
# threshold = 4000 + 0.125*L_in : 8k->5000ms, 0->4000ms
|
||||
self.assertTrue(ev(8000, 4900).passed)
|
||||
self.assertFalse(ev(8000, 5100).passed)
|
||||
self.assertTrue(ev(0, 3900).passed)
|
||||
self.assertFalse(ev(0, 4100).passed)
|
||||
|
||||
def test_lca_similarity_matrix_separates_different_profiles(self) -> None:
|
||||
window = WindowRecord(
|
||||
window_id="base",
|
||||
|
||||
Reference in New Issue
Block a user