# Source file: agentic-pd-hybrid/src/agentic_pd_hybrid/replay.py
# (file-browser listing: 2272 lines, 76 KiB, Python)

from __future__ import annotations
import asyncio
import json
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
import httpx
from agentic_pd_hybrid.metrics import (
RequestMetrics,
write_metrics_jsonl,
write_summary_json,
)
from agentic_pd_hybrid.policies import RoutingState, create_policy
from agentic_pd_hybrid.topology import SingleNodeTopology
from agentic_pd_hybrid.trace import (
TraceRequest,
build_synthetic_append_chunk,
build_synthetic_prompt,
load_trace,
)
# Allowed values for ``ReplayConfig.header_mode``; the chosen value is passed
# to ``_build_headers`` when a request is sent through the router.
HeaderMode = Literal["none", "routing-key", "target-worker", "auto"]
# Allowed values for ``ReplayConfig.kvcache_admission_mode``; "router" selects
# the router-state reservation path in ``_reserve_decode_session_capacity``.
KvCacheAdmissionMode = Literal["router", "worker"]
@dataclass(frozen=True)
class ReplayConfig:
    """Immutable settings for one trace replay run (consumed by ``replay_trace``)."""

    trace_path: Path  # trace file handed to load_trace
    output_path: Path  # metrics output; the summary path is derived from it
    policy_name: str  # passed to create_policy and to _build_headers
    mechanism_name: str  # recorded in metrics; "kvcache-centric" enables residency discovery
    topology: SingleNodeTopology
    router_url: str | None = None  # required by _invoke_router
    model_name: str | None = None
    pace: bool = True  # when True, launches follow the trace timestamps
    time_scale: float = 1.0  # trace time offsets are divided by this when pacing
    request_limit: int | None = None  # forwarded to load_trace
    concurrency_limit: int = 32  # size of the semaphore bounding in-flight requests
    header_mode: HeaderMode = "auto"
    timeout_s: float = 600.0  # used for the HTTP client and per-request timeouts
    stream: bool = True  # whether generation requests are streamed
    stream_idle_timeout_s: float | None = 900.0  # forwarded to the stream line reader
    kvcache_direct_max_uncached_tokens: int = 2048
    kvcache_admission_mode: KvCacheAdmissionMode = "router"
@dataclass
class DirectSessionState:
    """Mutable bookkeeping for one streaming session tracked by the replayer.

    Holds the decode-side state (``server_url``, ``opened``, ...) and the
    state of an optional copy of the session on a prefill server
    (``prefill_*`` fields).
    """

    session_id: str
    server_url: str  # decode server the session lives on
    opened: bool = False  # True once the decode-side session has been committed
    last_trace_request: TraceRequest | None = None  # last request committed to the session
    resident_tokens: int = 0  # tokens accounted to this session on the decode server
    last_access_s: float = 0.0  # time.perf_counter() value of the last commit
    active_requests: int = 0  # sessions with active requests are not evicted
    prefill_server_url: str | None = None
    prefill_opened: bool = False
    prefill_resident_tokens: int = 0
    prefill_last_access_s: float = 0.0
    prefill_low_priority: bool = False  # low-priority prefill copies are evicted first
@dataclass
class DecodeResidencyState:
    """Replayer-side token accounting for decode and prefill servers.

    Every dictionary is keyed by server URL.
    """

    capacity_tokens: dict[str, int] = field(default_factory=dict)  # max_total_num_tokens per decode server
    headroom_tokens: dict[str, int] = field(default_factory=dict)  # safety margin kept free per decode server
    reserved_decode_tokens: dict[str, int] = field(default_factory=dict)  # num_reserved_decode_tokens reported by the server
    resident_tokens_by_server: dict[str, int] = field(default_factory=dict)  # tokens of committed sessions
    reserved_tokens_by_server: dict[str, int] = field(default_factory=dict)  # tokens reserved but not yet committed
    prefill_capacity_tokens: dict[str, int] = field(default_factory=dict)
    prefill_headroom_tokens: dict[str, int] = field(default_factory=dict)
    prefill_resident_tokens_by_server: dict[str, int] = field(default_factory=dict)
    prefill_reserved_tokens_by_server: dict[str, int] = field(default_factory=dict)
    decode_evictions_prefill_backed: int = 0  # capacity evictions where a prefill copy was open
    decode_evictions_without_prefill_backup: int = 0  # capacity evictions with no prefill copy
@dataclass(frozen=True)
class DecodeLoadSnapshot:
    """Point-in-time load figures read from a server's ``/v1/loads`` endpoint."""

    timestamp_s: float  # time.perf_counter() value when the snapshot was built
    num_running_reqs: int
    num_waiting_reqs: int
    num_used_tokens: int
    max_total_num_tokens: int
    token_usage: float  # compared against thresholds such as 0.94 / 0.985 / 0.99
    decode_prealloc_queue_reqs: int  # read from the "disaggregation" section
    decode_transfer_queue_reqs: int  # read from the "disaggregation" section
    decode_retracted_queue_reqs: int  # read from the "disaggregation" section
@dataclass(frozen=True)
class ExecutionResult:
    """Outcome of executing one trace request, later copied into ``RequestMetrics``."""

    execution_mode: str
    actual_kv_transfer_blocks: int
    effective_input_length: int | None
    cached_tokens: int
    session_reused: bool
    session_reset: bool
    latency_s: float | None  # None when the request failed
    ttft_s: float | None  # time to first token; None when not observed
    tpot_s: float | None  # time per output token; None when ttft_s is None
    error: str | None = None  # "<ExceptionType>: <message>" on failure
async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]:
    """Replay every request of the configured trace and persist the metrics.

    Requests are loaded with ``load_trace`` and launched as asyncio tasks in
    trace order. When ``config.pace`` is set, each launch waits until the
    request's trace offset (divided by ``config.time_scale``) has elapsed on
    the wall clock. Each task is handed the previous task of the same
    session, so requests within one session run one after another.

    Once all tasks have finished, every streaming session still marked open
    (decode side and prefill side) is closed on a best-effort basis. The
    per-request metrics and a ``.summary.json`` file are then written.

    Returns:
        One ``RequestMetrics`` per replayed request, in trace order.
    """
    requests = load_trace(config.trace_path, request_limit=config.request_limit)
    policy = create_policy(config.policy_name)
    state = RoutingState.create(config.topology)
    state_lock = asyncio.Lock()
    semaphore = asyncio.Semaphore(config.concurrency_limit)
    start_time = time.perf_counter()
    first_timestamp = requests[0].timestamp_s if requests else 0.0
    # Most recently launched task per session id; used to chain requests.
    session_tail_tasks: dict[str, asyncio.Task[RequestMetrics]] = {}
    direct_sessions: dict[str, DirectSessionState] = {}
    direct_session_lock = asyncio.Lock()

    async with httpx.AsyncClient(timeout=config.timeout_s, trust_env=False) as client:
        decode_residency = await _discover_decode_residency(
            client=client,
            config=config,
        )
        tasks = []
        for request in requests:
            if config.pace:
                # Sleep until the wall clock catches up with the trace offset.
                due_at = (request.timestamp_s - first_timestamp) / config.time_scale
                remaining = due_at - (time.perf_counter() - start_time)
                if remaining > 0:
                    await asyncio.sleep(remaining)
            task = asyncio.create_task(
                _run_request(
                    request=request,
                    config=config,
                    client=client,
                    policy=policy,
                    state=state,
                    state_lock=state_lock,
                    semaphore=semaphore,
                    direct_sessions=direct_sessions,
                    direct_session_lock=direct_session_lock,
                    decode_residency=decode_residency,
                    depends_on=session_tail_tasks.get(request.session_id),
                )
            )
            tasks.append(task)
            session_tail_tasks[request.session_id] = task
        results = await asyncio.gather(*tasks)

        # Best-effort cleanup: failures to close a session are ignored.
        for session in direct_sessions.values():
            urls_to_close = []
            if session.opened:
                urls_to_close.append(session.server_url)
            if session.prefill_opened and session.prefill_server_url is not None:
                urls_to_close.append(session.prefill_server_url)
            for url in urls_to_close:
                try:
                    await _close_streaming_session(
                        client=client,
                        server_url=url,
                        session_id=session.session_id,
                        allow_missing=True,
                    )
                except Exception:
                    pass

    write_metrics_jsonl(config.output_path, results)
    summary_path = config.output_path.with_suffix(
        config.output_path.suffix + ".summary.json"
    )
    write_summary_json(
        summary_path,
        results,
        trace_path=config.trace_path,
        router_url=config.router_url,
    )
    return results
async def _run_request(
    *,
    request: TraceRequest,
    config: ReplayConfig,
    client: httpx.AsyncClient,
    policy,
    state: RoutingState,
    state_lock: asyncio.Lock,
    semaphore: asyncio.Semaphore,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
    depends_on: asyncio.Task[RequestMetrics] | None,
) -> RequestMetrics:
    """Route and execute one trace request and return its metrics.

    Waits for ``depends_on`` (the previous request of the same session, if
    any), then, under the concurrency semaphore, asks the policy for a
    routing decision, executes the request and reports completion to the
    routing state. Any ``Exception`` raised during execution is converted
    into an ``ExecutionResult`` carrying the error text instead of
    propagating.
    """
    if depends_on is not None:
        await depends_on

    async with semaphore:
        # Policy selection and state updates are serialized by state_lock.
        async with state_lock:
            decision = policy.select(request, topology=config.topology, state=state)
        try:
            outcome = await _execute_request(
                client=client,
                request=request,
                config=config,
                decision=decision,
                direct_sessions=direct_sessions,
                direct_session_lock=direct_session_lock,
                decode_residency=decode_residency,
            )
        except Exception as exc:  # pragma: no cover - defensive logging path
            outcome = ExecutionResult(
                execution_mode=config.mechanism_name,
                actual_kv_transfer_blocks=0,
                effective_input_length=None,
                cached_tokens=0,
                session_reused=False,
                session_reset=False,
                latency_s=None,
                ttft_s=None,
                tpot_s=None,
                error=f"{type(exc).__name__}: {exc}",
            )
        async with state_lock:
            state.finish(request, decision)

    return RequestMetrics.from_decision(
        request,
        decision,
        mechanism_name=config.mechanism_name,
        execution_mode=outcome.execution_mode,
        actual_kv_transfer_blocks=outcome.actual_kv_transfer_blocks,
        effective_input_length=outcome.effective_input_length,
        cached_tokens=outcome.cached_tokens,
        session_reused=outcome.session_reused,
        session_reset=outcome.session_reset,
        latency_s=outcome.latency_s,
        ttft_s=outcome.ttft_s,
        tpot_s=outcome.tpot_s,
        error=outcome.error,
    )
async def _invoke_router(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decode_worker_index: int,
    session_id: str | None = None,
) -> tuple[float, float | None, float | None, int]:
    """Send one request to the router's ``/generate`` endpoint.

    Args:
        client: Shared HTTP client.
        request: Trace request whose input ids and output length are used.
        config: Replay configuration; ``router_url`` must be set.
        decode_worker_index: Forwarded to ``_build_headers``.
        session_id: When given, sent as ``session_params.id``.

    Returns:
        ``(latency_s, ttft_s, tpot_s, cached_tokens)`` as produced by
        ``_invoke_generate``.

    Raises:
        ValueError: If ``config.router_url`` is not configured.
    """
    # Explicit check instead of ``assert`` so the validation survives ``-O``.
    if config.router_url is None:
        raise ValueError("router_url must be configured to invoke the router")
    headers = _build_headers(
        request=request,
        header_mode=config.header_mode,
        decode_worker_index=decode_worker_index,
        policy_name=config.policy_name,
    )
    payload: dict[str, object] = {
        "input_ids": _build_direct_full_input_ids(request),
        "sampling_params": {
            "temperature": 0,
            # At least one token is always requested.
            "max_new_tokens": max(1, request.output_length),
            "ignore_eos": True,
            "no_stop_trim": True,
            "skip_special_tokens": False,
        },
        "stream": config.stream,
    }
    if session_id is not None:
        payload["session_params"] = {"id": session_id}
    return await _invoke_generate(
        client=client,
        base_url=config.router_url,
        headers=headers,
        payload=payload,
        timeout_s=config.timeout_s,
        stream_idle_timeout_s=config.stream_idle_timeout_s,
        stream=config.stream,
    )
def _build_payload(
    *,
    request: TraceRequest,
    model_name: str,
    prompt: str,
    stream: bool,
    session_params: dict[str, str] | None,
    exact_output_length: bool,
) -> dict[str, object]:
    """Build the JSON body for a chat-completions request.

    The requested token budget is ``request.output_length`` but never less
    than one. With ``exact_output_length`` the body additionally carries
    ``min_tokens`` and the flags that keep generation going past EOS.
    ``session_params`` is attached only when it is not ``None``.
    """
    token_budget = max(1, request.output_length)
    body: dict[str, object] = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": token_budget,
        "temperature": 0,
        "stream": stream,
    }
    if stream:
        body["stream_options"] = {"include_usage": True}
    if exact_output_length:
        body["min_tokens"] = token_budget
        body["ignore_eos"] = True
        body["no_stop_trim"] = True
        body["skip_special_tokens"] = False
    if session_params is not None:
        body["session_params"] = session_params
    return body
async def _invoke_chat_completion(
    *,
    client: httpx.AsyncClient,
    base_url: str,
    headers: dict[str, str],
    payload: dict[str, object],
    timeout_s: float,
    stream_idle_timeout_s: float | None,
    stream: bool,
) -> tuple[float, float | None, float | None, int]:
    """POST ``payload`` to ``/v1/chat/completions`` and time the response.

    In streaming mode the SSE ``data:`` lines are parsed one by one; the time
    of the first chunk containing a token becomes the TTFT, and reading stops
    at ``[DONE]`` or at a terminal chunk. In non-streaming mode only the
    total latency is measured.

    Returns:
        ``(latency_s, ttft_s, tpot_s, cached_tokens)``. ``ttft_s`` and
        ``tpot_s`` are ``None`` when no first-token time was observed (always
        the case for non-streaming calls). ``tpot_s`` divides the post-TTFT
        time by the requested ``max_tokens``.

    Raises:
        RuntimeError: If a streamed response ends without producing a token.
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    start = time.perf_counter()
    url = f"{base_url.rstrip('/')}/v1/chat/completions"
    ttft_s: float | None = None
    cached_tokens = 0
    # The requested budget stands in for the number of generated tokens.
    generated_tokens = int(payload.get("max_tokens", 1))
    if stream:
        async with client.stream(
            "POST",
            url,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        ) as response:
            response.raise_for_status()
            async for line in _aiter_lines(
                response,
                idle_timeout_s=stream_idle_timeout_s,
            ):
                if not line.startswith("data:"):
                    continue
                data = line[5:].strip()
                if data == "[DONE]":
                    break
                parsed = json.loads(data)
                cached_tokens = max(cached_tokens, _extract_openai_cached_tokens(parsed))
                if _contains_token(parsed) and ttft_s is None:
                    ttft_s = time.perf_counter() - start
                if _is_terminal_chunk(parsed):
                    break
    else:
        response = await client.post(
            url,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        )
        response.raise_for_status()
        parsed = response.json()
        cached_tokens = _extract_openai_cached_tokens(parsed)
    latency_s = time.perf_counter() - start
    if stream and ttft_s is None and generated_tokens > 0:
        # The message names this endpoint; it previously read "generate
        # stream", copied from the /generate code path.
        raise RuntimeError("chat completion stream ended before producing any token")
    if ttft_s is None:
        tpot_s = None
    else:
        tpot_s = max(0.0, latency_s - ttft_s) / max(1, generated_tokens)
    return latency_s, ttft_s, tpot_s, cached_tokens
async def _invoke_generate(
    *,
    client: httpx.AsyncClient,
    base_url: str,
    headers: dict[str, str],
    payload: dict[str, object],
    timeout_s: float,
    stream_idle_timeout_s: float | None,
    stream: bool,
) -> tuple[float, float | None, float | None, int]:
    """POST ``payload`` to the native ``/generate`` endpoint and time it.

    Streaming responses are read as SSE ``data:`` lines until ``[DONE]`` or a
    terminal chunk. A response object whose ``error`` field is a dict is
    turned into a ``ValueError``.

    Returns:
        ``(latency_s, ttft_s, tpot_s, cached_tokens)``. ``ttft_s`` and
        ``tpot_s`` are ``None`` when no first-token time was recorded.

    Raises:
        ValueError: If the server reports an error object.
        RuntimeError: If a streamed response ends without producing a token.
    """

    def fail_on_error(body) -> None:
        # Only dict-shaped errors are treated as failures.
        reported = body.get("error")
        if isinstance(reported, dict):
            raise ValueError(reported.get("message", json.dumps(reported)))

    start = time.perf_counter()
    endpoint = f"{base_url.rstrip('/')}/generate"
    first_token_s: float | None = None
    cached_tokens = 0
    sampling_params = payload.get("sampling_params", {})
    generated_tokens = int(sampling_params.get("max_new_tokens", 1))

    if not stream:
        response = await client.post(
            endpoint,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        )
        response.raise_for_status()
        body = response.json()
        fail_on_error(body)
        cached_tokens = _extract_generate_cached_tokens(body)
    else:
        async with client.stream(
            "POST",
            endpoint,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        ) as response:
            response.raise_for_status()
            async for line in _aiter_lines(
                response,
                idle_timeout_s=stream_idle_timeout_s,
            ):
                if not line.startswith("data:"):
                    continue
                data = line[5:].strip()
                if data == "[DONE]":
                    break
                chunk = json.loads(data)
                fail_on_error(chunk)
                cached_tokens = max(
                    cached_tokens, _extract_generate_cached_tokens(chunk)
                )
                if _contains_generate_token(chunk) and first_token_s is None:
                    first_token_s = time.perf_counter() - start
                if _is_generate_terminal_chunk(chunk):
                    break

    latency_s = time.perf_counter() - start
    if stream and first_token_s is None and generated_tokens > 0:
        raise RuntimeError("generate stream ended before producing any token")
    tpot_s = (
        None
        if first_token_s is None
        else max(0.0, latency_s - first_token_s) / max(1, generated_tokens)
    )
    return latency_s, first_token_s, tpot_s, cached_tokens
async def _open_streaming_session(
    *,
    client: httpx.AsyncClient,
    server_url: str,
    session_id: str,
    request: TraceRequest,
) -> None:
    """Open a streaming session with a caller-chosen id on ``server_url``.

    The requested ``capacity_of_str_len`` is 16 times the larger of the
    request's input length and its input+output length, with a floor of 4096.

    Raises:
        ValueError: If the server answers with a different session id.
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    total_length = request.input_length + request.output_length
    requested_capacity = max(4096, request.input_length * 16, total_length * 16)
    body = {
        "capacity_of_str_len": requested_capacity,
        "session_id": session_id,
        "streaming": True,
    }
    response = await client.post(f"{server_url.rstrip('/')}/open_session", json=body)
    response.raise_for_status()
    opened_session_id = response.json()
    if opened_session_id == session_id:
        return
    raise ValueError(
        f"Unexpected session id from {server_url}: {opened_session_id!r} != {session_id!r}"
    )
async def _close_streaming_session(
    *,
    client: httpx.AsyncClient,
    server_url: str,
    session_id: str,
    allow_missing: bool = False,
) -> None:
    """Ask ``server_url`` to close ``session_id``.

    With ``allow_missing`` a 404 response, or any error response whose body
    contains "does not exist" (case-insensitive), counts as success. Every
    other error status is raised via ``raise_for_status``.
    """
    endpoint = f"{server_url.rstrip('/')}/close_session"
    response = await client.post(endpoint, json={"session_id": session_id})
    if not response.is_success:
        if allow_missing:
            body_text = response.text.lower()
            already_gone = (
                response.status_code == 404 or "does not exist" in body_text
            )
            if already_gone:
                return
        response.raise_for_status()
def _extract_internal_state(payload: dict[str, Any]) -> dict[str, Any]:
    """Return the first ``internal_states`` entry, or ``payload`` itself.

    ``payload`` is returned unchanged when ``internal_states`` is missing,
    is not a non-empty list, or its first element is not a dict.
    """
    states = payload.get("internal_states")
    if not isinstance(states, list) or not states:
        return payload
    first = states[0]
    return first if isinstance(first, dict) else payload
def _extract_server_int(
    payload: dict[str, Any],
    key: str,
) -> int:
    """Read ``key`` as an int from ``payload`` or from its internal state.

    The top-level value wins when the key is present there; otherwise the
    value from ``_extract_internal_state(payload)`` is used. Missing or falsy
    values yield ``0``.
    """
    nested_value = _extract_internal_state(payload).get(key, 0)
    raw = payload.get(key, nested_value)
    return int(raw or 0)
def _extract_session_cache(
    payload: dict[str, Any],
) -> dict[str, Any]:
    """Return the ``session_cache`` dict from the internal state, else ``{}``."""
    candidate = _extract_internal_state(payload).get("session_cache")
    return candidate if isinstance(candidate, dict) else {}
def _find_session_cache_status(
    session_cache: dict[str, Any],
    session_id: str,
) -> dict[str, Any] | None:
    """Return the first ``sessions`` entry whose ``session_id`` matches.

    ``None`` is returned when ``sessions`` is not a list or no dict entry
    matches.
    """
    entries = session_cache.get("sessions")
    if not isinstance(entries, list):
        return None
    matches = (
        entry
        for entry in entries
        if isinstance(entry, dict) and entry.get("session_id") == session_id
    )
    return next(matches, None)
async def _fetch_decode_server_state(
    *,
    client: httpx.AsyncClient,
    server_url: str,
) -> tuple[dict[str, Any], int, int]:
    """Query ``/server_info`` and extract the figures used for residency.

    Returns:
        ``(session_cache, max_total_num_tokens, num_reserved_decode_tokens)``.
        ``({}, 0, 0)`` is returned when the request fails or the response
        body is not a JSON object, so callers can treat the server as
        "capacity unknown".
    """
    try:
        response = await client.get(f"{server_url.rstrip('/')}/server_info")
        response.raise_for_status()
        payload = response.json()
    except Exception:
        # Best-effort probe: an unreachable server is reported as unknown.
        return {}, 0, 0
    # The extractors below call ``.get`` on the payload; a list or scalar
    # body would raise outside the ``try`` and defeat the fallback.
    if not isinstance(payload, dict):
        return {}, 0, 0
    return (
        _extract_session_cache(payload),
        _extract_server_int(payload, "max_total_num_tokens"),
        _extract_server_int(payload, "num_reserved_decode_tokens"),
    )
async def _query_decode_direct_admission(
    *,
    client: httpx.AsyncClient,
    server_url: str,
    session_id: str,
    uncached_input_tokens: int,
    output_tokens: int,
) -> dict[str, Any]:
    """Ask a server whether a direct append for ``session_id`` can be admitted.

    Posts to ``/session_cache/admit_direct_append`` with both token counts
    clamped at zero. The server's JSON object is returned as-is. If the call
    fails, or the body is not a JSON object, a denial with reason
    ``"admission-query-failed"`` is returned instead.
    """
    denial: dict[str, Any] = {
        "can_admit": False,
        "resident": False,
        "reason": "admission-query-failed",
        "required_tokens": 0,
        "available_tokens_before": 0,
        "available_tokens_after": 0,
        "evicted_session_count": 0,
        "freed_tokens": 0,
    }
    query = {
        "session_id": session_id,
        "uncached_input_tokens": max(0, uncached_input_tokens),
        "output_tokens": max(0, output_tokens),
    }
    try:
        response = await client.post(
            f"{server_url.rstrip('/')}/session_cache/admit_direct_append",
            json=query,
        )
        response.raise_for_status()
        answer = response.json()
    except Exception:
        return denial
    return answer if isinstance(answer, dict) else denial
async def _discover_decode_residency(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
) -> DecodeResidencyState:
    """Probe all workers and build the initial capacity/headroom tables.

    Only the ``"kvcache-centric"`` mechanism uses residency tracking; for any
    other mechanism an empty state is returned without probing. Workers that
    report no positive ``max_total_num_tokens`` are left out of the tables.
    """
    state = DecodeResidencyState()
    if config.mechanism_name != "kvcache-centric":
        return state

    for decode_worker in config.topology.decode_workers:
        url = decode_worker.url
        _, capacity, reserved = await _fetch_decode_server_state(
            client=client,
            server_url=url,
        )
        if capacity <= 0:
            continue
        # Decode headroom: the largest of 4x the reserved decode tokens, 5%
        # of capacity and 8192 tokens, capped at the capacity itself.
        wanted_headroom = max(reserved * 4, capacity // 20, 8192)
        state.capacity_tokens[url] = capacity
        state.headroom_tokens[url] = min(capacity, wanted_headroom)
        state.reserved_decode_tokens[url] = reserved

    for prefill_worker in config.topology.prefill_workers:
        url = prefill_worker.url
        _, capacity, _ = await _fetch_decode_server_state(
            client=client,
            server_url=url,
        )
        if capacity <= 0:
            continue
        # Prefill headroom: the larger of 10% of capacity and 16384 tokens,
        # capped at the capacity itself.
        wanted_headroom = max(capacity // 10, 16384)
        state.prefill_capacity_tokens[url] = capacity
        state.prefill_headroom_tokens[url] = min(capacity, wanted_headroom)

    return state
def _estimate_session_resident_tokens(request: TraceRequest) -> int:
    """Estimate a session's resident tokens as input plus output length."""
    prompt_tokens = request.input_length
    generated_tokens = request.output_length
    return prompt_tokens + generated_tokens
def _inspect_direct_request(
    *,
    request: TraceRequest,
    session: DirectSessionState,
) -> tuple[int, bool, bool]:
    """Classify ``request`` relative to the session's previous request.

    Returns:
        A ``(tokens, is_append, is_reset)`` tuple:

        * No previous request: ``(input_length, False, False)``.
        * The input grew beyond the previous input+output: the growth and
          ``True, False``.
        * Otherwise (the input did not grow): ``(input_length, False, True)``.
    """
    prior = session.last_trace_request
    if prior is None:
        return request.input_length, False, False
    growth = request.input_length - prior.input_length - prior.output_length
    if growth > 0:
        return growth, True, False
    return request.input_length, False, True
def _add_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Add ``delta_tokens`` to the decode reservation of ``server_url``.

    Non-positive deltas are ignored.
    """
    if delta_tokens > 0:
        ledger = residency.reserved_tokens_by_server
        ledger[server_url] = ledger.get(server_url, 0) + delta_tokens
def _release_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Subtract ``delta_tokens`` from the decode reservation of ``server_url``.

    Non-positive deltas are ignored. The server's entry is removed once its
    reservation drops to zero or below.
    """
    if delta_tokens <= 0:
        return
    ledger = residency.reserved_tokens_by_server
    left = ledger.get(server_url, 0) - delta_tokens
    if left <= 0:
        ledger.pop(server_url, None)
    else:
        ledger[server_url] = left
def _add_prefill_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Add ``delta_tokens`` to the prefill reservation of ``server_url``.

    Non-positive deltas are ignored.
    """
    if delta_tokens > 0:
        ledger = residency.prefill_reserved_tokens_by_server
        ledger[server_url] = ledger.get(server_url, 0) + delta_tokens
def _release_prefill_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Subtract ``delta_tokens`` from the prefill reservation of ``server_url``.

    Non-positive deltas are ignored. The server's entry is removed once its
    reservation drops to zero or below.
    """
    if delta_tokens <= 0:
        return
    ledger = residency.prefill_reserved_tokens_by_server
    left = ledger.get(server_url, 0) - delta_tokens
    if left <= 0:
        ledger.pop(server_url, None)
    else:
        ledger[server_url] = left
def _usable_capacity_tokens(
    residency: DecodeResidencyState,
    server_url: str,
) -> int:
    """Return decode capacity minus headroom for ``server_url``, floored at 0.

    Servers missing from either table contribute ``0`` for that figure.
    """
    capacity = residency.capacity_tokens.get(server_url, 0)
    headroom = residency.headroom_tokens.get(server_url, 0)
    spare = capacity - headroom
    return spare if spare > 0 else 0
def _usable_prefill_backup_capacity_tokens(
    residency: DecodeResidencyState,
    server_url: str,
) -> int:
    """Return prefill capacity minus headroom for ``server_url``, floored at 0.

    Servers missing from either table contribute ``0`` for that figure.
    """
    capacity = residency.prefill_capacity_tokens.get(server_url, 0)
    headroom = residency.prefill_headroom_tokens.get(server_url, 0)
    spare = capacity - headroom
    return spare if spare > 0 else 0
def _eviction_suffix(evicted_sessions: int, prefill_backed_evictions: int) -> str:
    """Return a label suffix describing which evictions preceded admission.

    Empty when nothing was evicted; otherwise the suffix tells whether all,
    some or none of the evicted sessions were prefill-backed.
    """
    if evicted_sessions <= 0:
        return ""
    all_backed = prefill_backed_evictions >= evicted_sessions
    some_backed = prefill_backed_evictions > 0
    return (
        "-after-prefill-backed-eviction"
        if all_backed
        else "-after-mixed-eviction"
        if some_backed
        else "-after-eviction"
    )
def _decode_session_soft_cap(
    *,
    residency: DecodeResidencyState,
    server_url: str,
    request: TraceRequest,
) -> int:
    """Return how many open sessions ``server_url`` may hold (between 1 and 4).

    The cap is the number of sessions of this request's estimated size that
    fit into the server's usable capacity, clamped to ``[1, 4]``. When no
    usable capacity is known for the server, the cap is 4.
    """
    target_tokens = max(1, _estimate_session_resident_tokens(request))
    # _usable_capacity_tokens already yields max(0, capacity - headroom), so
    # the former second computation of the same value was dead code.
    usable_capacity_tokens = _usable_capacity_tokens(residency, server_url)
    if usable_capacity_tokens <= 0:
        return 4
    return max(1, min(4, usable_capacity_tokens // target_tokens))
def _should_admit_new_decode_session(
    *,
    residency: DecodeResidencyState,
    server_url: str,
    request: TraceRequest,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    treat_as_fresh_session: bool,
) -> bool:
    """Decide whether ``session`` may occupy a session slot on ``server_url``.

    A session that is already open on that server (and is not being treated
    as fresh) is always admitted. Otherwise admission requires the number of
    sessions currently open on the server to be below the soft cap from
    ``_decode_session_soft_cap``.
    """
    already_there = session.opened and session.server_url == server_url
    if already_there and not treat_as_fresh_session:
        return True

    open_on_server = 0
    for other in direct_sessions.values():
        if other.opened and other.server_url == server_url:
            open_on_server += 1

    soft_cap = _decode_session_soft_cap(
        residency=residency,
        server_url=server_url,
        request=request,
    )
    return open_on_server < soft_cap
async def _fetch_decode_load_snapshot(
    *,
    client: httpx.AsyncClient,
    server_url: str,
) -> DecodeLoadSnapshot | None:
    """Fetch ``/v1/loads`` and turn the first load entry into a snapshot.

    Returns:
        A ``DecodeLoadSnapshot`` built from ``loads[0]``, or ``None`` when the
        request fails or the response does not have the expected shape
        (a JSON object with a non-empty ``loads`` list of objects).
    """
    try:
        response = await client.get(
            f"{server_url.rstrip('/')}/v1/loads",
            params={"include": "core,disagg"},
        )
        response.raise_for_status()
        payload = response.json()
    except Exception:
        # Best-effort probe: callers treat None as "no load information".
        return None
    # Shape checks keep malformed bodies on the "return None" path instead of
    # raising AttributeError from ``.get`` outside the try block.
    if not isinstance(payload, dict):
        return None
    loads = payload.get("loads")
    if not isinstance(loads, list) or not loads:
        return None
    load = loads[0]
    if not isinstance(load, dict):
        return None
    disagg = load.get("disaggregation")
    if not isinstance(disagg, dict):
        disagg = {}

    def as_int(section: dict, key: str) -> int:
        # Missing and null values count as zero.
        return int(section.get(key, 0) or 0)

    return DecodeLoadSnapshot(
        timestamp_s=time.perf_counter(),
        num_running_reqs=as_int(load, "num_running_reqs"),
        num_waiting_reqs=as_int(load, "num_waiting_reqs"),
        num_used_tokens=as_int(load, "num_used_tokens"),
        max_total_num_tokens=as_int(load, "max_total_num_tokens"),
        token_usage=float(load.get("token_usage", 0.0) or 0.0),
        decode_prealloc_queue_reqs=as_int(disagg, "decode_prealloc_queue_reqs"),
        decode_transfer_queue_reqs=as_int(disagg, "decode_transfer_queue_reqs"),
        decode_retracted_queue_reqs=as_int(disagg, "decode_retracted_queue_reqs"),
    )
def _decode_load_backpressure_reason(
    snapshot: DecodeLoadSnapshot | None,
    *,
    routing_mode: Literal["direct", "seed"],
) -> str | None:
    """Return the backpressure reason implied by ``snapshot``, if any.

    Direct routing tolerates more load: retracted requests only count at
    >= 0.99 token usage, and usage alone becomes critical at >= 0.992. Any
    other mode reports retracted requests immediately and treats usage
    >= 0.985 as critical. Seed mode additionally reports prealloc
    backpressure at >= 0.94 usage when the prealloc or transfer queue is
    non-empty. ``None`` means no backpressure (or no snapshot).
    """
    if snapshot is None:
        return None

    usage = snapshot.token_usage
    has_retracted = snapshot.decode_retracted_queue_reqs > 0

    if routing_mode == "direct":
        if has_retracted and usage >= 0.99:
            return "d-retracted"
        if usage >= 0.992:
            return "d-token-usage-critical"
        return None

    if has_retracted:
        return "d-retracted"
    if usage >= 0.985:
        return "d-token-usage-critical"
    queues_busy = (
        snapshot.decode_prealloc_queue_reqs > 0
        or snapshot.decode_transfer_queue_reqs > 0
    )
    if routing_mode == "seed" and usage >= 0.94 and queues_busy:
        return "d-prealloc-backpressure"
    return None
def _is_decode_backpressure_reason(reason: str | None) -> bool:
    """Tell whether ``reason`` is one of the decode backpressure reasons."""
    backpressure_reasons = frozenset(
        (
            "d-retracted",
            "d-token-usage-critical",
            "d-prealloc-backpressure",
        )
    )
    return reason in backpressure_reasons
def _dynamic_decode_headroom_tokens(
    *,
    residency: DecodeResidencyState,
    server_url: str,
    snapshot: DecodeLoadSnapshot | None,
    routing_mode: Literal["direct", "seed"],
) -> int:
    """Compute the token headroom to keep free, scaled by current load.

    Without a snapshot the static headroom recorded for the server is
    returned. Otherwise the result is the largest of three terms: the
    per-request reserve (at least 512) times a pressure factor, a fraction of
    the server's total tokens, and a fixed minimum. Direct routing counts
    only the disaggregation queues as pressure and uses 1/24 of capacity with
    a 4096 floor. Any other mode also counts running requests and uses 1/15
    of capacity with a 12288 floor.
    """
    if snapshot is None:
        return residency.headroom_tokens.get(server_url, 0)

    per_request_reserve = max(512, residency.reserved_decode_tokens.get(server_url, 0))
    queued = (
        snapshot.decode_prealloc_queue_reqs
        + snapshot.decode_transfer_queue_reqs
        + snapshot.decode_retracted_queue_reqs
    )
    if routing_mode == "direct":
        pressure = max(1, queued)
        divisor, floor_tokens = 24, 4096
    else:
        pressure = max(1, snapshot.num_running_reqs + queued)
        divisor, floor_tokens = 15, 12288
    return max(
        per_request_reserve * pressure,
        snapshot.max_total_num_tokens // divisor,
        floor_tokens,
    )
def _commit_session_residency(
    *,
    residency: DecodeResidencyState,
    session: DirectSessionState,
    request: TraceRequest,
    reserved_tokens: int,
) -> None:
    """Convert a decode reservation into committed residency for ``session``.

    Releases ``reserved_tokens`` on the session's server, adjusts the
    server's resident total by the change in the session's estimated size,
    and marks the session open with ``request`` as its latest request. If
    the session also has a prefill copy, that copy becomes low priority.
    """
    server_url = session.server_url
    _release_reserved_tokens(residency, server_url, reserved_tokens)

    # A session that is not open contributes nothing to the current total.
    tokens_before = session.resident_tokens if session.opened else 0
    tokens_after = _estimate_session_resident_tokens(request)
    if tokens_after != tokens_before:
        totals = residency.resident_tokens_by_server
        totals[server_url] = totals.get(server_url, 0) + (tokens_after - tokens_before)

    session.opened = True
    session.resident_tokens = tokens_after
    session.last_trace_request = request
    session.last_access_s = time.perf_counter()
    if session.prefill_opened:
        session.prefill_low_priority = True
def _commit_prefill_backup_residency(
    *,
    residency: DecodeResidencyState,
    session: DirectSessionState,
    request: TraceRequest,
    prefill_url: str,
    reserved_tokens: int,
) -> None:
    """Convert a prefill reservation into committed residency for ``session``.

    Releases ``reserved_tokens`` on ``prefill_url``, adjusts that server's
    resident total by the change in the session's estimated size, and marks
    the prefill copy open. The copy is low priority exactly when the decode
    side of the session is open.
    """
    _release_prefill_reserved_tokens(residency, prefill_url, reserved_tokens)

    # A prefill copy that is not open contributes nothing to the total.
    tokens_before = session.prefill_resident_tokens if session.prefill_opened else 0
    tokens_after = _estimate_session_resident_tokens(request)
    if tokens_after != tokens_before:
        totals = residency.prefill_resident_tokens_by_server
        totals[prefill_url] = totals.get(prefill_url, 0) + (tokens_after - tokens_before)

    session.prefill_server_url = prefill_url
    session.prefill_opened = True
    session.prefill_resident_tokens = tokens_after
    session.prefill_last_access_s = time.perf_counter()
    session.prefill_low_priority = session.opened
async def _close_prefill_session(
    *,
    client: httpx.AsyncClient,
    session: DirectSessionState,
    residency: DecodeResidencyState,
) -> None:
    """Close the prefill-side copy of ``session`` and release its tokens.

    When a prefill copy is open, the server is asked to close it (a missing
    session is tolerated) and the session's tokens are removed from the
    prefill server's resident total. In every case the session's prefill
    flags and token count are reset afterwards.
    """
    prefill_url = session.prefill_server_url
    if session.prefill_opened and prefill_url is not None:
        await _close_streaming_session(
            client=client,
            server_url=prefill_url,
            session_id=session.session_id,
            allow_missing=True,
        )
        totals = residency.prefill_resident_tokens_by_server
        left = totals.get(prefill_url, 0) - session.prefill_resident_tokens
        if left <= 0:
            totals.pop(prefill_url, None)
        else:
            totals[prefill_url] = left

    session.prefill_opened = False
    session.prefill_resident_tokens = 0
    session.prefill_low_priority = False
async def _close_decode_session(
    *,
    client: httpx.AsyncClient,
    session: DirectSessionState,
    residency: DecodeResidencyState,
    evicting_for_capacity: bool = False,
) -> None:
    """Close the decode-side session and release its resident tokens.

    A session that is not open only has its token count reset. Otherwise the
    server is asked to close it (a missing session is tolerated), the
    server's resident total is reduced, and, when ``evicting_for_capacity``
    is set, the matching eviction counter is incremented depending on
    whether a prefill copy is open.
    """
    if not session.opened:
        session.resident_tokens = 0
        return

    decode_url = session.server_url
    await _close_streaming_session(
        client=client,
        server_url=decode_url,
        session_id=session.session_id,
        allow_missing=True,
    )

    totals = residency.resident_tokens_by_server
    left = totals.get(decode_url, 0) - session.resident_tokens
    if left <= 0:
        totals.pop(decode_url, None)
    else:
        totals[decode_url] = left

    session.opened = False
    session.resident_tokens = 0

    if session.prefill_opened:
        residency.decode_evictions_prefill_backed += int(evicting_for_capacity)
        # With the decode side gone, the prefill copy is no longer redundant.
        session.prefill_low_priority = False
    elif evicting_for_capacity:
        residency.decode_evictions_without_prefill_backup += 1
async def _reserve_prefill_backup_capacity(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    prefill_url: str,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    residency: DecodeResidencyState,
) -> tuple[bool, int, int]:
    """Try to reserve room on a prefill server for a copy of ``session``.

    The server's current state is fetched, the extra tokens needed for
    ``request`` are computed, and idle prefill copies of other sessions are
    closed one at a time until enough headroom remains or no candidate is
    left.

    Returns:
        ``(admitted, reserved_tokens, evicted_sessions)``. When the server's
        capacity is unknown the call succeeds with nothing reserved. On
        failure ``reserved_tokens`` is 0, but sessions evicted along the way
        stay evicted and are counted.
    """
    session_cache, max_total_num_tokens, _reserved_decode_tokens = (
        await _fetch_decode_server_state(
            client=client,
            server_url=prefill_url,
        )
    )
    # Refresh the recorded capacity whenever the server reports one.
    if max_total_num_tokens > 0:
        residency.prefill_capacity_tokens[prefill_url] = max_total_num_tokens
    capacity_tokens = residency.prefill_capacity_tokens.get(prefill_url, 0)
    headroom_tokens = residency.prefill_headroom_tokens.get(prefill_url, 0)
    if capacity_tokens <= 0:
        # Capacity unknown: admit without reserving anything.
        return True, 0, 0
    # Keep at least half of the capacity (or the configured headroom, if
    # larger) free after the reservation.
    low_occupancy_headroom_tokens = max(
        headroom_tokens,
        capacity_tokens // 2,
    )
    target_session_status = _find_session_cache_status(
        session_cache,
        session.session_id,
    )
    # Prefer the server's own view of the session's size; fall back to the
    # locally tracked value when the server does not report it as resident.
    if (
        isinstance(target_session_status, dict)
        and bool(target_session_status.get("resident"))
    ):
        current_tokens = int(target_session_status.get("resident_tokens", 0) or 0)
    else:
        current_tokens = (
            session.prefill_resident_tokens
            if session.prefill_opened and session.prefill_server_url == prefill_url
            else 0
        )
    target_tokens = _estimate_session_resident_tokens(request)
    required_extra_tokens = max(0, target_tokens - current_tokens)
    evicted_sessions = 0
    # Allow between 1 and 4 prefill copies, sized so that each may use twice
    # the target token count.
    max_backup_sessions = max(1, capacity_tokens // max(1, target_tokens * 2))
    max_backup_sessions = min(max_backup_sessions, 4)
    available_tokens = int(session_cache.get("available_tokens", 0) or 0)
    if available_tokens <= 0:
        # No usable "available_tokens" figure: derive it from held tokens.
        held_tokens = int(session_cache.get("held_tokens", 0) or 0)
        available_tokens = max(0, capacity_tokens - held_tokens)
    # Tokens already promised to in-flight reservations are not available.
    available_tokens -= residency.prefill_reserved_tokens_by_server.get(prefill_url, 0)

    def has_enough_prefill_headroom() -> bool:
        # Reads the enclosing available_tokens, which grows as sessions are evicted.
        return available_tokens - required_extra_tokens >= low_occupancy_headroom_tokens

    def prefill_backup_count() -> int:
        # Number of tracked sessions with an open prefill copy on this server.
        return sum(
            1
            for candidate in direct_sessions.values()
            if candidate.prefill_opened and candidate.prefill_server_url == prefill_url
        )

    # Evict while room is needed and either headroom is short or, for a
    # session without a prefill copy yet, the copy limit is reached.
    while (
        required_extra_tokens > 0
        and (
            not has_enough_prefill_headroom()
            or (
                not session.prefill_opened
                and prefill_backup_count() >= max_backup_sessions
            )
        )
    ):
        # Candidates: other sessions' idle prefill copies on this server,
        # low-priority copies first, then least recently accessed.
        candidates = sorted(
            (
                candidate
                for candidate in direct_sessions.values()
                if candidate.prefill_opened
                and candidate.prefill_server_url == prefill_url
                and candidate.session_id != session.session_id
                and candidate.active_requests <= 0
            ),
            key=lambda candidate: (
                0 if candidate.prefill_low_priority else 1,
                candidate.prefill_last_access_s,
            ),
        )
        if not candidates:
            break
        # Capture the size first: closing resets prefill_resident_tokens to 0.
        freed_tokens = candidates[0].prefill_resident_tokens
        await _close_prefill_session(
            client=client,
            session=candidates[0],
            residency=residency,
        )
        available_tokens += freed_tokens
        evicted_sessions += 1
    if not has_enough_prefill_headroom():
        return False, 0, evicted_sessions
    _add_prefill_reserved_tokens(residency, prefill_url, required_extra_tokens)
    return True, required_extra_tokens, evicted_sessions
async def _reserve_decode_session_capacity(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    server_url: str,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    residency: DecodeResidencyState,
    treat_as_fresh_session: bool,
    routing_mode: Literal["direct", "seed"],
    admission_mode: KvCacheAdmissionMode,
) -> tuple[bool, int, int, int, str | None]:
    """Try to reserve decode-side token capacity for ``request``.

    When ``admission_mode`` is ``"router"`` the decision is delegated to
    ``_reserve_decode_session_capacity_from_router_state``. Otherwise the
    decision is based on state fetched from the decode worker at
    ``server_url``.

    Returns:
        A 5-tuple ``(admitted, reserved_tokens, evicted_sessions,
        prefill_backed_evictions, reason)``. ``reason`` is ``None`` when the
        request is admitted and a short string (for example ``"d-no-space"``
        or ``"d-session-not-resident"``) when it is rejected. On admission the
        reserved tokens have already been added to ``residency`` for
        ``server_url``.
    """
    if admission_mode == "router":
        return await _reserve_decode_session_capacity_from_router_state(
            client=client,
            request=request,
            server_url=server_url,
            session=session,
            direct_sessions=direct_sessions,
            residency=residency,
            treat_as_fresh_session=treat_as_fresh_session,
            routing_mode=routing_mode,
        )
    # A fresh session replaces whatever this session currently holds.
    if treat_as_fresh_session and session.opened:
        await _close_decode_session(
            client=client,
            session=session,
            residency=residency,
        )
    current_tokens = 0 if treat_as_fresh_session else session.resident_tokens
    target_tokens = _estimate_session_resident_tokens(request)
    required_extra_tokens = max(0, target_tokens - current_tokens)
    prefill_backed_evictions = 0
    # Direct routing onto an existing session: the worker's own admission
    # answer is authoritative, so this branch always returns.
    if routing_mode == "direct" and not treat_as_fresh_session:
        if not session.opened:
            return False, 0, 0, 0, "d-session-not-resident"
        admission = await _query_decode_direct_admission(
            client=client,
            server_url=server_url,
            session_id=session.session_id,
            uncached_input_tokens=max(0, request.input_length - current_tokens),
            output_tokens=request.output_length,
        )
        if not bool(admission.get("resident")):
            return False, 0, 0, 0, str(admission.get("reason") or "d-session-not-resident")
        if not bool(admission.get("can_admit")):
            return (
                False,
                0,
                int(admission.get("evicted_session_count", 0) or 0),
                0,
                str(admission.get("reason") or "d-no-space"),
            )
        # Prefer the worker-reported requirement; fall back to the local
        # estimate when the field is missing or zero.
        reserved_tokens = int(
            admission.get("required_tokens", required_extra_tokens)
            or required_extra_tokens
        )
        _add_reserved_tokens(residency, server_url, reserved_tokens)
        return (
            True,
            reserved_tokens,
            int(admission.get("evicted_session_count", 0) or 0),
            0,
            None,
        )
    # Remaining cases: refresh the router's view of the worker before deciding.
    session_cache, max_total_num_tokens, reserved_decode_tokens = (
        await _fetch_decode_server_state(
            client=client,
            server_url=server_url,
        )
    )
    if max_total_num_tokens > 0:
        residency.capacity_tokens[server_url] = max_total_num_tokens
    if reserved_decode_tokens > 0:
        residency.reserved_decode_tokens[server_url] = reserved_decode_tokens
    target_session_status = _find_session_cache_status(
        session_cache,
        session.session_id,
    )
    # Direct routing requires the worker to report the session as resident.
    if routing_mode == "direct" and not (
        isinstance(target_session_status, dict)
        and bool(target_session_status.get("resident"))
    ):
        return False, 0, 0, 0, "d-session-not-resident"
    load_snapshot = await _fetch_decode_load_snapshot(
        client=client,
        server_url=server_url,
    )
    if load_snapshot is not None and load_snapshot.max_total_num_tokens > 0:
        residency.capacity_tokens[server_url] = load_snapshot.max_total_num_tokens
    backpressure_reason = _decode_load_backpressure_reason(
        load_snapshot,
        routing_mode=routing_mode,
    )
    if backpressure_reason is not None:
        return False, 0, 0, 0, backpressure_reason
    usable_capacity_tokens = _usable_capacity_tokens(residency, server_url)
    evicted_sessions = 0
    # Close idle sessions on this server, least recently accessed first, until
    # the request fits or no candidate is left.
    while (
        required_extra_tokens > 0
        and residency.resident_tokens_by_server.get(server_url, 0)
        + residency.reserved_tokens_by_server.get(server_url, 0)
        + required_extra_tokens
        > usable_capacity_tokens + int(session_cache.get("idle_evictable_tokens", 0) or 0)
    ):
        candidates = sorted(
            (
                candidate
                for candidate in direct_sessions.values()
                if candidate.opened
                and candidate.server_url == server_url
                and candidate.session_id != session.session_id
                and candidate.active_requests <= 0
            ),
            key=lambda candidate: candidate.last_access_s,
        )
        if not candidates:
            break
        await _close_decode_session(
            client=client,
            session=candidates[0],
            residency=residency,
            evicting_for_capacity=True,
        )
        prefill_backed_evictions += int(candidates[0].prefill_opened)
        evicted_sessions += 1
    # After evictions the earlier snapshots are stale, so fetch them again.
    if evicted_sessions > 0:
        load_snapshot = await _fetch_decode_load_snapshot(
            client=client,
            server_url=server_url,
        )
        session_cache, max_total_num_tokens, reserved_decode_tokens = (
            await _fetch_decode_server_state(
                client=client,
                server_url=server_url,
            )
        )
        if max_total_num_tokens > 0:
            residency.capacity_tokens[server_url] = max_total_num_tokens
        if reserved_decode_tokens > 0:
            residency.reserved_decode_tokens[server_url] = reserved_decode_tokens
        usable_capacity_tokens = _usable_capacity_tokens(residency, server_url)
        if load_snapshot is not None:
            dynamic_headroom = _dynamic_decode_headroom_tokens(
                residency=residency,
                server_url=server_url,
                snapshot=load_snapshot,
                routing_mode=routing_mode,
            )
            residency.headroom_tokens[server_url] = min(
                residency.capacity_tokens.get(server_url, dynamic_headroom),
                dynamic_headroom,
            )
            usable_capacity_tokens = max(
                0,
                residency.capacity_tokens.get(server_url, 0) - dynamic_headroom,
            )
    # NOTE(review): the headroom computation below repeats the one just above
    # verbatim; it looks redundant after an eviction — confirm against
    # _dynamic_decode_headroom_tokens before deduplicating.
    usable_capacity_tokens = _usable_capacity_tokens(residency, server_url)
    if load_snapshot is not None:
        dynamic_headroom = _dynamic_decode_headroom_tokens(
            residency=residency,
            server_url=server_url,
            snapshot=load_snapshot,
            routing_mode=routing_mode,
        )
        residency.headroom_tokens[server_url] = min(
            residency.capacity_tokens.get(server_url, dynamic_headroom),
            dynamic_headroom,
        )
        usable_capacity_tokens = max(
            0,
            residency.capacity_tokens.get(server_url, 0) - dynamic_headroom,
        )
    # Prefer the worker-reported usage; fall back to the router's own
    # bookkeeping when no load snapshot is available.
    effective_used_tokens = (
        load_snapshot.num_used_tokens
        if load_snapshot is not None
        else residency.resident_tokens_by_server.get(server_url, 0)
    ) + residency.reserved_tokens_by_server.get(server_url, 0)
    idle_evictable_tokens = int(session_cache.get("idle_evictable_tokens", 0) or 0)
    # The target session's own tokens must not be counted as reclaimable.
    if (
        routing_mode == "direct"
        and isinstance(target_session_status, dict)
        and bool(target_session_status.get("idle_evictable"))
    ):
        idle_evictable_tokens = max(
            0,
            idle_evictable_tokens
            - int(target_session_status.get("resident_tokens", 0) or 0),
        )
    if effective_used_tokens + required_extra_tokens > (
        usable_capacity_tokens + idle_evictable_tokens
    ):
        return False, 0, evicted_sessions, prefill_backed_evictions, "d-no-space"
    _add_reserved_tokens(residency, server_url, required_extra_tokens)
    return True, required_extra_tokens, evicted_sessions, prefill_backed_evictions, None
async def _reserve_decode_session_capacity_from_router_state(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    server_url: str,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    residency: DecodeResidencyState,
    treat_as_fresh_session: bool,
    routing_mode: Literal["direct", "seed"],
) -> tuple[bool, int, int, int, str | None]:
    """Reserve decode capacity using only the router's own bookkeeping.

    Idle sessions on ``server_url`` are closed, least recently accessed
    first, until the extra tokens needed by ``request`` fit under the usable
    capacity recorded in ``residency``.

    Returns:
        ``(admitted, reserved_tokens, evicted_sessions,
        prefill_backed_evictions, reason)`` where ``reason`` is ``None`` on
        admission, ``"d-session-not-resident"`` when direct routing targets a
        session that is not open, and ``"d-no-space"`` when eviction could not
        free enough room.
    """
    if treat_as_fresh_session and session.opened:
        await _close_decode_session(
            client=client,
            session=session,
            residency=residency,
        )
    if routing_mode == "direct" and not session.opened:
        return False, 0, 0, 0, "d-session-not-resident"

    already_resident = 0 if treat_as_fresh_session else session.resident_tokens
    extra_tokens = max(0, _estimate_session_resident_tokens(request) - already_resident)
    capacity_limit = _usable_capacity_tokens(residency, server_url)

    # If discovery failed, do not force every request down the P/D fallback path.
    # The router can still preserve correctness; this only disables proactive
    # capacity admission until the worker reports capacity again in a later run.
    if capacity_limit <= 0:
        _add_reserved_tokens(residency, server_url, extra_tokens)
        return True, extra_tokens, 0, 0, None

    def exceeds_capacity() -> bool:
        # Re-read the counters on every call: closing a session is expected to
        # change them between checks.
        committed = residency.resident_tokens_by_server.get(
            server_url, 0
        ) + residency.reserved_tokens_by_server.get(server_url, 0)
        return committed + extra_tokens > capacity_limit

    eviction_count = 0
    prefill_backed_count = 0
    while extra_tokens > 0 and exceeds_capacity():
        idle_peers = [
            peer
            for peer in direct_sessions.values()
            if peer.opened
            and peer.server_url == server_url
            and peer.session_id != session.session_id
            and peer.active_requests <= 0
        ]
        if not idle_peers:
            break
        # Least recently accessed idle session goes first.
        victim = min(idle_peers, key=lambda peer: peer.last_access_s)
        await _close_decode_session(
            client=client,
            session=victim,
            residency=residency,
            evicting_for_capacity=True,
        )
        prefill_backed_count += int(victim.prefill_opened)
        eviction_count += 1

    if exceeds_capacity():
        return False, 0, eviction_count, prefill_backed_count, "d-no-space"
    _add_reserved_tokens(residency, server_url, extra_tokens)
    return True, extra_tokens, eviction_count, prefill_backed_count, None
def _build_direct_prompt(
    *,
    request: TraceRequest,
    session: DirectSessionState,
) -> tuple[str, int, bool, bool]:
    """Build the prompt text to send directly to a session-holding worker.

    Returns:
        ``(prompt, effective_input_length, session_reused, session_reset)``.
        A reset or non-reused session gets the full synthetic prompt for
        ``request``; a reused session gets only the appended chunk.
    """
    append_length, session_reused, session_reset = _inspect_direct_request(
        request=request,
        session=session,
    )
    if not session_reset and session_reused:
        append_chunk = build_synthetic_append_chunk(request, append_length)
        return append_chunk, append_length, session_reused, session_reset

    # Either the session was reset or it cannot be reused: send everything.
    full_prompt = build_synthetic_prompt(request)
    if session_reset:
        return full_prompt, request.input_length, False, True
    return full_prompt, request.input_length, False, False
def _build_direct_full_input_ids(
    request: TraceRequest,
    *,
    block_token_budget: int = 24,
) -> list[int]:
    """Build deterministic synthetic token ids for the whole request input.

    Each entry of ``request.hash_ids`` contributes ``block_token_budget``
    ids derived from that hash. The sequence is then padded with
    position-derived filler ids, or truncated, to ``request.input_length``.
    """
    token_ids: list[int] = [
        1000 + ((1000 + (hash_id % 3000) + offset) % 3000)
        for hash_id in request.hash_ids
        for offset in range(block_token_budget)
    ]
    # Filler ids depend only on their position in the sequence.
    token_ids.extend(
        5000 + (position % 2000)
        for position in range(len(token_ids), request.input_length)
    )
    return token_ids[: request.input_length]
def _build_direct_append_input_ids(
    request: TraceRequest,
    append_length: int,
) -> list[int]:
    """Build deterministic synthetic token ids for an appended chunk.

    The ids are derived from ``request.chat_id + request.turn_id`` and all
    fall in the range ``[9000, 11000)``. A non-positive ``append_length``
    yields an empty list.
    """
    appended: list[int] = []
    if append_length > 0:
        start = 9000 + ((request.chat_id + request.turn_id) % 2000)
        for position in range(start, start + append_length):
            appended.append(9000 + (position % 2000))
    return appended
async def _invoke_plain_router(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    execution_mode: str,
) -> ExecutionResult:
    """Send ``request`` through the router without any session handling.

    The result is reported with the full input length, the planned KV
    transfer block count from ``decision``, and no session reuse or reset.
    """
    timings = await _invoke_router(
        client=client,
        request=request,
        config=config,
        decode_worker_index=decision.decode_worker_index,
    )
    latency_s, ttft_s, tpot_s, cached_tokens = timings
    return ExecutionResult(
        execution_mode=execution_mode,
        actual_kv_transfer_blocks=decision.kv_transfer_blocks,
        effective_input_length=request.input_length,
        cached_tokens=cached_tokens,
        session_reused=False,
        session_reset=False,
        latency_s=latency_s,
        ttft_s=ttft_s,
        tpot_s=tpot_s,
    )
async def _invoke_kvcache_seeded_router(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    prefill_url: str,
    decode_session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
    reserved_tokens: int,
    execution_mode: str,
) -> ExecutionResult:
    """Route ``request`` through the router while seeding worker sessions.

    Streaming sessions are opened on the prefill worker (``prefill_url``) and
    on the decode worker (``decode_session.server_url``) if they are not open
    yet, then the request is sent through the router with the session id
    attached. On success the decode reservation is committed and the prefill
    session is either kept as a backup or closed.

    ``reserved_tokens`` is the decode reservation the caller already made;
    this function owns it from here on and releases it on any failure.

    Raises:
        Exception: whatever the reservation, session-open or router call
            raised, re-raised after the reservations made for this request
            have been released and sessions opened here have been closed.
    """
    # Track what this call actually did so the failure path undoes exactly
    # that and nothing belonging to other in-flight requests.
    keep_prefill_backup = False
    prefill_reserved_tokens = 0
    prefill_session_newly_opened = False
    decode_session_newly_opened = False
    active_request_counted = False
    try:
        # The reservation and both session opens sit inside the try so that a
        # failure in any of them still releases the caller's decode
        # reservation instead of leaking it.
        async with direct_session_lock:
            keep_prefill_backup, prefill_reserved_tokens, _prefill_evicted = (
                await _reserve_prefill_backup_capacity(
                    client=client,
                    request=request,
                    prefill_url=prefill_url,
                    session=decode_session,
                    direct_sessions=direct_sessions,
                    residency=decode_residency,
                )
            )
            # A backup held on a different prefill worker is of no use here.
            if (
                decode_session.prefill_opened
                and decode_session.prefill_server_url != prefill_url
            ):
                await _close_prefill_session(
                    client=client,
                    session=decode_session,
                    residency=decode_residency,
                )
        async with direct_session_lock:
            if not decode_session.prefill_opened:
                await _open_streaming_session(
                    client=client,
                    server_url=prefill_url,
                    session_id=request.session_id,
                    request=request,
                )
                decode_session.prefill_opened = True
                decode_session.prefill_server_url = prefill_url
                prefill_session_newly_opened = True
        async with direct_session_lock:
            if not decode_session.opened:
                await _open_streaming_session(
                    client=client,
                    server_url=decode_session.server_url,
                    session_id=request.session_id,
                    request=request,
                )
                decode_session.opened = True
                decode_session_newly_opened = True
            decode_session.active_requests += 1
            active_request_counted = True
        latency_s, ttft_s, tpot_s, cached_tokens = await _invoke_router(
            client=client,
            request=request,
            config=config,
            decode_worker_index=decision.decode_worker_index,
            session_id=request.session_id,
        )
    except Exception:
        async with direct_session_lock:
            # Only undo the increment if it happened; otherwise another
            # request's count on the same session would be decremented.
            if active_request_counted:
                decode_session.active_requests = max(
                    0, decode_session.active_requests - 1
                )
            _release_reserved_tokens(
                decode_residency,
                decode_session.server_url,
                reserved_tokens,
            )
            _release_prefill_reserved_tokens(
                decode_residency,
                prefill_url,
                prefill_reserved_tokens,
            )
            if decode_session_newly_opened:
                await _close_decode_session(
                    client=client,
                    session=decode_session,
                    residency=decode_residency,
                )
            if prefill_session_newly_opened:
                await _close_prefill_session(
                    client=client,
                    session=decode_session,
                    residency=decode_residency,
                )
        raise
    async with direct_session_lock:
        decode_session.active_requests = max(0, decode_session.active_requests - 1)
        if keep_prefill_backup:
            _commit_prefill_backup_residency(
                residency=decode_residency,
                session=decode_session,
                request=request,
                prefill_url=prefill_url,
                reserved_tokens=prefill_reserved_tokens,
            )
        else:
            # No room to keep a backup: give the reservation back and close
            # the prefill-side session.
            _release_prefill_reserved_tokens(
                decode_residency,
                prefill_url,
                prefill_reserved_tokens,
            )
            await _close_prefill_session(
                client=client,
                session=decode_session,
                residency=decode_residency,
            )
        _commit_session_residency(
            residency=decode_residency,
            session=decode_session,
            request=request,
            reserved_tokens=reserved_tokens,
        )
    return ExecutionResult(
        execution_mode=execution_mode,
        actual_kv_transfer_blocks=decision.kv_transfer_blocks,
        effective_input_length=request.input_length,
        cached_tokens=cached_tokens,
        session_reused=False,
        session_reset=False,
        latency_s=latency_s,
        ttft_s=ttft_s,
        tpot_s=tpot_s,
    )
async def _reserve_decode_seed_under_lock(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decode_url: str,
    decode_session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
) -> tuple[bool, int, int, int, str | None]:
    """Check the session cap and reserve capacity for seeding a fresh session.

    Both steps run while holding ``direct_session_lock``. Returns the same
    5-tuple as ``_reserve_decode_session_capacity``; when the session cap
    refuses the request the result is ``(False, 0, 0, 0, "d-session-cap")``.
    """
    async with direct_session_lock:
        if not _should_admit_new_decode_session(
            residency=decode_residency,
            server_url=decode_url,
            request=request,
            session=decode_session,
            direct_sessions=direct_sessions,
            treat_as_fresh_session=True,
        ):
            return False, 0, 0, 0, "d-session-cap"
        return await _reserve_decode_session_capacity(
            client=client,
            request=request,
            server_url=decode_url,
            session=decode_session,
            direct_sessions=direct_sessions,
            residency=decode_residency,
            treat_as_fresh_session=True,
            routing_mode="seed",
            admission_mode=config.kvcache_admission_mode,
        )


async def _execute_request(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
) -> ExecutionResult:
    """Execute one trace request according to ``config.mechanism_name``.

    * ``"pd-disaggregation"``: always goes through the router.
    * ``"pd-colo"``: goes straight to a direct worker session.
    * ``"kvcache-centric"``: turn 1 tries to seed a decode session; later
      turns first try to append directly to the resident decode session and
      otherwise try to re-seed. Every failed attempt falls back to the plain
      router path, and the chosen path is recorded in the result's
      ``execution_mode``.

    Raises:
        ValueError: if the mechanism is unknown, or a router URL / decode
            worker required by the mechanism is missing.
    """
    if config.mechanism_name == "pd-disaggregation":
        if not config.router_url:
            raise ValueError("router_url is required for pd-disaggregation replay")
        return await _invoke_plain_router(
            client=client,
            request=request,
            config=config,
            decision=decision,
            execution_mode="pd-disaggregation-router",
        )
    if config.mechanism_name == "pd-colo":
        return await _invoke_direct(
            client=client,
            request=request,
            config=config,
            decision=decision,
            direct_sessions=direct_sessions,
            direct_session_lock=direct_session_lock,
        )
    if config.mechanism_name != "kvcache-centric":
        raise ValueError(f"Unsupported mechanism: {config.mechanism_name}")

    if not config.router_url:
        raise ValueError("router_url is required for kvcache-centric replay")
    if not config.topology.decode_workers:
        raise ValueError("kvcache-centric mechanism requires at least one decode worker")
    prefill_url = _worker_url_by_id(
        config.topology.prefill_workers,
        decision.prefill_worker_id,
    )
    decode_url = config.topology.decode_workers[decision.decode_worker_index].url

    # Find or create the bookkeeping entry and point it at the chosen decode
    # worker, closing a session that is still open on a different worker.
    async with direct_session_lock:
        decode_session = direct_sessions.get(request.session_id)
        if decode_session is None:
            decode_session = DirectSessionState(
                session_id=request.session_id,
                server_url=decode_url,
            )
            direct_sessions[request.session_id] = decode_session
        elif decode_session.server_url != decode_url and decode_session.opened:
            await _close_decode_session(
                client=client,
                session=decode_session,
                residency=decode_residency,
            )
            decode_session.server_url = decode_url
        else:
            decode_session.server_url = decode_url

    async def plain_router(mode: str) -> ExecutionResult:
        return await _invoke_plain_router(
            client=client,
            request=request,
            config=config,
            decision=decision,
            execution_mode=mode,
        )

    async def seeded_router(mode: str, seed_reserved_tokens: int) -> ExecutionResult:
        # Must be called without holding direct_session_lock: the callee
        # acquires it itself.
        return await _invoke_kvcache_seeded_router(
            client=client,
            request=request,
            config=config,
            decision=decision,
            prefill_url=prefill_url,
            decode_session=decode_session,
            direct_sessions=direct_sessions,
            direct_session_lock=direct_session_lock,
            decode_residency=decode_residency,
            reserved_tokens=seed_reserved_tokens,
            execution_mode=mode,
        )

    async def reserve_seed() -> tuple[bool, int, int, int, str | None]:
        return await _reserve_decode_seed_under_lock(
            client=client,
            request=request,
            config=config,
            decode_url=decode_url,
            decode_session=decode_session,
            direct_sessions=direct_sessions,
            direct_session_lock=direct_session_lock,
            decode_residency=decode_residency,
        )

    direct_append_length: int | None = None
    direct_session_reused = False
    direct_session_reset = False
    if request.turn_id > 1:
        async with direct_session_lock:
            (
                direct_append_length,
                direct_session_reused,
                direct_session_reset,
            ) = _inspect_direct_request(
                request=request,
                session=decode_session,
            )

    if request.turn_id == 1:
        can_seed, reserved_tokens, _evicted, _p_backed, seed_reason = await reserve_seed()
        if seed_reason == "d-session-cap":
            return await plain_router("pd-router-turn1-session-cap")
        if can_seed:
            return await seeded_router("pd-router-turn1-seed", reserved_tokens)
        # Turn 1 treats any rejection reason other than "d-no-space" as
        # backpressure.
        return await plain_router(
            "pd-router-turn1-d-backpressure"
            if seed_reason is not None and seed_reason != "d-no-space"
            else "pd-router-turn1-no-d-capacity"
        )

    direct_eligible = (
        _should_bypass_prefill(
            request=request,
            config=config,
            decision=decision,
        )
        and direct_append_length is not None
        and direct_session_reused
        and not direct_session_reset
        and direct_append_length <= config.kvcache_direct_max_uncached_tokens
    )
    if direct_eligible:
        async with direct_session_lock:
            # Re-check under the lock: the session may have been closed or
            # moved since it was inspected.
            can_direct = (
                decode_session.opened
                and decode_session.server_url == decode_url
                and direct_session_reused
                and not direct_session_reset
            )
        direct_reserved_tokens = 0
        direct_reason: str | None = None
        if can_direct:
            async with direct_session_lock:
                (
                    can_direct,
                    direct_reserved_tokens,
                    _evicted,
                    _p_backed,
                    direct_reason,
                ) = await _reserve_decode_session_capacity(
                    client=client,
                    request=request,
                    server_url=decode_url,
                    session=decode_session,
                    direct_sessions=direct_sessions,
                    residency=decode_residency,
                    treat_as_fresh_session=False,
                    routing_mode="direct",
                    admission_mode=config.kvcache_admission_mode,
                )
        if can_direct:
            return await _invoke_decode_session_direct(
                client=client,
                request=request,
                config=config,
                decision=decision,
                direct_sessions=direct_sessions,
                direct_session_lock=direct_session_lock,
                decode_residency=decode_residency,
                reserved_tokens=direct_reserved_tokens,
            )
        if _is_decode_backpressure_reason(direct_reason):
            return await plain_router("pd-router-fallback-d-backpressure")
        session_cap_mode = "pd-router-fallback-session-cap"
        reseed_mode = "pd-router-d-session-reseed"
        no_capacity_mode = "pd-router-fallback-no-d-capacity"
    else:
        # Not eligible for the direct append path.
        session_cap_mode = "pd-router-fallback-large-append-session-cap"
        reseed_mode = "pd-router-large-append-reseed"
        no_capacity_mode = "pd-router-fallback-large-append"

    # Both turn>1 paths end the same way: try to re-seed the decode session,
    # otherwise fall back to the plain router.
    (
        can_seed,
        reserved_tokens,
        evicted_sessions,
        prefill_backed_evictions,
        seed_reason,
    ) = await reserve_seed()
    if seed_reason == "d-session-cap":
        return await plain_router(session_cap_mode)
    if can_seed:
        return await seeded_router(
            reseed_mode
            + _eviction_suffix(
                evicted_sessions,
                prefill_backed_evictions,
            ),
            reserved_tokens,
        )
    return await plain_router(
        "pd-router-fallback-d-backpressure"
        if _is_decode_backpressure_reason(seed_reason)
        else no_capacity_mode
    )
async def _invoke_direct(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
) -> ExecutionResult:
    """Run ``request`` against the direct worker chosen by ``decision``.

    A new ``DirectSessionState`` replaces the tracked one whenever the
    session is unknown or was last bound to a different worker URL.

    Raises:
        ValueError: if the topology has no direct workers.
    """
    workers = config.topology.direct_workers
    if not workers:
        raise ValueError("pd-colo mechanism requires at least one direct worker")
    target_url = workers[decision.decode_worker_index].url

    tracked = direct_sessions.get(request.session_id)
    if tracked is not None and tracked.server_url == target_url:
        session = tracked
    else:
        session = DirectSessionState(session_id=request.session_id, server_url=target_url)
        direct_sessions[request.session_id] = session

    return await _invoke_session_direct(
        client=client,
        request=request,
        config=config,
        session=session,
        execution_mode="pd-colo-direct-session",
        direct_session_lock=direct_session_lock,
    )
async def _invoke_session_direct(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    session: DirectSessionState,
    execution_mode: str,
    decode_residency: DecodeResidencyState | None = None,
    reserved_tokens: int = 0,
    direct_session_lock: asyncio.Lock | None = None,
) -> ExecutionResult:
    """Send ``request`` straight to the worker that holds ``session``.

    A reused session receives only the appended token ids; otherwise the
    session is (re)opened and receives the full synthetic input. When
    ``decode_residency`` is given, residency is committed after a successful
    call; otherwise only the session's last-request fields are updated.

    Args:
        decode_residency: residency tracker to update, or ``None`` to skip
            residency accounting.
        reserved_tokens: reservation passed through to
            ``_commit_session_residency`` on success.
        direct_session_lock: when given, guards the ``active_requests``
            counter updates.

    Returns:
        An ``ExecutionResult`` whose ``effective_input_length`` is the number
        of token ids actually sent.
    """
    # Only the metadata is used here; the prompt text itself is discarded
    # because the request is sent as token ids.
    _prompt, effective_input_length, session_reused, session_reset = _build_direct_prompt(
        request=request,
        session=session,
    )
    if session_reused:
        input_ids = _build_direct_append_input_ids(request, effective_input_length)
    else:
        input_ids = _build_direct_full_input_ids(request)
    # An open session that cannot be reused is closed before starting over.
    if session.opened and (session_reset or not session_reused):
        if decode_residency is not None:
            # NOTE(review): _close_decode_session is presumed to reset
            # session.opened itself — confirm against its definition.
            await _close_decode_session(
                client=client,
                session=session,
                residency=decode_residency,
            )
        else:
            await _close_streaming_session(
                client=client,
                server_url=session.server_url,
                session_id=session.session_id,
                allow_missing=True,
            )
            session.opened = False
            session.resident_tokens = 0
    if not session.opened:
        await _open_streaming_session(
            client=client,
            server_url=session.server_url,
            session_id=session.session_id,
            request=request,
        )
        session.opened = True
    if direct_session_lock is not None:
        async with direct_session_lock:
            session.active_requests += 1
    try:
        latency_s, ttft_s, tpot_s, cached_tokens = await _invoke_generate(
            client=client,
            base_url=session.server_url,
            headers={"x-request-id": request.request_id},
            payload={
                "input_ids": input_ids,
                # min == max together with ignore_eos presumably pins the
                # generated length to the trace's output length.
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max(1, request.output_length),
                    "min_new_tokens": max(1, request.output_length),
                    "ignore_eos": True,
                    "no_stop_trim": True,
                    "skip_special_tokens": False,
                },
                "session_params": {"id": session.session_id},
                "stream": config.stream,
            },
            timeout_s=config.timeout_s,
            stream_idle_timeout_s=config.stream_idle_timeout_s,
            stream=config.stream,
        )
    finally:
        # Always drop the in-flight marker, whether the call succeeded or not.
        if direct_session_lock is not None:
            async with direct_session_lock:
                session.active_requests = max(0, session.active_requests - 1)
    # Reached only when the generate call succeeded.
    if decode_residency is not None:
        _commit_session_residency(
            residency=decode_residency,
            session=session,
            request=request,
            reserved_tokens=reserved_tokens,
        )
    else:
        session.last_trace_request = request
        session.last_access_s = time.perf_counter()
    return ExecutionResult(
        execution_mode=execution_mode,
        actual_kv_transfer_blocks=0,
        effective_input_length=len(input_ids),
        cached_tokens=cached_tokens,
        session_reused=session_reused,
        session_reset=session_reset,
        latency_s=latency_s,
        ttft_s=ttft_s,
        tpot_s=tpot_s,
    )
async def _invoke_decode_session_direct(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
    reserved_tokens: int,
) -> ExecutionResult:
    """Append ``request`` directly to its already tracked decode session.

    If the direct call fails, the decode reservation of ``reserved_tokens``
    is released under ``direct_session_lock`` and the error is re-raised.
    """
    decode_worker = config.topology.decode_workers[decision.decode_worker_index]
    reservation_url = decode_worker.url
    tracked_session = direct_sessions[request.session_id]
    try:
        outcome = await _invoke_session_direct(
            client=client,
            request=request,
            config=config,
            session=tracked_session,
            execution_mode="kvcache-direct-to-d-session",
            decode_residency=decode_residency,
            reserved_tokens=reserved_tokens,
            direct_session_lock=direct_session_lock,
        )
    except Exception:
        # The reservation was never committed, so hand it back.
        async with direct_session_lock:
            _release_reserved_tokens(
                decode_residency,
                reservation_url,
                reserved_tokens,
            )
        raise
    return outcome
def _should_bypass_prefill(
    *,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    block_token_budget: int = 24,
) -> bool:
    """Decide whether a request may skip the prefill worker.

    Bypass is considered only after the first turn and only when the
    decision observed overlapping blocks. It is then allowed when the
    estimated uncached tokens (KV transfer blocks times
    ``block_token_budget``) do not exceed
    ``config.kvcache_direct_max_uncached_tokens``.
    """
    if request.turn_id <= 1 or decision.observed_overlap_blocks <= 0:
        return False
    estimated_uncached = decision.kv_transfer_blocks * block_token_budget
    if estimated_uncached < 0:
        estimated_uncached = 0
    return estimated_uncached <= config.kvcache_direct_max_uncached_tokens
def _worker_url_by_id(workers, worker_id: str) -> str:
    """Return the URL of the first worker whose ``worker_id`` matches.

    Raises:
        KeyError: if no worker carries ``worker_id``.
    """
    not_found = object()
    url = next(
        (worker.url for worker in workers if worker.worker_id == worker_id),
        not_found,
    )
    if url is not_found:
        raise KeyError(f"Unknown worker id: {worker_id}")
    return url
def _build_headers(
    *,
    request: TraceRequest,
    header_mode: HeaderMode,
    decode_worker_index: int,
    policy_name: str,
) -> dict[str, str]:
    """Build the HTTP headers sent with a router request.

    ``x-request-id`` is always present. ``"auto"`` resolves to
    ``"routing-key"`` for the ``"sticky"`` policy and to ``"none"`` otherwise.
    ``"routing-key"`` adds the session id as ``x-smg-routing-key``;
    ``"target-worker"`` adds the decode worker index as
    ``x-smg-target-worker``.
    """
    resolved_mode = header_mode
    if resolved_mode == "auto":
        resolved_mode = "routing-key" if policy_name == "sticky" else "none"

    headers = {"x-request-id": request.request_id}
    if resolved_mode == "routing-key":
        headers.update({"x-smg-routing-key": request.session_id})
    elif resolved_mode == "target-worker":
        headers.update({"x-smg-target-worker": str(decode_worker_index)})
    return headers
def _contains_token(payload: dict) -> bool:
    """Report whether a chat-completion stream chunk carries generated output.

    Only the first entry of ``choices`` is inspected. Its ``delta`` counts as
    carrying output when it has non-empty ``reasoning_content``, non-empty
    string ``content``, a list ``content`` containing a dict with truthy
    ``text``, or a tool call with an ``id``, a function ``name`` or function
    ``arguments``. Malformed payloads yield ``False`` instead of raising.
    """
    choices = payload.get("choices")
    if not isinstance(choices, list) or not choices:
        return False
    first_choice = choices[0]
    # A null or otherwise malformed choice carries no delta.
    if not isinstance(first_choice, dict):
        return False
    delta = first_choice.get("delta")
    if not isinstance(delta, dict):
        return False
    reasoning_content = delta.get("reasoning_content")
    if isinstance(reasoning_content, str) and reasoning_content:
        return True
    content = delta.get("content")
    if isinstance(content, str) and content:
        return True
    # A list-valued content without text must not short-circuit to False:
    # the same delta may still carry tool calls.
    if isinstance(content, list) and any(
        isinstance(item, dict) and item.get("text")
        for item in content
    ):
        return True
    tool_calls = delta.get("tool_calls")
    if isinstance(tool_calls, list):
        for tool_call in tool_calls:
            if not isinstance(tool_call, dict):
                continue
            if tool_call.get("id"):
                return True
            function = tool_call.get("function")
            if not isinstance(function, dict):
                continue
            if function.get("name") or function.get("arguments"):
                return True
    return False
def _contains_generate_token(payload: dict) -> bool:
    """Report whether a native generate stream chunk carries generated output.

    A chunk counts when it has non-empty ``text`` or when
    ``meta_info.completion_tokens`` is greater than zero. A missing or null
    ``completion_tokens`` is treated as zero.
    """
    text = payload.get("text")
    if isinstance(text, str) and text:
        return True
    meta_info = payload.get("meta_info")
    if not isinstance(meta_info, dict):
        return False
    # `or 0` guards against an explicit null, which int() would reject.
    return int(meta_info.get("completion_tokens", 0) or 0) > 0
def _is_generate_terminal_chunk(payload: dict) -> bool:
    """Report whether a native generate chunk has a ``finish_reason`` set."""
    meta_info = payload.get("meta_info")
    return isinstance(meta_info, dict) and meta_info.get("finish_reason") is not None
def _extract_generate_cached_tokens(payload: dict) -> int:
    """Read ``meta_info.cached_tokens`` from a native generate chunk.

    Returns 0 when ``meta_info`` is not a dict or the field is missing,
    null or otherwise falsy.
    """
    meta_info = payload.get("meta_info")
    if isinstance(meta_info, dict):
        return int(meta_info.get("cached_tokens", 0) or 0)
    return 0
def _extract_openai_cached_tokens(payload: dict) -> int:
    """Read ``usage.prompt_tokens_details.cached_tokens`` from a chunk.

    Returns 0 when either nesting level is not a dict or the field is
    missing, null or otherwise falsy.
    """
    usage = payload.get("usage")
    details = usage.get("prompt_tokens_details") if isinstance(usage, dict) else None
    if isinstance(details, dict):
        return int(details.get("cached_tokens", 0) or 0)
    return 0
async def _aiter_lines(
response: httpx.Response,
*,
idle_timeout_s: float | None,
):
line_iterator = response.aiter_lines()
while True:
try:
if idle_timeout_s is None or idle_timeout_s <= 0:
line = await anext(line_iterator)
else:
line = await asyncio.wait_for(
anext(line_iterator),
timeout=idle_timeout_s,
)
except StopAsyncIteration:
return
yield line
def _is_terminal_chunk(payload: dict) -> bool:
    """Report whether a chat-completion chunk's first choice has finished.

    Returns ``True`` only when ``choices[0]`` is a dict whose
    ``finish_reason`` is not ``None``. Missing, empty or malformed
    ``choices`` yield ``False`` instead of raising.
    """
    choices = payload.get("choices")
    if not isinstance(choices, list) or not choices:
        return False
    first_choice = choices[0]
    # A null or otherwise malformed choice cannot signal completion.
    if not isinstance(first_choice, dict):
        return False
    return first_choice.get("finish_reason") is not None