Add decode-only study mode support

This commit is contained in:
2026-04-09 11:23:17 +08:00
parent 96140b79bb
commit c158807fac
6 changed files with 282 additions and 1 deletions

View File

@@ -34,7 +34,10 @@ from aituner.trace import TraceRequest
def _write_study_assets(
tmp_path: Path, *, trace_overrides: dict[str, object] | None = None
tmp_path: Path,
*,
trace_overrides: dict[str, object] | None = None,
slo_overrides: dict[str, object] | None = None,
) -> Path:
trace_dir = tmp_path / "trace_windows" / "traces"
trace_dir.mkdir(parents=True)
@@ -148,6 +151,8 @@ def _write_study_assets(
"llm": {"system_prompt": "Tune it.", "max_history_trials": 8},
"capability_profile_path": str(capability_path)
}
if slo_overrides:
study_payload["slo"].update(slo_overrides)
study_path.write_text(json.dumps(study_payload), encoding="utf-8")
return study_path
@@ -222,6 +227,30 @@ class CoreFlowTests(unittest.TestCase):
with self.assertRaisesRegex(SpecError, "min_input_tokens must be <="):
load_study_spec(study_path)
def test_decode_only_mode_is_loaded_and_prompt_mentions_it(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
study_path = _write_study_assets(
tmp_path,
trace_overrides={"request_mode": "decode_only"},
slo_overrides={
"ttft_rule": None,
"tpot_rule": {"kind": "fixed_ms", "threshold_ms": 20},
},
)
study = load_study_spec(study_path)
self.assertEqual(study.trace.request_mode, "decode_only")
window, requests = load_trace_requests(study, study_spec_path=study_path)
prompt = build_prompt(
study=study,
window_summary=summarize_window(requests, window),
state=StudyState(study_id=study.study_id),
capability_profile=None,
)
self.assertIn('"request_mode": "decode_only"', prompt)
self.assertIn("There is no TTFT SLO for this study.", prompt)
self.assertIn("decode-only", prompt)
def test_bailian_endpoint_defaults(self) -> None:
endpoint = LLMEndpointSpec.from_dict({"provider": "bailian", "model": "qwen-plus"})
self.assertEqual(endpoint.provider, "bailian")
@@ -481,6 +510,40 @@ class CoreFlowTests(unittest.TestCase):
self.assertFalse(evaluations[1].passed)
self.assertEqual(summary["slo_pass_rate"], 0.5)
def test_slo_evaluation_supports_tpot_only_95_percent_target(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
study = load_study_spec(
_write_study_assets(
Path(tmp),
slo_overrides={
"ttft_rule": None,
"tpot_rule": {"kind": "fixed_ms", "threshold_ms": 20},
},
)
)
outcomes = [
RequestOutcome(
request_id="r1",
success=True,
ttft_ms=3000,
tpot_ms=10,
prompt_tokens=1000,
completion_tokens=16,
),
RequestOutcome(
request_id="r2",
success=True,
ttft_ms=9000,
tpot_ms=21,
prompt_tokens=5000,
completion_tokens=16,
),
]
evaluations, summary = summarize_evaluations(outcomes, study.slo)
self.assertEqual([item.passed for item in evaluations], [True, False])
self.assertEqual(summary["slo_pass_rate"], 0.5)
self.assertFalse(summary["feasible"])
def test_prepare_trace_windows_materializes_repo_local_assets(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
@@ -1241,6 +1304,7 @@ class CoreFlowTests(unittest.TestCase):
evaluations = [evaluate_request(item, study.slo) for item in outcomes]
summary = _latency_summary(outcomes=outcomes, evaluations=evaluations, study=study)
self.assertEqual(summary["observed_request_count"], 2)
self.assertEqual(summary["request_mode"], "chat")
self.assertEqual(summary["ttft_ms"]["mean"], 150.0)
self.assertEqual(summary["ttft_ms"]["p50"], 100.0)
self.assertEqual(summary["ttft_ms"]["p99"], 200.0)