refactor(policy): extract pure score_candidate() from KvAwarePolicy

Pulls the per-D score computation out of KvAwarePolicy.select
into a top-level pure function that takes primitives. The
in-method behavior is unchanged — the loop now calls
score_candidate() instead of inlining the arithmetic.

Motivation:
  Algorithm 1 (KVC_ROUTER_ALGORITHM.md §3.1) is the routing
  core. Until now its only API was select(), which requires
  building TraceRequest + SingleNodeTopology + RoutingState
  to test even a single lex-score property. After this
  extraction, unit tests can drive the four-tuple score
  directly with integers.

What changed:
  - Added module-level CandidateScore type alias.
  - Added score_candidate(*, overlap, sticky, inflight,
    assigned, mean_assigned, sticky_bonus,
    load_floor_bonus) -> CandidateScore.
  - KvAwarePolicy.select() loop body collapsed to a
    score_candidate() call; sticky now bool (was int)
    inside the call site.
  - Moved the load-floor docstring from KvAwarePolicy
    onto score_candidate where the formula lives.

Verified pure:
  - same kwargs -> same tuple
  - overlap=5 beats sticky-only (no load_floor): (5,0,0,0) > (1,1,0,0)
  - load_floor gated off when sticky=True

No behavior change; follow-up commit adds the unit tests
this refactor enables.
This commit is contained in:
2026-05-12 23:53:17 +08:00
parent 591cd6d382
commit 76a79dfdda

View File

@@ -152,6 +152,49 @@ class StickyDecodePolicy:
)
CandidateScore = tuple[int, int, int, int]
def score_candidate(
*,
overlap: int,
sticky: bool,
inflight: int,
assigned: int,
mean_assigned: float,
sticky_bonus: int,
load_floor_bonus: int,
) -> CandidateScore:
"""Pure scoring function for KvAwarePolicy (Algorithm 1 in KVC_ROUTER_ALGORITHM.md).
Returns the 4-tuple compared lexicographically by `select()` to pick the
best D. Extracted as a top-level function so unit tests can exercise it
without constructing topology/state objects.
Score tuple positions:
0: overlap + sticky_bonus*sticky + floor_bonus — primary, KV reuse aware
1: sticky — tie-1, session locality
2: -inflight — tie-2, prefer low load
3: -assigned — tie-3, prefer rarely-picked
Load-floor bonus is gated on `not sticky` (turn-1+ sessions continue to
stick to their original D). The boost magnitude scales linearly with the
D's deficit relative to the running mean of decode_assignment_counts:
floor_bonus = load_floor_bonus * max(0, mean - assigned) / max(1, mean)
When mean == 0 (warmup) the bonus is 0 for all candidates (lex tiebreak
falls through to iteration order).
See docs/E1_E2_FIX_DESIGN_ZH.md §Q2 for the load-floor design and
docs/KVC_ROUTER_ALGORITHM.md §3.1 for the lex-score formalism.
"""
floor_bonus = 0
if load_floor_bonus > 0 and not sticky and mean_assigned > 0:
deficit = max(0.0, mean_assigned - assigned)
floor_bonus = int(load_floor_bonus * deficit / mean_assigned)
primary = overlap + (sticky_bonus if sticky else 0) + floor_bonus
return (primary, int(sticky), -inflight, -assigned)
@dataclass(frozen=True)
class KvAwarePolicy:
name: str = "kv-aware"
@@ -161,27 +204,11 @@ class KvAwarePolicy:
# 0 disables the mechanism. Default 3 picked empirically to allow brief
# transient saturation without panicking, but to reroute persistent starvation.
migration_reject_threshold: int = 3
# Load-floor bonus: graduated boost added to lex-score position 0 for
# under-loaded D workers, gated on `not sticky` so turn-1+ requests of an
# existing session continue to stick to their original D. The boost
# magnitude scales linearly with the D's deficit relative to the running
# mean of `decode_assignment_counts`:
# floor_bonus = K * max(0, mean - assigned[D]) / max(1, mean)
# When mean=0 (warmup), bonus is 0 for all workers (lex tiebreak by
# iteration order). Once any D has been assigned, under-loaded D's get a
# bonus that approaches K as their deficit-to-mean ratio approaches 1.
# The bonus naturally decays as load equalises, leaving the original
# overlap+sticky scoring in charge of steady-state selection.
#
# Set this above the maximum cross-session boilerplate overlap you expect
# so that fresh sessions are routed to under-loaded D's even when those
# D's currently have 0 overlap, but below the magnitude of "real" prefix
# overlap (e.g., a session with 800-block per-session prefix on an
# already-warm D should still go there).
#
# 0 disables. See docs/E1_E2_FIX_DESIGN_ZH.md §Q2 for the full design and
# docs/E1_E2_RESULTS_ZH.md §5d for why this is needed on Inferact-shaped
# workloads where boilerplate overlap pins D2 cold forever.
# Load-floor bonus: see score_candidate() docstring for the exact formula.
# Set above the max cross-session boilerplate overlap you expect (so fresh
# sessions reach under-loaded D's even at 0 overlap), but below the
# magnitude of "real" prefix overlap (so a warm D still wins for its own
# session). 0 disables.
load_floor_bonus: int = 0
def select(
@@ -194,15 +221,12 @@ class KvAwarePolicy:
prefill_worker_id = state.next_prefill_worker_id(topology)
session = state.session_state.get(request.session_id)
# Pre-compute the running mean of decode assignments. Used by the
# load-floor bonus inside the candidate loop.
n_route_workers = max(1, len(topology.route_workers))
total_assigned = sum(state.decode_assignment_counts.values())
mean_assigned = total_assigned / n_route_workers
best_decode_worker_id: str | None = None
best_score: tuple[int, int, int, int] | None = None
candidates_considered = 0
best_score: CandidateScore | None = None
for worker in topology.route_workers:
# Migration: skip workers that have rejected this session too many times.
# If all candidates get filtered (degenerate case), fall through to
@@ -213,25 +237,17 @@ class KvAwarePolicy:
)
if rejects >= self.migration_reject_threshold:
continue
candidates_considered += 1
overlap = _overlap_blocks(request, state, worker.worker_id)
sticky = int(session is not None and session.last_decode_worker == worker.worker_id)
inflight_penalty = -state.inflight_decode.get(worker.worker_id, 0)
worker_assigned = state.decode_assignment_counts.get(worker.worker_id, 0)
assignment_penalty = -worker_assigned
# Load-floor bonus: only for fresh placements (not sticky), and
# only when the knob is enabled. See docstring above.
floor_bonus = 0
if self.load_floor_bonus > 0 and not sticky and mean_assigned > 0:
deficit = max(0.0, mean_assigned - worker_assigned)
floor_bonus = int(self.load_floor_bonus * deficit / mean_assigned)
score = (
overlap + sticky * self.sticky_bonus + floor_bonus,
sticky,
inflight_penalty,
assignment_penalty,
score = score_candidate(
overlap=_overlap_blocks(request, state, worker.worker_id),
sticky=(
session is not None
and session.last_decode_worker == worker.worker_id
),
inflight=state.inflight_decode.get(worker.worker_id, 0),
assigned=state.decode_assignment_counts.get(worker.worker_id, 0),
mean_assigned=mean_assigned,
sticky_bonus=self.sticky_bonus,
load_floor_bonus=self.load_floor_bonus,
)
if best_score is None or score > best_score:
best_score = score