KVC v2 beats 4DP at ts=1 same-scale on 7/8 metrics: TTFT mean -24%, p50 -54%, p90 -64%; lat mean -0.8%, p50 -12.6%, p90 -0.7%. Direct-to-D rate jumped 42.8% -> 91.7%. REFACTOR_PLAN_V1 scenario C achieved. Two-knob fix: - reset-on-success blacklist decay: clear (sess, D) reject counter on successful direct-to-D path. Eliminates v1 thrashing where session 6880 was stable on decode-1 for 70 turns then collapsed to 75 D-changes after cumulative transient pressure tripped the permanent blacklist. - bump --kvcache-direct-max-uncached-tokens default 2048 -> 8192 via CLI flag. 41% of v1 fallbacks were 'real-large-append' (>2048 token append); raising the threshold lets these go through the direct-to-D fast path. Code: - policies.py: RoutingState.session_d_rejects counter + KvAwarePolicy migration_reject_threshold; degenerate fallback picks least-rejected D. - replay.py: record_admission_reject + reset-on-success in _run_request; _fallthrough_reason classifies turn-2+ fall-throughs as session-not-resident / real-large-append / etc, replacing misleading 'large-append' suffix (TEAM_REPORT §2.7). - cli.py + benchmark.py: --kvcache-migration-reject-threshold flag wiring. Docs: - REFACTOR_PLAN_V1_ZH.md: forward-looking plan after ts=1 validation. - MIGRATION_V1_FINDINGS_ZH.md: v1 thrashing root-cause analysis. - V2_RESULTS_ZH.md: v2 results, scenario C achievement, attribution. - TEAM_REPORT_AGENTIC_PD_HYBRID_ZH.md: comprehensive team report. Scripts: - sweep_ts1_kvc_n3_plus_dp.sh: ts=1 baseline (KVC 1P3D N=3 + 4DP CA). - sweep_ts1_migration_v1.sh / v2.sh: validation runs. - analyze_ts1_validation.py: 4-way comparison analyzer. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
272 lines
9.4 KiB
Python
272 lines
9.4 KiB
Python
from __future__ import annotations
|
|
|
|
from collections import Counter
|
|
from dataclasses import dataclass, field
|
|
from typing import Protocol
|
|
|
|
from agentic_pd_hybrid.topology import SingleNodeTopology
|
|
from agentic_pd_hybrid.trace import TraceRequest
|
|
|
|
|
|
@dataclass
class SessionRouteState:
    """Routing memory kept for a single session.

    Attributes:
        last_decode_worker: Id of the decode worker most recently recorded
            for the session, or ``None`` when nothing has been recorded yet.
    """

    last_decode_worker: str | None = field(default=None)
|
|
|
|
|
|
@dataclass
class RoutingDecision:
    """Record describing where one request was routed.

    The fields are plain values filled in by whoever constructs the decision;
    the two properties below are derived from ``observed_overlap_blocks`` and
    ``turn_id`` only.
    """

    # Which policy produced the decision and the workers it picked.
    policy_name: str
    prefill_worker_id: str
    decode_worker_id: str
    decode_worker_index: int

    # Reuse / transfer / load figures attached to the decision.
    reuse_expected: bool
    observed_overlap_blocks: int
    kv_transfer_blocks: int
    inflight_decode_load: int

    # Identity of the routed request.
    session_id: str
    request_id: str
    turn_id: int

    @property
    def observed_reuse(self) -> bool:
        """Return True when at least one overlapping block was observed."""
        overlap = self.observed_overlap_blocks
        return overlap > 0

    @property
    def re_prefill_required(self) -> bool:
        """Return True for a follow-up turn (``turn_id > 1``) with zero overlap."""
        if self.turn_id <= 1:
            # First turns never count as a re-prefill.
            return False
        return self.observed_overlap_blocks == 0
|
|
|
|
|
|
@dataclass
class RoutingState:
    """Mutable bookkeeping shared across routing decisions.

    Use :meth:`create` to pre-seed ``decode_resident_blocks`` for every route
    worker of a topology. A default-constructed instance is also usable:
    :meth:`finish` creates missing per-worker entries on demand.
    """

    # Positions advanced by next_prefill_worker_id / next_decode_worker_id.
    prefill_cursor: int = 0
    decode_cursor: int = 0
    # session_id -> per-session routing memory (written by finish()).
    session_state: dict[str, SessionRouteState] = field(default_factory=dict)
    # decode worker id -> number of requests currently in flight
    # (decremented by finish()).
    inflight_decode: Counter[str] = field(default_factory=Counter)
    # decode worker id -> assignment tally (not modified inside this class).
    decode_assignment_counts: Counter[str] = field(default_factory=Counter)
    # decode worker id -> block hash ids recorded as resident by finish().
    decode_resident_blocks: dict[str, set[int]] = field(default_factory=dict)
    # Migration support: per-(session_id, decode_worker_id) admission reject counter.
    # KvAwarePolicy uses this to skip D's that have repeatedly rejected this session
    # (avoids the structural starvation observed in TEAM_REPORT §2.1).
    session_d_rejects: Counter[tuple[str, str]] = field(default_factory=Counter)

    @classmethod
    def create(cls, topology: SingleNodeTopology) -> "RoutingState":
        """Build a state with an empty resident-block set per route worker."""
        return cls(
            decode_resident_blocks={
                worker.worker_id: set() for worker in topology.route_workers
            }
        )

    def next_prefill_worker_id(self, topology: SingleNodeTopology) -> str:
        """Return the next prefill worker id in rotation.

        Returns the literal ``"none"`` (without advancing the cursor) when the
        topology has no prefill workers.
        """
        if not topology.prefill_workers:
            return "none"
        worker = topology.prefill_workers[self.prefill_cursor % len(topology.prefill_workers)]
        self.prefill_cursor += 1
        return worker.worker_id

    def next_decode_worker_id(self, topology: SingleNodeTopology) -> str:
        """Return the next route worker id in rotation and advance the cursor.

        Raises:
            ZeroDivisionError: if ``topology.route_workers`` is empty.
        """
        route_workers = topology.route_workers
        worker = route_workers[self.decode_cursor % len(route_workers)]
        self.decode_cursor += 1
        return worker.worker_id

    def record_admission_reject(self, session_id: str, decode_worker_id: str) -> int:
        """Increment per-(session, D) rejection counter. Returns new count."""
        key = (session_id, decode_worker_id)
        self.session_d_rejects[key] += 1
        return self.session_d_rejects[key]

    def finish(self, request: TraceRequest, decision: RoutingDecision) -> None:
        """Record completion of ``request`` on the decode worker in ``decision``.

        Remembers the decode worker as the session's last one, marks the
        request's ``hash_ids`` as resident on that worker, and releases one
        in-flight slot (dropping the counter entry once it reaches zero or
        below).
        """
        worker_id = decision.decode_worker_id

        # Create the session record lazily instead of building a throwaway
        # default object on every call.
        session = self.session_state.get(request.session_id)
        if session is None:
            session = SessionRouteState()
            self.session_state[request.session_id] = session
        session.last_decode_worker = worker_id

        # setdefault keeps this in line with the tolerant `.get(..., set())`
        # lookup used when computing overlap: a worker that was never seeded
        # by create() gets its resident set on first use instead of raising
        # KeyError.
        self.decode_resident_blocks.setdefault(worker_id, set()).update(request.hash_ids)

        self.inflight_decode[worker_id] -= 1
        if self.inflight_decode[worker_id] <= 0:
            del self.inflight_decode[worker_id]
|
|
|
|
|
|
class RoutingPolicy(Protocol):
    """Structural interface for routing policies.

    An implementation exposes a ``name`` attribute and a ``select`` method
    taking the request positionally and ``topology`` / ``state`` by keyword.
    """

    name: str

    def select(
        self, request: TraceRequest, *, topology: SingleNodeTopology, state: RoutingState
    ) -> RoutingDecision:
        """Return the routing decision for ``request``."""
|
|
|
|
|
|
@dataclass(frozen=True)
class DefaultPolicy:
    """Policy that takes the next prefill and decode worker from ``state``.

    It never claims reuse: ``reuse_expected`` is always ``False``.
    """

    name: str = "default"

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Route ``request`` to whichever workers ``state`` hands out next."""
        # Keyword arguments are evaluated left to right, so the prefill
        # worker is drawn before the decode worker.
        return _build_decision(
            policy_name=self.name,
            request=request,
            topology=topology,
            state=state,
            prefill_worker_id=state.next_prefill_worker_id(topology),
            decode_worker_id=state.next_decode_worker_id(topology),
            reuse_expected=False,
        )
|
|
|
|
|
|
@dataclass(frozen=True)
class StickyDecodePolicy:
    """Policy that keeps follow-up turns on the session's last decode worker.

    A request with ``turn_id > 1`` whose session has a recorded
    ``last_decode_worker`` is sent back to that worker with
    ``reuse_expected=True``. Every other request takes the next decode worker
    handed out by ``state`` with ``reuse_expected=False``.
    """

    name: str = "sticky"

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Route ``request``, preferring the session's previous decode worker."""
        session = state.session_state.get(request.session_id)
        prefill_worker_id = state.next_prefill_worker_id(topology)

        # Only follow-up turns of a known session can be sticky.
        sticky_target = None
        if request.turn_id > 1 and session:
            sticky_target = session.last_decode_worker

        reuse_expected = sticky_target is not None
        if reuse_expected:
            decode_worker_id = sticky_target
        else:
            decode_worker_id = state.next_decode_worker_id(topology)

        return _build_decision(
            policy_name=self.name,
            request=request,
            topology=topology,
            state=state,
            prefill_worker_id=prefill_worker_id,
            decode_worker_id=decode_worker_id,
            reuse_expected=reuse_expected,
        )
|
|
|
|
|
|
@dataclass(frozen=True)
class KvAwarePolicy:
    """Policy that scores every route worker and picks the best one.

    Workers are ranked by the tuple
    ``(overlap + sticky * sticky_bonus, sticky, -inflight, -assignments)``,
    compared lexicographically; on an exact tie the earlier worker in
    ``topology.route_workers`` wins. Workers that have rejected the session
    ``migration_reject_threshold`` times or more are skipped. If that filter
    removes every worker, the least-rejected one is used instead.
    """

    name: str = "kv-aware"
    sticky_bonus: int = 1
    # Session migration: when (session, D) has been rejected this many times,
    # skip D entirely for this session (force migration to another D).
    # 0 disables the mechanism. Default 3 picked empirically to allow brief
    # transient saturation without panicking, but to reroute persistent starvation.
    migration_reject_threshold: int = 3

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Choose the decode worker with the highest score for ``request``.

        Raises:
            ValueError: if ``topology.route_workers`` is empty (raised by
                ``min()`` in the fallback path).
        """
        prefill_worker_id = state.next_prefill_worker_id(topology)
        session = state.session_state.get(request.session_id)

        best_decode_worker_id: str | None = None
        best_score: tuple[int, int, int, int] | None = None
        for worker in topology.route_workers:
            worker_id = worker.worker_id
            # Migration: skip workers that have rejected this session too many times.
            if self._is_blacklisted(request.session_id, worker_id, state):
                continue
            overlap = _overlap_blocks(request, state, worker_id)
            sticky = int(session is not None and session.last_decode_worker == worker_id)
            score = (
                overlap + sticky * self.sticky_bonus,
                sticky,
                # Negated so that a lighter load compares as a higher score.
                -state.inflight_decode.get(worker_id, 0),
                -state.decode_assignment_counts.get(worker_id, 0),
            )
            # Strict ">" keeps the first worker on ties.
            if best_score is None or score > best_score:
                best_score = score
                best_decode_worker_id = worker_id

        if best_decode_worker_id is None or best_score is None:
            # Degenerate fallback: every D was filtered. Pick the least-rejected D.
            best_decode_worker_id = min(
                (w.worker_id for w in topology.route_workers),
                key=lambda wid: state.session_d_rejects.get((request.session_id, wid), 0),
            )
            # The fallback pick is not driven by overlap, so no reuse is claimed.
            reuse_expected = False
        else:
            reuse_expected = best_score[0] > 0

        return _build_decision(
            policy_name=self.name,
            request=request,
            topology=topology,
            state=state,
            prefill_worker_id=prefill_worker_id,
            decode_worker_id=best_decode_worker_id,
            reuse_expected=reuse_expected,
        )

    def _is_blacklisted(self, session_id: str, worker_id: str, state: RoutingState) -> bool:
        """Return True when ``worker_id`` hit the reject threshold for the session."""
        if self.migration_reject_threshold <= 0:
            # Mechanism disabled.
            return False
        rejects = state.session_d_rejects.get((session_id, worker_id), 0)
        return rejects >= self.migration_reject_threshold
|
|
|
|
|
|
def create_policy(name: str, *, migration_reject_threshold: int = 3) -> RoutingPolicy:
    """Instantiate the routing policy registered under ``name``.

    Matching ignores surrounding whitespace and letter case. Accepted names
    are ``"default"``, ``"sticky"`` and ``"kv-aware"`` (aliases ``"kv_aware"``
    and ``"kv"``). ``migration_reject_threshold`` is forwarded only to the
    kv-aware policy.

    Raises:
        ValueError: if ``name`` matches none of the accepted names.
    """
    key = name.strip().lower()
    kv_aliases = ("kv-aware", "kv_aware", "kv")

    if key in kv_aliases:
        return KvAwarePolicy(migration_reject_threshold=migration_reject_threshold)
    if key == "sticky":
        return StickyDecodePolicy()
    if key == "default":
        return DefaultPolicy()

    raise ValueError(f"Unsupported policy: {name}")
|
|
|
|
|
|
def _build_decision(
    *,
    policy_name: str,
    request: TraceRequest,
    topology: SingleNodeTopology,
    state: RoutingState,
    prefill_worker_id: str,
    decode_worker_id: str,
    reuse_expected: bool,
) -> RoutingDecision:
    """Register the assignment in ``state`` and package it as a decision.

    Side effects: increments both ``state.inflight_decode`` and
    ``state.decode_assignment_counts`` for ``decode_worker_id``.
    """
    resident_hits = _overlap_blocks(request, state, decode_worker_id)

    # Count the new assignment first so the load reported below already
    # includes this request.
    for tally in (state.inflight_decode, state.decode_assignment_counts):
        tally[decode_worker_id] += 1
    current_load = state.inflight_decode[decode_worker_id]

    # Blocks not already resident on the chosen worker; never negative.
    blocks_to_move = max(0, len(request.hash_ids) - resident_hits)

    return RoutingDecision(
        policy_name=policy_name,
        prefill_worker_id=prefill_worker_id,
        decode_worker_id=decode_worker_id,
        decode_worker_index=topology.route_index(decode_worker_id),
        reuse_expected=reuse_expected,
        observed_overlap_blocks=resident_hits,
        kv_transfer_blocks=blocks_to_move,
        inflight_decode_load=current_load,
        session_id=request.session_id,
        request_id=request.request_id,
        turn_id=request.turn_id,
    )
|
|
|
|
|
|
def _overlap_blocks(
    request: TraceRequest,
    state: RoutingState,
    decode_worker_id: str,
) -> int:
    """Count entries of ``request.hash_ids`` resident on ``decode_worker_id``.

    Every entry is counted individually, so a repeated hash id contributes
    once per occurrence. A worker with no recorded resident blocks yields 0.
    """
    resident = state.decode_resident_blocks.get(decode_worker_id, frozenset())
    hits = [block_id for block_id in request.hash_ids if block_id in resident]
    return len(hits)
|