Files
agentic-pd-hybrid/src/agentic_pd_hybrid/replay.py
Claude Code Agent e729d62ddf fix(d2p): structural log + relax entrance condition for sync
E4 forensics (docs/E4_RESULTS_ZH.md): 272 admission rejections triggered
the fallback seeded_router path, but zero /_snapshot/* HTTP calls hit
the workers. Two root causes:

1. _attempt_d_to_p_sync was gated on the agentic-side `decode_session.opened`.
   By the time fallback runs, agentic has already flipped that flag
   to False in response to admission rejection. But D-side
   SessionAwareCache may still hold the session (release_session is
   not called automatically on admission rejection). Removing the
   gate; let D respond authoritatively with "session-not-resident"
   if it has actually evicted.

2. _attempt_d_to_p_sync logged decisions via logger.info, but
   agentic configures no root logger handler, so those events were silently dropped.
   Switching every branch (entry skip, prepare fail/not-ok, dump
   fail/not-ok, finalize fail/not-ok, ok) to write a structural-log
   line at outputs/<run>/structural/d-to-p-sync.jsonl. Each line
   carries stage, reason, durations, bytes pushed.

The result doc is updated to reflect the honest E4-1 outcome and
the P1 fix list.
2026-05-13 09:34:09 +08:00

3219 lines
114 KiB
Python

from __future__ import annotations
import asyncio
import json
import time
from collections import Counter
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Any, Literal
import httpx
from agentic_pd_hybrid.metrics import (
RequestMetrics,
write_metrics_jsonl,
write_summary_json,
)
from agentic_pd_hybrid.policies import RoutingState, create_policy
from agentic_pd_hybrid.topology import SingleNodeTopology
from agentic_pd_hybrid.trace import (
TraceRequest,
build_synthetic_append_chunk,
build_synthetic_prompt,
load_trace,
)
# Allowed values for ReplayConfig.header_mode (consumed by _build_headers).
HeaderMode = Literal["none", "routing-key", "target-worker", "auto"]
# Allowed values for ReplayConfig.kvcache_admission_mode.
KvCacheAdmissionMode = Literal["router", "worker"]
# Allowed values for ReplayConfig.kvcache_prefill_backup_policy.
KvCachePrefillBackupPolicy = Literal["release-after-transfer", "capacity-backup"]
# Timeout (seconds) for short control-plane calls in this module: the
# admission probe, /server_info, /open_session and /close_session.
_ADMISSION_PROBE_TIMEOUT_S = 2.0
# --- Structural event logging (admission probes, backpressure pauses, ---
# --- session-D bindings). Module-level state keeps call-site diff small. ---
# Directory receiving structural JSONL files; None disables structural logging.
_STRUCTURAL_LOG_DIR: Path | None = None
# Serializes writers inside _structural_emit.
# NOTE(review): created at import time. On Python < 3.10 asyncio.Lock binds to
# the event loop current at construction — confirm the supported runtime.
_STRUCTURAL_LOG_LOCK = asyncio.Lock()
# Lazily opened append-mode handles, keyed by bare filename (not full path).
_STRUCTURAL_LOG_FILES: dict[str, Any] = {}
# time.perf_counter() captured by _structural_init; event "t" values are
# seconds relative to this instant.
_STRUCTURAL_RUN_START_S: float = 0.0
def _structural_init(log_dir: Path | None) -> None:
    """Configure structural logging for a new run.

    Args:
        log_dir: Directory for the structural JSONL files, created if
            missing. ``None`` disables structural logging.

    Any handles still cached from a previous run are closed first. They are
    keyed by bare filename only, so a stale handle would otherwise keep
    receiving this run's events in the old directory. That happens, for
    example, when a previous replay raised before ``_structural_close`` ran.
    The run clock used for the ``t`` field of every event is reset.
    """
    global _STRUCTURAL_LOG_DIR, _STRUCTURAL_RUN_START_S
    _structural_close()
    _STRUCTURAL_LOG_DIR = log_dir
    _STRUCTURAL_RUN_START_S = time.perf_counter()
    if log_dir is not None:
        log_dir.mkdir(parents=True, exist_ok=True)
def _structural_close() -> None:
    """Close every cached structural-log handle, then forget them all.

    Closing is best-effort: an error raised by one handle is ignored so the
    remaining handles still get closed.
    """
    for filename in list(_STRUCTURAL_LOG_FILES):
        cached_handle = _STRUCTURAL_LOG_FILES[filename]
        try:
            cached_handle.close()
        except Exception:
            # Best-effort teardown; nothing useful to do with the error.
            pass
    _STRUCTURAL_LOG_FILES.clear()
async def _structural_emit(filename: str, event: dict[str, Any]) -> None:
    """Append *event* as one JSON line to ``<structural dir>/<filename>``.

    Does nothing when no structural directory is configured. A ``t`` field
    (seconds since ``_structural_init``, rounded to 4 decimals) is placed
    first; a ``t`` key already in *event* overrides it. Files are opened
    lazily in append mode, cached per filename, and flushed after every
    line. Writers are serialized through the module lock.
    """
    if _STRUCTURAL_LOG_DIR is None:
        return
    record: dict[str, Any] = {
        "t": round(time.perf_counter() - _STRUCTURAL_RUN_START_S, 4),
    }
    record.update(event)
    async with _STRUCTURAL_LOG_LOCK:
        sink = _STRUCTURAL_LOG_FILES.get(filename)
        if sink is None:
            target = _STRUCTURAL_LOG_DIR / filename
            sink = target.open("a", encoding="utf-8")
            _STRUCTURAL_LOG_FILES[filename] = sink
        sink.write(json.dumps(record, sort_keys=True) + "\n")
        sink.flush()
@dataclass(frozen=True)
class ReplayConfig:
    """Immutable settings for one trace replay (consumed by ``replay_trace``).

    Groups: trace/output locations, routing policy and execution mechanism,
    pacing and concurrency, HTTP/streaming timeouts, KV-cache seeding and
    admission knobs, pool polling, backpressure, session migration and
    D->P snapshot sync.
    """

    trace_path: Path  # trace file passed to load_trace
    # Metrics JSONL path. The ".summary.json" sibling, the pool timeseries
    # and the default structural/ directory are derived from it.
    output_path: Path
    policy_name: str  # routing policy name passed to create_policy
    # Execution mechanism; "kvcache-centric" enables residency discovery.
    mechanism_name: str
    topology: SingleNodeTopology
    router_url: str | None = None  # must be set for router-path requests
    model_name: str | None = None
    pace: bool = True  # sleep so submissions follow trace timestamps
    time_scale: float = 1.0  # trace offsets are divided by this when pacing
    request_limit: int | None = None  # forwarded to load_trace
    concurrency_limit: int = 32  # size of the request semaphore
    header_mode: HeaderMode = "auto"
    timeout_s: float = 600.0  # HTTP client and per-request timeout
    stream: bool = True
    # Idle timeout forwarded to _aiter_lines while streaming.
    stream_idle_timeout_s: float | None = 900.0
    # NOTE(review): not read in the visible part of this file; presumably the
    # cap on uncached tokens for the direct-to-D path — confirm.
    kvcache_direct_max_uncached_tokens: int = 2048
    kvcache_admission_mode: KvCacheAdmissionMode = "router"
    # Seed filters, evaluated in order by _seed_filter_reason.
    kvcache_seed_max_resident_tokens: int | None = None
    kvcache_seed_max_output_tokens: int | None = None
    kvcache_seed_min_turn_id: int = 1
    # When True, replay_trace restricts seeding to sessions with > 1 turn.
    kvcache_seed_only_multiturn_sessions: bool = False
    kvcache_seed_allowed_session_ids: frozenset[str] | None = None
    kvcache_prefill_backup_policy: KvCachePrefillBackupPolicy = (
        "release-after-transfer"
    )
    kvcache_seed_max_inflight_decode: int | None = 3
    kvcache_seed_max_decode_transfer_queue_reqs: int | None = None
    kvcache_direct_max_decode_transfer_queue_reqs: int | None = None
    # Prefill priority tagging (see _prefill_priority_for_router_request).
    kvcache_prefill_priority_eviction: bool = False
    kvcache_prefill_direct_priority: int = -100
    kvcache_prefill_normal_priority: int = 100
    # > 0 starts the background pool sampler with this period (seconds).
    pool_poll_interval_s: float = 0.0
    pool_poll_include_sessions: bool = True
    # Honour recommended_pause_ms from admission probes, capped per pause.
    enable_backpressure: bool = False
    backpressure_max_pause_s: float = 2.0
    # Session migration via per-(sess, D) admission reject memory.
    # When a session has been admission-rejected this many times on a given D,
    # KvAwarePolicy skips that D for the session (forcing migration). Default 3.
    # Set 0 to disable. See REFACTOR_PLAN_V1 §6.2.
    kvcache_migration_reject_threshold: int = 3
    # Load-floor bonus magnitude for KvAwarePolicy: graduated boost added to
    # under-loaded D workers to break overlap-pinning imbalance on workloads
    # with shared cross-session prefix. 0 disables. See
    # docs/E1_E2_FIX_DESIGN_ZH.md §Q2.
    kvcache_load_floor_bonus: int = 0
    # D→P snapshot push: when True and reseed fires, agentic will RDMA-dump
    # the session's KV from the D-side worker that last held it onto the P
    # worker and insert into P's radix tree, so the subsequent P prefill
    # hits cache. See docs/D_TO_P_SYNC_DESIGN_ZH.md.
    enable_d_to_p_sync: bool = False
    # Structural log directory; None means output_path.parent / "structural".
    structural_log_dir: Path | None = None
@dataclass
class DirectSessionState:
    """Mutable per-session bookkeeping for the direct (session-cache) path.

    Tracks the decode-side session and an optional prefill-side session.
    ``replay_trace`` closes whichever of the two is still open at the end.
    """

    session_id: str
    server_url: str  # decode worker URL that hosts the session
    opened: bool = False  # session currently open on server_url
    # Previous turn; _inspect_direct_request diffs against it to size appends.
    last_trace_request: TraceRequest | None = None
    # NOTE(review): the fields below are not read in the visible part of the
    # file; meanings are inferred from names — confirm against their users.
    resident_tokens: int = 0
    last_access_s: float = 0.0
    active_requests: int = 0
    prefill_server_url: str | None = None  # prefill worker hosting a backup session
    prefill_opened: bool = False  # prefill-side session currently open
    prefill_resident_tokens: int = 0
    prefill_last_access_s: float = 0.0
    prefill_low_priority: bool = False
@dataclass
class DecodeResidencyState:
    """Per-worker token accounting, keyed by worker URL.

    Capacity and headroom are seeded once by _discover_decode_residency.
    The remaining counters are updated while requests run.
    """

    capacity_tokens: dict[str, int] = field(default_factory=dict)  # decode max_total_num_tokens
    headroom_tokens: dict[str, int] = field(default_factory=dict)  # decode safety margin
    reserved_decode_tokens: dict[str, int] = field(default_factory=dict)  # server-reported reservation
    # NOTE(review): the following maps are maintained outside the visible
    # code (e.g. _add_reserved_tokens grows reserved_tokens_by_server).
    resident_tokens_by_server: dict[str, int] = field(default_factory=dict)
    reserved_tokens_by_server: dict[str, int] = field(default_factory=dict)
    prefill_capacity_tokens: dict[str, int] = field(default_factory=dict)  # prefill max_total_num_tokens
    prefill_headroom_tokens: dict[str, int] = field(default_factory=dict)  # prefill safety margin
    prefill_resident_tokens_by_server: dict[str, int] = field(default_factory=dict)
    prefill_reserved_tokens_by_server: dict[str, int] = field(default_factory=dict)
    decode_evictions_prefill_backed: int = 0
    decode_evictions_without_prefill_backup: int = 0
    # Backpressure: per-D timestamp until which new requests should pause.
    # Values are time.perf_counter() deadlines.
    pause_until_s: dict[str, float] = field(default_factory=dict)
@dataclass(frozen=True)
class DecodeLoadSnapshot:
    """Point-in-time load of one decode worker.

    NOTE(review): not constructed in the visible part of this file. Field
    meanings are inferred from names (queue lengths in requests, usage in
    tokens) — confirm against the producer.
    """

    timestamp_s: float
    num_running_reqs: int
    num_waiting_reqs: int
    num_used_tokens: int
    max_total_num_tokens: int
    token_usage: float  # presumably num_used_tokens / max_total_num_tokens
    decode_prealloc_queue_reqs: int
    decode_transfer_queue_reqs: int
    decode_retracted_queue_reqs: int
@dataclass(frozen=True)
class ExecutionResult:
    """Outcome of executing one request; copied into RequestMetrics.

    ``execution_mode`` also drives migration feedback in _run_request:
    admission-rejection modes record a (session, D) reject, and
    "kvcache-direct-to-d-session" resets the counter.
    """

    execution_mode: str
    actual_kv_transfer_blocks: int
    effective_input_length: int | None
    cached_tokens: int
    session_reused: bool
    session_reset: bool
    latency_s: float | None  # None when the request failed
    ttft_s: float | None
    tpot_s: float | None
    prefill_request_priority: int | None = None
    decode_request_priority: int | None = None
    error: str | None = None  # "ExcType: message" on failure
    actual_output_tokens: int | None = None
    requested_output_tokens: int | None = None
    finish_reason: str | None = None
async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]:
    """Replay a trace against the configured topology and persist metrics.

    Structural logs go to ``config.structural_log_dir``, or
    ``<output dir>/structural`` when unset. Requests are submitted in trace
    order. With ``config.pace`` they are spaced by their trace timestamps
    divided by ``time_scale``. Turns of the same session are chained so each
    waits for the previous one. An optional background sampler records pool
    state to ``<output dir>/d-pool-timeseries.jsonl``.

    Cleanup always runs, even if a request task raises or the replay is
    cancelled: the sampler is stopped, sessions left open on workers are
    closed (best-effort), and structural log handles are closed. Metrics and
    the ``.summary.json`` file are written only after a complete replay.

    Returns:
        One RequestMetrics per trace request, in trace order.
    """
    structural_dir = config.structural_log_dir
    if structural_dir is None and config.output_path is not None:
        structural_dir = config.output_path.parent / "structural"
    _structural_init(structural_dir)
    try:
        requests = load_trace(config.trace_path, request_limit=config.request_limit)
        if config.kvcache_seed_only_multiturn_sessions:
            session_turns = Counter(request.session_id for request in requests)
            config = replace(
                config,
                kvcache_seed_allowed_session_ids=frozenset(
                    session_id
                    for session_id, turn_count in session_turns.items()
                    if turn_count > 1
                ),
            )
        policy = create_policy(
            config.policy_name,
            migration_reject_threshold=config.kvcache_migration_reject_threshold,
            load_floor_bonus=config.kvcache_load_floor_bonus,
        )
        state = RoutingState.create(config.topology)
        state_lock = asyncio.Lock()
        semaphore = asyncio.Semaphore(config.concurrency_limit)
        start_time = time.perf_counter()
        first_timestamp = requests[0].timestamp_s if requests else 0.0
        session_tail_tasks: dict[str, asyncio.Task[RequestMetrics]] = {}
        direct_sessions: dict[str, DirectSessionState] = {}
        direct_session_lock = asyncio.Lock()
        async with httpx.AsyncClient(timeout=config.timeout_s, trust_env=False) as client:
            decode_residency = await _discover_decode_residency(
                client=client,
                config=config,
            )
            poll_task = _start_pool_poller(
                client=client,
                config=config,
                start_time=start_time,
            )
            try:
                tasks: list[asyncio.Task[RequestMetrics]] = []
                for request in requests:
                    if config.pace:
                        target_offset = (
                            request.timestamp_s - first_timestamp
                        ) / config.time_scale
                        sleep_s = target_offset - (time.perf_counter() - start_time)
                        if sleep_s > 0:
                            await asyncio.sleep(sleep_s)
                    task = asyncio.create_task(
                        _run_request(
                            request=request,
                            config=config,
                            client=client,
                            policy=policy,
                            state=state,
                            state_lock=state_lock,
                            semaphore=semaphore,
                            direct_sessions=direct_sessions,
                            direct_session_lock=direct_session_lock,
                            decode_residency=decode_residency,
                            # Chain onto the previous turn of this session.
                            depends_on=session_tail_tasks.get(request.session_id),
                        )
                    )
                    tasks.append(task)
                    session_tail_tasks[request.session_id] = task
                results = await asyncio.gather(*tasks)
            finally:
                # Must run before the client closes, even on failure/cancel.
                try:
                    await _stop_pool_poller(poll_task)
                finally:
                    await _close_direct_sessions(
                        client=client,
                        direct_sessions=direct_sessions,
                    )
        write_metrics_jsonl(config.output_path, results)
        write_summary_json(
            config.output_path.with_suffix(config.output_path.suffix + ".summary.json"),
            results,
            trace_path=config.trace_path,
            router_url=config.router_url,
        )
        return results
    finally:
        _structural_close()


def _start_pool_poller(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
    start_time: float,
) -> asyncio.Task[None] | None:
    """Start the background pool sampler, or return None when it is disabled."""
    if not (config.pool_poll_interval_s > 0):
        return None
    poll_workers: list[tuple[str, str, str]] = [
        (worker.worker_id, "decode", worker.url)
        for worker in config.topology.decode_workers
    ]
    poll_workers.extend(
        (worker.worker_id, "prefill", worker.url)
        for worker in config.topology.prefill_workers
    )
    if not poll_workers:
        return None
    return asyncio.create_task(
        _poll_pool_timeseries(
            client=client,
            workers=poll_workers,
            interval_s=config.pool_poll_interval_s,
            output_path=config.output_path.parent / "d-pool-timeseries.jsonl",
            start_time=start_time,
            include_sessions=config.pool_poll_include_sessions,
        )
    )


async def _stop_pool_poller(poll_task: asyncio.Task[None] | None) -> None:
    """Cancel the pool sampler and wait for it to finish.

    The sampler is diagnostics only. If it died with an error (for example a
    disk error while writing the timeseries), that error is recorded in the
    structural log instead of discarding an otherwise completed replay.
    """
    if poll_task is None:
        return
    poll_task.cancel()
    try:
        await poll_task
    except asyncio.CancelledError:
        pass
    except Exception as exc:
        await _structural_emit(
            "pool-poller.jsonl",
            {"error": f"{type(exc).__name__}: {exc}"},
        )


async def _close_direct_sessions(
    *,
    client: httpx.AsyncClient,
    direct_sessions: dict[str, DirectSessionState],
) -> None:
    """Best-effort close of every session still open on a decode or prefill worker."""
    for session in direct_sessions.values():
        server_urls: list[str] = []
        if session.opened:
            server_urls.append(session.server_url)
        if session.prefill_opened and session.prefill_server_url is not None:
            server_urls.append(session.prefill_server_url)
        for server_url in server_urls:
            try:
                await _close_streaming_session(
                    client=client,
                    server_url=server_url,
                    session_id=session.session_id,
                    allow_missing=True,
                )
            except Exception:
                # Teardown is best-effort; the worker may already be gone.
                pass
async def _run_request(
    *,
    request: TraceRequest,
    config: ReplayConfig,
    client: httpx.AsyncClient,
    policy,
    state: RoutingState,
    state_lock: asyncio.Lock,
    semaphore: asyncio.Semaphore,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
    depends_on: asyncio.Task[RequestMetrics] | None,
) -> RequestMetrics:
    """Route, execute and account for one trace request.

    Waits for *depends_on* (the previous turn of the same session), then,
    under the concurrency semaphore:

    1. asks *policy* for a routing decision;
    2. logs the session->D binding;
    3. executes the request. Execution errors are captured into an
       ExecutionResult rather than raised.

    ``state.finish`` always runs once a decision has been taken, even if
    logging fails or the task is cancelled. Otherwise the routing state
    would keep counting this request as in flight forever. Migration
    feedback is applied only when an execution result exists.
    """
    if depends_on is not None:
        await depends_on
    async with semaphore:
        async with state_lock:
            decision = policy.select(request, topology=config.topology, state=state)
        execution: ExecutionResult | None = None
        try:
            await _structural_emit(
                "session-d-binding.jsonl",
                {
                    "session_id": request.session_id,
                    "request_id": request.request_id,
                    "turn_id": request.turn_id,
                    "decode_worker_index": decision.decode_worker_index,
                    "decode_worker_id": decision.decode_worker_id,
                    "prefill_worker_id": decision.prefill_worker_id,
                    "observed_overlap_blocks": decision.observed_overlap_blocks,
                    "kv_transfer_blocks": decision.kv_transfer_blocks,
                    "inflight_decode_load": decision.inflight_decode_load,
                },
            )
            try:
                execution = await _execute_request(
                    client=client,
                    request=request,
                    config=config,
                    decision=decision,
                    direct_sessions=direct_sessions,
                    direct_session_lock=direct_session_lock,
                    decode_residency=decode_residency,
                )
            except Exception as exc:  # pragma: no cover - defensive logging path
                execution = ExecutionResult(
                    execution_mode=config.mechanism_name,
                    actual_kv_transfer_blocks=0,
                    effective_input_length=None,
                    cached_tokens=0,
                    session_reused=False,
                    session_reset=False,
                    latency_s=None,
                    ttft_s=None,
                    tpot_s=None,
                    error=f"{type(exc).__name__}: {exc}",
                )
        finally:
            async with state_lock:
                state.finish(request, decision)
                if execution is not None:
                    # Migration feedback: if this request was forced into a
                    # fallback path because the chosen D rejected admission,
                    # record the (session, D) rejection so KvAwarePolicy can
                    # migrate this session next turn.
                    if _is_admission_rejection_mode(execution.execution_mode):
                        state.record_admission_reject(
                            request.session_id,
                            decision.decode_worker_id,
                        )
                    # Reset-on-success: a successful direct-to-D path proves
                    # D-X can currently serve this session — clear the
                    # cumulative reject counter so that brief past saturation
                    # doesn't permanently blacklist the D.
                    # (MIGRATION_V1_FINDINGS §4.1: blacklist-permanence bug fix.)
                    elif execution.execution_mode == "kvcache-direct-to-d-session":
                        state.session_d_rejects[
                            (request.session_id, decision.decode_worker_id)
                        ] = 0
    return RequestMetrics.from_decision(
        request,
        decision,
        mechanism_name=config.mechanism_name,
        execution_mode=execution.execution_mode,
        actual_kv_transfer_blocks=execution.actual_kv_transfer_blocks,
        effective_input_length=execution.effective_input_length,
        cached_tokens=execution.cached_tokens,
        session_reused=execution.session_reused,
        session_reset=execution.session_reset,
        latency_s=execution.latency_s,
        ttft_s=execution.ttft_s,
        tpot_s=execution.tpot_s,
        prefill_request_priority=execution.prefill_request_priority,
        decode_request_priority=execution.decode_request_priority,
        error=execution.error,
        actual_output_tokens=execution.actual_output_tokens,
        requested_output_tokens=execution.requested_output_tokens,
        finish_reason=execution.finish_reason,
    )
async def _invoke_router(
*,
client: httpx.AsyncClient,
request: TraceRequest,
config: ReplayConfig,
decode_worker_index: int,
session_id: str | None = None,
prefill_request_priority: int | None = None,
decode_request_priority: int | None = None,
decode_residency: "DecodeResidencyState | None" = None,
) -> GenerateResult:
if decode_residency is not None and config.enable_backpressure:
decode_url = config.topology.decode_workers[decode_worker_index].url
await _wait_for_decode_pause(
config=config,
residency=decode_residency,
server_url=decode_url,
request_id=request.request_id,
session_id=session_id,
)
headers = _build_headers(
request=request,
header_mode=config.header_mode,
decode_worker_index=decode_worker_index,
policy_name=config.policy_name,
)
assert config.router_url is not None
payload: dict[str, object] = {
"input_ids": _build_direct_full_input_ids(request),
"sampling_params": {
"temperature": 0,
"max_new_tokens": max(1, request.output_length),
"ignore_eos": True,
"no_stop_trim": True,
"skip_special_tokens": False,
},
"stream": config.stream,
}
if session_id is not None:
payload["session_params"] = {"id": session_id}
if prefill_request_priority is not None:
payload["smg_prefill_priority"] = prefill_request_priority
if decode_request_priority is not None:
payload["smg_decode_priority"] = decode_request_priority
return await _invoke_generate(
client=client,
base_url=config.router_url,
headers=headers,
payload=payload,
timeout_s=config.timeout_s,
stream_idle_timeout_s=config.stream_idle_timeout_s,
stream=config.stream,
)
def _build_payload(
*,
request: TraceRequest,
model_name: str,
prompt: str,
stream: bool,
session_params: dict[str, str] | None,
exact_output_length: bool,
) -> dict[str, object]:
payload: dict[str, object] = {
"model": model_name,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": max(1, request.output_length),
"temperature": 0,
"stream": stream,
}
if stream:
payload["stream_options"] = {"include_usage": True}
if exact_output_length:
payload.update(
{
"min_tokens": max(1, request.output_length),
"ignore_eos": True,
"no_stop_trim": True,
"skip_special_tokens": False,
}
)
if session_params is not None:
payload["session_params"] = session_params
return payload
async def _invoke_chat_completion(
    *,
    client: httpx.AsyncClient,
    base_url: str,
    headers: dict[str, str],
    payload: dict[str, object],
    timeout_s: float,
    stream_idle_timeout_s: float | None,
    stream: bool,
) -> tuple[float, float | None, float | None, int]:
    """POST to ``/v1/chat/completions`` and measure the response.

    Returns ``(latency_s, ttft_s, tpot_s, cached_tokens)``. TTFT is known
    only when streaming. TPOT spreads the post-first-token time over the
    *requested* ``max_tokens`` (not the tokens actually produced). A stream
    that ends without any token raises RuntimeError.
    """
    endpoint = f"{base_url.rstrip('/')}/v1/chat/completions"
    t0 = time.perf_counter()
    first_token_at: float | None = None
    cached = 0
    token_budget = int(payload.get("max_tokens", 1))
    if not stream:
        reply = await client.post(
            endpoint,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        )
        reply.raise_for_status()
        cached = _extract_openai_cached_tokens(reply.json())
    else:
        async with client.stream(
            "POST",
            endpoint,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        ) as reply:
            reply.raise_for_status()
            async for raw_line in _aiter_lines(
                reply,
                idle_timeout_s=stream_idle_timeout_s,
            ):
                if not raw_line.startswith("data:"):
                    continue
                body = raw_line[5:].strip()
                if body == "[DONE]":
                    break
                chunk = json.loads(body)
                cached = max(cached, _extract_openai_cached_tokens(chunk))
                if _contains_token(chunk) and first_token_at is None:
                    first_token_at = time.perf_counter() - t0
                if _is_terminal_chunk(chunk):
                    break
    elapsed = time.perf_counter() - t0
    if stream and first_token_at is None and token_budget > 0:
        raise RuntimeError("generate stream ended before producing any token")
    if first_token_at is not None:
        per_token = max(0.0, elapsed - first_token_at) / max(1, token_budget)
    else:
        per_token = None
    return elapsed, first_token_at, per_token, cached
@dataclass(frozen=True)
class GenerateResult:
    """Measurements from one ``/generate`` call (built by _invoke_generate)."""

    latency_s: float  # request start to end of response
    ttft_s: float | None  # time to first token; None when not streaming
    tpot_s: float | None  # (latency - ttft) / output tokens; None without ttft
    cached_tokens: int  # cached-prefix tokens reported by the server
    actual_output_tokens: int  # tokens observed / reported by the server
    requested_output_tokens: int  # sampling_params.max_new_tokens
    finish_reason: str | None
    server_meta_info: dict | None  # last meta_info dict seen, if any
async def _invoke_generate(
    *,
    client: httpx.AsyncClient,
    base_url: str,
    headers: dict[str, str],
    payload: dict[str, object],
    timeout_s: float,
    stream_idle_timeout_s: float | None,
    stream: bool,
) -> GenerateResult:
    """POST *payload* to ``<base_url>/generate`` and measure the response.

    Streaming mode parses every ``data:`` line. TTFT is the time until the
    first chunk carrying a token. The output-token count is the larger of
    the chunks seen and the server-reported ``meta_info.completion_tokens``.
    Non-streaming mode takes both from the single response body.
    ``finish_reason`` is normalized to ``str | None`` in both modes.

    Returns:
        GenerateResult. TPOT divides by the tokens actually produced, falling
        back to the requested count when none were reported.

    Raises:
        ValueError: the server returned an ``error`` object.
        RuntimeError: a stream ended without any token although at least one
            was requested.
        Any ``raise_for_status`` error for non-2xx responses.
    """
    url = f"{base_url.rstrip('/')}/generate"
    start = time.perf_counter()
    ttft_s: float | None = None
    cached_tokens = 0
    sampling_params = payload.get("sampling_params", {})
    requested_output_tokens = int(sampling_params.get("max_new_tokens", 1))
    actual_token_count = 0
    finish_reason: str | None = None
    last_meta_info: dict | None = None
    if stream:
        async with client.stream(
            "POST",
            url,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        ) as response:
            response.raise_for_status()
            async for line in _aiter_lines(
                response,
                idle_timeout_s=stream_idle_timeout_s,
            ):
                if not line.startswith("data:"):
                    continue
                data = line[5:].strip()
                if data == "[DONE]":
                    break
                parsed = json.loads(data)
                error = parsed.get("error")
                if isinstance(error, dict):
                    raise ValueError(error.get("message", json.dumps(error)))
                cached_tokens = max(cached_tokens, _extract_generate_cached_tokens(parsed))
                if _contains_generate_token(parsed):
                    actual_token_count += 1
                    if ttft_s is None:
                        ttft_s = time.perf_counter() - start
                meta_info = parsed.get("meta_info")
                if isinstance(meta_info, dict):
                    last_meta_info = meta_info
                    # Chunks may carry several tokens; trust the larger of our
                    # chunk tally and the server-reported count. A null count
                    # is treated as 0 instead of crashing int().
                    completion_tokens = int(meta_info.get("completion_tokens") or 0)
                    actual_token_count = max(actual_token_count, completion_tokens)
                    fr = meta_info.get("finish_reason")
                    if fr is not None:
                        finish_reason = str(fr)
                if _is_generate_terminal_chunk(parsed):
                    break
    else:
        response = await client.post(
            url,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        )
        response.raise_for_status()
        parsed = response.json()
        error = parsed.get("error")
        if isinstance(error, dict):
            raise ValueError(error.get("message", json.dumps(error)))
        cached_tokens = _extract_generate_cached_tokens(parsed)
        meta_info = parsed.get("meta_info")
        if isinstance(meta_info, dict):
            last_meta_info = meta_info
            actual_token_count = int(meta_info.get("completion_tokens") or 0)
            fr = meta_info.get("finish_reason")
            # Normalize like the streaming path so the field is always str | None.
            finish_reason = None if fr is None else str(fr)
    latency_s = time.perf_counter() - start
    if stream and ttft_s is None and requested_output_tokens > 0:
        raise RuntimeError("generate stream ended before producing any token")
    # Use actual token count for TPOT (not requested count).
    effective_tokens = (
        actual_token_count
        if actual_token_count > 0
        else max(1, requested_output_tokens)
    )
    tpot_s = None if ttft_s is None else max(0.0, latency_s - ttft_s) / effective_tokens
    return GenerateResult(
        latency_s=latency_s,
        ttft_s=ttft_s,
        tpot_s=tpot_s,
        cached_tokens=cached_tokens,
        actual_output_tokens=actual_token_count,
        requested_output_tokens=requested_output_tokens,
        finish_reason=finish_reason,
        server_meta_info=last_meta_info,
    )
async def _open_streaming_session(
    *,
    client: httpx.AsyncClient,
    server_url: str,
    session_id: str,
    request: TraceRequest,
) -> None:
    """Open a streaming session named *session_id* on *server_url*.

    The string-length capacity is 16x the larger of the input length and the
    input+output length, with a floor of 4096. Raises on HTTP errors, and
    ValueError if the server echoes back a different session id.
    """
    full_length = request.input_length + request.output_length
    capacity = max(4096, 16 * request.input_length, 16 * full_length)
    endpoint = f"{server_url.rstrip('/')}/open_session"
    reply = await client.post(
        endpoint,
        json={
            "capacity_of_str_len": capacity,
            "session_id": session_id,
            "streaming": True,
        },
        timeout=_ADMISSION_PROBE_TIMEOUT_S,
    )
    reply.raise_for_status()
    echoed_id = reply.json()
    if echoed_id == session_id:
        return
    raise ValueError(
        f"Unexpected session id from {server_url}: {echoed_id!r} != {session_id!r}"
    )
async def _close_streaming_session(
    *,
    client: httpx.AsyncClient,
    server_url: str,
    session_id: str,
    allow_missing: bool = False,
) -> None:
    """Close session *session_id* on *server_url*.

    With *allow_missing*, a 404 or a body mentioning "does not exist" counts
    as success (the session is already gone). Any other failure raises via
    ``raise_for_status``.
    """
    reply = await client.post(
        f"{server_url.rstrip('/')}/close_session",
        json={"session_id": session_id},
        timeout=_ADMISSION_PROBE_TIMEOUT_S,
    )
    if reply.is_success:
        return
    already_gone = allow_missing and (
        reply.status_code == 404 or "does not exist" in reply.text.lower()
    )
    if not already_gone:
        reply.raise_for_status()
def _extract_internal_state(payload: dict[str, Any]) -> dict[str, Any]:
internal_states = payload.get("internal_states")
if isinstance(internal_states, list) and internal_states:
internal_state = internal_states[0]
if isinstance(internal_state, dict):
return internal_state
return payload
def _extract_server_int(
    payload: dict[str, Any],
    key: str,
) -> int:
    """Read integer *key* from a /server_info payload.

    A top-level value wins. Otherwise the value comes from the internal state,
    and missing or falsy values become 0.
    """
    fallback = _extract_internal_state(payload).get(key, 0)
    raw = payload[key] if key in payload else fallback
    return int(raw) if raw else 0
def _extract_session_cache(
    payload: dict[str, Any],
) -> dict[str, Any]:
    """Return the ``session_cache`` dict from a /server_info payload, or {}."""
    cache = _extract_internal_state(payload).get("session_cache")
    return cache if isinstance(cache, dict) else {}
def _find_session_cache_status(
session_cache: dict[str, Any],
session_id: str,
) -> dict[str, Any] | None:
sessions = session_cache.get("sessions")
if not isinstance(sessions, list):
return None
for session in sessions:
if isinstance(session, dict) and session.get("session_id") == session_id:
return session
return None
async def _fetch_decode_server_state(
    *,
    client: httpx.AsyncClient,
    server_url: str,
) -> tuple[dict[str, Any], int, int]:
    """Fetch ``(session_cache, max_total_num_tokens, num_reserved_decode_tokens)``.

    Best-effort: any transport/HTTP/JSON failure, a non-object JSON body, or
    a non-numeric token field yields ``({}, 0, 0)``. The caller treats a
    non-positive capacity as "skip this worker". Previously a malformed body
    raised from outside the ``try`` and aborted residency discovery.
    """
    empty: tuple[dict[str, Any], int, int] = ({}, 0, 0)
    try:
        response = await client.get(
            f"{server_url.rstrip('/')}/server_info",
            timeout=_ADMISSION_PROBE_TIMEOUT_S,
        )
        response.raise_for_status()
        payload = response.json()
    except Exception:
        return empty
    if not isinstance(payload, dict):
        return empty
    try:
        max_total_num_tokens = _extract_server_int(payload, "max_total_num_tokens")
        reserved_decode_tokens = _extract_server_int(
            payload, "num_reserved_decode_tokens"
        )
    except (TypeError, ValueError):
        return empty
    return (
        _extract_session_cache(payload),
        max_total_num_tokens,
        reserved_decode_tokens,
    )
async def _query_pool_snapshot(
    *,
    client: httpx.AsyncClient,
    server_url: str,
    include_sessions: bool,
) -> dict[str, Any]:
    """Summarize one worker's /server_info into a flat pool-usage record.

    On any request/parse failure returns ``{"error": <exception type name>}``.
    Otherwise the record holds session-cache counters, memory usage, the
    ``pool_breakdown`` buckets (0 when absent), and, if *include_sessions*,
    a compact per-session list.
    """
    try:
        reply = await client.get(
            f"{server_url.rstrip('/')}/server_info",
            timeout=_ADMISSION_PROBE_TIMEOUT_S,
        )
        reply.raise_for_status()
        server_info = reply.json()
    except Exception as exc:
        return {"error": type(exc).__name__}
    internal = _extract_internal_state(server_info)
    cache = _extract_session_cache(server_info)
    internal_is_dict = isinstance(internal, dict)

    session_rows: list[dict[str, Any]] = []
    raw_sessions = cache.get("sessions")
    if include_sessions and isinstance(raw_sessions, list):
        session_rows = [
            {
                "session_id": item.get("session_id"),
                "resident": bool(item.get("resident")),
                "resident_tokens": int(item.get("resident_tokens") or 0),
                "idle_evictable": bool(item.get("idle_evictable")),
                "timed_out": bool(item.get("timed_out")),
            }
            for item in raw_sessions
            if isinstance(item, dict)
        ]

    mem = internal.get("memory_usage") if internal_is_dict else None
    if not isinstance(mem, dict):
        mem = {}
    # P1 instrument: pool_breakdown decomposes "other" into named buckets
    breakdown = internal.get("pool_breakdown") if internal_is_dict else None
    if not isinstance(breakdown, dict):
        breakdown = {}

    def as_int(source: dict[str, Any], name: str) -> int:
        return int(source.get(name) or 0)

    cache_counters = (
        "session_count",
        "resident_session_count",
        "held_tokens",
        "available_tokens",
        "capacity_tokens",
        "idle_evictable_session_count",
        "idle_evictable_tokens",
    )
    breakdown_counters = (
        "radix_evictable_tokens",
        "radix_protected_tokens",
        "slot_private_held_tokens",
        "session_slot_count",
        "running_batch_reqs",
        "running_batch_kv_tokens",
        "transfer_queue_reqs",
        "transfer_queue_tokens",
        "prealloc_queue_reqs",
        "prealloc_queue_tokens",
        "retracted_queue_reqs",
        "retracted_queue_tokens",
    )
    return {
        "session_cache_enabled": bool(cache.get("enabled")),
        **{name: as_int(cache, name) for name in cache_counters},
        "kvcache_mem_gb": float(mem.get("kvcache") or 0.0),
        "token_capacity": as_int(mem, "token_capacity"),
        "max_total_num_tokens": (
            as_int(internal, "max_total_num_tokens") if internal_is_dict else 0
        ),
        "last_gen_throughput": (
            float(internal.get("last_gen_throughput") or 0.0)
            if internal_is_dict
            else 0.0
        ),
        **{name: as_int(breakdown, name) for name in breakdown_counters},
        "sessions": session_rows,
    }
async def _poll_pool_timeseries(
    *,
    client: httpx.AsyncClient,
    workers: list[tuple[str, str, str]],
    interval_s: float,
    output_path: Path,
    start_time: float,
    include_sessions: bool,
) -> None:
    """Sample every worker's pool state every *interval_s* seconds into JSONL.

    *workers* holds ``(worker_id, role, url)`` triples. Each tick queries all
    workers concurrently and writes one sorted-key JSON row per worker. A
    query that raised becomes a row with an ``error`` field. The file is
    truncated on start and flushed per row. Runs until cancelled;
    cancellation ends the task normally instead of propagating.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as sink:
        try:
            while True:
                tick = time.perf_counter()
                wall_clock = time.time()
                since_start = tick - start_time
                outcomes = await asyncio.gather(
                    *[
                        _query_pool_snapshot(
                            client=client,
                            server_url=url,
                            include_sessions=include_sessions,
                        )
                        for _, _, url in workers
                    ],
                    return_exceptions=True,
                )
                for (worker_id, role, url), outcome in zip(workers, outcomes):
                    row: dict[str, Any] = {
                        "ts": wall_clock,
                        "wall_s": since_start,
                        "worker_id": worker_id,
                        "worker_role": role,
                        "worker_url": url,
                    }
                    if isinstance(outcome, BaseException):
                        row["error"] = type(outcome).__name__
                    else:
                        row.update(outcome)
                    sink.write(json.dumps(row, sort_keys=True) + "\n")
                    sink.flush()
                remaining = interval_s - (time.perf_counter() - tick)
                if remaining > 0:
                    await asyncio.sleep(remaining)
        except asyncio.CancelledError:
            return
async def _query_decode_direct_admission(
    *,
    client: httpx.AsyncClient,
    server_url: str,
    session_id: str,
    uncached_input_tokens: int,
    output_tokens: int,
    mode: str = "direct_append",
    config: "ReplayConfig | None" = None,
    residency: "DecodeResidencyState | None" = None,
    request_id: str | None = None,
    turn_id: int | None = None,
) -> dict[str, Any]:
    """Probe a decode worker's ``/session_cache/admit_direct_append`` endpoint.

    Returns the server's admission dict. On any failure (transport, HTTP,
    JSON, or non-object body) it returns a synthetic "cannot admit" dict with
    ``reason="admission-query-failed"``.

    Side effects:
      * With backpressure enabled and a positive ``recommended_pause_ms``,
        extends ``residency.pause_until_s[server_url]``. The pause is capped
        at ``config.backpressure_max_pause_s``.
      * Always appends an event to ``admission-events.jsonl``. The event now
        includes ``query_error`` (exception type name, "non-dict-response",
        or None). The failure cause was previously computed and discarded,
        so timeouts were indistinguishable from refused connections.
    """
    started = time.perf_counter()
    payload: dict[str, Any] | None = None
    query_error: str | None = None
    try:
        response = await client.post(
            f"{server_url.rstrip('/')}/session_cache/admit_direct_append",
            json={
                "session_id": session_id,
                "uncached_input_tokens": max(0, uncached_input_tokens),
                "output_tokens": max(0, output_tokens),
                "mode": mode,
            },
            timeout=_ADMISSION_PROBE_TIMEOUT_S,
        )
        response.raise_for_status()
        body = response.json()
    except Exception as exc:
        query_error = type(exc).__name__
    else:
        if isinstance(body, dict):
            payload = body
        else:
            query_error = "non-dict-response"
    if payload is None:
        payload = {
            "can_admit": False,
            "resident": False,
            "reason": "admission-query-failed",
            "required_tokens": 0,
            "available_tokens_before": 0,
            "available_tokens_after": 0,
            "evicted_session_count": 0,
            "freed_tokens": 0,
        }
    rtt_s = time.perf_counter() - started
    pause_ms = int(payload.get("recommended_pause_ms", 0) or 0)
    # Update per-D pause window when backpressure is enabled.
    if (
        config is not None
        and residency is not None
        and config.enable_backpressure
        and pause_ms > 0
    ):
        max_pause_s = max(0.0, config.backpressure_max_pause_s)
        applied_pause_s = min(pause_ms / 1000.0, max_pause_s)
        new_until = time.perf_counter() + applied_pause_s
        # Only ever extend an existing window, never shorten it.
        if new_until > residency.pause_until_s.get(server_url, 0.0):
            residency.pause_until_s[server_url] = new_until
    # Always emit admission event for analysis (even if backpressure disabled).
    await _structural_emit(
        "admission-events.jsonl",
        {
            "server_url": server_url,
            "session_id": session_id,
            "request_id": request_id,
            "turn_id": turn_id,
            "mode": mode,
            "rtt_s": round(rtt_s, 4),
            "can_admit": bool(payload.get("can_admit")),
            "resident": bool(payload.get("resident")),
            "reason": payload.get("reason"),
            "query_error": query_error,
            "queue_depth": int(payload.get("decode_transfer_queue_reqs", 0) or 0),
            "retracted_depth": int(payload.get("decode_retracted_queue_reqs", 0) or 0),
            "available_tokens_after": int(payload.get("available_tokens_after", 0) or 0),
            "token_usage": float(payload.get("token_usage", 0.0) or 0.0),
            "evicted_session_count": int(payload.get("evicted_session_count", 0) or 0),
            "recommended_pause_ms": pause_ms,
            "uncached_input_tokens": int(uncached_input_tokens),
            "output_tokens": int(output_tokens),
        },
    )
    return payload
async def _wait_for_decode_pause(
*,
config: "ReplayConfig",
residency: "DecodeResidencyState",
server_url: str,
request_id: str | None = None,
session_id: str | None = None,
) -> None:
if not config.enable_backpressure:
return
until = residency.pause_until_s.get(server_url, 0.0)
if until <= 0:
return
now = time.perf_counter()
if now >= until:
return
sleep_s = min(until - now, config.backpressure_max_pause_s)
await _structural_emit(
"backpressure-events.jsonl",
{
"server_url": server_url,
"session_id": session_id,
"request_id": request_id,
"sleep_s": round(sleep_s, 4),
"until_offset_s": round(until - _STRUCTURAL_RUN_START_S, 4),
},
)
await asyncio.sleep(sleep_s)
async def _discover_decode_residency(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
) -> DecodeResidencyState:
    """Seed per-worker capacity and safety headroom from /server_info.

    Only the "kvcache-centric" mechanism needs this; other mechanisms get an
    empty state. Workers reporting no capacity are skipped.

    Decode headroom is ``max(4 * reserved, capacity // 20, 8192)``, capped at
    capacity. Prefill headroom is ``max(capacity // 10, 16384)``, capped at
    capacity.
    """
    state = DecodeResidencyState()
    if config.mechanism_name != "kvcache-centric":
        return state

    for decode_worker in config.topology.decode_workers:
        url = decode_worker.url
        _, capacity, reserved = await _fetch_decode_server_state(
            client=client,
            server_url=url,
        )
        if capacity <= 0:
            continue
        margin = max(reserved * 4, capacity // 20, 8192)
        state.capacity_tokens[url] = capacity
        state.headroom_tokens[url] = min(capacity, margin)
        state.reserved_decode_tokens[url] = reserved

    for prefill_worker in config.topology.prefill_workers:
        url = prefill_worker.url
        _, capacity, _unused_reserved = await _fetch_decode_server_state(
            client=client,
            server_url=url,
        )
        if capacity <= 0:
            continue
        margin = max(capacity // 10, 16384)
        state.prefill_capacity_tokens[url] = capacity
        state.prefill_headroom_tokens[url] = min(capacity, margin)

    return state
def _estimate_session_resident_tokens(request: TraceRequest) -> int:
    """Return the KV tokens a session holds after serving ``request``.

    The estimate is the turn's full prompt length plus its output length.
    """
    prompt_tokens = request.input_length
    generated_tokens = request.output_length
    return prompt_tokens + generated_tokens
def _seed_filter_reason(
    *,
    request: TraceRequest,
    config: ReplayConfig,
    inflight_decode_load: int | None = None,
) -> str | None:
    """Return why ``request`` must not seed a decode session, or ``None``.

    The checks run in a fixed order and the first failing one wins:

    1. minimum turn id;
    2. in-flight decode load cap (skipped if the cap or the load is ``None``);
    3. session allow-list (skipped if ``None``);
    4. estimated resident-token cap (skipped if ``None``);
    5. output-token cap (skipped if ``None``).
    """
    if request.turn_id < config.kvcache_seed_min_turn_id:
        return "seed-filter-early-turn"

    load_cap = config.kvcache_seed_max_inflight_decode
    if load_cap is not None and inflight_decode_load is not None:
        if inflight_decode_load > load_cap:
            return "seed-filter-inflight-decode-load"

    allowed_ids = config.kvcache_seed_allowed_session_ids
    if allowed_ids is not None and request.session_id not in allowed_ids:
        return "seed-filter-single-turn-session"

    estimated_resident = _estimate_session_resident_tokens(request)
    resident_cap = config.kvcache_seed_max_resident_tokens
    if resident_cap is not None and estimated_resident > resident_cap:
        return "seed-filter-resident-tokens"

    output_cap = config.kvcache_seed_max_output_tokens
    if output_cap is not None and request.output_length > output_cap:
        return "seed-filter-output-tokens"

    return None
def _prefill_priority_for_router_request(
    *,
    config: ReplayConfig,
    direct_to_d_predicted: bool,
) -> int | None:
    """Pick the prefill-side eviction priority to attach to a router request.

    Returns ``None`` when priority-based prefill eviction is disabled.
    Otherwise returns the configured direct priority when the request is
    predicted to go straight to decode, and the normal priority when it is not.
    """
    if config.kvcache_prefill_priority_eviction:
        return (
            config.kvcache_prefill_direct_priority
            if direct_to_d_predicted
            else config.kvcache_prefill_normal_priority
        )
    return None
def _inspect_direct_request(
    *,
    request: TraceRequest,
    session: DirectSessionState,
) -> tuple[int, bool, bool]:
    """Classify ``request`` against the previous turn recorded on ``session``.

    Returns ``(prompt_tokens_to_send, session_reused, session_reset)``:

    * no previous turn: full input, not reused, not reset;
    * input does not grow beyond the previous input+output: full input, reset;
    * otherwise: only the appended tokens, reused.
    """
    prior = session.last_trace_request
    if prior is None:
        return request.input_length, False, False
    prior_context = prior.input_length + prior.output_length
    appended = request.input_length - prior_context
    if appended > 0:
        return appended, True, False
    return request.input_length, False, True
def _add_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Increase the pending decode reservation recorded for ``server_url``.

    Non-positive deltas are ignored.
    """
    if delta_tokens > 0:
        ledger = residency.reserved_tokens_by_server
        ledger[server_url] = ledger.get(server_url, 0) + delta_tokens
def _release_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Decrease the pending decode reservation recorded for ``server_url``.

    Non-positive deltas are ignored. The entry is removed once the remaining
    reservation reaches zero or below.
    """
    if delta_tokens <= 0:
        return
    ledger = residency.reserved_tokens_by_server
    left = ledger.get(server_url, 0) - delta_tokens
    if left <= 0:
        ledger.pop(server_url, None)
    else:
        ledger[server_url] = left
def _add_prefill_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Increase the pending prefill-backup reservation for ``server_url``.

    Non-positive deltas are ignored.
    """
    if delta_tokens > 0:
        ledger = residency.prefill_reserved_tokens_by_server
        ledger[server_url] = ledger.get(server_url, 0) + delta_tokens
def _release_prefill_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Decrease the pending prefill-backup reservation for ``server_url``.

    Non-positive deltas are ignored. The entry is removed once the remaining
    reservation reaches zero or below.
    """
    if delta_tokens <= 0:
        return
    ledger = residency.prefill_reserved_tokens_by_server
    left = ledger.get(server_url, 0) - delta_tokens
    if left <= 0:
        ledger.pop(server_url, None)
    else:
        ledger[server_url] = left
def _usable_capacity_tokens(
    residency: DecodeResidencyState,
    server_url: str,
) -> int:
    """Return decode capacity minus headroom for ``server_url``, floored at 0.

    An unknown server counts as zero capacity and zero headroom.
    """
    capacity = residency.capacity_tokens.get(server_url, 0)
    headroom = residency.headroom_tokens.get(server_url, 0)
    return max(0, capacity - headroom)
def _usable_prefill_backup_capacity_tokens(
    residency: DecodeResidencyState,
    server_url: str,
) -> int:
    """Return prefill capacity minus prefill headroom for ``server_url``, floored at 0.

    An unknown server counts as zero capacity and zero headroom.
    """
    capacity = residency.prefill_capacity_tokens.get(server_url, 0)
    headroom = residency.prefill_headroom_tokens.get(server_url, 0)
    return max(0, capacity - headroom)
def _eviction_suffix(evicted_sessions: int, prefill_backed_evictions: int) -> str:
    """Return the execution_mode suffix describing evictions that preceded admission.

    Empty when nothing was evicted. Otherwise it distinguishes evictions that
    were all prefill-backed, partly prefill-backed, or not backed at all.
    """
    if evicted_sessions <= 0:
        return ""
    if prefill_backed_evictions <= 0:
        return "-after-eviction"
    if prefill_backed_evictions < evicted_sessions:
        return "-after-mixed-eviction"
    return "-after-prefill-backed-eviction"
def _decode_session_soft_cap(
    *,
    residency: DecodeResidencyState,
    server_url: str,
    request: TraceRequest,
    max_sessions: int = 16,
) -> int:
    """Return the router-side cap on concurrently open decode sessions for ``server_url``.

    The cap approximates how many sessions of ``request``'s estimated size fit
    in the worker's usable capacity (capacity minus headroom), clamped to
    ``[1, max_sessions]``. When no usable capacity is known (e.g. discovery
    failed), ``max_sessions`` is returned.

    Args:
        residency: Router-side capacity bookkeeping.
        server_url: Decode worker the cap applies to.
        request: Request whose estimated resident size sizes each slot.
        max_sessions: Upper bound on the cap; defaults to 16.
    """
    target_tokens = max(1, _estimate_session_resident_tokens(request))
    usable_capacity_tokens = _usable_capacity_tokens(residency, server_url)
    if usable_capacity_tokens <= 0:
        return max_sessions
    return max(1, min(max_sessions, usable_capacity_tokens // target_tokens))
def _should_admit_new_decode_session(
    *,
    residency: DecodeResidencyState,
    server_url: str,
    request: TraceRequest,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    treat_as_fresh_session: bool,
    admission_mode: KvCacheAdmissionMode = "router",
) -> bool:
    """Decide whether ``session`` may open (or keep) a decode session on ``server_url``.

    A session already open on the same worker is kept unless the caller asks
    to treat it as fresh. In ``"worker"`` admission mode the decode worker
    makes the capacity decision, so this returns ``True``. In ``"router"``
    mode the number of sessions open on ``server_url`` must be below
    ``_decode_session_soft_cap``.
    """
    if not treat_as_fresh_session:
        if session.opened and session.server_url == server_url:
            return True
    if admission_mode == "worker":
        # D's admit_direct_append (mode=seed) checks the real KV pool and runs
        # LRU eviction itself; the local soft cap only applies in router mode.
        return True
    open_here = len(
        [
            other
            for other in direct_sessions.values()
            if other.opened and other.server_url == server_url
        ]
    )
    soft_cap = _decode_session_soft_cap(
        residency=residency,
        server_url=server_url,
        request=request,
    )
    return open_here < soft_cap
async def _fetch_decode_load_snapshot(
    *,
    client: httpx.AsyncClient,
    server_url: str,
) -> DecodeLoadSnapshot | None:
    """Fetch a point-in-time load snapshot from a worker's ``/v1/loads`` endpoint.

    Only the first entry of the response's ``loads`` list is used; the
    ``disaggregation`` sub-object supplies the queue depths and is treated as
    empty when missing.

    Returns:
        A :class:`DecodeLoadSnapshot`, or ``None`` when the request fails, the
        body is not JSON, or the payload does not have the expected shape. This
        lets callers treat "no snapshot" uniformly.
    """
    try:
        response = await client.get(
            f"{server_url.rstrip('/')}/v1/loads",
            params={"include": "core,disagg"},
            timeout=_ADMISSION_PROBE_TIMEOUT_S,
        )
        response.raise_for_status()
        payload = response.json()
    except Exception:
        return None
    # Shape validation happens outside the try above, so guard explicitly
    # rather than letting AttributeError escape to the admission path.
    if not isinstance(payload, dict):
        return None
    loads = payload.get("loads")
    if not isinstance(loads, list) or not loads:
        return None
    load = loads[0]
    if not isinstance(load, dict):
        return None
    disagg = load.get("disaggregation")
    if not isinstance(disagg, dict):
        disagg = {}
    try:
        num_running_reqs = int(load.get("num_running_reqs", 0) or 0)
        num_waiting_reqs = int(load.get("num_waiting_reqs", 0) or 0)
        num_used_tokens = int(load.get("num_used_tokens", 0) or 0)
        max_total_num_tokens = int(load.get("max_total_num_tokens", 0) or 0)
        token_usage = float(load.get("token_usage", 0.0) or 0.0)
        prealloc_reqs = int(disagg.get("decode_prealloc_queue_reqs", 0) or 0)
        transfer_reqs = int(disagg.get("decode_transfer_queue_reqs", 0) or 0)
        retracted_reqs = int(disagg.get("decode_retracted_queue_reqs", 0) or 0)
    except (TypeError, ValueError):
        # Non-numeric fields: treat as an unusable snapshot, same as a failed probe.
        return None
    return DecodeLoadSnapshot(
        timestamp_s=time.perf_counter(),
        num_running_reqs=num_running_reqs,
        num_waiting_reqs=num_waiting_reqs,
        num_used_tokens=num_used_tokens,
        max_total_num_tokens=max_total_num_tokens,
        token_usage=token_usage,
        decode_prealloc_queue_reqs=prealloc_reqs,
        decode_transfer_queue_reqs=transfer_reqs,
        decode_retracted_queue_reqs=retracted_reqs,
    )
def _decode_load_backpressure_reason(
    snapshot: DecodeLoadSnapshot | None,
    *,
    config: ReplayConfig,
    routing_mode: Literal["direct", "seed"],
) -> str | None:
    """Map a decode load snapshot to a backpressure label, or ``None`` if D is healthy.

    The checks run in order and the first match wins:

    * transfer queue above the mode's configured cap gives
      ``d-transfer-backpressure``. Direct mode uses
      ``kvcache_direct_max_decode_transfer_queue_reqs``, every other mode uses
      ``kvcache_seed_max_decode_transfer_queue_reqs``; ``None`` disables it.
    * retracted requests give ``d-retracted``; in direct mode only when token
      usage is also >= 0.99.
    * token usage >= 0.992 (direct) or >= 0.985 (other modes) gives
      ``d-token-usage-critical``.
    * seed mode only: token usage >= 0.94 with a non-empty prealloc or
      transfer queue gives ``d-prealloc-backpressure``.

    A missing snapshot never triggers backpressure.
    """
    if snapshot is None:
        return None

    if routing_mode == "direct":
        transfer_cap = config.kvcache_direct_max_decode_transfer_queue_reqs
        retracted_min_usage: float | None = 0.99
        critical_usage = 0.992
    else:
        transfer_cap = config.kvcache_seed_max_decode_transfer_queue_reqs
        retracted_min_usage = None  # any retraction counts outside direct mode
        critical_usage = 0.985

    if (
        transfer_cap is not None
        and snapshot.decode_transfer_queue_reqs > transfer_cap
    ):
        return "d-transfer-backpressure"
    if snapshot.decode_retracted_queue_reqs > 0 and (
        retracted_min_usage is None
        or snapshot.token_usage >= retracted_min_usage
    ):
        return "d-retracted"
    if snapshot.token_usage >= critical_usage:
        return "d-token-usage-critical"
    if (
        routing_mode == "seed"
        and snapshot.token_usage >= 0.94
        and (
            snapshot.decode_prealloc_queue_reqs > 0
            or snapshot.decode_transfer_queue_reqs > 0
        )
    ):
        return "d-prealloc-backpressure"
    return None
def _is_decode_backpressure_reason(reason: str | None) -> bool:
    """Return True if ``reason`` is one of the decode-load backpressure labels.

    These are the four labels ``_decode_load_backpressure_reason`` can return.
    """
    backpressure_labels = frozenset(
        (
            "d-retracted",
            "d-token-usage-critical",
            "d-prealloc-backpressure",
            "d-transfer-backpressure",
        )
    )
    return reason in backpressure_labels
def _is_stale_decode_session_error(exc: Exception) -> bool:
    """Return True when ``exc`` is an httpx HTTP-status error with status 400.

    Going by this helper's name, callers presumably treat a 400 from a decode
    worker as "the streaming session is stale" -- confirm at call sites.
    """
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    return exc.response.status_code == 400
# Substrings of ``execution_mode`` that mark a request as rejected by D-side
# admission. _run_request uses this to update state.session_d_rejects so that
# KvAwarePolicy can migrate persistently starved sessions to a different D on
# the next turn.
_ADMISSION_REJECTION_SUBSTRINGS = (
    "session-cap",
    "no-d-capacity",
    "d-backpressure",
)
def _is_admission_rejection_mode(execution_mode: str) -> bool:
    """Return True if ``execution_mode`` contains any admission-rejection marker."""
    for marker in _ADMISSION_REJECTION_SUBSTRINGS:
        if marker in execution_mode:
            return True
    return False
def _fallthrough_reason(
    *,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    direct_append_length: int | None,
    direct_session_reused: bool,
    direct_session_reset: bool,
) -> str:
    """Classify why a turn-2+ KVC request fell through to the seed/large-append branch.

    Returns a short label suffix used in execution_mode strings to replace the
    misleading 'large-append' label (TEAM_REPORT §2.7). Labels, first match
    wins:

    * ``session-was-evicted``: the turn could not extend the recorded
      session context (``direct_session_reset``).
    * ``session-not-resident``: the §1 starvation signature.
      ``direct_session_reused`` is False because the session was never opened
      on the policy-chosen D.
    * ``no-direct-info``: no direct append length was computed.
    * ``real-large-append``: the append exceeds
      ``config.kvcache_direct_max_uncached_tokens``.
    * ``policy-no-bypass``: the policy chose not to bypass prefill.
    * ``other-large-append``: none of the above.
    """
    # _inspect_direct_request never reports reused=True together with
    # reset=True, so reset must be checked first or this label is unreachable.
    if direct_session_reset:
        return "session-was-evicted"
    if not direct_session_reused:
        return "session-not-resident"
    if direct_append_length is None:
        return "no-direct-info"
    if direct_append_length > config.kvcache_direct_max_uncached_tokens:
        return "real-large-append"
    if not _should_bypass_prefill(request=request, config=config, decision=decision):
        return "policy-no-bypass"
    return "other-large-append"
def _dynamic_decode_headroom_tokens(
    *,
    residency: DecodeResidencyState,
    server_url: str,
    snapshot: DecodeLoadSnapshot | None,
    routing_mode: Literal["direct", "seed"],
) -> int:
    """Compute decode headroom for ``server_url`` from a live load snapshot.

    Without a snapshot, the stored headroom is returned (0 if unknown).
    Otherwise headroom is the largest of three terms:

    * per-request reserve (at least 512) times a pressure factor;
    * capacity // 24 (direct) or // 15 (other modes);
    * a floor of 4096 (direct) or 12288 (other modes).

    The pressure factor counts queued prealloc/transfer/retracted requests.
    Outside direct mode it also counts running requests. It is at least 1.
    """
    if snapshot is None:
        return residency.headroom_tokens.get(server_url, 0)
    per_request_reserve = max(
        512, residency.reserved_decode_tokens.get(server_url, 0)
    )
    disagg_backlog = (
        snapshot.decode_prealloc_queue_reqs
        + snapshot.decode_transfer_queue_reqs
        + snapshot.decode_retracted_queue_reqs
    )
    if routing_mode == "direct":
        pressure = max(1, disagg_backlog)
        divisor, floor_tokens = 24, 4096
    else:
        pressure = max(1, snapshot.num_running_reqs + disagg_backlog)
        divisor, floor_tokens = 15, 12288
    return max(
        per_request_reserve * pressure,
        snapshot.max_total_num_tokens // divisor,
        floor_tokens,
    )
def _commit_session_residency(
    *,
    residency: DecodeResidencyState,
    session: DirectSessionState,
    request: TraceRequest,
    reserved_tokens: int,
) -> None:
    """Turn a pending decode reservation into resident-session bookkeeping.

    Releases ``reserved_tokens`` from the worker's reservation ledger. Shifts
    the worker's resident-token total by the change in this session's
    estimated footprint; a closed session counts as 0 tokens before. Marks the
    session open with ``request`` as its latest turn and refreshes its access
    time. An existing prefill backup for the session becomes low priority.
    """
    url = session.server_url
    _release_reserved_tokens(residency, url, reserved_tokens)
    old_footprint = session.resident_tokens if session.opened else 0
    new_footprint = _estimate_session_resident_tokens(request)
    change = new_footprint - old_footprint
    if change != 0:
        totals = residency.resident_tokens_by_server
        totals[url] = totals.get(url, 0) + change
    session.opened = True
    session.resident_tokens = new_footprint
    session.last_trace_request = request
    session.last_access_s = time.perf_counter()
    if session.prefill_opened:
        # The decode copy is now authoritative, so the P backup can go first.
        session.prefill_low_priority = True
def _commit_prefill_backup_residency(
    *,
    residency: DecodeResidencyState,
    session: DirectSessionState,
    request: TraceRequest,
    prefill_url: str,
    reserved_tokens: int,
) -> None:
    """Turn a pending prefill-backup reservation into resident bookkeeping.

    Releases ``reserved_tokens`` from ``prefill_url``'s prefill reservation
    ledger and shifts that worker's prefill resident total by the change in
    this session's estimated footprint. Records the backup as open on
    ``prefill_url``. The backup is low priority exactly when the session is
    also open on decode.

    NOTE(review): the previous footprint is subtracted from ``prefill_url``
    even if the old backup lived on another prefill worker -- confirm callers
    close a backup on a different worker first.
    """
    _release_prefill_reserved_tokens(residency, prefill_url, reserved_tokens)
    old_footprint = (
        session.prefill_resident_tokens if session.prefill_opened else 0
    )
    new_footprint = _estimate_session_resident_tokens(request)
    change = new_footprint - old_footprint
    if change != 0:
        totals = residency.prefill_resident_tokens_by_server
        totals[prefill_url] = totals.get(prefill_url, 0) + change
    session.prefill_server_url = prefill_url
    session.prefill_opened = True
    session.prefill_resident_tokens = new_footprint
    session.prefill_last_access_s = time.perf_counter()
    session.prefill_low_priority = session.opened
async def _close_prefill_session(
    *,
    client: httpx.AsyncClient,
    session: DirectSessionState,
    residency: DecodeResidencyState,
) -> None:
    """Drop ``session``'s prefill-side backup and its bookkeeping.

    If a backup is recorded (opened, with a known prefill worker), the
    streaming session on that worker is closed, tolerating a missing session.
    That worker's prefill resident total is then reduced, and the entry is
    dropped once it reaches zero. The session's prefill flags and token count
    are always cleared.
    """
    has_backup = (
        session.prefill_opened and session.prefill_server_url is not None
    )
    if has_backup:
        backup_url = session.prefill_server_url
        await _close_streaming_session(
            client=client,
            server_url=backup_url,
            session_id=session.session_id,
            allow_missing=True,
        )
        totals = residency.prefill_resident_tokens_by_server
        left = totals.get(backup_url, 0) - session.prefill_resident_tokens
        if left <= 0:
            totals.pop(backup_url, None)
        else:
            totals[backup_url] = left
    session.prefill_opened = False
    session.prefill_resident_tokens = 0
    session.prefill_low_priority = False
async def _close_decode_session(
    *,
    client: httpx.AsyncClient,
    session: DirectSessionState,
    residency: DecodeResidencyState,
    evicting_for_capacity: bool = False,
) -> None:
    """Close ``session``'s decode-side streaming session and update bookkeeping.

    For an open session:

    * closes the streaming session on its decode worker, tolerating a missing
      session;
    * subtracts its resident tokens from that worker's total, dropping the
      entry at zero;
    * marks the session closed.

    When the close is a capacity eviction, it is counted as prefill-backed or
    unbacked depending on whether a prefill backup exists. An existing backup
    always loses its low-priority flag. ``resident_tokens`` ends at 0 in every
    case.
    """
    if session.opened:
        await _close_streaming_session(
            client=client,
            server_url=session.server_url,
            session_id=session.session_id,
            allow_missing=True,
        )
        totals = residency.resident_tokens_by_server
        left = totals.get(session.server_url, 0) - session.resident_tokens
        if left <= 0:
            totals.pop(session.server_url, None)
        else:
            totals[session.server_url] = left
        session.opened = False
        session.resident_tokens = 0
        if session.prefill_opened:
            # Adds 0 unless this close is a capacity eviction.
            residency.decode_evictions_prefill_backed += int(evicting_for_capacity)
            session.prefill_low_priority = False
        elif evicting_for_capacity:
            residency.decode_evictions_without_prefill_backup += 1
    else:
        session.resident_tokens = 0
async def _reserve_prefill_backup_capacity(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    prefill_url: str,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    residency: DecodeResidencyState,
) -> tuple[bool, int, int]:
    """Decide whether ``session`` may keep a KV backup on prefill worker ``prefill_url``.

    Returns ``(keep_backup, reserved_tokens, evicted_sessions)``:

    * ``keep_backup``: True when the backup fits, or when P's capacity is
      unknown.
    * ``reserved_tokens``: extra tokens added to
      ``residency.prefill_reserved_tokens_by_server`` for this backup (0 if
      none).
    * ``evicted_sessions``: number of other idle prefill backups on
      ``prefill_url`` closed to make room.

    Admission requires that, after this backup, at least half of P's capacity
    (or the configured prefill headroom, if larger) remains available. A new
    backup is also limited to ``min(4, max(1, capacity // (2 * target_tokens)))``
    backups per prefill worker.
    """
    session_cache, max_total_num_tokens, _reserved_decode_tokens = (
        await _fetch_decode_server_state(
            client=client,
            server_url=prefill_url,
        )
    )
    # Refresh cached capacity only when the probe returned a real value.
    if max_total_num_tokens > 0:
        residency.prefill_capacity_tokens[prefill_url] = max_total_num_tokens
    capacity_tokens = residency.prefill_capacity_tokens.get(prefill_url, 0)
    headroom_tokens = residency.prefill_headroom_tokens.get(prefill_url, 0)
    # Capacity unknown: keep the backup without reserving or evicting anything.
    if capacity_tokens <= 0:
        return True, 0, 0
    # Free-token floor that must survive this backup: half of P, or the
    # configured headroom if that is larger.
    low_occupancy_headroom_tokens = max(
        headroom_tokens,
        capacity_tokens // 2,
    )
    target_session_status = _find_session_cache_status(
        session_cache,
        session.session_id,
    )
    # Prefer the worker-reported resident size for this session; otherwise fall
    # back to the router's own record if the backup already lives on this P.
    if (
        isinstance(target_session_status, dict)
        and bool(target_session_status.get("resident"))
    ):
        current_tokens = int(target_session_status.get("resident_tokens", 0) or 0)
    else:
        current_tokens = (
            session.prefill_resident_tokens
            if session.prefill_opened and session.prefill_server_url == prefill_url
            else 0
        )
    target_tokens = _estimate_session_resident_tokens(request)
    required_extra_tokens = max(0, target_tokens - current_tokens)
    evicted_sessions = 0
    # Per-P backup count cap, clamped to [1, 4].
    max_backup_sessions = max(1, capacity_tokens // max(1, target_tokens * 2))
    max_backup_sessions = min(max_backup_sessions, 4)
    # Worker-reported free tokens; if non-positive, derive from capacity minus
    # held tokens. Pending reservations on this P are then subtracted.
    available_tokens = int(session_cache.get("available_tokens", 0) or 0)
    if available_tokens <= 0:
        held_tokens = int(session_cache.get("held_tokens", 0) or 0)
        available_tokens = max(0, capacity_tokens - held_tokens)
    available_tokens -= residency.prefill_reserved_tokens_by_server.get(prefill_url, 0)
    # Closure reads the current ``available_tokens``, which the eviction loop
    # below increases as backups are freed.
    def has_enough_prefill_headroom() -> bool:
        return available_tokens - required_extra_tokens >= low_occupancy_headroom_tokens
    def prefill_backup_count() -> int:
        return sum(
            1
            for candidate in direct_sessions.values()
            if candidate.prefill_opened and candidate.prefill_server_url == prefill_url
        )
    # Evict while extra tokens are needed and either the free-token floor would
    # be violated or a brand-new backup would exceed the per-P count cap.
    while (
        required_extra_tokens > 0
        and (
            not has_enough_prefill_headroom()
            or (
                not session.prefill_opened
                and prefill_backup_count() >= max_backup_sessions
            )
        )
    ):
        # Victims: other idle backups on this P, low-priority ones first, then
        # least recently accessed.
        candidates = sorted(
            (
                candidate
                for candidate in direct_sessions.values()
                if candidate.prefill_opened
                and candidate.prefill_server_url == prefill_url
                and candidate.session_id != session.session_id
                and candidate.active_requests <= 0
            ),
            key=lambda candidate: (
                0 if candidate.prefill_low_priority else 1,
                candidate.prefill_last_access_s,
            ),
        )
        if not candidates:
            break
        # Read the size before closing; _close_prefill_session zeroes it.
        freed_tokens = candidates[0].prefill_resident_tokens
        await _close_prefill_session(
            client=client,
            session=candidates[0],
            residency=residency,
        )
        available_tokens += freed_tokens
        evicted_sessions += 1
    # NOTE(review): this check also runs when required_extra_tokens == 0, so a
    # backup needing no extra tokens is still refused while P is over the
    # floor -- confirm intended.
    # NOTE(review): the per-P count cap is not rechecked here; if eviction ran
    # out of candidates, a new backup can still be admitted above the cap.
    if not has_enough_prefill_headroom():
        return False, 0, evicted_sessions
    _add_prefill_reserved_tokens(residency, prefill_url, required_extra_tokens)
    return True, required_extra_tokens, evicted_sessions
async def _reserve_decode_session_capacity(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
    request: TraceRequest,
    server_url: str,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    residency: DecodeResidencyState,
    treat_as_fresh_session: bool,
    routing_mode: Literal["direct", "seed"],
    admission_mode: KvCacheAdmissionMode,
) -> tuple[bool, int, int, int, str | None]:
    """Reserve decode-side KV capacity on ``server_url`` for ``request``.

    Returns ``(admitted, reserved_tokens, evicted_sessions,
    prefill_backed_evictions, reject_reason)``. When admitted,
    ``reserved_tokens`` has been added to
    ``residency.reserved_tokens_by_server`` and ``reject_reason`` is ``None``.
    When rejected, ``reserved_tokens`` is 0 and ``reject_reason`` names the
    cause.

    * ``admission_mode == "router"``: delegates to
      ``_reserve_decode_session_capacity_from_router_state``.
    * Otherwise, direct append on a kept session: asks the decode worker's
      ``direct_append`` admission endpoint.
    * Otherwise, seed/reseed: asks the worker's ``seed`` admission endpoint.
      Only if that query fails does it fall back to router-side estimation
      from the worker's session cache and ``/v1/loads`` snapshot, evicting
      idle least-recently-used sessions locally when needed.
    """
    if admission_mode == "router":
        return await _reserve_decode_session_capacity_from_router_state(
            client=client,
            config=config,
            request=request,
            server_url=server_url,
            session=session,
            direct_sessions=direct_sessions,
            residency=residency,
            treat_as_fresh_session=treat_as_fresh_session,
            routing_mode=routing_mode,
        )
    if treat_as_fresh_session and session.opened:
        await _close_decode_session(
            client=client,
            session=session,
            residency=residency,
        )
    current_tokens = 0 if treat_as_fresh_session else session.resident_tokens
    target_tokens = _estimate_session_resident_tokens(request)
    required_extra_tokens = max(0, target_tokens - current_tokens)
    prefill_backed_evictions = 0
    if routing_mode == "direct" and not treat_as_fresh_session:
        # Direct append requires the session to already be resident on this D;
        # the worker decides whether the append fits.
        if not session.opened:
            return False, 0, 0, 0, "d-session-not-resident"
        admission = await _query_decode_direct_admission(
            client=client,
            server_url=server_url,
            session_id=session.session_id,
            uncached_input_tokens=max(0, request.input_length - current_tokens),
            output_tokens=request.output_length,
            mode="direct_append",
            config=config,
            residency=residency,
            request_id=request.request_id,
            turn_id=request.turn_id,
        )
        if not bool(admission.get("resident")):
            return False, 0, 0, 0, str(admission.get("reason") or "d-session-not-resident")
        if not bool(admission.get("can_admit")):
            return (
                False,
                0,
                int(admission.get("evicted_session_count", 0) or 0),
                0,
                str(admission.get("reason") or "d-no-space"),
            )
        # NOTE(review): a reported required_tokens of 0 falls back to the local
        # estimate because of the ``or`` -- confirm that is intended.
        reserved_tokens = int(
            admission.get("required_tokens", required_extra_tokens)
            or required_extra_tokens
        )
        _add_reserved_tokens(residency, server_url, reserved_tokens)
        return (
            True,
            reserved_tokens,
            int(admission.get("evicted_session_count", 0) or 0),
            0,
            None,
        )
    # Seed / reseed path: ask D itself via the seed-mode admission endpoint
    # instead of estimating capacity from a stale router-state snapshot. D
    # will run LRU eviction internally to make room. Falls through to the
    # legacy router-state logic below if the endpoint is unavailable.
    seed_admission = await _query_decode_direct_admission(
        client=client,
        server_url=server_url,
        session_id=session.session_id,
        uncached_input_tokens=max(0, request.input_length - current_tokens),
        output_tokens=request.output_length,
        mode="seed",
        config=config,
        residency=residency,
        request_id=request.request_id,
        turn_id=request.turn_id,
    )
    seed_reason = seed_admission.get("reason")
    if seed_reason != "admission-query-failed":
        if not bool(seed_admission.get("can_admit")):
            return (
                False,
                0,
                int(seed_admission.get("evicted_session_count", 0) or 0),
                0,
                str(seed_reason or "d-no-space"),
            )
        reserved_tokens = int(
            seed_admission.get("required_tokens", required_extra_tokens)
            or required_extra_tokens
        )
        _add_reserved_tokens(residency, server_url, reserved_tokens)
        return (
            True,
            reserved_tokens,
            int(seed_admission.get("evicted_session_count", 0) or 0),
            0,
            None,
        )
    # Legacy fallback: estimate capacity from the worker's session cache and
    # load snapshot, evicting idle sessions locally if needed.
    session_cache, max_total_num_tokens, reserved_decode_tokens = (
        await _fetch_decode_server_state(
            client=client,
            server_url=server_url,
        )
    )
    if max_total_num_tokens > 0:
        residency.capacity_tokens[server_url] = max_total_num_tokens
    if reserved_decode_tokens > 0:
        residency.reserved_decode_tokens[server_url] = reserved_decode_tokens
    target_session_status = _find_session_cache_status(
        session_cache,
        session.session_id,
    )
    if routing_mode == "direct" and not (
        isinstance(target_session_status, dict)
        and bool(target_session_status.get("resident"))
    ):
        return False, 0, 0, 0, "d-session-not-resident"
    load_snapshot = await _fetch_decode_load_snapshot(
        client=client,
        server_url=server_url,
    )
    if load_snapshot is not None and load_snapshot.max_total_num_tokens > 0:
        residency.capacity_tokens[server_url] = load_snapshot.max_total_num_tokens
    backpressure_reason = _decode_load_backpressure_reason(
        load_snapshot,
        config=config,
        routing_mode=routing_mode,
    )
    if backpressure_reason is not None:
        return False, 0, 0, 0, backpressure_reason
    usable_capacity_tokens = _usable_capacity_tokens(residency, server_url)
    evicted_sessions = 0
    while (
        required_extra_tokens > 0
        and residency.resident_tokens_by_server.get(server_url, 0)
        + residency.reserved_tokens_by_server.get(server_url, 0)
        + required_extra_tokens
        > usable_capacity_tokens + int(session_cache.get("idle_evictable_tokens", 0) or 0)
    ):
        # Evict the least recently used idle session on this D (never this one).
        candidates = sorted(
            (
                candidate
                for candidate in direct_sessions.values()
                if candidate.opened
                and candidate.server_url == server_url
                and candidate.session_id != session.session_id
                and candidate.active_requests <= 0
            ),
            key=lambda candidate: candidate.last_access_s,
        )
        if not candidates:
            break
        await _close_decode_session(
            client=client,
            session=candidates[0],
            residency=residency,
            evicting_for_capacity=True,
        )
        prefill_backed_evictions += int(candidates[0].prefill_opened)
        evicted_sessions += 1
    if evicted_sessions > 0:
        # Evictions changed the worker's state; re-probe before the final check.
        load_snapshot = await _fetch_decode_load_snapshot(
            client=client,
            server_url=server_url,
        )
        session_cache, max_total_num_tokens, reserved_decode_tokens = (
            await _fetch_decode_server_state(
                client=client,
                server_url=server_url,
            )
        )
        if max_total_num_tokens > 0:
            residency.capacity_tokens[server_url] = max_total_num_tokens
        if reserved_decode_tokens > 0:
            residency.reserved_decode_tokens[server_url] = reserved_decode_tokens
    # Recompute headroom from live load when available; usable capacity is
    # then capacity minus that headroom, floored at zero.
    if load_snapshot is not None:
        dynamic_headroom = _dynamic_decode_headroom_tokens(
            residency=residency,
            server_url=server_url,
            snapshot=load_snapshot,
            routing_mode=routing_mode,
        )
        residency.headroom_tokens[server_url] = min(
            residency.capacity_tokens.get(server_url, dynamic_headroom),
            dynamic_headroom,
        )
    usable_capacity_tokens = _usable_capacity_tokens(residency, server_url)
    effective_used_tokens = (
        load_snapshot.num_used_tokens
        if load_snapshot is not None
        else residency.resident_tokens_by_server.get(server_url, 0)
    ) + residency.reserved_tokens_by_server.get(server_url, 0)
    idle_evictable_tokens = int(session_cache.get("idle_evictable_tokens", 0) or 0)
    # The target session's own tokens cannot be evicted to make room for itself.
    if (
        routing_mode == "direct"
        and isinstance(target_session_status, dict)
        and bool(target_session_status.get("idle_evictable"))
    ):
        idle_evictable_tokens = max(
            0,
            idle_evictable_tokens
            - int(target_session_status.get("resident_tokens", 0) or 0),
        )
    if effective_used_tokens + required_extra_tokens > (
        usable_capacity_tokens + idle_evictable_tokens
    ):
        return False, 0, evicted_sessions, prefill_backed_evictions, "d-no-space"
    _add_reserved_tokens(residency, server_url, required_extra_tokens)
    return True, required_extra_tokens, evicted_sessions, prefill_backed_evictions, None
async def _reserve_decode_session_capacity_from_router_state(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
    request: TraceRequest,
    server_url: str,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    residency: DecodeResidencyState,
    treat_as_fresh_session: bool,
    routing_mode: Literal["direct", "seed"],
) -> tuple[bool, int, int, int, str | None]:
    """Reserve decode capacity using only the router's own residency ledger.

    Returns ``(admitted, reserved_tokens, evicted_sessions,
    prefill_backed_evictions, reject_reason)``.

    * A fresh session first closes any existing decode session.
    * Direct routing requires the session to still be open.
    * When usable capacity is unknown (non-positive), the request is admitted
      without eviction.
    * Otherwise idle least-recently-used sessions on the same worker are
      closed until resident + reserved + required tokens fit.
    """
    if treat_as_fresh_session and session.opened:
        await _close_decode_session(
            client=client,
            session=session,
            residency=residency,
        )
    if routing_mode == "direct" and not session.opened:
        return False, 0, 0, 0, "d-session-not-resident"

    already_held = 0 if treat_as_fresh_session else session.resident_tokens
    needed = max(0, _estimate_session_resident_tokens(request) - already_held)
    budget = _usable_capacity_tokens(residency, server_url)
    if budget <= 0:
        # Discovery failed: don't push every request to the P/D fallback. The
        # router stays correct; proactive admission is simply off until the
        # worker reports capacity again in a later run.
        _add_reserved_tokens(residency, server_url, needed)
        return True, needed, 0, 0, None

    def projected_tokens() -> int:
        return (
            residency.resident_tokens_by_server.get(server_url, 0)
            + residency.reserved_tokens_by_server.get(server_url, 0)
            + needed
        )

    evicted = 0
    prefill_backed = 0
    while needed > 0 and projected_tokens() > budget:
        # Least recently accessed idle session on this worker, excluding ours.
        victim = min(
            (
                other
                for other in direct_sessions.values()
                if other.opened
                and other.server_url == server_url
                and other.session_id != session.session_id
                and other.active_requests <= 0
            ),
            key=lambda other: other.last_access_s,
            default=None,
        )
        if victim is None:
            break
        await _close_decode_session(
            client=client,
            session=victim,
            residency=residency,
            evicting_for_capacity=True,
        )
        prefill_backed += int(victim.prefill_opened)
        evicted += 1

    if projected_tokens() > budget:
        return False, 0, evicted, prefill_backed, "d-no-space"
    _add_reserved_tokens(residency, server_url, needed)
    return True, needed, evicted, prefill_backed, None
def _build_direct_prompt(
    *,
    request: TraceRequest,
    session: DirectSessionState,
) -> tuple[str, int, bool, bool]:
    """Build the prompt sent to a decode worker for a direct-to-D turn.

    Returns ``(prompt, prompt_tokens, session_reused, session_reset)``. When
    the session is reusable, only a synthetic append chunk for the new tokens
    is sent. On a first turn or a reset, the full synthetic prompt is sent and
    ``session_reused`` is False.
    """
    appended, reused, reset = _inspect_direct_request(
        request=request,
        session=session,
    )
    if reused and not reset:
        chunk = build_synthetic_append_chunk(request, appended)
        return chunk, appended, reused, reset
    full_prompt = build_synthetic_prompt(request)
    return full_prompt, request.input_length, False, reset
def _build_direct_full_input_ids(
    request: TraceRequest,
    *,
    block_token_budget: int = 24,
) -> list[int]:
    """Build deterministic synthetic token ids covering ``request``'s full input.

    Each hash id contributes ``block_token_budget`` ids in ``[1000, 4000)``
    derived from the hash. The result is padded with ids in ``[5000, 7000)``
    and truncated to ``request.input_length``.
    """
    token_ids = [
        1000 + ((1000 + (hash_id % 3000) + step) % 3000)
        for hash_id in request.hash_ids
        for step in range(block_token_budget)
    ]
    token_ids.extend(
        5000 + (position % 2000)
        for position in range(len(token_ids), request.input_length)
    )
    return token_ids[: request.input_length]
def _build_direct_append_input_ids(
    request: TraceRequest,
    append_length: int,
) -> list[int]:
    """Build ``append_length`` deterministic synthetic token ids in ``[9000, 11000)``.

    The sequence is offset by ``(chat_id + turn_id)`` so different turns get
    different ids. Non-positive lengths yield an empty list.
    """
    if append_length > 0:
        start = 9000 + ((request.chat_id + request.turn_id) % 2000)
        return [9000 + ((start + step) % 2000) for step in range(append_length)]
    return []
async def _invoke_plain_router(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    execution_mode: str,
    decode_residency: "DecodeResidencyState | None" = None,
) -> ExecutionResult:
    """Send ``request`` through the standard P/D router to the decision's decode worker.

    The prefill priority is the one chosen for requests not predicted to go
    straight to decode (``None`` when priority eviction is off). The returned
    :class:`ExecutionResult` is tagged with ``execution_mode`` and the
    decision's KV-transfer block count. It never reports session reuse or
    reset.
    """
    priority = _prefill_priority_for_router_request(
        config=config,
        direct_to_d_predicted=False,
    )
    outcome = await _invoke_router(
        client=client,
        request=request,
        config=config,
        decode_worker_index=decision.decode_worker_index,
        prefill_request_priority=priority,
        decode_residency=decode_residency,
    )
    result_fields = dict(
        execution_mode=execution_mode,
        actual_kv_transfer_blocks=decision.kv_transfer_blocks,
        effective_input_length=request.input_length,
        cached_tokens=outcome.cached_tokens,
        prefill_request_priority=priority,
        session_reused=False,
        session_reset=False,
        latency_s=outcome.latency_s,
        ttft_s=outcome.ttft_s,
        tpot_s=outcome.tpot_s,
        actual_output_tokens=outcome.actual_output_tokens,
        requested_output_tokens=outcome.requested_output_tokens,
        finish_reason=outcome.finish_reason,
    )
    return ExecutionResult(**result_fields)
async def _attempt_d_to_p_sync(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    prefill_url: str,
    decode_session: DirectSessionState,
) -> dict | None:
    """Try to RDMA-dump session KV from the D that last held it to ``prefill_url``.

    The sync is three HTTP calls:

    1. ``POST <prefill>/_snapshot/prepare_receive``: P reserves slots for the
       session's estimated resident tokens.
    2. ``POST <source D>/_snapshot/dump``: D pushes the session KV into those
       slots.
    3. ``POST <prefill>/_snapshot/finalize_ingest``: P inserts the pushed
       tokens into its radix cache. The same endpoint called with empty
       ``token_ids`` is used here to discard the reserved slots.

    Returns:
        ``None`` only when ``config.enable_d_to_p_sync`` is off. Otherwise a
        dict whose ``"status"`` is ``"ok"`` or a skip/failure code, optionally
        with ``reason``/``error``/``bytes_pushed``. On success it also carries
        ``inserted_prefix_len``/``snapshot_session_id``. The caller falls back
        to normal re-prefill on any non-``"ok"`` status.

    Every path emits a structural-log line to ``d-to-p-sync.jsonl`` so we can
    trace why sync skipped vs succeeded vs failed.
    """
    if not config.enable_d_to_p_sync:
        return None
    source_d_url = decode_session.server_url
    sid = request.session_id
    rid = request.request_id

    def _int_field(payload: dict, key: str) -> int:
        # Worker-reported counters may be null or non-numeric; a logging field
        # must not turn a best-effort sync into a request failure.
        try:
            return int(payload.get(key, 0) or 0)
        except (TypeError, ValueError):
            return 0

    if not source_d_url:
        await _structural_emit(
            "d-to-p-sync.jsonl",
            {"event": "skipped", "stage": "entry", "sid": sid, "rid": rid,
             "reason": "no-source-d"},
        )
        return {"status": "skipped-no-source-d"}
    # NB: do NOT gate on decode_session.opened. By the time we reach the
    # fallback seeded_router, agentic has already flipped that flag to False
    # in response to admission rejection. But the D-side scheduler's
    # SessionAwareCache may STILL hold the session resident (release_session
    # is only called explicitly, not from admission events). Let D be the
    # source of truth via its own snapshot_dump response.
    target_tokens = max(0, int(_estimate_session_resident_tokens(request)))
    if target_tokens <= 0:
        await _structural_emit(
            "d-to-p-sync.jsonl",
            {"event": "skipped", "stage": "entry", "sid": sid, "rid": rid,
             "reason": "zero-target-tokens"},
        )
        return {"status": "skipped-zero-tokens"}
    # Strip trailing slashes so the joined endpoint paths never contain "//".
    prefill_base = prefill_url.rstrip("/")
    source_d_base = source_d_url.rstrip("/")

    t_prep0 = time.perf_counter()
    try:
        prep_resp = await client.post(
            f"{prefill_base}/_snapshot/prepare_receive",
            json={
                "session_id": request.session_id,
                "num_tokens": target_tokens,
            },
            timeout=30.0,
        )
        prep_resp.raise_for_status()
        prep = prep_resp.json()
    except Exception as exc:
        await _structural_emit(
            "d-to-p-sync.jsonl",
            {"event": "failed", "stage": "prepare", "sid": sid, "rid": rid,
             "error": repr(exc)[:200]},
        )
        return {"status": "prepare-failed", "error": repr(exc)}
    t_prep1 = time.perf_counter()
    if not prep.get("ok"):
        await _structural_emit(
            "d-to-p-sync.jsonl",
            {"event": "skipped", "stage": "prepare", "sid": sid, "rid": rid,
             "reason": prep.get("reason"),
             "prepare_dur_ms": round((t_prep1 - t_prep0) * 1000, 2)},
        )
        return {"status": "prepare-not-ok", "reason": prep.get("reason")}

    async def _discard_prepared_slots() -> bool:
        """Best-effort: ask P to drop the slots it reserved in prepare_receive."""
        try:
            discard_resp = await client.post(
                f"{prefill_base}/_snapshot/finalize_ingest",
                json={
                    "session_id": request.session_id,
                    "token_ids": [],
                    "slot_indices": prep["slot_indices"],
                },
                timeout=15.0,
            )
            discard_resp.raise_for_status()
        except Exception:
            return False
        return True

    t_dump0 = time.perf_counter()
    try:
        dump_resp = await client.post(
            f"{source_d_base}/_snapshot/dump",
            json={
                "session_id": request.session_id,
                "target_snapshot_session_id": prep["snapshot_session_id"],
                "target_k_base_ptrs": prep["k_base_ptrs"],
                "target_v_base_ptrs": prep["v_base_ptrs"],
                "target_slot_indices": prep["slot_indices"],
                "target_stride_k_bytes": prep["stride_k_bytes"],
                "target_stride_v_bytes": prep["stride_v_bytes"],
            },
            timeout=60.0,
        )
        dump_resp.raise_for_status()
        dump = dump_resp.json()
    except Exception as exc:
        # Deliberately NOT discarding P's slots here: on a timeout or
        # connection error the D-side push may still be in flight, and
        # releasing the slots could let P reuse memory that is being written.
        await _structural_emit(
            "d-to-p-sync.jsonl",
            {"event": "failed", "stage": "dump", "sid": sid, "rid": rid,
             "error": repr(exc)[:200]},
        )
        return {"status": "dump-failed", "error": repr(exc)}
    t_dump1 = time.perf_counter()
    if not dump.get("ok"):
        # D answered (e.g. "session-not-resident"), so -- assuming the dump
        # call returns only after any push has finished -- the slots on P
        # will not be filled. Release them instead of leaking them on P.
        discarded = await _discard_prepared_slots()
        await _structural_emit(
            "d-to-p-sync.jsonl",
            {"event": "skipped", "stage": "dump", "sid": sid, "rid": rid,
             "reason": dump.get("reason"),
             "dump_dur_ms": round((t_dump1 - t_dump0) * 1000, 2),
             "kv_committed_len": _int_field(dump, "kv_committed_len"),
             "discarded": discarded},
        )
        return {"status": "dump-not-ok", "reason": dump.get("reason"),
                "bytes_pushed": dump.get("bytes_pushed", 0)}
    # We need token_ids for radix insert. The caller has request.input_token_ids
    # for the first N — use that as best-available approximation.
    tokens = list(getattr(request, "input_token_ids", []) or [])
    if not tokens:
        # No token_ids available — can't insert into radix. P will fall back
        # to normal prefill; discard the filled slots so they are not wasted.
        discarded = await _discard_prepared_slots()
        await _structural_emit(
            "d-to-p-sync.jsonl",
            {"event": "skipped", "stage": "post-dump", "sid": sid, "rid": rid,
             "reason": "no-input-token-ids",
             "bytes_pushed": _int_field(dump, "bytes_pushed"),
             "discarded": discarded},
        )
        return {"status": "no-tokens-discard", "bytes_pushed": dump.get("bytes_pushed", 0)}
    # NOTE(review): when fewer tokens than slots are available, slots[n:] are
    # not passed to finalize_ingest -- confirm P releases them.
    n = min(len(tokens), len(prep["slot_indices"]))
    t_fin0 = time.perf_counter()
    try:
        fin_resp = await client.post(
            f"{prefill_base}/_snapshot/finalize_ingest",
            json={
                "session_id": request.session_id,
                "token_ids": tokens[:n],
                "slot_indices": prep["slot_indices"][:n],
            },
            timeout=30.0,
        )
        fin_resp.raise_for_status()
        fin = fin_resp.json()
    except Exception as exc:
        await _structural_emit(
            "d-to-p-sync.jsonl",
            {"event": "failed", "stage": "finalize", "sid": sid, "rid": rid,
             "error": repr(exc)[:200],
             "bytes_pushed": _int_field(dump, "bytes_pushed")},
        )
        return {"status": "finalize-failed", "error": repr(exc),
                "bytes_pushed": dump.get("bytes_pushed", 0)}
    t_fin1 = time.perf_counter()
    if not fin.get("ok"):
        await _structural_emit(
            "d-to-p-sync.jsonl",
            {"event": "skipped", "stage": "finalize", "sid": sid, "rid": rid,
             "reason": fin.get("reason"),
             "bytes_pushed": _int_field(dump, "bytes_pushed")},
        )
        return {"status": "finalize-not-ok", "reason": fin.get("reason"),
                "bytes_pushed": dump.get("bytes_pushed", 0)}
    await _structural_emit(
        "d-to-p-sync.jsonl",
        {"event": "ok", "sid": sid, "rid": rid,
         "bytes_pushed": _int_field(dump, "bytes_pushed"),
         "kv_committed_len": _int_field(dump, "kv_committed_len"),
         "inserted_prefix_len": _int_field(fin, "inserted_prefix_len"),
         "prepare_dur_ms": round((t_prep1 - t_prep0) * 1000, 2),
         "dump_dur_ms": round((t_dump1 - t_dump0) * 1000, 2),
         "finalize_dur_ms": round((t_fin1 - t_fin0) * 1000, 2),
         "snapshot_session_id": prep.get("snapshot_session_id")},
    )
    return {
        "status": "ok",
        "bytes_pushed": _int_field(dump, "bytes_pushed"),
        "inserted_prefix_len": _int_field(fin, "inserted_prefix_len"),
        "snapshot_session_id": prep.get("snapshot_session_id"),
    }
async def _invoke_kvcache_seeded_router(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    prefill_url: str,
    decode_session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
    reserved_tokens: int,
    execution_mode: str,
) -> ExecutionResult:
    """Seed (or re-seed) ``decode_session`` by running the request through the router.

    Sequence:
      1. With ``kvcache_prefill_backup_policy == "capacity-backup"``, reserve
         prefill-side capacity so the prefill copy can be kept as a backup.
         A prefill session that is open on a different prefill worker is
         closed.
      2. Open a streaming session on ``prefill_url`` if none is open.
      3. If ``config.enable_d_to_p_sync``, attempt a best-effort D->P snapshot
         push via ``_attempt_d_to_p_sync``. Its failures never fail the request.
      4. Open the decode streaming session if needed, then call the router
         pinned to ``decision.decode_worker_index``.
      5. On success, either commit the prefill backup residency or release its
         reservation and close the prefill session, then commit the decode
         residency.

    Args:
        reserved_tokens: Decode capacity the caller already reserved on
            ``decode_session.server_url`` for this request.
        execution_mode: Label copied into the returned ``ExecutionResult``.

    Raises:
        Any exception from steps 1, 2 or 4. Before re-raising, the decode and
        prefill reservations are released, sessions opened by this call are
        closed and the in-flight counter is restored. A failed seed therefore
        does not leak capacity accounting.
    """
    keep_prefill_backup = False
    prefill_reserved_tokens = 0
    prefill_session_newly_opened = False
    decode_session_newly_opened = False
    # Records whether active_requests was incremented, so the error path only
    # undoes an increment that actually happened.
    active_request_counted = False
    try:
        async with direct_session_lock:
            if config.kvcache_prefill_backup_policy == "capacity-backup":
                keep_prefill_backup, prefill_reserved_tokens, _prefill_evicted = (
                    await _reserve_prefill_backup_capacity(
                        client=client,
                        request=request,
                        prefill_url=prefill_url,
                        session=decode_session,
                        direct_sessions=direct_sessions,
                        residency=decode_residency,
                    )
                )
            # A prefill session pinned to another prefill worker cannot be
            # reused for this request.
            if (
                decode_session.prefill_opened
                and decode_session.prefill_server_url != prefill_url
            ):
                await _close_prefill_session(
                    client=client,
                    session=decode_session,
                    residency=decode_residency,
                )
        async with direct_session_lock:
            if not decode_session.prefill_opened:
                await _open_streaming_session(
                    client=client,
                    server_url=prefill_url,
                    session_id=request.session_id,
                    request=request,
                )
                decode_session.prefill_opened = True
                decode_session.prefill_server_url = prefill_url
                prefill_session_newly_opened = True
        # D->P snapshot push (Phase 3) is best-effort. On any failure we fall
        # back to the existing re-prefill path. The result is logged for
        # post-hoc analysis but does not affect correctness.
        if config.enable_d_to_p_sync:
            try:
                sync_result = await _attempt_d_to_p_sync(
                    client=client,
                    request=request,
                    config=config,
                    prefill_url=prefill_url,
                    decode_session=decode_session,
                )
            except Exception as exc:
                # Warning level so the event is visible even without a
                # configured handler (logging's last-resort handler).
                logger.warning(
                    "d_to_p_sync sid=%s rid=%s raised, falling back to re-prefill: %r",
                    request.session_id,
                    request.request_id,
                    exc,
                )
            else:
                if sync_result is not None and sync_result.get("status") != "ok":
                    logger.info(
                        "d_to_p_sync sid=%s rid=%s skipped: %s",
                        request.session_id, request.request_id, sync_result,
                    )
                elif sync_result and sync_result.get("status") == "ok":
                    logger.info(
                        "d_to_p_sync sid=%s rid=%s pushed=%d ingested_prefix=%d",
                        request.session_id,
                        request.request_id,
                        sync_result.get("bytes_pushed", 0),
                        sync_result.get("inserted_prefix_len", 0),
                    )
        prefill_priority = _prefill_priority_for_router_request(
            config=config,
            direct_to_d_predicted=True,
        )
        async with direct_session_lock:
            if not decode_session.opened:
                await _open_streaming_session(
                    client=client,
                    server_url=decode_session.server_url,
                    session_id=request.session_id,
                    request=request,
                )
                decode_session.opened = True
                decode_session_newly_opened = True
            decode_session.active_requests += 1
            active_request_counted = True
        gen = await _invoke_router(
            client=client,
            request=request,
            config=config,
            decode_worker_index=decision.decode_worker_index,
            session_id=request.session_id,
            prefill_request_priority=prefill_priority,
            decode_residency=decode_residency,
        )
    except Exception:
        async with direct_session_lock:
            if active_request_counted:
                decode_session.active_requests = max(
                    0, decode_session.active_requests - 1
                )
            _release_reserved_tokens(
                decode_residency,
                decode_session.server_url,
                reserved_tokens,
            )
            _release_prefill_reserved_tokens(
                decode_residency,
                prefill_url,
                prefill_reserved_tokens,
            )
            if decode_session_newly_opened:
                await _close_decode_session(
                    client=client,
                    session=decode_session,
                    residency=decode_residency,
                )
            if prefill_session_newly_opened:
                await _close_prefill_session(
                    client=client,
                    session=decode_session,
                    residency=decode_residency,
                )
        raise
    async with direct_session_lock:
        decode_session.active_requests = max(0, decode_session.active_requests - 1)
        if keep_prefill_backup:
            _commit_prefill_backup_residency(
                residency=decode_residency,
                session=decode_session,
                request=request,
                prefill_url=prefill_url,
                reserved_tokens=prefill_reserved_tokens,
            )
        else:
            _release_prefill_reserved_tokens(
                decode_residency,
                prefill_url,
                prefill_reserved_tokens,
            )
            await _close_prefill_session(
                client=client,
                session=decode_session,
                residency=decode_residency,
            )
        _commit_session_residency(
            residency=decode_residency,
            session=decode_session,
            request=request,
            reserved_tokens=reserved_tokens,
        )
    return ExecutionResult(
        execution_mode=execution_mode,
        actual_kv_transfer_blocks=decision.kv_transfer_blocks,
        effective_input_length=request.input_length,
        cached_tokens=gen.cached_tokens,
        prefill_request_priority=prefill_priority,
        session_reused=False,
        session_reset=False,
        latency_s=gen.latency_s,
        ttft_s=gen.ttft_s,
        tpot_s=gen.tpot_s,
        actual_output_tokens=gen.actual_output_tokens,
        requested_output_tokens=gen.requested_output_tokens,
        finish_reason=gen.finish_reason,
    )
async def _reserve_seed_capacity_with_admission(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
    request: TraceRequest,
    decode_url: str,
    decode_session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
) -> tuple[bool, int, int, int, str | None]:
    """Run the D session admission check, then reserve D capacity for a (re)seed.

    Both steps run under ``direct_session_lock``. Both receive
    ``config.kvcache_admission_mode``, so every seed path in
    ``_execute_request`` honours the configured admission mode.

    Returns:
        ``(can_seed, reserved_tokens, evicted_sessions,
        prefill_backed_evictions, seed_reason)``. When admission is refused
        nothing is reserved and ``seed_reason`` is ``"d-session-cap"``.
    """
    async with direct_session_lock:
        admitted = _should_admit_new_decode_session(
            residency=decode_residency,
            server_url=decode_url,
            request=request,
            session=decode_session,
            direct_sessions=direct_sessions,
            treat_as_fresh_session=True,
            admission_mode=config.kvcache_admission_mode,
        )
        if not admitted:
            return False, 0, 0, 0, "d-session-cap"
        (
            can_seed,
            reserved_tokens,
            evicted_sessions,
            prefill_backed_evictions,
            seed_reason,
        ) = await _reserve_decode_session_capacity(
            client=client,
            config=config,
            request=request,
            server_url=decode_url,
            session=decode_session,
            direct_sessions=direct_sessions,
            residency=decode_residency,
            treat_as_fresh_session=True,
            routing_mode="seed",
            admission_mode=config.kvcache_admission_mode,
        )
    return (
        can_seed,
        reserved_tokens,
        evicted_sessions,
        prefill_backed_evictions,
        seed_reason,
    )


async def _execute_request(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
) -> ExecutionResult:
    """Execute one trace request according to ``config.mechanism_name``.

    * ``pd-disaggregation``: forward to the plain router.
    * ``pd-colo``: forward to the plain router and report zero KV-transfer
      blocks.
    * ``kvcache-centric``:
        1. Turn 1: seed a decode (D) session through the router, or use the
           plain router when the seed is filtered or D capacity is refused.
        2. Later turns: send the append straight to the resident D session
           when prefill can be bypassed and the session is reusable. Otherwise
           re-seed the D session, or fall back to the plain router.
        3. Anything else: classify why the direct path was not taken
           (``_fallthrough_reason``), then re-seed or fall back.

    The returned ``execution_mode`` label records which path was taken and
    why.

    Raises:
        ValueError: A required router URL or decode worker is missing, or the
            mechanism is unknown.
    """
    if config.mechanism_name == "pd-disaggregation":
        if not config.router_url:
            raise ValueError("router_url is required for pd-disaggregation replay")
        return await _invoke_plain_router(
            client=client,
            request=request,
            config=config,
            decision=decision,
            execution_mode="pd-disaggregation-router",
            decode_residency=decode_residency,
        )
    if config.mechanism_name == "pd-colo":
        if not config.router_url:
            raise ValueError("router_url is required for pd-colo replay")
        result = await _invoke_plain_router(
            client=client,
            request=request,
            config=config,
            decision=decision,
            execution_mode="dp-colo-router",
            decode_residency=decode_residency,
        )
        # Colocated prefill/decode does not move KV between workers.
        return replace(result, actual_kv_transfer_blocks=0)
    if config.mechanism_name == "kvcache-centric":
        if not config.router_url:
            raise ValueError("router_url is required for kvcache-centric replay")
        if not config.topology.decode_workers:
            raise ValueError("kvcache-centric mechanism requires at least one decode worker")
        prefill_url = _worker_url_by_id(
            config.topology.prefill_workers,
            decision.prefill_worker_id,
        )
        decode_url = config.topology.decode_workers[decision.decode_worker_index].url
        # Get or create the local D session state and point it at the decode
        # worker chosen for this request. An open session on another worker is
        # closed first.
        async with direct_session_lock:
            decode_session = direct_sessions.get(request.session_id)
            if decode_session is None:
                decode_session = DirectSessionState(
                    session_id=request.session_id,
                    server_url=decode_url,
                )
                direct_sessions[request.session_id] = decode_session
            elif decode_session.server_url != decode_url and decode_session.opened:
                await _close_decode_session(
                    client=client,
                    session=decode_session,
                    residency=decode_residency,
                )
                decode_session.server_url = decode_url
            else:
                decode_session.server_url = decode_url
        direct_append_length: int | None = None
        direct_session_reused = False
        direct_session_reset = False
        if request.turn_id > 1:
            async with direct_session_lock:
                (
                    direct_append_length,
                    direct_session_reused,
                    direct_session_reset,
                ) = _inspect_direct_request(
                    request=request,
                    session=decode_session,
                )

        # --- Turn 1: seed a D session through the router, or plain router.
        if request.turn_id == 1:
            seed_filter_reason = _seed_filter_reason(
                request=request,
                config=config,
                inflight_decode_load=decision.inflight_decode_load,
            )
            if seed_filter_reason is not None:
                return await _invoke_plain_router(
                    client=client,
                    request=request,
                    config=config,
                    decision=decision,
                    execution_mode=f"pd-router-turn1-{seed_filter_reason}",
                    decode_residency=decode_residency,
                )
            can_seed, reserved_tokens, _evicted, _p_backed, seed_reason = (
                await _reserve_seed_capacity_with_admission(
                    client=client,
                    config=config,
                    request=request,
                    decode_url=decode_url,
                    decode_session=decode_session,
                    direct_sessions=direct_sessions,
                    direct_session_lock=direct_session_lock,
                    decode_residency=decode_residency,
                )
            )
            if seed_reason == "d-session-cap":
                return await _invoke_plain_router(
                    client=client,
                    request=request,
                    config=config,
                    decision=decision,
                    execution_mode="pd-router-turn1-session-cap",
                    decode_residency=decode_residency,
                )
            if can_seed:
                return await _invoke_kvcache_seeded_router(
                    client=client,
                    request=request,
                    config=config,
                    decision=decision,
                    prefill_url=prefill_url,
                    decode_session=decode_session,
                    direct_sessions=direct_sessions,
                    direct_session_lock=direct_session_lock,
                    decode_residency=decode_residency,
                    reserved_tokens=reserved_tokens,
                    execution_mode="pd-router-turn1-seed",
                )
            # NOTE(review): unlike the later fallback paths, this label uses a
            # negative test ("not d-no-space") rather than
            # _is_decode_backpressure_reason. Confirm this is intended.
            return await _invoke_plain_router(
                client=client,
                request=request,
                config=config,
                decision=decision,
                execution_mode=(
                    "pd-router-turn1-d-backpressure"
                    if seed_reason is not None and seed_reason != "d-no-space"
                    else "pd-router-turn1-no-d-capacity"
                ),
                decode_residency=decode_residency,
            )

        # --- Later turns: try the direct-to-D append path first.
        if (
            _should_bypass_prefill(
                request=request,
                config=config,
                decision=decision,
            )
            and direct_append_length is not None
            and direct_session_reused
            and not direct_session_reset
            and direct_append_length <= config.kvcache_direct_max_uncached_tokens
        ):
            async with direct_session_lock:
                can_direct = (
                    decode_session.opened
                    and decode_session.server_url == decode_url
                    and direct_session_reused
                    and not direct_session_reset
                )
            direct_reserved_tokens = 0
            direct_reason: str | None = None
            if can_direct:
                async with direct_session_lock:
                    (
                        can_direct,
                        direct_reserved_tokens,
                        _evicted,
                        _p_backed,
                        direct_reason,
                    ) = (
                        await _reserve_decode_session_capacity(
                            client=client,
                            config=config,
                            request=request,
                            server_url=decode_url,
                            session=decode_session,
                            direct_sessions=direct_sessions,
                            residency=decode_residency,
                            treat_as_fresh_session=False,
                            routing_mode="direct",
                            admission_mode=config.kvcache_admission_mode,
                        )
                    )
            if can_direct:
                try:
                    return await _invoke_decode_session_direct(
                        client=client,
                        request=request,
                        config=config,
                        decision=decision,
                        direct_sessions=direct_sessions,
                        direct_session_lock=direct_session_lock,
                        decode_residency=decode_residency,
                        reserved_tokens=direct_reserved_tokens,
                    )
                except Exception as exc:
                    if not _is_stale_decode_session_error(exc):
                        raise
                    # The worker no longer has the session. Drop the local
                    # state and serve this request through the router.
                    async with direct_session_lock:
                        await _close_decode_session(
                            client=client,
                            session=decode_session,
                            residency=decode_residency,
                        )
                    return await _invoke_plain_router(
                        client=client,
                        request=request,
                        config=config,
                        decision=decision,
                        execution_mode="pd-router-fallback-stale-d-session",
                        decode_residency=decode_residency,
                    )
            if _is_decode_backpressure_reason(direct_reason):
                return await _invoke_plain_router(
                    client=client,
                    request=request,
                    config=config,
                    decision=decision,
                    execution_mode="pd-router-fallback-d-backpressure",
                    decode_residency=decode_residency,
                )
            seed_filter_reason = _seed_filter_reason(
                request=request,
                config=config,
                inflight_decode_load=decision.inflight_decode_load,
            )
            if seed_filter_reason is not None:
                return await _invoke_plain_router(
                    client=client,
                    request=request,
                    config=config,
                    decision=decision,
                    execution_mode=f"pd-router-fallback-{seed_filter_reason}",
                    decode_residency=decode_residency,
                )
            (
                can_seed,
                reserved_tokens,
                evicted_sessions,
                prefill_backed_evictions,
                seed_reason,
            ) = await _reserve_seed_capacity_with_admission(
                client=client,
                config=config,
                request=request,
                decode_url=decode_url,
                decode_session=decode_session,
                direct_sessions=direct_sessions,
                direct_session_lock=direct_session_lock,
                decode_residency=decode_residency,
            )
            if seed_reason == "d-session-cap":
                return await _invoke_plain_router(
                    client=client,
                    request=request,
                    config=config,
                    decision=decision,
                    execution_mode="pd-router-fallback-session-cap",
                    decode_residency=decode_residency,
                )
            if can_seed:
                return await _invoke_kvcache_seeded_router(
                    client=client,
                    request=request,
                    config=config,
                    decision=decision,
                    prefill_url=prefill_url,
                    decode_session=decode_session,
                    direct_sessions=direct_sessions,
                    direct_session_lock=direct_session_lock,
                    decode_residency=decode_residency,
                    reserved_tokens=reserved_tokens,
                    execution_mode=(
                        "pd-router-d-session-reseed"
                        + _eviction_suffix(
                            evicted_sessions,
                            prefill_backed_evictions,
                        )
                    ),
                )
            return await _invoke_plain_router(
                client=client,
                request=request,
                config=config,
                decision=decision,
                execution_mode=(
                    "pd-router-fallback-d-backpressure"
                    if _is_decode_backpressure_reason(seed_reason)
                    else "pd-router-fallback-no-d-capacity"
                ),
                decode_residency=decode_residency,
            )

        # --- Direct path not applicable.
        # TEAM_REPORT §2.7: 'large-append' is misleading; most fallthroughs are
        # actually 'session-not-resident-on-pinned-D' (§1 starvation). Classify
        # the real reason and embed it in the execution_mode label.
        fallthrough = _fallthrough_reason(
            request=request,
            config=config,
            decision=decision,
            direct_append_length=direct_append_length,
            direct_session_reused=direct_session_reused,
            direct_session_reset=direct_session_reset,
        )
        seed_filter_reason = _seed_filter_reason(
            request=request,
            config=config,
            inflight_decode_load=decision.inflight_decode_load,
        )
        if seed_filter_reason is not None:
            return await _invoke_plain_router(
                client=client,
                request=request,
                config=config,
                decision=decision,
                execution_mode=f"pd-router-fallback-{fallthrough}-{seed_filter_reason}",
                decode_residency=decode_residency,
            )
        # Uses the shared helper so this path honours
        # config.kvcache_admission_mode like the other seed paths.
        (
            can_seed,
            reserved_tokens,
            evicted_sessions,
            prefill_backed_evictions,
            seed_reason,
        ) = await _reserve_seed_capacity_with_admission(
            client=client,
            config=config,
            request=request,
            decode_url=decode_url,
            decode_session=decode_session,
            direct_sessions=direct_sessions,
            direct_session_lock=direct_session_lock,
            decode_residency=decode_residency,
        )
        if seed_reason == "d-session-cap":
            return await _invoke_plain_router(
                client=client,
                request=request,
                config=config,
                decision=decision,
                execution_mode=f"pd-router-fallback-{fallthrough}-session-cap",
                decode_residency=decode_residency,
            )
        if can_seed:
            return await _invoke_kvcache_seeded_router(
                client=client,
                request=request,
                config=config,
                decision=decision,
                prefill_url=prefill_url,
                decode_session=decode_session,
                direct_sessions=direct_sessions,
                direct_session_lock=direct_session_lock,
                decode_residency=decode_residency,
                reserved_tokens=reserved_tokens,
                execution_mode=(
                    f"pd-router-{fallthrough}-reseed"
                    + _eviction_suffix(
                        evicted_sessions,
                        prefill_backed_evictions,
                    )
                ),
            )
        # Preserve seed_reason in the label so migration feedback fires for
        # 'd-no-space' / 'd-*-backpressure' (matched via _is_admission_rejection_mode).
        if _is_decode_backpressure_reason(seed_reason):
            mode_label = f"pd-router-fallback-{fallthrough}-d-backpressure"
        elif seed_reason == "d-no-space":
            mode_label = f"pd-router-fallback-{fallthrough}-no-d-capacity"
        else:
            mode_label = f"pd-router-fallback-{fallthrough}"
        return await _invoke_plain_router(
            client=client,
            request=request,
            config=config,
            decision=decision,
            execution_mode=mode_label,
            decode_residency=decode_residency,
        )
    raise ValueError(f"Unsupported mechanism: {config.mechanism_name}")
async def _invoke_direct(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
) -> ExecutionResult:
    """Serve a pd-colo request on a direct worker through a streaming session.

    The worker is picked by ``decision.decode_worker_index``. The cached
    session state for ``request.session_id`` is reused only when it already
    points at that worker; otherwise fresh state replaces it.

    Raises:
        ValueError: The topology has no direct workers.
    """
    workers = config.topology.direct_workers
    if not workers:
        raise ValueError("pd-colo mechanism requires at least one direct worker")
    target_url = workers[decision.decode_worker_index].url

    state = direct_sessions.get(request.session_id)
    reusable = state is not None and state.server_url == target_url
    if not reusable:
        state = DirectSessionState(session_id=request.session_id, server_url=target_url)
        direct_sessions[request.session_id] = state

    return await _invoke_session_direct(
        client=client,
        request=request,
        config=config,
        session=state,
        execution_mode="pd-colo-direct-session",
        direct_session_lock=direct_session_lock,
    )
async def _invoke_session_direct(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    session: DirectSessionState,
    execution_mode: str,
    decode_residency: DecodeResidencyState | None = None,
    reserved_tokens: int = 0,
    direct_session_lock: asyncio.Lock | None = None,
) -> ExecutionResult:
    """Run ``request`` on ``session.server_url`` inside a server-side streaming session.

    When ``_build_direct_prompt`` reports that the session can be reused, only
    the appended input ids are sent; otherwise the full input is sent. An open
    session that must be reset (or cannot be reused) is closed and a fresh one
    is opened before generation.

    Args:
        client: HTTP client used for every call to the worker.
        request: Trace request to execute.
        config: Replay configuration (backpressure flag, stream and timeout
            settings).
        session: Local state of the streaming session. Mutated in place
            (``opened``, ``active_requests`` and, without residency tracking,
            ``last_trace_request`` / ``last_access_s``).
        execution_mode: Label copied into the returned ``ExecutionResult``.
        decode_residency: Decode residency tracker. When given, this call
            waits on ``_wait_for_decode_pause`` if ``config.enable_backpressure``
            is set, and closes sessions through ``_close_decode_session``. On
            success it commits ``reserved_tokens`` via
            ``_commit_session_residency``.
        reserved_tokens: Capacity reserved by the caller. It is only handed to
            ``_commit_session_residency`` on success and is not released here
            if generation fails.
        direct_session_lock: When given, guards the ``active_requests``
            increment and decrement.

    Returns:
        ExecutionResult with ``actual_kv_transfer_blocks=0`` and
        ``effective_input_length`` equal to the number of input ids sent.
    """
    # Honour decode-side backpressure before issuing any work to the worker.
    if decode_residency is not None and config.enable_backpressure:
        await _wait_for_decode_pause(
            config=config,
            residency=decode_residency,
            server_url=session.server_url,
            request_id=request.request_id,
            session_id=session.session_id,
        )
    _prompt, effective_input_length, session_reused, session_reset = _build_direct_prompt(
        request=request,
        session=session,
    )
    # A reused session only needs the tokens appended since the last turn.
    if session_reused:
        input_ids = _build_direct_append_input_ids(request, effective_input_length)
    else:
        input_ids = _build_direct_full_input_ids(request)
    # An open session whose server-side state cannot be extended is discarded.
    if session.opened and (session_reset or not session_reused):
        if decode_residency is not None:
            await _close_decode_session(
                client=client,
                session=session,
                residency=decode_residency,
            )
        else:
            # _close_streaming_session only receives the URL and id, so the
            # local session state is reset here.
            await _close_streaming_session(
                client=client,
                server_url=session.server_url,
                session_id=session.session_id,
                allow_missing=True,
            )
            session.opened = False
            session.resident_tokens = 0
    if not session.opened:
        await _open_streaming_session(
            client=client,
            server_url=session.server_url,
            session_id=session.session_id,
            request=request,
        )
        session.opened = True
    if direct_session_lock is not None:
        async with direct_session_lock:
            session.active_requests += 1
    try:
        # Greedy decoding with a forced output length (min == max, EOS
        # ignored) so each request produces the trace's output length.
        gen = await _invoke_generate(
            client=client,
            base_url=session.server_url,
            headers={"x-request-id": request.request_id},
            payload={
                "input_ids": input_ids,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max(1, request.output_length),
                    "min_new_tokens": max(1, request.output_length),
                    "ignore_eos": True,
                    "no_stop_trim": True,
                    "skip_special_tokens": False,
                },
                "session_params": {"id": session.session_id},
                "stream": config.stream,
            },
            timeout_s=config.timeout_s,
            stream_idle_timeout_s=config.stream_idle_timeout_s,
            stream=config.stream,
        )
    finally:
        # Undo the in-flight count whether or not generation succeeded.
        if direct_session_lock is not None:
            async with direct_session_lock:
                session.active_requests = max(0, session.active_requests - 1)
    if decode_residency is not None:
        _commit_session_residency(
            residency=decode_residency,
            session=session,
            request=request,
            reserved_tokens=reserved_tokens,
        )
    else:
        session.last_trace_request = request
        session.last_access_s = time.perf_counter()
    return ExecutionResult(
        execution_mode=execution_mode,
        actual_kv_transfer_blocks=0,
        effective_input_length=len(input_ids),
        cached_tokens=gen.cached_tokens,
        session_reused=session_reused,
        session_reset=session_reset,
        latency_s=gen.latency_s,
        ttft_s=gen.ttft_s,
        tpot_s=gen.tpot_s,
        actual_output_tokens=gen.actual_output_tokens,
        requested_output_tokens=gen.requested_output_tokens,
        finish_reason=gen.finish_reason,
    )
async def _invoke_decode_session_direct(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
    reserved_tokens: int,
) -> ExecutionResult:
    """Send a request straight to its resident decode session, bypassing prefill.

    The session must already be registered in ``direct_sessions``. If the
    direct call fails, the ``reserved_tokens`` the caller reserved on the
    chosen decode worker are released under ``direct_session_lock`` before the
    exception propagates.
    """
    target_worker = config.topology.decode_workers[decision.decode_worker_index]
    target_url = target_worker.url
    resident_session = direct_sessions[request.session_id]
    try:
        outcome = await _invoke_session_direct(
            client=client,
            request=request,
            config=config,
            session=resident_session,
            execution_mode="kvcache-direct-to-d-session",
            decode_residency=decode_residency,
            reserved_tokens=reserved_tokens,
            direct_session_lock=direct_session_lock,
        )
    except Exception:
        async with direct_session_lock:
            _release_reserved_tokens(decode_residency, target_url, reserved_tokens)
        raise
    return outcome
def _should_bypass_prefill(
*,
request: TraceRequest,
config: ReplayConfig,
decision,
block_token_budget: int = 24,
) -> bool:
if request.turn_id <= 1:
return False
if decision.observed_overlap_blocks <= 0:
return False
uncached_tokens = max(0, decision.kv_transfer_blocks * block_token_budget)
return uncached_tokens <= config.kvcache_direct_max_uncached_tokens
def _worker_url_by_id(workers, worker_id: str) -> str:
for worker in workers:
if worker.worker_id == worker_id:
return worker.url
raise KeyError(f"Unknown worker id: {worker_id}")
def _build_headers(
*,
request: TraceRequest,
header_mode: HeaderMode,
decode_worker_index: int,
policy_name: str,
) -> dict[str, str]:
if header_mode == "auto":
header_mode = "routing-key" if policy_name == "sticky" else "none"
headers = {
"x-request-id": request.request_id,
}
if header_mode == "routing-key":
headers["x-smg-routing-key"] = request.session_id
elif header_mode == "target-worker":
headers["x-smg-target-worker"] = str(decode_worker_index)
return headers
def _contains_token(payload: dict) -> bool:
choices = payload.get("choices")
if not isinstance(choices, list) or not choices:
return False
delta = choices[0].get("delta")
if not isinstance(delta, dict):
return False
reasoning_content = delta.get("reasoning_content")
if isinstance(reasoning_content, str) and reasoning_content:
return True
content = delta.get("content")
if isinstance(content, str) and content:
return True
if isinstance(content, list):
return any(
isinstance(item, dict) and item.get("text")
for item in content
)
tool_calls = delta.get("tool_calls")
if isinstance(tool_calls, list):
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
continue
if tool_call.get("id"):
return True
function = tool_call.get("function")
if not isinstance(function, dict):
continue
if function.get("name") or function.get("arguments"):
return True
return False
def _contains_generate_token(payload: dict) -> bool:
text = payload.get("text")
if isinstance(text, str) and text:
return True
meta_info = payload.get("meta_info")
if not isinstance(meta_info, dict):
return False
return int(meta_info.get("completion_tokens", 0)) > 0
def _is_generate_terminal_chunk(payload: dict) -> bool:
meta_info = payload.get("meta_info")
if not isinstance(meta_info, dict):
return False
return meta_info.get("finish_reason") is not None
def _extract_generate_cached_tokens(payload: dict) -> int:
meta_info = payload.get("meta_info")
if not isinstance(meta_info, dict):
return 0
return int(meta_info.get("cached_tokens", 0) or 0)
def _extract_openai_cached_tokens(payload: dict) -> int:
usage = payload.get("usage")
if not isinstance(usage, dict):
return 0
prompt_tokens_details = usage.get("prompt_tokens_details")
if not isinstance(prompt_tokens_details, dict):
return 0
return int(prompt_tokens_details.get("cached_tokens", 0) or 0)
async def _aiter_lines(
response: httpx.Response,
*,
idle_timeout_s: float | None,
):
line_iterator = response.aiter_lines()
while True:
try:
if idle_timeout_s is None or idle_timeout_s <= 0:
line = await anext(line_iterator)
else:
line = await asyncio.wait_for(
anext(line_iterator),
timeout=idle_timeout_s,
)
except StopAsyncIteration:
return
yield line
def _is_terminal_chunk(payload: dict) -> bool:
choices = payload.get("choices")
if not isinstance(choices, list) or not choices:
return False
finish_reason = choices[0].get("finish_reason")
return finish_reason is not None