"""Tests for A5 joined analysis: join + indices + labels.""" from __future__ import annotations from analysis.characterization.joined_analysis import ( build_joined_records, hotspot_index, interference_index, label_slow_requests, reuse_decomposition, window_summary, _normalize_worker, _percentile, _resolve_worker, _vllm_rid_matches, ) def _mk_metric(rid, **kw): base = { "request_id": rid, "session_id": "s1", "turn_id": 0, "trace_timestamp_s": 1.0, "input_length": 1000, "output_length": 50, "cached_tokens": 0, "actual_output_tokens": 50, "latency_s": 1.0, "ttft_s": 0.5, "tpot_s": 0.04, "t_dispatch_unix": 1000.0, "t_first_token_unix": 1000.5, "t_finish_unix": 1001.0, "endpoint_url": "http://h:8000", "trace_hash_ids": [], "error": None, } base.update(kw) return base def test_build_joined_records_merges_by_request_id(): metrics = [_mk_metric("r1"), _mk_metric("r2")] breakdown = [{"request_id": "r1", "policy": "lmetric", "chosen_idx": 3, "estimated_new_tokens": 500, "routed_to": "http://h:8000"}] worker_state = [{"request_id": "r2", "workers": [{"idx": 0, "url": "x"}]}] joined = build_joined_records(metrics, breakdown, worker_state) assert len(joined) == 2 j_by_id = {r["request_id"]: r for r in joined} assert j_by_id["r1"]["policy"] == "lmetric" assert j_by_id["r1"]["chosen_idx"] == 3 assert j_by_id["r1"]["estimated_new_tokens"] == 500 assert j_by_id["r2"]["worker_state_at_decision"][0]["url"] == "x" assert j_by_id["r2"].get("policy") is None # no breakdown for r2 def test_reuse_decomposition_classifies_intra_and_cross(): records = [ _mk_metric("r1", session_id="A", trace_hash_ids=[11], cached_tokens=0, t_dispatch_unix=1.0), _mk_metric("r2", session_id="A", trace_hash_ids=[11], cached_tokens=100, t_dispatch_unix=2.0), _mk_metric("r3", session_id="B", trace_hash_ids=[11], cached_tokens=100, t_dispatch_unix=3.0), ] out = reuse_decomposition(records) assert out["status"] == "supported" assert out["intra_session_tokens"] > 0 assert out["cross_session_tokens"] > 0 fr = out["fractions"] assert abs(sum(fr.values()) - 1.0) < 1e-9 def test_normalize_worker_maps_port_to_engine_id(): assert _normalize_worker("http://node:8000") == "engine_0" assert _normalize_worker("http://node:8005/foo") == "engine_5" assert _normalize_worker("engine_3") == "engine_3" assert _normalize_worker(None) is None def test_interference_index_marks_overlap_when_other_request_prefilling(): metrics = [ _mk_metric("decode_target", t_first_token_unix=10.0, t_finish_unix=11.0, tpot_s=0.10), _mk_metric("decode_clean", t_first_token_unix=20.0, t_finish_unix=21.0, tpot_s=0.04), ] breakdown = [ {"request_id": "decode_target", "routed_to": "http://h:8000"}, {"request_id": "decode_clean", "routed_to": "http://h:8001"}, ] joined = build_joined_records(metrics, breakdown, []) engine_state = { "engine_0": [ {"t_unix": 10.5, "prefill_tokens": 8000, "per_req": [{"rid": "cmpl-other-0-aaaa", "phase": "prefill"}]}, ], "engine_1": [ {"t_unix": 20.5, "prefill_tokens": 0, "per_req": [{"rid": "decode_clean", "phase": "decode"}]}, ], } out = interference_index(joined, engine_state) assert out["status"] == "supported" assert out["n_overlap_requests"] == 1 assert out["n_clean_requests"] == 1 assert out["interference_index"] is not None assert out["interference_index"] > 2.0 def test_resolve_worker_prefers_explicit_map(): assert _resolve_worker("http://h:9100", {"http://h:9100": "engine_0"}) == "engine_0" assert _resolve_worker("http://h:9100", None) == "engine_1100" def test_vllm_rid_matches_strips_cmpl_prefix(): assert _vllm_rid_matches("cmpl-1237198:1:1237198:0-0-b07fed77", "1237198:1:1237198:0") assert _vllm_rid_matches("chatcmpl-abc-0-xx", "abc") assert not _vllm_rid_matches("cmpl-other-0-xx", "1237198:1:1237198:0") assert not _vllm_rid_matches(None, "x") def test_hotspot_index_max_over_median_p90(): """One hot worker with TTFT 10x the others should drive a >1 index.""" rows = [] for i in range(3): for _ in range(10): rows.append({ "request_id": f"x{i}", "routed_to": f"http://h:800{i}", "endpoint_url": f"http://h:800{i}", "ttft_s": 0.5 if i < 2 else 5.0, "latency_s": 1.0, "error": None, }) out = hotspot_index(rows) assert out["status"] == "supported" assert out["hotspot_index_ttft_p90"] > 5.0 def test_hotspot_index_uses_true_median_for_even_n(): """8 workers, sorted TTFT p90 [1,2,3,4,5,6,7,80]. True median = (4+5)/2 = 4.5; hotspot = 80/4.5 ≈ 17.78. Previous buggy implementation used sorted[4] = 5, giving 80/5 = 16.0. """ rows = [] ttfts = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 80.0] for i, t in enumerate(ttfts): for _ in range(10): rows.append({ "request_id": f"x{i}", "routed_to": f"http://h:800{i}", "endpoint_url": f"http://h:800{i}", "ttft_s": t, "latency_s": 1.0, "error": None, }) out = hotspot_index(rows) assert out["status"] == "supported" idx = out["hotspot_index_ttft_p90"] assert abs(idx - 80.0 / 4.5) < 1e-6, f"expected ~17.78, got {idx}" def test_label_slow_requests_flags_overlap_and_hot_worker(): metrics = [ _mk_metric("slow_overlap", ttft_s=10.0, t_first_token_unix=10.0, t_finish_unix=11.0), _mk_metric("slow_no_overlap", ttft_s=10.0, t_first_token_unix=20.0, t_finish_unix=21.0), _mk_metric("fast", ttft_s=0.5, t_first_token_unix=15.0, t_finish_unix=16.0), ] bk = [ {"request_id": "slow_overlap", "routed_to": "http://h:8000"}, {"request_id": "slow_no_overlap", "routed_to": "http://h:8005"}, {"request_id": "fast", "routed_to": "http://h:8000"}, ] joined = build_joined_records(metrics, bk, []) engine_state = { "engine_0": [{"t_unix": 10.5, "prefill_tokens": 5000, "per_req": [{"rid": "cmpl-other-0-x", "phase": "prefill"}]}], } labels = label_slow_requests(joined, engine_state, slow_ttft_factor=2.0) by_id = {L["request_id"]: L["label"] for L in labels} assert by_id.get("slow_overlap") == "same_worker_prefill_overlap" assert "fast" not in by_id assert "slow_no_overlap" in by_id def test_window_summary_buckets_by_dispatch_unix(): run_meta = { "run_start_unix": 1000.0, "warmup_end_unix": 1010.0, "steady_end_unix": 1030.0, "drain_end_unix": 1040.0, } joined = [ _mk_metric("w", t_dispatch_unix=1005.0, ttft_s=0.5, latency_s=1.0, tpot_s=0.04), _mk_metric("s", t_dispatch_unix=1020.0, ttft_s=0.6, latency_s=1.5, tpot_s=0.05), _mk_metric("d", t_dispatch_unix=1035.0, ttft_s=0.7, latency_s=2.0, tpot_s=0.06), ] out = window_summary(joined, run_meta) assert out["windows"]["warmup"]["attempted"] == 1 assert out["windows"]["steady"]["attempted"] == 1 assert out["windows"]["drain"]["attempted"] == 1 assert out["windows"]["steady"]["ttft_p90_s"] is not None def test_percentile_helper_handles_singleton(): assert _percentile([5.0], 0.99) == 5.0 assert _percentile([], 0.50) is None