Files
agentic-pd-hybrid/src/agentic_pd_hybrid/policies.py
kzlin 2ec0debef4 feat(kvc): session migration with reset-on-success + direct-append threshold tuning
KVC v2 beats 4DP at ts=1 same-scale on 7/8 metrics:
  TTFT mean -24%, p50 -54%, p90 -64%; lat mean -0.8%, p50 -12.6%, p90 -0.7%.
  Direct-to-D rate jumped 42.8% -> 91.7%. REFACTOR_PLAN_V1 scenario C achieved.

Two-knob fix:
- reset-on-success blacklist decay: clear (sess, D) reject counter on
  successful direct-to-D path. Eliminates v1 thrashing where session 6880
  was stable on decode-1 for 70 turns then collapsed to 75 D-changes after
  cumulative transient pressure tripped the permanent blacklist.
- bump --kvcache-direct-max-uncached-tokens default 2048 -> 8192 via CLI flag.
  41% of v1 fallbacks were 'real-large-append' (>2048 token append); raising
  the threshold lets these go through the direct-to-D fast path.

Code:
- policies.py: RoutingState.session_d_rejects counter + KvAwarePolicy
  migration_reject_threshold; degenerate fallback picks least-rejected D.
- replay.py: record_admission_reject + reset-on-success in _run_request;
  _fallthrough_reason classifies turn-2+ fall-throughs as session-not-resident
  / real-large-append / etc, replacing misleading 'large-append' suffix
  (TEAM_REPORT §2.7).
- cli.py + benchmark.py: --kvcache-migration-reject-threshold flag wiring.

Docs:
- REFACTOR_PLAN_V1_ZH.md: forward-looking plan after ts=1 validation.
- MIGRATION_V1_FINDINGS_ZH.md: v1 thrashing root-cause analysis.
- V2_RESULTS_ZH.md: v2 results, scenario C achievement, attribution.
- TEAM_REPORT_AGENTIC_PD_HYBRID_ZH.md: comprehensive team report.

Scripts:
- sweep_ts1_kvc_n3_plus_dp.sh: ts=1 baseline (KVC 1P3D N=3 + 4DP CA).
- sweep_ts1_migration_v1.sh / v2.sh: validation runs.
- analyze_ts1_validation.py: 4-way comparison analyzer.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-09 01:18:13 +08:00

272 lines
9.4 KiB
Python

from __future__ import annotations
from collections import Counter
from dataclasses import dataclass, field
from typing import Protocol
from agentic_pd_hybrid.topology import SingleNodeTopology
from agentic_pd_hybrid.trace import TraceRequest
@dataclass
class SessionRouteState:
    """Routing memory kept for one session between turns."""

    # Decode worker recorded for this session by ``RoutingState.finish``;
    # stays ``None`` until a request of the session has finished.
    last_decode_worker: str | None = field(default=None)
@dataclass
class RoutingDecision:
    """Outcome of routing one request to a prefill/decode worker pair.

    Instances are assembled by ``_build_decision``; the field order is part
    of the constructor signature and must not be rearranged.
    """

    # Name of the policy that produced this decision.
    policy_name: str
    # Chosen workers; ``decode_worker_index`` is ``topology.route_index`` of
    # the decode worker.
    prefill_worker_id: str
    decode_worker_id: str
    decode_worker_index: int
    # Whether the policy itself predicted cache reuse on the decode worker.
    reuse_expected: bool
    # Number of request blocks already resident on the decode worker.
    observed_overlap_blocks: int
    # Request blocks that are not resident (never negative).
    kv_transfer_blocks: int
    # In-flight count of the decode worker, this request included.
    inflight_decode_load: int
    # Identity of the routed request.
    session_id: str
    request_id: str
    turn_id: int

    @property
    def observed_reuse(self) -> bool:
        """True when at least one request block was found resident."""
        return 0 < self.observed_overlap_blocks

    @property
    def re_prefill_required(self) -> bool:
        """True for a follow-up turn (turn_id > 1) with zero resident blocks."""
        nothing_resident = self.observed_overlap_blocks == 0
        return nothing_resident and self.turn_id > 1
@dataclass
class RoutingState:
    """Mutable bookkeeping shared by the routing policies.

    Holds the round-robin cursors, per-session routing memory, per-worker
    load counters, and the set of blocks assumed resident on each decode
    worker.
    """

    # Round-robin positions; they only ever grow and are reduced modulo the
    # worker count at lookup time.
    prefill_cursor: int = 0
    decode_cursor: int = 0
    # session_id -> routing memory for that session.
    session_state: dict[str, SessionRouteState] = field(default_factory=dict)
    # decode_worker_id -> requests currently assigned and not yet finished.
    inflight_decode: Counter[str] = field(default_factory=Counter)
    # decode_worker_id -> total number of assignments made so far.
    decode_assignment_counts: Counter[str] = field(default_factory=Counter)
    # decode_worker_id -> block ids recorded by ``finish`` for that worker.
    decode_resident_blocks: dict[str, set[int]] = field(default_factory=dict)
    # Migration support: per-(session_id, decode_worker_id) admission reject counter.
    # KvAwarePolicy uses this to skip D's that have repeatedly rejected this session
    # (avoids the structural starvation observed in TEAM_REPORT §2.1).
    session_d_rejects: Counter[tuple[str, str]] = field(default_factory=Counter)

    @classmethod
    def create(cls, topology: SingleNodeTopology) -> "RoutingState":
        """Build a state with an empty resident-block set per route worker."""
        return cls(
            decode_resident_blocks={
                worker.worker_id: set() for worker in topology.route_workers
            }
        )

    def next_prefill_worker_id(self, topology: SingleNodeTopology) -> str:
        """Return the next prefill worker id in round-robin order.

        Returns the literal string ``"none"`` (without advancing the cursor)
        when the topology has no prefill workers.
        """
        prefill_workers = topology.prefill_workers
        if not prefill_workers:
            return "none"
        worker = prefill_workers[self.prefill_cursor % len(prefill_workers)]
        self.prefill_cursor += 1
        return worker.worker_id

    def next_decode_worker_id(self, topology: SingleNodeTopology) -> str:
        """Return the next route (decode) worker id in round-robin order.

        An empty ``topology.route_workers`` is not handled here: the modulo
        below would divide by zero.
        """
        route_workers = topology.route_workers
        worker = route_workers[self.decode_cursor % len(route_workers)]
        self.decode_cursor += 1
        return worker.worker_id

    def record_admission_reject(self, session_id: str, decode_worker_id: str) -> int:
        """Increment per-(session, D) rejection counter. Returns new count."""
        key = (session_id, decode_worker_id)
        self.session_d_rejects[key] += 1
        return self.session_d_rejects[key]

    def finish(self, request: TraceRequest, decision: RoutingDecision) -> None:
        """Record that ``request`` completed on the decode worker in ``decision``.

        Remembers the worker as the session's last decode worker, marks the
        request's ``hash_ids`` as resident there, and releases one in-flight
        slot (dropping the counter entry once it reaches zero or below).
        """
        worker_id = decision.decode_worker_id

        # Look up first so a fresh SessionRouteState is only built when needed.
        session = self.session_state.get(request.session_id)
        if session is None:
            session = SessionRouteState()
            self.session_state[request.session_id] = session
        session.last_decode_worker = worker_id

        # A worker that was never registered through ``create`` gets its
        # resident set on first use instead of raising KeyError; this matches
        # ``_overlap_blocks``, which also tolerates unknown workers.
        self.decode_resident_blocks.setdefault(worker_id, set()).update(
            request.hash_ids
        )

        self.inflight_decode[worker_id] -= 1
        if self.inflight_decode[worker_id] <= 0:
            # Keep the counter free of zero/negative entries.
            del self.inflight_decode[worker_id]
class RoutingPolicy(Protocol):
    """Structural interface implemented by every routing policy."""

    # Identifier of the policy; copied into each RoutingDecision it produces.
    name: str

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Choose the prefill and decode workers for ``request``."""
@dataclass(frozen=True)
class DefaultPolicy:
    """Cache-oblivious baseline: round-robin over prefill and decode workers."""

    name: str = "default"

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Take the next prefill and decode worker from the state's cursors.

        ``reuse_expected`` is always False because the choice ignores any
        session or cache information.
        """
        # Keyword arguments are evaluated left to right, so the prefill cursor
        # still advances before the decode cursor.
        return _build_decision(
            policy_name=self.name,
            request=request,
            topology=topology,
            state=state,
            prefill_worker_id=state.next_prefill_worker_id(topology),
            decode_worker_id=state.next_decode_worker_id(topology),
            reuse_expected=False,
        )
@dataclass(frozen=True)
class StickyDecodePolicy:
    """Send follow-up turns of a session back to its previous decode worker."""

    name: str = "sticky"

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Reuse the session's last decode worker when one is known.

        The first turn, or any session without a recorded decode worker,
        falls back to the round-robin decode cursor. The prefill worker is
        always taken round-robin.
        """
        previous = state.session_state.get(request.session_id)
        prefill_worker_id = state.next_prefill_worker_id(topology)

        # Stick only on turn 2+ and only if a previous decode worker exists.
        stick = bool(
            request.turn_id > 1
            and previous
            and previous.last_decode_worker is not None
        )
        decode_worker_id = (
            previous.last_decode_worker
            if stick
            else state.next_decode_worker_id(topology)
        )
        return _build_decision(
            policy_name=self.name,
            request=request,
            topology=topology,
            state=state,
            prefill_worker_id=prefill_worker_id,
            decode_worker_id=decode_worker_id,
            reuse_expected=stick,
        )
@dataclass(frozen=True)
class KvAwarePolicy:
    """Route to the decode worker with the best resident-block overlap.

    Workers are ranked by a lexicographic score (see ``_score``). Workers that
    have rejected the session ``migration_reject_threshold`` times or more are
    skipped so the session migrates elsewhere.
    """

    name: str = "kv-aware"
    # Added to the overlap term of the score for the session's previous worker.
    sticky_bonus: int = 1
    # Session migration: when (session, D) has been rejected this many times,
    # skip D entirely for this session (force migration to another D).
    # 0 disables the mechanism. Default 3 picked empirically to allow brief
    # transient saturation without panicking, but to reroute persistent starvation.
    migration_reject_threshold: int = 3

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Pick the highest-scoring non-blacklisted decode worker.

        If every worker is blacklisted for this session, the least-rejected
        worker is used and ``reuse_expected`` is reported as False. The
        prefill worker is taken round-robin. An empty ``route_workers`` makes
        the fallback ``min`` raise ``ValueError``.
        """
        prefill_worker_id = state.next_prefill_worker_id(topology)
        session = state.session_state.get(request.session_id)

        scored = [
            (self._score(request, state, session, worker.worker_id), worker.worker_id)
            for worker in topology.route_workers
            if not self._is_blacklisted(request.session_id, worker.worker_id, state)
        ]

        if scored:
            # ``max`` keeps the first of several equal scores, i.e. the
            # earliest worker in ``route_workers`` wins ties.
            best_score, decode_worker_id = max(scored, key=lambda entry: entry[0])
            reuse_expected = best_score[0] > 0
        else:
            # Degenerate fallback: every D was filtered. Pick the least-rejected D.
            decode_worker_id = min(
                (worker.worker_id for worker in topology.route_workers),
                key=lambda wid: state.session_d_rejects.get(
                    (request.session_id, wid), 0
                ),
            )
            reuse_expected = False

        return _build_decision(
            policy_name=self.name,
            request=request,
            topology=topology,
            state=state,
            prefill_worker_id=prefill_worker_id,
            decode_worker_id=decode_worker_id,
            reuse_expected=reuse_expected,
        )

    def _is_blacklisted(
        self, session_id: str, worker_id: str, state: RoutingState
    ) -> bool:
        """Return True when ``worker_id`` rejected the session too many times."""
        if self.migration_reject_threshold <= 0:
            # Mechanism disabled: nothing is ever filtered.
            return False
        rejects = state.session_d_rejects.get((session_id, worker_id), 0)
        return rejects >= self.migration_reject_threshold

    def _score(
        self,
        request: TraceRequest,
        state: RoutingState,
        session: SessionRouteState | None,
        worker_id: str,
    ) -> tuple[int, int, int, int]:
        """Return the ranking tuple for one worker; larger compares better.

        Components, in priority order: overlap plus sticky bonus, stickiness,
        negated in-flight load, negated lifetime assignment count.
        """
        overlap = _overlap_blocks(request, state, worker_id)
        sticky = int(session is not None and session.last_decode_worker == worker_id)
        return (
            overlap + sticky * self.sticky_bonus,
            sticky,
            -state.inflight_decode.get(worker_id, 0),
            -state.decode_assignment_counts.get(worker_id, 0),
        )
def create_policy(name: str, *, migration_reject_threshold: int = 3) -> RoutingPolicy:
    """Instantiate the routing policy registered under ``name``.

    Matching ignores case and surrounding whitespace. Accepted names are
    ``default``, ``sticky`` and ``kv-aware`` (aliases ``kv_aware``, ``kv``).
    ``migration_reject_threshold`` is forwarded only to the KV-aware policy.

    Raises:
        ValueError: if ``name`` matches no known policy.
    """
    key = name.strip().lower()
    policy: RoutingPolicy | None = None
    if key == "default":
        policy = DefaultPolicy()
    elif key == "sticky":
        policy = StickyDecodePolicy()
    elif key in ("kv-aware", "kv_aware", "kv"):
        policy = KvAwarePolicy(migration_reject_threshold=migration_reject_threshold)

    if policy is None:
        # Report the caller's original spelling, not the normalised key.
        raise ValueError(f"Unsupported policy: {name}")
    return policy
def _build_decision(
    *,
    policy_name: str,
    request: TraceRequest,
    topology: SingleNodeTopology,
    state: RoutingState,
    prefill_worker_id: str,
    decode_worker_id: str,
    reuse_expected: bool,
) -> RoutingDecision:
    """Commit a routing choice to ``state`` and describe it as a decision.

    Side effects: bumps both the in-flight counter and the lifetime
    assignment counter of ``decode_worker_id``.
    """
    # Overlap is measured against the blocks currently recorded as resident.
    resident_hits = _overlap_blocks(request, state, decode_worker_id)

    state.inflight_decode[decode_worker_id] += 1
    state.decode_assignment_counts[decode_worker_id] += 1

    worker_index = topology.route_index(decode_worker_id)
    # Blocks not already resident; clamped so it can never go negative.
    missing_blocks = max(0, len(request.hash_ids) - resident_hits)
    # Read after the increment above, so the load includes this request.
    load_with_this_request = state.inflight_decode[decode_worker_id]

    return RoutingDecision(
        policy_name=policy_name,
        prefill_worker_id=prefill_worker_id,
        decode_worker_id=decode_worker_id,
        decode_worker_index=worker_index,
        reuse_expected=reuse_expected,
        observed_overlap_blocks=resident_hits,
        kv_transfer_blocks=missing_blocks,
        inflight_decode_load=load_with_this_request,
        session_id=request.session_id,
        request_id=request.request_id,
        turn_id=request.turn_id,
    )
def _overlap_blocks(
request: TraceRequest,
state: RoutingState,
decode_worker_id: str,
) -> int:
resident = state.decode_resident_blocks.get(decode_worker_id, set())
return sum(1 for block in request.hash_ids if block in resident)