2461 lines
83 KiB
Python
2461 lines
83 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import time
|
|
from collections import Counter
|
|
from dataclasses import dataclass, field, replace
|
|
from pathlib import Path
|
|
from typing import Any, Literal
|
|
|
|
import httpx
|
|
|
|
from agentic_pd_hybrid.metrics import (
|
|
RequestMetrics,
|
|
write_metrics_jsonl,
|
|
write_summary_json,
|
|
)
|
|
from agentic_pd_hybrid.policies import RoutingState, create_policy
|
|
from agentic_pd_hybrid.topology import SingleNodeTopology
|
|
from agentic_pd_hybrid.trace import (
|
|
TraceRequest,
|
|
build_synthetic_append_chunk,
|
|
build_synthetic_prompt,
|
|
load_trace,
|
|
)
|
|
|
|
|
|
# Timeout, in seconds, used for the short control-plane HTTP calls
# (/server_info, /v1/loads, admission queries, open/close session).
_ADMISSION_PROBE_TIMEOUT_S = 2.0

# Accepted values for ``ReplayConfig.header_mode``.
HeaderMode = Literal["none", "routing-key", "target-worker", "auto"]

# Accepted values for ``ReplayConfig.kvcache_admission_mode``.
KvCacheAdmissionMode = Literal["router", "worker"]

# Accepted values for ``ReplayConfig.kvcache_prefill_backup_policy``.
KvCachePrefillBackupPolicy = Literal["release-after-transfer", "capacity-backup"]
|
|
|
|
|
|
@dataclass(frozen=True)
class ReplayConfig:
    """Immutable settings for one trace-replay run.

    The fields without defaults identify the trace, where metrics are written,
    the routing policy / mechanism names and the worker topology.  The
    ``kvcache_*`` knobs tune the kvcache-centric admission, seeding and
    prefill-backup behaviour.

    Raises:
        ValueError: if ``concurrency_limit`` is below 1 (an
            ``asyncio.Semaphore(0)`` would block every request forever), or if
            pacing is enabled with a non-positive ``time_scale`` (the pacing
            offset is divided by it).
    """

    trace_path: Path
    output_path: Path
    policy_name: str
    mechanism_name: str
    topology: SingleNodeTopology
    router_url: str | None = None
    model_name: str | None = None
    # When True, requests are launched at their trace offset / time_scale.
    pace: bool = True
    time_scale: float = 1.0
    request_limit: int | None = None
    # Size of the semaphore bounding concurrently executing requests.
    concurrency_limit: int = 32
    header_mode: HeaderMode = "auto"
    timeout_s: float = 600.0
    stream: bool = True
    stream_idle_timeout_s: float | None = 900.0
    kvcache_direct_max_uncached_tokens: int = 2048
    kvcache_admission_mode: KvCacheAdmissionMode = "router"
    kvcache_seed_max_resident_tokens: int | None = None
    kvcache_seed_max_output_tokens: int | None = None
    kvcache_seed_min_turn_id: int = 1
    # If set, replay_trace replaces kvcache_seed_allowed_session_ids with the
    # ids of sessions that have more than one turn in the trace.
    kvcache_seed_only_multiturn_sessions: bool = False
    kvcache_seed_allowed_session_ids: frozenset[str] | None = None
    kvcache_prefill_backup_policy: KvCachePrefillBackupPolicy = (
        "release-after-transfer"
    )
    kvcache_seed_max_inflight_decode: int | None = 3
    kvcache_seed_max_decode_transfer_queue_reqs: int | None = None
    kvcache_direct_max_decode_transfer_queue_reqs: int | None = None
    # When enabled, router requests carry the direct/normal prefill priorities.
    kvcache_prefill_priority_eviction: bool = False
    kvcache_prefill_direct_priority: int = -100
    kvcache_prefill_normal_priority: int = 100

    def __post_init__(self) -> None:
        # Semaphore(0) never lets a request through: fail fast instead of hanging.
        if self.concurrency_limit < 1:
            raise ValueError(
                f"concurrency_limit must be >= 1, got {self.concurrency_limit}"
            )
        # time_scale is only used (as a divisor) when pacing is enabled.
        if self.pace and self.time_scale <= 0:
            raise ValueError(
                f"time_scale must be > 0 when pace is enabled, got {self.time_scale}"
            )
|
|
|
|
|
|
@dataclass
class DirectSessionState:
    """Replay-side bookkeeping for one trace session's streaming sessions.

    The un-prefixed fields describe the session on a decode worker
    (``server_url``); the ``prefill_*`` fields describe an optional backup
    session on a prefill worker.
    """

    session_id: str
    server_url: str
    # --- decode-side session ---
    opened: bool = False
    last_trace_request: TraceRequest | None = None
    resident_tokens: int = 0
    last_access_s: float = 0.0
    active_requests: int = 0
    # --- prefill-side backup session ---
    prefill_server_url: str | None = None
    prefill_opened: bool = False
    prefill_resident_tokens: int = 0
    prefill_last_access_s: float = 0.0
    prefill_low_priority: bool = False
|
|
|
|
|
|
@dataclass
class DecodeResidencyState:
    """Replay-side token accounting for workers, keyed by worker URL.

    Capacity and headroom come from probing the workers at start-up; the
    resident/reserved maps are adjusted as sessions are reserved, committed
    and closed.  The two counters split decode-session evictions by whether a
    prefill-side backup session existed.
    """

    # Decode workers: probed max_total_num_tokens, safety headroom and the
    # server-reported num_reserved_decode_tokens.
    capacity_tokens: dict[str, int] = field(default_factory=dict)
    headroom_tokens: dict[str, int] = field(default_factory=dict)
    reserved_decode_tokens: dict[str, int] = field(default_factory=dict)
    # Decode workers: tokens held by committed sessions / pending reservations.
    resident_tokens_by_server: dict[str, int] = field(default_factory=dict)
    reserved_tokens_by_server: dict[str, int] = field(default_factory=dict)
    # Prefill workers: the same accounting for backup sessions.
    prefill_capacity_tokens: dict[str, int] = field(default_factory=dict)
    prefill_headroom_tokens: dict[str, int] = field(default_factory=dict)
    prefill_resident_tokens_by_server: dict[str, int] = field(default_factory=dict)
    prefill_reserved_tokens_by_server: dict[str, int] = field(default_factory=dict)
    # Eviction counters.
    decode_evictions_prefill_backed: int = 0
    decode_evictions_without_prefill_backup: int = 0
|
|
|
|
|
|
@dataclass(frozen=True)
class DecodeLoadSnapshot:
    """Immutable point-in-time view of one decode worker's load.

    ``timestamp_s`` is a ``time.perf_counter()`` reading; the remaining
    fields mirror counters reported by the worker's ``/v1/loads`` endpoint
    (core scheduler counters plus the disaggregation queue depths).
    """

    timestamp_s: float
    # Core scheduler counters.
    num_running_reqs: int
    num_waiting_reqs: int
    num_used_tokens: int
    max_total_num_tokens: int
    token_usage: float
    # Disaggregation queue depths.
    decode_prealloc_queue_reqs: int
    decode_transfer_queue_reqs: int
    decode_retracted_queue_reqs: int
|
|
|
|
|
|
@dataclass(frozen=True)
class ExecutionResult:
    """Outcome of executing one request, before it is turned into metrics.

    On failure the timing fields are ``None`` and ``error`` holds
    ``"<ExceptionType>: <message>"``.
    """

    execution_mode: str
    actual_kv_transfer_blocks: int
    effective_input_length: int | None
    cached_tokens: int
    session_reused: bool
    session_reset: bool
    # Timing measurements in seconds (None when unavailable).
    latency_s: float | None
    ttft_s: float | None
    tpot_s: float | None
    # Priorities attached to the request, if any.
    prefill_request_priority: int | None = None
    decode_request_priority: int | None = None
    error: str | None = None
|
|
|
|
|
|
async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]:
    """Replay the configured trace and write per-request metrics to disk.

    Requests are loaded from ``config.trace_path`` (honouring
    ``request_limit``).  With ``config.pace`` each request is launched at its
    trace offset divided by ``config.time_scale``; requests of the same
    session are chained so each waits for its predecessor.  Once every
    request has finished, streaming sessions still open on decode or prefill
    workers are closed on a best-effort basis.  Metrics go to
    ``config.output_path`` and a summary to ``<output_path>.summary.json``.

    Returns:
        One ``RequestMetrics`` per replayed request, in trace order.
    """
    requests = load_trace(config.trace_path, request_limit=config.request_limit)
    if config.kvcache_seed_only_multiturn_sessions:
        # Only sessions that appear more than once in the trace may be seeded.
        turns_per_session = Counter(item.session_id for item in requests)
        multiturn_ids = frozenset(
            sid for sid, turns in turns_per_session.items() if turns > 1
        )
        config = replace(config, kvcache_seed_allowed_session_ids=multiturn_ids)

    policy = create_policy(config.policy_name)
    state = RoutingState.create(config.topology)
    state_lock = asyncio.Lock()
    semaphore = asyncio.Semaphore(config.concurrency_limit)
    replay_start = time.perf_counter()
    trace_origin = requests[0].timestamp_s if requests else 0.0
    last_task_by_session: dict[str, asyncio.Task[RequestMetrics]] = {}
    direct_sessions: dict[str, DirectSessionState] = {}
    direct_session_lock = asyncio.Lock()

    async with httpx.AsyncClient(timeout=config.timeout_s, trust_env=False) as client:
        decode_residency = await _discover_decode_residency(client=client, config=config)

        pending: list[asyncio.Task[RequestMetrics]] = []
        for request in requests:
            if config.pace:
                due_s = (request.timestamp_s - trace_origin) / config.time_scale
                delay_s = due_s - (time.perf_counter() - replay_start)
                if delay_s > 0:
                    await asyncio.sleep(delay_s)
            task = asyncio.create_task(
                _run_request(
                    request=request,
                    config=config,
                    client=client,
                    policy=policy,
                    state=state,
                    state_lock=state_lock,
                    semaphore=semaphore,
                    direct_sessions=direct_sessions,
                    direct_session_lock=direct_session_lock,
                    decode_residency=decode_residency,
                    # Chain onto the previous request of the same session.
                    depends_on=last_task_by_session.get(request.session_id),
                )
            )
            pending.append(task)
            last_task_by_session[request.session_id] = task

        results = await asyncio.gather(*pending)
        await _close_direct_sessions_quietly(
            client=client,
            direct_sessions=direct_sessions,
        )

    write_metrics_jsonl(config.output_path, results)
    write_summary_json(
        config.output_path.with_suffix(config.output_path.suffix + ".summary.json"),
        results,
        trace_path=config.trace_path,
        router_url=config.router_url,
    )
    return results


async def _close_direct_sessions_quietly(
    *,
    client: httpx.AsyncClient,
    direct_sessions: dict[str, DirectSessionState],
) -> None:
    """Best-effort close of every decode and prefill session still marked open.

    Each close call is attempted independently (decode first, then prefill);
    any exception is swallowed so shutdown cannot mask the replay results.
    """
    for session in direct_sessions.values():
        if session.opened:
            try:
                await _close_streaming_session(
                    client=client,
                    server_url=session.server_url,
                    session_id=session.session_id,
                    allow_missing=True,
                )
            except Exception:
                pass
        if session.prefill_opened and session.prefill_server_url is not None:
            try:
                await _close_streaming_session(
                    client=client,
                    server_url=session.prefill_server_url,
                    session_id=session.session_id,
                    allow_missing=True,
                )
            except Exception:
                pass
|
|
|
|
|
|
async def _run_request(
    *,
    request: TraceRequest,
    config: ReplayConfig,
    client: httpx.AsyncClient,
    policy,
    state: RoutingState,
    state_lock: asyncio.Lock,
    semaphore: asyncio.Semaphore,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
    depends_on: asyncio.Task[RequestMetrics] | None,
) -> RequestMetrics:
    """Route, execute and measure one trace request.

    Waits for *depends_on* (the previous request of the same session) before
    taking a *semaphore* slot.  Inside the slot it asks *policy* for a routing
    decision, executes the request and reports completion to *state*; both
    state accesses are serialised by *state_lock*.  Ordinary exceptions from
    execution are recorded in the metrics' ``error`` field rather than raised.
    """
    if depends_on is not None:
        await depends_on

    async with semaphore:
        async with state_lock:
            decision = policy.select(request, topology=config.topology, state=state)

        try:
            execution = await _execute_request(
                client=client,
                request=request,
                config=config,
                decision=decision,
                direct_sessions=direct_sessions,
                direct_session_lock=direct_session_lock,
                decode_residency=decode_residency,
            )
        except Exception as exc:  # pragma: no cover - defensive logging path
            execution = ExecutionResult(
                execution_mode=config.mechanism_name,
                actual_kv_transfer_blocks=0,
                effective_input_length=None,
                cached_tokens=0,
                session_reused=False,
                session_reset=False,
                latency_s=None,
                ttft_s=None,
                tpot_s=None,
                error=f"{type(exc).__name__}: {exc}",
            )

        async with state_lock:
            state.finish(request, decision)

    # Every ExecutionResult field is forwarded verbatim to the metrics record.
    forwarded_fields = (
        "execution_mode",
        "actual_kv_transfer_blocks",
        "effective_input_length",
        "cached_tokens",
        "session_reused",
        "session_reset",
        "latency_s",
        "ttft_s",
        "tpot_s",
        "prefill_request_priority",
        "decode_request_priority",
        "error",
    )
    return RequestMetrics.from_decision(
        request,
        decision,
        mechanism_name=config.mechanism_name,
        **{name: getattr(execution, name) for name in forwarded_fields},
    )
|
|
|
|
|
|
async def _invoke_router(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decode_worker_index: int,
    session_id: str | None = None,
    prefill_request_priority: int | None = None,
    decode_request_priority: int | None = None,
) -> tuple[float, float | None, float | None, int]:
    """Send *request* through the router's ``/generate`` endpoint.

    The body carries the request's token ids (``_build_direct_full_input_ids``)
    with ``temperature=0``, ``max_new_tokens=max(1, output_length)`` and
    ``ignore_eos``.  ``session_params`` and the ``smg_*_priority`` fields are
    only added when the corresponding argument is not None.

    Returns:
        ``(latency_s, ttft_s, tpot_s, cached_tokens)`` from ``_invoke_generate``.

    Raises:
        ValueError: ``config.router_url`` is not configured.
    """
    # Validate explicitly: an ``assert`` is stripped under ``python -O`` and the
    # failure would resurface later as an opaque AttributeError on None.
    router_url = config.router_url
    if router_url is None:
        raise ValueError("router_url must be configured to invoke the router")

    headers = _build_headers(
        request=request,
        header_mode=config.header_mode,
        decode_worker_index=decode_worker_index,
        policy_name=config.policy_name,
    )
    payload: dict[str, object] = {
        "input_ids": _build_direct_full_input_ids(request),
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max(1, request.output_length),
            "ignore_eos": True,
            "no_stop_trim": True,
            "skip_special_tokens": False,
        },
        "stream": config.stream,
    }
    if session_id is not None:
        payload["session_params"] = {"id": session_id}
    if prefill_request_priority is not None:
        payload["smg_prefill_priority"] = prefill_request_priority
    if decode_request_priority is not None:
        payload["smg_decode_priority"] = decode_request_priority

    return await _invoke_generate(
        client=client,
        base_url=router_url,
        headers=headers,
        payload=payload,
        timeout_s=config.timeout_s,
        stream_idle_timeout_s=config.stream_idle_timeout_s,
        stream=config.stream,
    )
|
|
|
|
|
|
def _build_payload(
    *,
    request: TraceRequest,
    model_name: str,
    prompt: str,
    stream: bool,
    session_params: dict[str, str] | None,
    exact_output_length: bool,
) -> dict[str, object]:
    """Assemble an OpenAI-style chat-completions request body.

    The prompt is sent as a single user message at ``temperature=0`` with a
    token budget of ``max(1, request.output_length)``.  Streaming bodies add
    ``stream_options.include_usage``; *exact_output_length* adds
    ``min_tokens`` (same budget), ``ignore_eos``, ``no_stop_trim`` and
    ``skip_special_tokens``; *session_params* is passed through when given.
    """
    token_budget = max(1, request.output_length)
    body: dict[str, object] = dict(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=token_budget,
        temperature=0,
        stream=stream,
    )
    if stream:
        body["stream_options"] = {"include_usage": True}
    if exact_output_length:
        body["min_tokens"] = token_budget
        body["ignore_eos"] = True
        body["no_stop_trim"] = True
        body["skip_special_tokens"] = False
    if session_params is not None:
        body["session_params"] = session_params
    return body
|
|
|
|
|
|
async def _invoke_chat_completion(
    *,
    client: httpx.AsyncClient,
    base_url: str,
    headers: dict[str, str],
    payload: dict[str, object],
    timeout_s: float,
    stream_idle_timeout_s: float | None,
    stream: bool,
) -> tuple[float, float | None, float | None, int]:
    """POST *payload* to ``{base_url}/v1/chat/completions`` and time it.

    Returns:
        ``(latency_s, ttft_s, tpot_s, cached_tokens)``.  ``ttft_s`` is the
        time until the first streamed chunk carrying a token (always None
        for non-streaming calls, which also makes ``tpot_s`` None);
        ``tpot_s`` spreads the post-first-token time over
        ``payload["max_tokens"]``.

    Raises:
        ValueError: the server returned an in-band ``error`` object.
        RuntimeError: a streamed response ended without producing a token.
        httpx.HTTPStatusError: the response status was not 2xx.
    """
    start = time.perf_counter()
    ttft_s: float | None = None
    cached_tokens = 0
    generated_tokens = int(payload.get("max_tokens", 1))
    url = f"{base_url.rstrip('/')}/v1/chat/completions"
    if stream:
        async with client.stream(
            "POST",
            url,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        ) as response:
            response.raise_for_status()
            async for line in _aiter_lines(
                response,
                idle_timeout_s=stream_idle_timeout_s,
            ):
                if not line.startswith("data:"):
                    continue
                data = line[5:].strip()
                if data == "[DONE]":
                    break
                parsed = json.loads(data)
                # Surface in-band server errors the same way _invoke_generate does
                # instead of failing later with a misleading "no token" error.
                error = parsed.get("error") if isinstance(parsed, dict) else None
                if isinstance(error, dict):
                    raise ValueError(error.get("message", json.dumps(error)))
                cached_tokens = max(cached_tokens, _extract_openai_cached_tokens(parsed))
                if _contains_token(parsed) and ttft_s is None:
                    ttft_s = time.perf_counter() - start
                if _is_terminal_chunk(parsed):
                    break
    else:
        response = await client.post(
            url,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        )
        response.raise_for_status()
        parsed = response.json()
        error = parsed.get("error") if isinstance(parsed, dict) else None
        if isinstance(error, dict):
            raise ValueError(error.get("message", json.dumps(error)))
        cached_tokens = _extract_openai_cached_tokens(parsed)

    latency_s = time.perf_counter() - start
    if stream and ttft_s is None and generated_tokens > 0:
        raise RuntimeError("chat completion stream ended before producing any token")
    if ttft_s is None:
        tpot_s = None
    else:
        tpot_s = max(0.0, latency_s - ttft_s) / max(1, generated_tokens)
    return latency_s, ttft_s, tpot_s, cached_tokens
|
|
|
|
|
|
async def _invoke_generate(
    *,
    client: httpx.AsyncClient,
    base_url: str,
    headers: dict[str, str],
    payload: dict[str, object],
    timeout_s: float,
    stream_idle_timeout_s: float | None,
    stream: bool,
) -> tuple[float, float | None, float | None, int]:
    """POST *payload* to ``{base_url}/generate`` and measure the response.

    Returns:
        ``(latency_s, ttft_s, tpot_s, cached_tokens)``.  ``ttft_s`` is the
        time until the first streamed chunk carrying a token (None for
        non-streaming calls, and then ``tpot_s`` is None too); ``tpot_s``
        spreads the post-first-token time over
        ``payload["sampling_params"]["max_new_tokens"]``.

    Raises:
        ValueError: the server returned an in-band ``error`` object.
        RuntimeError: a streamed response ended without producing a token.
        httpx.HTTPStatusError: the response status was not 2xx.
    """

    def raise_on_error(body: dict[str, Any]) -> None:
        problem = body.get("error")
        if isinstance(problem, dict):
            raise ValueError(problem.get("message", json.dumps(problem)))

    started = time.perf_counter()
    first_token_s: float | None = None
    cached = 0
    sampling = payload.get("sampling_params", {})
    expected_tokens = int(sampling.get("max_new_tokens", 1))
    endpoint = f"{base_url.rstrip('/')}/generate"

    if not stream:
        response = await client.post(
            endpoint,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        )
        response.raise_for_status()
        body = response.json()
        raise_on_error(body)
        cached = _extract_generate_cached_tokens(body)
    else:
        async with client.stream(
            "POST",
            endpoint,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        ) as response:
            response.raise_for_status()
            async for line in _aiter_lines(response, idle_timeout_s=stream_idle_timeout_s):
                if not line.startswith("data:"):
                    continue
                data = line[5:].strip()
                if data == "[DONE]":
                    break
                chunk = json.loads(data)
                raise_on_error(chunk)
                cached = max(cached, _extract_generate_cached_tokens(chunk))
                if _contains_generate_token(chunk) and first_token_s is None:
                    first_token_s = time.perf_counter() - started
                if _is_generate_terminal_chunk(chunk):
                    break

    latency_s = time.perf_counter() - started
    if stream and first_token_s is None and expected_tokens > 0:
        raise RuntimeError("generate stream ended before producing any token")
    tpot_s = (
        None
        if first_token_s is None
        else max(0.0, latency_s - first_token_s) / max(1, expected_tokens)
    )
    return latency_s, first_token_s, tpot_s, cached
|
|
|
|
|
|
async def _open_streaming_session(
    *,
    client: httpx.AsyncClient,
    server_url: str,
    session_id: str,
    request: TraceRequest,
) -> None:
    """Open a streaming session named *session_id* on *server_url*.

    ``capacity_of_str_len`` is the largest of 4096, 16x the request's input
    length and 16x its input+output length.

    Raises:
        httpx.HTTPStatusError: the server rejected the call.
        ValueError: the server acknowledged a different session id.
    """
    size_factor = 16
    capacity_candidates = (
        4096,
        request.input_length * size_factor,
        (request.input_length + request.output_length) * size_factor,
    )
    body = {
        "capacity_of_str_len": max(capacity_candidates),
        "session_id": session_id,
        "streaming": True,
    }
    response = await client.post(
        f"{server_url.rstrip('/')}/open_session",
        json=body,
        timeout=_ADMISSION_PROBE_TIMEOUT_S,
    )
    response.raise_for_status()
    echoed_id = response.json()
    if echoed_id == session_id:
        return
    raise ValueError(
        f"Unexpected session id from {server_url}: {echoed_id!r} != {session_id!r}"
    )
|
|
|
|
|
|
async def _close_streaming_session(
    *,
    client: httpx.AsyncClient,
    server_url: str,
    session_id: str,
    allow_missing: bool = False,
) -> None:
    """Close streaming session *session_id* on *server_url*.

    With *allow_missing*, a 404 or a body mentioning "does not exist"
    (case-insensitive) is treated as already closed.  Any other non-2xx
    response raises via ``response.raise_for_status()``.
    """
    response = await client.post(
        f"{server_url.rstrip('/')}/close_session",
        json={"session_id": session_id},
        timeout=_ADMISSION_PROBE_TIMEOUT_S,
    )
    if response.is_success:
        return
    already_gone = allow_missing and (
        response.status_code == 404
        or "does not exist" in response.text.lower()
    )
    if not already_gone:
        response.raise_for_status()
|
|
|
|
|
|
def _extract_internal_state(payload: dict[str, Any]) -> dict[str, Any]:
    """Return the first ``internal_states`` entry if it is a dict, else *payload*."""
    states = payload.get("internal_states")
    first = states[0] if isinstance(states, list) and states else None
    return first if isinstance(first, dict) else payload
|
|
|
|
|
|
def _extract_server_int(
    payload: dict[str, Any],
    key: str,
) -> int:
    """Read integer *key* from *payload*, falling back to its internal state.

    A top-level value wins; otherwise the value from
    ``_extract_internal_state(payload)`` is used.  Missing or falsy values
    yield 0.
    """
    fallback = _extract_internal_state(payload).get(key, 0)
    raw_value = payload.get(key, fallback)
    return int(raw_value or 0)
|
|
|
|
|
|
def _extract_session_cache(
    payload: dict[str, Any],
) -> dict[str, Any]:
    """Return the ``session_cache`` dict from *payload*'s internal state, or ``{}``."""
    cache = _extract_internal_state(payload).get("session_cache")
    return cache if isinstance(cache, dict) else {}
|
|
|
|
|
|
def _find_session_cache_status(
    session_cache: dict[str, Any],
    session_id: str,
) -> dict[str, Any] | None:
    """Return the first dict in ``session_cache["sessions"]`` whose
    ``session_id`` matches, or ``None`` (also when ``sessions`` is not a list).
    """
    entries = session_cache.get("sessions")
    if not isinstance(entries, list):
        return None
    return next(
        (
            entry
            for entry in entries
            if isinstance(entry, dict) and entry.get("session_id") == session_id
        ),
        None,
    )
|
|
|
|
|
|
async def _fetch_decode_server_state(
    *,
    client: httpx.AsyncClient,
    server_url: str,
) -> tuple[dict[str, Any], int, int]:
    """Best-effort read of a worker's ``/server_info``.

    Used for both decode and prefill workers.

    Returns:
        ``(session_cache, max_total_num_tokens, num_reserved_decode_tokens)``.
        The integers are read from the top level of the body or its first
        ``internal_states`` entry.  ``({}, 0, 0)`` is returned when the
        request fails, the body is not a JSON object, or a counter cannot be
        converted to ``int``.  The probe never raises in those cases, so one
        misbehaving worker cannot abort the caller.
    """
    try:
        response = await client.get(
            f"{server_url.rstrip('/')}/server_info",
            timeout=_ADMISSION_PROBE_TIMEOUT_S,
        )
        response.raise_for_status()
        payload = response.json()
    except Exception:
        return {}, 0, 0

    # The extract helpers call ``.get`` on the payload, so a non-object body
    # (list, string, null) must be rejected here rather than crash them.
    if not isinstance(payload, dict):
        return {}, 0, 0
    try:
        max_total_num_tokens = _extract_server_int(payload, "max_total_num_tokens")
        reserved_decode_tokens = _extract_server_int(
            payload, "num_reserved_decode_tokens"
        )
    except (TypeError, ValueError, OverflowError):
        # Non-numeric counters: treat like an unavailable worker.
        return {}, 0, 0
    return (
        _extract_session_cache(payload),
        max_total_num_tokens,
        reserved_decode_tokens,
    )
|
|
|
|
|
|
async def _query_decode_direct_admission(
    *,
    client: httpx.AsyncClient,
    server_url: str,
    session_id: str,
    uncached_input_tokens: int,
    output_tokens: int,
) -> dict[str, Any]:
    """Ask a decode worker whether a direct append to *session_id* fits.

    Token counts are clamped at 0 before being sent.  The worker's JSON
    object is returned as-is; on any failure, or a non-object body, a
    conservative "cannot admit" answer with reason
    ``"admission-query-failed"`` is returned instead.
    """
    request_body = {
        "session_id": session_id,
        "uncached_input_tokens": max(0, uncached_input_tokens),
        "output_tokens": max(0, output_tokens),
    }
    try:
        response = await client.post(
            f"{server_url.rstrip('/')}/session_cache/admit_direct_append",
            json=request_body,
            timeout=_ADMISSION_PROBE_TIMEOUT_S,
        )
        response.raise_for_status()
        answer = response.json()
    except Exception:
        answer = None
    if isinstance(answer, dict):
        return answer
    return {
        "can_admit": False,
        "resident": False,
        "reason": "admission-query-failed",
        "required_tokens": 0,
        "available_tokens_before": 0,
        "available_tokens_after": 0,
        "evicted_session_count": 0,
        "freed_tokens": 0,
    }
|
|
|
|
|
|
async def _discover_decode_residency(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
) -> DecodeResidencyState:
    """Probe worker capacities for the ``kvcache-centric`` mechanism.

    For any other mechanism an empty state is returned without network
    calls.  Otherwise every decode worker, then every prefill worker, is
    probed with ``_fetch_decode_server_state``; workers reporting no capacity
    are skipped.  The recorded headroom never exceeds the worker's capacity:

    * decode:  ``max(4 * num_reserved_decode_tokens, capacity // 20, 8192)``
    * prefill: ``max(capacity // 10, 16384)``
    """
    residency = DecodeResidencyState()
    if config.mechanism_name != "kvcache-centric":
        return residency

    for decode_worker in config.topology.decode_workers:
        url = decode_worker.url
        _cache, capacity, reserved = await _fetch_decode_server_state(
            client=client,
            server_url=url,
        )
        if capacity <= 0:
            continue
        headroom = max(reserved * 4, capacity // 20, 8192)
        residency.capacity_tokens[url] = capacity
        residency.headroom_tokens[url] = min(capacity, headroom)
        residency.reserved_decode_tokens[url] = reserved

    for prefill_worker in config.topology.prefill_workers:
        url = prefill_worker.url
        _cache, capacity, _reserved = await _fetch_decode_server_state(
            client=client,
            server_url=url,
        )
        if capacity <= 0:
            continue
        headroom = max(capacity // 10, 16384)
        residency.prefill_capacity_tokens[url] = capacity
        residency.prefill_headroom_tokens[url] = min(capacity, headroom)

    return residency
|
|
|
|
|
|
def _estimate_session_resident_tokens(request: TraceRequest) -> int:
    """Estimate a session's KV footprint after *request*: its input plus output tokens."""
    prompt_tokens = request.input_length
    completion_tokens = request.output_length
    return prompt_tokens + completion_tokens
|
|
|
|
|
|
def _seed_filter_reason(
    *,
    request: TraceRequest,
    config: ReplayConfig,
    inflight_decode_load: int | None = None,
) -> str | None:
    """Return why *request* must not seed a decode session, or ``None``.

    Checks run in a fixed order and the first failing one wins:
    minimum turn id, in-flight decode load, allowed-session set, estimated
    resident tokens, then output length.  Limits that are ``None`` are skipped.
    """
    if request.turn_id < config.kvcache_seed_min_turn_id:
        return "seed-filter-early-turn"

    load_limit = config.kvcache_seed_max_inflight_decode
    if (
        load_limit is not None
        and inflight_decode_load is not None
        and inflight_decode_load > load_limit
    ):
        return "seed-filter-inflight-decode-load"

    allowed_ids = config.kvcache_seed_allowed_session_ids
    if allowed_ids is not None and request.session_id not in allowed_ids:
        return "seed-filter-single-turn-session"

    resident_limit = config.kvcache_seed_max_resident_tokens
    if (
        resident_limit is not None
        and _estimate_session_resident_tokens(request) > resident_limit
    ):
        return "seed-filter-resident-tokens"

    output_limit = config.kvcache_seed_max_output_tokens
    if output_limit is not None and request.output_length > output_limit:
        return "seed-filter-output-tokens"

    return None
|
|
|
|
|
|
def _prefill_priority_for_router_request(
    *,
    config: ReplayConfig,
    direct_to_d_predicted: bool,
) -> int | None:
    """Pick the prefill priority to attach to a router request.

    Returns ``None`` unless ``kvcache_prefill_priority_eviction`` is enabled;
    then the direct priority when *direct_to_d_predicted*, else the normal one.
    """
    if not config.kvcache_prefill_priority_eviction:
        return None
    return (
        config.kvcache_prefill_direct_priority
        if direct_to_d_predicted
        else config.kvcache_prefill_normal_priority
    )
|
|
|
|
|
|
def _inspect_direct_request(
    *,
    request: TraceRequest,
    session: DirectSessionState,
) -> tuple[int, bool, bool]:
    """Compare *request* with the session's previous turn.

    Returns ``(tokens, extends_previous, does_not_extend)``:

    * no previous turn: ``(input_length, False, False)``;
    * input grows past previous input+output: ``(appended_tokens, True, False)``;
    * otherwise: ``(input_length, False, True)``.
    """
    prior = session.last_trace_request
    if prior is None:
        return request.input_length, False, False

    appended = request.input_length - (prior.input_length + prior.output_length)
    if appended > 0:
        return appended, True, False
    return request.input_length, False, True
|
|
|
|
|
|
def _add_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Reserve *delta_tokens* more decode tokens on *server_url* (no-op if <= 0)."""
    if delta_tokens > 0:
        ledger = residency.reserved_tokens_by_server
        ledger[server_url] = ledger.get(server_url, 0) + delta_tokens
|
|
|
|
|
|
def _release_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Release *delta_tokens* of decode reservation on *server_url*.

    Non-positive deltas are ignored; the entry is removed once it drops to 0
    or below.
    """
    if delta_tokens <= 0:
        return
    ledger = residency.reserved_tokens_by_server
    left = ledger.get(server_url, 0) - delta_tokens
    if left <= 0:
        ledger.pop(server_url, None)
    else:
        ledger[server_url] = left
|
|
|
|
|
|
def _add_prefill_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Reserve *delta_tokens* more prefill-backup tokens on *server_url* (no-op if <= 0)."""
    if delta_tokens > 0:
        ledger = residency.prefill_reserved_tokens_by_server
        ledger[server_url] = ledger.get(server_url, 0) + delta_tokens
|
|
|
|
|
|
def _release_prefill_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Release *delta_tokens* of prefill-backup reservation on *server_url*.

    Non-positive deltas are ignored; the entry is removed once it drops to 0
    or below.
    """
    if delta_tokens <= 0:
        return
    ledger = residency.prefill_reserved_tokens_by_server
    left = ledger.get(server_url, 0) - delta_tokens
    if left <= 0:
        ledger.pop(server_url, None)
    else:
        ledger[server_url] = left
|
|
|
|
|
|
def _usable_capacity_tokens(
    residency: DecodeResidencyState,
    server_url: str,
) -> int:
    """Decode capacity minus headroom for *server_url*, floored at 0 (0 if unknown)."""
    capacity = residency.capacity_tokens.get(server_url, 0)
    headroom = residency.headroom_tokens.get(server_url, 0)
    usable = capacity - headroom
    return usable if usable > 0 else 0
|
|
|
|
|
|
def _usable_prefill_backup_capacity_tokens(
    residency: DecodeResidencyState,
    server_url: str,
) -> int:
    """Prefill capacity minus headroom for *server_url*, floored at 0 (0 if unknown)."""
    capacity = residency.prefill_capacity_tokens.get(server_url, 0)
    headroom = residency.prefill_headroom_tokens.get(server_url, 0)
    usable = capacity - headroom
    return usable if usable > 0 else 0
|
|
|
|
|
|
def _eviction_suffix(evicted_sessions: int, prefill_backed_evictions: int) -> str:
    """Describe the evictions performed, for appending to a mode/reason label.

    Empty when nothing was evicted; otherwise it says whether all, some or
    none of the evicted sessions had a prefill backup.
    """
    if evicted_sessions <= 0:
        suffix = ""
    elif prefill_backed_evictions >= evicted_sessions:
        suffix = "-after-prefill-backed-eviction"
    elif prefill_backed_evictions > 0:
        suffix = "-after-mixed-eviction"
    else:
        suffix = "-after-eviction"
    return suffix
|
|
|
|
|
|
def _decode_session_soft_cap(
    *,
    residency: DecodeResidencyState,
    server_url: str,
    request: TraceRequest,
) -> int:
    """Return how many decode sessions *server_url* should hold at once (1..4).

    The cap is how many sessions of *request*'s estimated size (input +
    output tokens, at least 1) fit into the worker's usable capacity
    (capacity minus headroom), clamped to ``[1, 4]``.  With no usable
    capacity on record, e.g. for a worker skipped during discovery, the most
    permissive cap of 4 is returned.
    """
    target_tokens = max(1, _estimate_session_resident_tokens(request))
    usable_capacity_tokens = _usable_capacity_tokens(residency, server_url)
    if usable_capacity_tokens <= 0:
        # NOTE(review): workers whose headroom equals their capacity (headroom
        # is clamped to capacity during discovery) also land here and get the
        # most permissive cap. Confirm that is intended rather than falling
        # back to the raw capacity.
        return 4
    return max(1, min(4, usable_capacity_tokens // target_tokens))
|
|
|
|
|
|
def _should_admit_new_decode_session(
    *,
    residency: DecodeResidencyState,
    server_url: str,
    request: TraceRequest,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    treat_as_fresh_session: bool,
) -> bool:
    """Decide whether *session* may hold a decode session on *server_url*.

    A session already open on that same server, and not being treated as
    fresh, is always allowed.  Otherwise the number of sessions currently
    open on *server_url* must be below ``_decode_session_soft_cap``.
    """
    already_resident_here = (
        not treat_as_fresh_session
        and session.opened
        and session.server_url == server_url
    )
    if already_resident_here:
        return True

    open_here = len(
        [
            other
            for other in direct_sessions.values()
            if other.opened and other.server_url == server_url
        ]
    )
    cap = _decode_session_soft_cap(
        residency=residency,
        server_url=server_url,
        request=request,
    )
    return open_here < cap
|
|
|
|
|
|
async def _fetch_decode_load_snapshot(
    *,
    client: httpx.AsyncClient,
    server_url: str,
) -> DecodeLoadSnapshot | None:
    """Best-effort read of a decode worker's ``/v1/loads`` (core + disagg).

    Returns:
        A snapshot of the first entry in ``loads``, with missing or null
        counters read as 0.  Returns ``None`` when the request fails or the body
        does not have the expected shape: a JSON object with a non-empty
        ``loads`` list whose first entry is an object and whose counters are
        numeric.  Malformed bodies no longer raise out of this probe.
    """
    try:
        response = await client.get(
            f"{server_url.rstrip('/')}/v1/loads",
            params={"include": "core,disagg"},
            timeout=_ADMISSION_PROBE_TIMEOUT_S,
        )
        response.raise_for_status()
        payload = response.json()
    except Exception:
        return None

    if not isinstance(payload, dict):
        return None
    loads = payload.get("loads")
    if not isinstance(loads, list) or not loads:
        return None
    load = loads[0]
    if not isinstance(load, dict):
        return None
    disagg = load.get("disaggregation")
    if not isinstance(disagg, dict):
        # Absent, null or malformed disaggregation section: treat queues as empty.
        disagg = {}

    def _count(source: dict[str, Any], key: str) -> int:
        return int(source.get(key, 0) or 0)

    try:
        counters = {
            "num_running_reqs": _count(load, "num_running_reqs"),
            "num_waiting_reqs": _count(load, "num_waiting_reqs"),
            "num_used_tokens": _count(load, "num_used_tokens"),
            "max_total_num_tokens": _count(load, "max_total_num_tokens"),
            "token_usage": float(load.get("token_usage", 0.0) or 0.0),
            "decode_prealloc_queue_reqs": _count(disagg, "decode_prealloc_queue_reqs"),
            "decode_transfer_queue_reqs": _count(disagg, "decode_transfer_queue_reqs"),
            "decode_retracted_queue_reqs": _count(disagg, "decode_retracted_queue_reqs"),
        }
    except (TypeError, ValueError, OverflowError):
        return None
    return DecodeLoadSnapshot(timestamp_s=time.perf_counter(), **counters)
|
|
|
|
|
|
def _decode_load_backpressure_reason(
    snapshot: DecodeLoadSnapshot | None,
    *,
    config: ReplayConfig,
    routing_mode: Literal["direct", "seed"],
) -> str | None:
    """Return a decode back-pressure reason for *routing_mode*, or ``None``.

    Both modes first compare the transfer queue with their own configured
    limit.  Direct routing then reports retraction only at >= 99% token usage
    and critical usage at >= 99.2%.  Other modes report any retraction,
    critical usage at >= 98.5%, and, for ``"seed"``, prealloc/transfer
    queueing at >= 94% usage.
    """
    if snapshot is None:
        return None

    usage = snapshot.token_usage
    retracted = snapshot.decode_retracted_queue_reqs
    is_direct = routing_mode == "direct"
    transfer_limit = (
        config.kvcache_direct_max_decode_transfer_queue_reqs
        if is_direct
        else config.kvcache_seed_max_decode_transfer_queue_reqs
    )
    if transfer_limit is not None and snapshot.decode_transfer_queue_reqs > transfer_limit:
        return "d-transfer-backpressure"

    if is_direct:
        if retracted > 0 and usage >= 0.99:
            return "d-retracted"
        return "d-token-usage-critical" if usage >= 0.992 else None

    if retracted > 0:
        return "d-retracted"
    if usage >= 0.985:
        return "d-token-usage-critical"
    queued_for_decode = (
        snapshot.decode_prealloc_queue_reqs > 0
        or snapshot.decode_transfer_queue_reqs > 0
    )
    if routing_mode == "seed" and usage >= 0.94 and queued_for_decode:
        return "d-prealloc-backpressure"
    return None
|
|
|
|
|
|
def _is_decode_backpressure_reason(reason: str | None) -> bool:
    """True if *reason* is one of the ``d-*`` decode back-pressure reasons."""
    backpressure_reasons = (
        "d-retracted",
        "d-token-usage-critical",
        "d-prealloc-backpressure",
        "d-transfer-backpressure",
    )
    return reason in backpressure_reasons
|
|
|
|
|
|
def _is_stale_decode_session_error(exc: Exception) -> bool:
    """Return True when *exc* is an httpx HTTP-status error with status 400.

    NOTE(review): the name suggests callers read a 400 as "the decode worker
    no longer has this session"; confirm against the server's error contract.
    """
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    return exc.response.status_code == 400
|
|
|
|
|
|
def _dynamic_decode_headroom_tokens(
    *,
    residency: DecodeResidencyState,
    server_url: str,
    snapshot: DecodeLoadSnapshot | None,
    routing_mode: Literal["direct", "seed"],
) -> int:
    """Compute the token headroom to keep free on a decode worker.

    Without a load *snapshot* the statically discovered headroom is used.
    Otherwise it is the largest of:

    * per-request reserve (``max(512, reserved_decode_tokens)``) times a
      pressure factor: queued disaggregation requests for ``"direct"``,
      running plus queued requests for other modes (at least 1 either way);
    * ``max_total_num_tokens // 24`` (direct) or ``// 15`` (other modes);
    * a floor of 4096 (direct) or 12288 (other modes).
    """
    if snapshot is None:
        return residency.headroom_tokens.get(server_url, 0)

    per_request_reserve = max(512, residency.reserved_decode_tokens.get(server_url, 0))
    queued = (
        snapshot.decode_prealloc_queue_reqs
        + snapshot.decode_transfer_queue_reqs
        + snapshot.decode_retracted_queue_reqs
    )
    if routing_mode == "direct":
        pressure, divisor, floor_tokens = max(1, queued), 24, 4096
    else:
        pressure = max(1, snapshot.num_running_reqs + queued)
        divisor, floor_tokens = 15, 12288
    return max(
        per_request_reserve * pressure,
        snapshot.max_total_num_tokens // divisor,
        floor_tokens,
    )
|
|
|
|
|
|
def _commit_session_residency(
    *,
    residency: DecodeResidencyState,
    session: DirectSessionState,
    request: TraceRequest,
    reserved_tokens: int,
) -> None:
    """Turn a decode-side reservation into committed residency for *session*.

    Releases *reserved_tokens* from the server's reservation, shifts the
    server's resident total by the change in this session's estimated
    footprint, and marks the session open with *request* as its last turn.
    If a prefill backup is open, it is flagged low-priority.
    """
    server = session.server_url
    _release_reserved_tokens(residency, server, reserved_tokens)

    already_counted = session.resident_tokens if session.opened else 0
    footprint = _estimate_session_resident_tokens(request)
    change = footprint - already_counted
    if change:
        totals = residency.resident_tokens_by_server
        totals[server] = totals.get(server, 0) + change

    session.opened = True
    session.resident_tokens = footprint
    session.last_trace_request = request
    session.last_access_s = time.perf_counter()
    if session.prefill_opened:
        session.prefill_low_priority = True
|
|
|
|
|
|
def _commit_prefill_backup_residency(
    *,
    residency: DecodeResidencyState,
    session: DirectSessionState,
    request: TraceRequest,
    prefill_url: str,
    reserved_tokens: int,
) -> None:
    """Record *session*'s prefill-side backup on *prefill_url* as committed.

    Releases *reserved_tokens* from the prefill reservation, shifts
    *prefill_url*'s resident total by the change in the session's estimated
    footprint, and marks the backup open.  The backup is low-priority
    exactly when the decode-side session is currently open.

    NOTE(review): if the backup was previously open on a different prefill
    URL, its old tokens are subtracted from *prefill_url*'s total rather than
    the old server's. Confirm callers close the old backup first.
    """
    _release_prefill_reserved_tokens(residency, prefill_url, reserved_tokens)

    already_counted = session.prefill_resident_tokens if session.prefill_opened else 0
    footprint = _estimate_session_resident_tokens(request)
    change = footprint - already_counted
    if change:
        totals = residency.prefill_resident_tokens_by_server
        totals[prefill_url] = totals.get(prefill_url, 0) + change

    session.prefill_server_url = prefill_url
    session.prefill_opened = True
    session.prefill_resident_tokens = footprint
    session.prefill_last_access_s = time.perf_counter()
    session.prefill_low_priority = session.opened
|
|
|
|
|
|
async def _close_prefill_session(
    *,
    client: httpx.AsyncClient,
    session: DirectSessionState,
    residency: DecodeResidencyState,
) -> None:
    """Drop ``session``'s prefill-side backup and its resident-token accounting.

    If a backup is open on a known prefill worker, the worker-side streaming
    session is closed (a missing session is tolerated). The session's tokens are
    then subtracted from that worker's ledger, and the ledger entry is removed
    once it reaches zero or below. The session's prefill flags are always reset.
    """
    prefill_url = session.prefill_server_url
    if session.prefill_opened and prefill_url is not None:
        await _close_streaming_session(
            client=client,
            server_url=prefill_url,
            session_id=session.session_id,
            allow_missing=True,
        )
        ledger = residency.prefill_resident_tokens_by_server
        leftover = ledger.get(prefill_url, 0) - session.prefill_resident_tokens
        if leftover > 0:
            ledger[prefill_url] = leftover
        else:
            ledger.pop(prefill_url, None)

    session.prefill_opened = False
    session.prefill_resident_tokens = 0
    session.prefill_low_priority = False
|
|
|
|
|
|
async def _close_decode_session(
    *,
    client: httpx.AsyncClient,
    session: DirectSessionState,
    residency: DecodeResidencyState,
    evicting_for_capacity: bool = False,
) -> None:
    """Close ``session``'s decode-side streaming session and release its accounting.

    When ``evicting_for_capacity`` is true, the eviction is counted in
    ``residency``. The counter used depends on whether the session still has a
    prefill backup. Closing the decode copy also clears the backup's
    low-priority flag.
    """
    if not session.opened:
        session.resident_tokens = 0
        return

    decode_url = session.server_url
    await _close_streaming_session(
        client=client,
        server_url=decode_url,
        session_id=session.session_id,
        allow_missing=True,
    )

    ledger = residency.resident_tokens_by_server
    leftover = ledger.get(decode_url, 0) - session.resident_tokens
    if leftover <= 0:
        ledger.pop(decode_url, None)
    else:
        ledger[decode_url] = leftover

    session.opened = False
    session.resident_tokens = 0

    if not session.prefill_opened:
        if evicting_for_capacity:
            residency.decode_evictions_without_prefill_backup += 1
        return

    residency.decode_evictions_prefill_backed += int(evicting_for_capacity)
    # With no decode copy left, the prefill backup is the only copy.
    session.prefill_low_priority = False
|
|
|
|
|
|
async def _reserve_prefill_backup_capacity(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    prefill_url: str,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    residency: DecodeResidencyState,
) -> tuple[bool, int, int]:
    """Decide whether ``session`` may keep a KV backup on prefill worker ``prefill_url``.

    The worker's capacity is refreshed via ``_fetch_decode_server_state``. The
    backup is admitted only if, after reserving the session's extra tokens, at
    least ``max(prefill headroom, capacity // 2)`` tokens remain available.
    While more tokens are needed and that threshold is not met, or the worker
    already holds ``max_backup_sessions`` backups, idle backups of other sessions
    on the same worker are evicted. Low-priority backups go first, then
    least-recently accessed.

    Returns:
        ``(keep_backup, reserved_tokens, evicted_sessions)``. When ``keep_backup``
        is true, ``reserved_tokens`` has been added to the prefill reservation
        ledger for ``prefill_url``. If no positive capacity is known for
        ``prefill_url``, the backup is kept with no reservation.
    """
    session_cache, max_total_num_tokens, _reserved_decode_tokens = (
        await _fetch_decode_server_state(
            client=client,
            server_url=prefill_url,
        )
    )
    if max_total_num_tokens > 0:
        residency.prefill_capacity_tokens[prefill_url] = max_total_num_tokens

    capacity_tokens = residency.prefill_capacity_tokens.get(prefill_url, 0)
    headroom_tokens = residency.prefill_headroom_tokens.get(prefill_url, 0)
    if capacity_tokens <= 0:
        return True, 0, 0
    low_occupancy_headroom_tokens = max(headroom_tokens, capacity_tokens // 2)

    # True only if the session's existing backup lives on *this* prefill worker.
    # A backup on another worker is about to be replaced, so it must not exempt
    # the session from this worker's backup-count cap.
    backed_up_here = (
        session.prefill_opened and session.prefill_server_url == prefill_url
    )

    target_session_status = _find_session_cache_status(
        session_cache,
        session.session_id,
    )
    if isinstance(target_session_status, dict) and bool(
        target_session_status.get("resident")
    ):
        current_tokens = int(target_session_status.get("resident_tokens", 0) or 0)
    else:
        current_tokens = session.prefill_resident_tokens if backed_up_here else 0

    target_tokens = _estimate_session_resident_tokens(request)
    required_extra_tokens = max(0, target_tokens - current_tokens)
    # Roughly capacity / (2 * session footprint) backups per worker, at most 4.
    max_backup_sessions = min(
        max(1, capacity_tokens // max(1, target_tokens * 2)),
        4,
    )

    available_tokens = int(session_cache.get("available_tokens", 0) or 0)
    if available_tokens <= 0:
        # No positive availability reported: derive it from held tokens.
        held_tokens = int(session_cache.get("held_tokens", 0) or 0)
        available_tokens = max(0, capacity_tokens - held_tokens)
    # Subtract tokens the router has already reserved on this worker.
    available_tokens -= residency.prefill_reserved_tokens_by_server.get(prefill_url, 0)

    def has_enough_prefill_headroom() -> bool:
        return available_tokens - required_extra_tokens >= low_occupancy_headroom_tokens

    def backup_cap_reached() -> bool:
        if backed_up_here:
            return False
        backups_here = sum(
            1
            for candidate in direct_sessions.values()
            if candidate.prefill_opened and candidate.prefill_server_url == prefill_url
        )
        return backups_here >= max_backup_sessions

    evicted_sessions = 0
    while required_extra_tokens > 0 and (
        not has_enough_prefill_headroom() or backup_cap_reached()
    ):
        evictable = [
            candidate
            for candidate in direct_sessions.values()
            if candidate.prefill_opened
            and candidate.prefill_server_url == prefill_url
            and candidate.session_id != session.session_id
            and candidate.active_requests <= 0
        ]
        if not evictable:
            break
        # Low-priority backups first, then least recently accessed; ``min``
        # returns the first minimal element, matching ``sorted(...)[0]``.
        victim = min(
            evictable,
            key=lambda candidate: (
                0 if candidate.prefill_low_priority else 1,
                candidate.prefill_last_access_s,
            ),
        )
        freed_tokens = victim.prefill_resident_tokens
        await _close_prefill_session(
            client=client,
            session=victim,
            residency=residency,
        )
        available_tokens += freed_tokens
        evicted_sessions += 1

    if not has_enough_prefill_headroom():
        return False, 0, evicted_sessions

    _add_prefill_reserved_tokens(residency, prefill_url, required_extra_tokens)
    return True, required_extra_tokens, evicted_sessions
|
|
|
|
|
|
async def _reserve_decode_session_capacity(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
    request: TraceRequest,
    server_url: str,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    residency: DecodeResidencyState,
    treat_as_fresh_session: bool,
    routing_mode: Literal["direct", "seed"],
    admission_mode: KvCacheAdmissionMode,
) -> tuple[bool, int, int, int, str | None]:
    """Try to reserve decode-worker KV capacity for ``request``'s session.

    * ``admission_mode == "router"``: delegated to
      ``_reserve_decode_session_capacity_from_router_state``.
    * ``routing_mode == "direct"`` on an existing session: the worker's direct
      admission endpoint decides.
    * Otherwise, worker cache/load state is fetched. A backpressure reason
      rejects immediately. Idle sessions on ``server_url`` are then closed
      least-recently-used first until the estimated footprint fits. Finally the
      request is checked against capacity minus a dynamic headroom derived from
      the latest load snapshot.

    Returns:
        ``(admitted, reserved_tokens, evicted_sessions, prefill_backed_evictions,
        reason)``. ``reason`` is ``None`` on success, and on rejection a short
        string such as ``"d-no-space"`` or ``"d-session-not-resident"``, or a
        backpressure reason. On success ``reserved_tokens`` has been added to the
        router's reservation ledger for ``server_url``.
    """
    if admission_mode == "router":
        return await _reserve_decode_session_capacity_from_router_state(
            client=client,
            config=config,
            request=request,
            server_url=server_url,
            session=session,
            direct_sessions=direct_sessions,
            residency=residency,
            treat_as_fresh_session=treat_as_fresh_session,
            routing_mode=routing_mode,
        )

    if treat_as_fresh_session and session.opened:
        await _close_decode_session(
            client=client,
            session=session,
            residency=residency,
        )

    current_tokens = 0 if treat_as_fresh_session else session.resident_tokens
    target_tokens = _estimate_session_resident_tokens(request)
    required_extra_tokens = max(0, target_tokens - current_tokens)

    if routing_mode == "direct" and not treat_as_fresh_session:
        # Continuing an existing session: the worker's admission endpoint decides.
        if not session.opened:
            return False, 0, 0, 0, "d-session-not-resident"
        admission = await _query_decode_direct_admission(
            client=client,
            server_url=server_url,
            session_id=session.session_id,
            uncached_input_tokens=max(0, request.input_length - current_tokens),
            output_tokens=request.output_length,
        )
        if not bool(admission.get("resident")):
            return False, 0, 0, 0, str(admission.get("reason") or "d-session-not-resident")
        worker_evictions = int(admission.get("evicted_session_count", 0) or 0)
        if not bool(admission.get("can_admit")):
            return (
                False,
                0,
                worker_evictions,
                0,
                str(admission.get("reason") or "d-no-space"),
            )
        reserved_tokens = int(
            admission.get("required_tokens", required_extra_tokens)
            or required_extra_tokens
        )
        _add_reserved_tokens(residency, server_url, reserved_tokens)
        return True, reserved_tokens, worker_evictions, 0, None

    session_cache, max_total_num_tokens, reserved_decode_tokens = (
        await _fetch_decode_server_state(
            client=client,
            server_url=server_url,
        )
    )
    if max_total_num_tokens > 0:
        residency.capacity_tokens[server_url] = max_total_num_tokens
    if reserved_decode_tokens > 0:
        residency.reserved_decode_tokens[server_url] = reserved_decode_tokens

    target_session_status = _find_session_cache_status(
        session_cache,
        session.session_id,
    )
    if routing_mode == "direct" and not (
        isinstance(target_session_status, dict)
        and bool(target_session_status.get("resident"))
    ):
        return False, 0, 0, 0, "d-session-not-resident"

    load_snapshot = await _fetch_decode_load_snapshot(
        client=client,
        server_url=server_url,
    )
    if load_snapshot is not None and load_snapshot.max_total_num_tokens > 0:
        residency.capacity_tokens[server_url] = load_snapshot.max_total_num_tokens
    backpressure_reason = _decode_load_backpressure_reason(
        load_snapshot,
        config=config,
        routing_mode=routing_mode,
    )
    if backpressure_reason is not None:
        return False, 0, 0, 0, backpressure_reason

    usable_capacity_tokens = _usable_capacity_tokens(residency, server_url)
    # Tokens the worker reports it could evict on its own count as free space;
    # ``session_cache`` does not change inside the loop, so parse it once.
    worker_idle_evictable_tokens = int(session_cache.get("idle_evictable_tokens", 0) or 0)
    evicted_sessions = 0
    prefill_backed_evictions = 0
    while (
        required_extra_tokens > 0
        and residency.resident_tokens_by_server.get(server_url, 0)
        + residency.reserved_tokens_by_server.get(server_url, 0)
        + required_extra_tokens
        > usable_capacity_tokens + worker_idle_evictable_tokens
    ):
        idle_candidates = [
            candidate
            for candidate in direct_sessions.values()
            if candidate.opened
            and candidate.server_url == server_url
            and candidate.session_id != session.session_id
            and candidate.active_requests <= 0
        ]
        if not idle_candidates:
            break
        # Least recently accessed first (``min`` keeps the first tie, like sorted()[0]).
        victim = min(idle_candidates, key=lambda candidate: candidate.last_access_s)
        await _close_decode_session(
            client=client,
            session=victim,
            residency=residency,
            evicting_for_capacity=True,
        )
        prefill_backed_evictions += int(victim.prefill_opened)
        evicted_sessions += 1

    if evicted_sessions > 0:
        # Re-read worker state so the final check sees the freed capacity.
        load_snapshot = await _fetch_decode_load_snapshot(
            client=client,
            server_url=server_url,
        )
        session_cache, max_total_num_tokens, reserved_decode_tokens = (
            await _fetch_decode_server_state(
                client=client,
                server_url=server_url,
            )
        )
        if max_total_num_tokens > 0:
            residency.capacity_tokens[server_url] = max_total_num_tokens
        if reserved_decode_tokens > 0:
            residency.reserved_decode_tokens[server_url] = reserved_decode_tokens

    # Usable capacity for the final check. This was previously computed twice in
    # a row after evictions; the first result was always overwritten.
    usable_capacity_tokens = _usable_capacity_tokens(residency, server_url)
    if load_snapshot is not None:
        dynamic_headroom = _dynamic_decode_headroom_tokens(
            residency=residency,
            server_url=server_url,
            snapshot=load_snapshot,
            routing_mode=routing_mode,
        )
        residency.headroom_tokens[server_url] = min(
            residency.capacity_tokens.get(server_url, dynamic_headroom),
            dynamic_headroom,
        )
        usable_capacity_tokens = max(
            0,
            residency.capacity_tokens.get(server_url, 0) - dynamic_headroom,
        )

    effective_used_tokens = (
        load_snapshot.num_used_tokens
        if load_snapshot is not None
        else residency.resident_tokens_by_server.get(server_url, 0)
    ) + residency.reserved_tokens_by_server.get(server_url, 0)
    idle_evictable_tokens = int(session_cache.get("idle_evictable_tokens", 0) or 0)
    if (
        routing_mode == "direct"
        and isinstance(target_session_status, dict)
        and bool(target_session_status.get("idle_evictable"))
    ):
        # The target session itself must not count as space we could free.
        idle_evictable_tokens = max(
            0,
            idle_evictable_tokens
            - int(target_session_status.get("resident_tokens", 0) or 0),
        )

    if effective_used_tokens + required_extra_tokens > (
        usable_capacity_tokens + idle_evictable_tokens
    ):
        return False, 0, evicted_sessions, prefill_backed_evictions, "d-no-space"

    _add_reserved_tokens(residency, server_url, required_extra_tokens)
    return True, required_extra_tokens, evicted_sessions, prefill_backed_evictions, None
|
|
|
|
|
|
async def _reserve_decode_session_capacity_from_router_state(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
    request: TraceRequest,
    server_url: str,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    residency: DecodeResidencyState,
    treat_as_fresh_session: bool,
    routing_mode: Literal["direct", "seed"],
) -> tuple[bool, int, int, int, str | None]:
    """Reserve decode capacity using only the router's own residency ledger.

    No worker cache or load state is queried here, and ``config`` is not read.
    Idle sessions on ``server_url`` are closed least-recently-used first until
    the session's estimated footprint fits the usable capacity.

    Returns:
        ``(admitted, reserved_tokens, evicted_sessions, prefill_backed_evictions,
        reason)``. ``reason`` is ``None`` on success, and otherwise
        ``"d-session-not-resident"`` or ``"d-no-space"``.
    """
    if treat_as_fresh_session and session.opened:
        await _close_decode_session(
            client=client,
            session=session,
            residency=residency,
        )

    if routing_mode == "direct" and not session.opened:
        return False, 0, 0, 0, "d-session-not-resident"

    already_resident = 0 if treat_as_fresh_session else session.resident_tokens
    needed = max(0, _estimate_session_resident_tokens(request) - already_resident)
    budget = _usable_capacity_tokens(residency, server_url)

    # If discovery failed, do not force every request down the P/D fallback path.
    # The router can still preserve correctness; this only disables proactive
    # capacity admission until the worker reports capacity again in a later run.
    if budget <= 0:
        _add_reserved_tokens(residency, server_url, needed)
        return True, needed, 0, 0, None

    def over_budget() -> bool:
        committed = residency.resident_tokens_by_server.get(
            server_url, 0
        ) + residency.reserved_tokens_by_server.get(server_url, 0)
        return committed + needed > budget

    evicted = 0
    prefill_backed = 0
    while needed > 0 and over_budget():
        idle_peers = [
            peer
            for peer in direct_sessions.values()
            if peer.opened
            and peer.server_url == server_url
            and peer.session_id != session.session_id
            and peer.active_requests <= 0
        ]
        if not idle_peers:
            break
        victim = min(idle_peers, key=lambda peer: peer.last_access_s)
        await _close_decode_session(
            client=client,
            session=victim,
            residency=residency,
            evicting_for_capacity=True,
        )
        prefill_backed += int(victim.prefill_opened)
        evicted += 1

    if over_budget():
        return False, 0, evicted, prefill_backed, "d-no-space"

    _add_reserved_tokens(residency, server_url, needed)
    return True, needed, evicted, prefill_backed, None
|
|
|
|
|
|
def _build_direct_prompt(
    *,
    request: TraceRequest,
    session: DirectSessionState,
) -> tuple[str, int, bool, bool]:
    """Build the prompt to send on ``session`` for ``request``.

    Returns ``(prompt, effective_input_length, session_reused, session_reset)``.
    The full synthetic prompt is returned when the session is reset or not
    reused. Otherwise only a synthetic append chunk of the inspected append
    length is returned.
    """
    append_length, session_reused, session_reset = _inspect_direct_request(
        request=request,
        session=session,
    )
    needs_full_prompt = session_reset or not session_reused
    if needs_full_prompt:
        return (
            build_synthetic_prompt(request),
            request.input_length,
            False,
            bool(session_reset),
        )
    chunk = build_synthetic_append_chunk(request, append_length)
    return chunk, append_length, session_reused, session_reset
|
|
|
|
|
|
def _build_direct_full_input_ids(
|
|
request: TraceRequest,
|
|
*,
|
|
block_token_budget: int = 24,
|
|
) -> list[int]:
|
|
input_ids: list[int] = []
|
|
for hash_id in request.hash_ids:
|
|
base = 1000 + (hash_id % 3000)
|
|
for offset in range(block_token_budget):
|
|
input_ids.append(1000 + ((base + offset) % 3000))
|
|
|
|
while len(input_ids) < request.input_length:
|
|
input_ids.append(5000 + (len(input_ids) % 2000))
|
|
|
|
return input_ids[: request.input_length]
|
|
|
|
|
|
def _build_direct_append_input_ids(
|
|
request: TraceRequest,
|
|
append_length: int,
|
|
) -> list[int]:
|
|
if append_length <= 0:
|
|
return []
|
|
|
|
base = 9000 + ((request.chat_id + request.turn_id) % 2000)
|
|
return [
|
|
9000 + ((base + offset) % 2000)
|
|
for offset in range(append_length)
|
|
]
|
|
|
|
|
|
async def _invoke_plain_router(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    execution_mode: str,
) -> ExecutionResult:
    """Send ``request`` through the P/D router without any session seeding.

    The prefill priority is computed with ``direct_to_d_predicted=False``. The
    result is tagged with ``execution_mode`` and reports no session reuse or
    reset.
    """
    priority = _prefill_priority_for_router_request(
        config=config,
        direct_to_d_predicted=False,
    )
    latency_s, ttft_s, tpot_s, cached_tokens = await _invoke_router(
        client=client,
        request=request,
        config=config,
        decode_worker_index=decision.decode_worker_index,
        prefill_request_priority=priority,
    )
    return ExecutionResult(
        execution_mode=execution_mode,
        latency_s=latency_s,
        ttft_s=ttft_s,
        tpot_s=tpot_s,
        cached_tokens=cached_tokens,
        effective_input_length=request.input_length,
        actual_kv_transfer_blocks=decision.kv_transfer_blocks,
        prefill_request_priority=priority,
        session_reused=False,
        session_reset=False,
    )
|
|
|
|
|
|
async def _invoke_kvcache_seeded_router(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    prefill_url: str,
    decode_session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
    reserved_tokens: int,
    execution_mode: str,
) -> ExecutionResult:
    """Route ``request`` through the P/D router while seeding a decode session.

    Streaming sessions are opened on the prefill worker and the decode worker if
    they are not already open, and the request is sent via the router with that
    session id. Under the ``"capacity-backup"`` policy, the prefill copy may be
    kept as a backup when ``_reserve_prefill_backup_capacity`` admits it;
    otherwise it is closed after the request.

    ``reserved_tokens`` is the decode reservation made by the caller. This
    function always settles it: it is committed on success and released on any
    failure, including failures while preparing the prefill side.
    """
    keep_prefill_backup = False
    prefill_reserved_tokens = 0
    prefill_session_newly_opened = False
    decode_session_newly_opened = False
    # Only undo the active-request increment if it actually happened; otherwise
    # a failed open would decrement another in-flight request's count.
    counted_as_active = False
    try:
        # All preparation is inside ``try`` so that a failure here (e.g. the
        # prefill session failing to open) still releases the reservations.
        async with direct_session_lock:
            if config.kvcache_prefill_backup_policy == "capacity-backup":
                keep_prefill_backup, prefill_reserved_tokens, _prefill_evicted = (
                    await _reserve_prefill_backup_capacity(
                        client=client,
                        request=request,
                        prefill_url=prefill_url,
                        session=decode_session,
                        direct_sessions=direct_sessions,
                        residency=decode_residency,
                    )
                )
            if (
                decode_session.prefill_opened
                and decode_session.prefill_server_url != prefill_url
            ):
                # A backup on a different prefill worker is replaced by this one.
                await _close_prefill_session(
                    client=client,
                    session=decode_session,
                    residency=decode_residency,
                )

        async with direct_session_lock:
            if not decode_session.prefill_opened:
                await _open_streaming_session(
                    client=client,
                    server_url=prefill_url,
                    session_id=request.session_id,
                    request=request,
                )
                decode_session.prefill_opened = True
                decode_session.prefill_server_url = prefill_url
                prefill_session_newly_opened = True

        prefill_priority = _prefill_priority_for_router_request(
            config=config,
            direct_to_d_predicted=True,
        )
        async with direct_session_lock:
            if not decode_session.opened:
                await _open_streaming_session(
                    client=client,
                    server_url=decode_session.server_url,
                    session_id=request.session_id,
                    request=request,
                )
                decode_session.opened = True
                decode_session_newly_opened = True
            decode_session.active_requests += 1
            counted_as_active = True

        latency_s, ttft_s, tpot_s, cached_tokens = await _invoke_router(
            client=client,
            request=request,
            config=config,
            decode_worker_index=decision.decode_worker_index,
            session_id=request.session_id,
            prefill_request_priority=prefill_priority,
        )
    except Exception:
        async with direct_session_lock:
            if counted_as_active:
                decode_session.active_requests = max(
                    0, decode_session.active_requests - 1
                )
            _release_reserved_tokens(
                decode_residency,
                decode_session.server_url,
                reserved_tokens,
            )
            _release_prefill_reserved_tokens(
                decode_residency,
                prefill_url,
                prefill_reserved_tokens,
            )
            if decode_session_newly_opened:
                await _close_decode_session(
                    client=client,
                    session=decode_session,
                    residency=decode_residency,
                )
            if prefill_session_newly_opened:
                await _close_prefill_session(
                    client=client,
                    session=decode_session,
                    residency=decode_residency,
                )
        raise

    async with direct_session_lock:
        decode_session.active_requests = max(0, decode_session.active_requests - 1)
        if keep_prefill_backup:
            _commit_prefill_backup_residency(
                residency=decode_residency,
                session=decode_session,
                request=request,
                prefill_url=prefill_url,
                reserved_tokens=prefill_reserved_tokens,
            )
        else:
            _release_prefill_reserved_tokens(
                decode_residency,
                prefill_url,
                prefill_reserved_tokens,
            )
            await _close_prefill_session(
                client=client,
                session=decode_session,
                residency=decode_residency,
            )
        _commit_session_residency(
            residency=decode_residency,
            session=decode_session,
            request=request,
            reserved_tokens=reserved_tokens,
        )
    return ExecutionResult(
        execution_mode=execution_mode,
        actual_kv_transfer_blocks=decision.kv_transfer_blocks,
        effective_input_length=request.input_length,
        cached_tokens=cached_tokens,
        prefill_request_priority=prefill_priority,
        session_reused=False,
        session_reset=False,
        latency_s=latency_s,
        ttft_s=ttft_s,
        tpot_s=tpot_s,
    )
|
|
|
|
|
|
async def _execute_request(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
) -> ExecutionResult:
    """Execute one trace request with the mechanism named by ``config.mechanism_name``.

    * ``"pd-disaggregation"``: always through the P/D router.
    * ``"pd-colo"``: through a direct worker session (``_invoke_direct``).
    * ``"kvcache-centric"``: see ``_execute_kvcache_centric_request``.

    Raises:
        ValueError: if the router URL or decode workers required by the chosen
            mechanism are missing, or the mechanism name is unknown.
    """
    mechanism = config.mechanism_name
    if mechanism == "pd-disaggregation":
        if not config.router_url:
            raise ValueError("router_url is required for pd-disaggregation replay")
        return await _invoke_plain_router(
            client=client,
            request=request,
            config=config,
            decision=decision,
            execution_mode="pd-disaggregation-router",
        )

    if mechanism == "pd-colo":
        return await _invoke_direct(
            client=client,
            request=request,
            config=config,
            decision=decision,
            direct_sessions=direct_sessions,
            direct_session_lock=direct_session_lock,
        )

    if mechanism == "kvcache-centric":
        return await _execute_kvcache_centric_request(
            client=client,
            request=request,
            config=config,
            decision=decision,
            direct_sessions=direct_sessions,
            direct_session_lock=direct_session_lock,
            decode_residency=decode_residency,
        )

    raise ValueError(f"Unsupported mechanism: {config.mechanism_name}")


async def _execute_kvcache_centric_request(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
) -> ExecutionResult:
    """Run one request under the ``kvcache-centric`` mechanism.

    * Turn 1: try to seed a decode session through the router
      (``pd-router-turn1-*`` modes).
    * Later turns that ``_should_bypass_prefill`` accepts, whose decode session
      is reused without reset, and whose append fits
      ``config.kvcache_direct_max_uncached_tokens``: go straight to the decode
      worker when capacity allows. Otherwise fall back to the router or reseed
      (``pd-router-fallback-*`` / ``pd-router-d-session-reseed*``).
    * All other requests reseed through the router
      (``pd-router-*large-append*`` modes).
    """
    if not config.router_url:
        raise ValueError("router_url is required for kvcache-centric replay")
    if not config.topology.decode_workers:
        raise ValueError("kvcache-centric mechanism requires at least one decode worker")
    prefill_url = _worker_url_by_id(
        config.topology.prefill_workers,
        decision.prefill_worker_id,
    )
    decode_url = config.topology.decode_workers[decision.decode_worker_index].url

    async with direct_session_lock:
        decode_session = direct_sessions.get(request.session_id)
        if decode_session is None:
            decode_session = DirectSessionState(
                session_id=request.session_id,
                server_url=decode_url,
            )
            direct_sessions[request.session_id] = decode_session
        else:
            if decode_session.server_url != decode_url and decode_session.opened:
                # The session moved to another decode worker; drop the old copy.
                await _close_decode_session(
                    client=client,
                    session=decode_session,
                    residency=decode_residency,
                )
            decode_session.server_url = decode_url

    direct_append_length: int | None = None
    direct_session_reused = False
    direct_session_reset = False
    if request.turn_id > 1:
        async with direct_session_lock:
            (
                direct_append_length,
                direct_session_reused,
                direct_session_reset,
            ) = _inspect_direct_request(
                request=request,
                session=decode_session,
            )

    if request.turn_id == 1:
        return await _seed_decode_session_via_router(
            client=client,
            request=request,
            config=config,
            decision=decision,
            prefill_url=prefill_url,
            decode_url=decode_url,
            decode_session=decode_session,
            direct_sessions=direct_sessions,
            direct_session_lock=direct_session_lock,
            decode_residency=decode_residency,
            filter_mode_prefix="pd-router-turn1-",
            session_cap_mode="pd-router-turn1-session-cap",
            seed_mode="pd-router-turn1-seed",
            with_eviction_suffix=False,
            backpressure_mode="pd-router-turn1-d-backpressure",
            no_capacity_mode="pd-router-turn1-no-d-capacity",
        )

    small_append_on_reused_session = (
        _should_bypass_prefill(
            request=request,
            config=config,
            decision=decision,
        )
        and direct_append_length is not None
        and direct_session_reused
        and not direct_session_reset
        and direct_append_length <= config.kvcache_direct_max_uncached_tokens
    )
    if small_append_on_reused_session:
        async with direct_session_lock:
            can_direct = (
                decode_session.opened
                and decode_session.server_url == decode_url
                and direct_session_reused
                and not direct_session_reset
            )
        direct_reason: str | None = None
        if can_direct:
            async with direct_session_lock:
                (
                    can_direct,
                    direct_reserved_tokens,
                    _evicted,
                    _prefill_backed,
                    direct_reason,
                ) = await _reserve_decode_session_capacity(
                    client=client,
                    config=config,
                    request=request,
                    server_url=decode_url,
                    session=decode_session,
                    direct_sessions=direct_sessions,
                    residency=decode_residency,
                    treat_as_fresh_session=False,
                    routing_mode="direct",
                    admission_mode=config.kvcache_admission_mode,
                )
            if can_direct:
                try:
                    return await _invoke_decode_session_direct(
                        client=client,
                        request=request,
                        config=config,
                        decision=decision,
                        direct_sessions=direct_sessions,
                        direct_session_lock=direct_session_lock,
                        decode_residency=decode_residency,
                        reserved_tokens=direct_reserved_tokens,
                    )
                except Exception as exc:
                    if not _is_stale_decode_session_error(exc):
                        raise
                    async with direct_session_lock:
                        await _close_decode_session(
                            client=client,
                            session=decode_session,
                            residency=decode_residency,
                        )
                    return await _invoke_plain_router(
                        client=client,
                        request=request,
                        config=config,
                        decision=decision,
                        execution_mode="pd-router-fallback-stale-d-session",
                    )

        if _is_decode_backpressure_reason(direct_reason):
            return await _invoke_plain_router(
                client=client,
                request=request,
                config=config,
                decision=decision,
                execution_mode="pd-router-fallback-d-backpressure",
            )

        return await _seed_decode_session_via_router(
            client=client,
            request=request,
            config=config,
            decision=decision,
            prefill_url=prefill_url,
            decode_url=decode_url,
            decode_session=decode_session,
            direct_sessions=direct_sessions,
            direct_session_lock=direct_session_lock,
            decode_residency=decode_residency,
            filter_mode_prefix="pd-router-fallback-",
            session_cap_mode="pd-router-fallback-session-cap",
            seed_mode="pd-router-d-session-reseed",
            with_eviction_suffix=True,
            backpressure_mode="pd-router-fallback-d-backpressure",
            no_capacity_mode="pd-router-fallback-no-d-capacity",
        )

    return await _seed_decode_session_via_router(
        client=client,
        request=request,
        config=config,
        decision=decision,
        prefill_url=prefill_url,
        decode_url=decode_url,
        decode_session=decode_session,
        direct_sessions=direct_sessions,
        direct_session_lock=direct_session_lock,
        decode_residency=decode_residency,
        filter_mode_prefix="pd-router-fallback-large-append-",
        session_cap_mode="pd-router-fallback-large-append-session-cap",
        seed_mode="pd-router-large-append-reseed",
        with_eviction_suffix=True,
        backpressure_mode="pd-router-fallback-d-backpressure",
        no_capacity_mode="pd-router-fallback-large-append",
    )


async def _seed_decode_session_via_router(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    prefill_url: str,
    decode_url: str,
    decode_session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
    filter_mode_prefix: str,
    session_cap_mode: str,
    seed_mode: str,
    with_eviction_suffix: bool,
    backpressure_mode: str,
    no_capacity_mode: str,
) -> ExecutionResult:
    """Seed a fresh decode session via the router, or fall back to a plain router call.

    Steps, in order:

    1. If ``_seed_filter_reason`` rejects seeding, do a plain router call tagged
       ``filter_mode_prefix + reason``.
    2. If ``_should_admit_new_decode_session`` refuses the session, do a plain
       router call tagged ``session_cap_mode``.
    3. If decode capacity is reserved, run ``_invoke_kvcache_seeded_router``
       tagged ``seed_mode`` (plus ``_eviction_suffix`` when
       ``with_eviction_suffix``).
    4. Otherwise do a plain router call tagged ``backpressure_mode`` when
       ``_is_decode_backpressure_reason`` accepts the reason, and
       ``no_capacity_mode`` otherwise.
    """
    seed_filter_reason = _seed_filter_reason(
        request=request,
        config=config,
        inflight_decode_load=decision.inflight_decode_load,
    )
    if seed_filter_reason is not None:
        return await _invoke_plain_router(
            client=client,
            request=request,
            config=config,
            decision=decision,
            execution_mode=f"{filter_mode_prefix}{seed_filter_reason}",
        )

    async with direct_session_lock:
        admit_new_decode_session = _should_admit_new_decode_session(
            residency=decode_residency,
            server_url=decode_url,
            request=request,
            session=decode_session,
            direct_sessions=direct_sessions,
            treat_as_fresh_session=True,
        )
        if admit_new_decode_session:
            (
                can_seed,
                reserved_tokens,
                evicted_sessions,
                prefill_backed_evictions,
                seed_reason,
            ) = await _reserve_decode_session_capacity(
                client=client,
                config=config,
                request=request,
                server_url=decode_url,
                session=decode_session,
                direct_sessions=direct_sessions,
                residency=decode_residency,
                treat_as_fresh_session=True,
                routing_mode="seed",
                admission_mode=config.kvcache_admission_mode,
            )
        else:
            can_seed = False
            reserved_tokens = 0
            evicted_sessions = 0
            prefill_backed_evictions = 0
            seed_reason = "d-session-cap"

    if seed_reason == "d-session-cap":
        return await _invoke_plain_router(
            client=client,
            request=request,
            config=config,
            decision=decision,
            execution_mode=session_cap_mode,
        )

    if can_seed:
        execution_mode = seed_mode
        if with_eviction_suffix:
            execution_mode += _eviction_suffix(
                evicted_sessions,
                prefill_backed_evictions,
            )
        # Called outside the lock: the seeded router acquires it itself.
        return await _invoke_kvcache_seeded_router(
            client=client,
            request=request,
            config=config,
            decision=decision,
            prefill_url=prefill_url,
            decode_session=decode_session,
            direct_sessions=direct_sessions,
            direct_session_lock=direct_session_lock,
            decode_residency=decode_residency,
            reserved_tokens=reserved_tokens,
            execution_mode=execution_mode,
        )

    return await _invoke_plain_router(
        client=client,
        request=request,
        config=config,
        decision=decision,
        execution_mode=(
            backpressure_mode
            if _is_decode_backpressure_reason(seed_reason)
            else no_capacity_mode
        ),
    )
|
|
|
|
|
|
async def _invoke_direct(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
) -> ExecutionResult:
    """Replay ``request`` on a colocated (pd-colo) direct worker.

    The target worker is ``config.topology.direct_workers`` indexed by
    ``decision.decode_worker_index``. Cached session state for
    ``request.session_id`` is reused only when it already points at that
    worker's URL; otherwise a fresh ``DirectSessionState`` replaces it in
    ``direct_sessions`` before the request is sent.

    Raises:
        ValueError: if the topology defines no direct workers.
    """
    workers = config.topology.direct_workers
    if not workers:
        raise ValueError("pd-colo mechanism requires at least one direct worker")

    target_url = workers[decision.decode_worker_index].url
    session_id = request.session_id
    cached = direct_sessions.get(session_id)
    if cached is not None and cached.server_url == target_url:
        session = cached
    else:
        # No state yet, or the session was routed to a different worker:
        # start over with state bound to the new server URL.
        session = DirectSessionState(session_id=session_id, server_url=target_url)
        direct_sessions[session_id] = session

    return await _invoke_session_direct(
        client=client,
        request=request,
        config=config,
        session=session,
        execution_mode="pd-colo-direct-session",
        direct_session_lock=direct_session_lock,
    )
|
|
|
|
|
|
async def _invoke_session_direct(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    session: DirectSessionState,
    execution_mode: str,
    decode_residency: DecodeResidencyState | None = None,
    reserved_tokens: int = 0,
    direct_session_lock: asyncio.Lock | None = None,
) -> ExecutionResult:
    """Send ``request`` straight to the worker that owns ``session``.

    The call goes through ``_invoke_generate`` against ``session.server_url``
    with ``session_params={"id": session.session_id}`` (presumably letting the
    worker keep state from earlier turns of the same session).

    Args:
        client: HTTP client used for every worker call.
        request: trace request being replayed.
        config: replay configuration; ``stream``, ``timeout_s`` and
            ``stream_idle_timeout_s`` are forwarded to ``_invoke_generate``.
        session: per-session bookkeeping (``opened``, ``resident_tokens``,
            ``active_requests``, ``last_trace_request``, ``last_access_s``),
            mutated in place.
        execution_mode: label copied verbatim into the returned result.
        decode_residency: when given, closing and committing the session go
            through the decode-residency helpers instead of the plain
            streaming-session path.
        reserved_tokens: reservation handed to ``_commit_session_residency``.
        direct_session_lock: when given, guards updates to
            ``session.active_requests``.

    Returns:
        ExecutionResult with ``actual_kv_transfer_blocks`` always 0, the
        number of input ids actually sent, the worker-reported cached tokens,
        and latency / TTFT / TPOT timings.

    Raises:
        Whatever ``_invoke_generate`` or the open/close helpers raise.
        ``active_requests`` is decremented even on failure.
    """
    # ``_prompt`` is unused; only the length and the reuse/reset flags matter.
    _prompt, effective_input_length, session_reused, session_reset = _build_direct_prompt(
        request=request,
        session=session,
    )
    if session_reused:
        # Presumably only the tokens appended since the previous turn.
        input_ids = _build_direct_append_input_ids(request, effective_input_length)
    else:
        input_ids = _build_direct_full_input_ids(request)
    # A reset (or a session that cannot be reused) means the worker-side
    # session must be closed and reopened from scratch below.
    if session.opened and (session_reset or not session_reused):
        if decode_residency is not None:
            await _close_decode_session(
                client=client,
                session=session,
                residency=decode_residency,
            )
        else:
            await _close_streaming_session(
                client=client,
                server_url=session.server_url,
                session_id=session.session_id,
                allow_missing=True,
            )
        session.opened = False
        session.resident_tokens = 0
    if not session.opened:
        await _open_streaming_session(
            client=client,
            server_url=session.server_url,
            session_id=session.session_id,
            request=request,
        )
        session.opened = True
    # Count this request as in flight on the session for the duration of the
    # generate call; the matching decrement is in the ``finally`` below.
    if direct_session_lock is not None:
        async with direct_session_lock:
            session.active_requests += 1

    try:
        latency_s, ttft_s, tpot_s, cached_tokens = await _invoke_generate(
            client=client,
            base_url=session.server_url,
            headers={"x-request-id": request.request_id},
            payload={
                "input_ids": input_ids,
                # Greedy decoding with min == max new tokens and ignore_eos,
                # so exactly ``max(1, output_length)`` tokens are generated.
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max(1, request.output_length),
                    "min_new_tokens": max(1, request.output_length),
                    "ignore_eos": True,
                    "no_stop_trim": True,
                    "skip_special_tokens": False,
                },
                "session_params": {"id": session.session_id},
                "stream": config.stream,
            },
            timeout_s=config.timeout_s,
            stream_idle_timeout_s=config.stream_idle_timeout_s,
            stream=config.stream,
        )
    finally:
        if direct_session_lock is not None:
            async with direct_session_lock:
                # Clamp at zero so bookkeeping never goes negative.
                session.active_requests = max(0, session.active_requests - 1)
    # NOTE(review): the SOURCE copy had its indentation stripped. This block is
    # reconstructed as the success-only path after the try/finally, with
    # ``last_access_s`` in the else-branch; confirm the nesting against upstream.
    if decode_residency is not None:
        _commit_session_residency(
            residency=decode_residency,
            session=session,
            request=request,
            reserved_tokens=reserved_tokens,
        )
    else:
        session.last_trace_request = request
        session.last_access_s = time.perf_counter()
    return ExecutionResult(
        execution_mode=execution_mode,
        actual_kv_transfer_blocks=0,
        effective_input_length=len(input_ids),
        cached_tokens=cached_tokens,
        session_reused=session_reused,
        session_reset=session_reset,
        latency_s=latency_s,
        ttft_s=ttft_s,
        tpot_s=tpot_s,
    )
|
|
|
|
|
|
async def _invoke_decode_session_direct(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
    reserved_tokens: int,
) -> ExecutionResult:
    """Serve ``request`` directly on its existing decode-worker session.

    ``reserved_tokens`` were reserved on the chosen decode worker before this
    call. If anything fails, that reservation is returned to
    ``decode_residency`` under ``direct_session_lock`` and the error is
    re-raised.

    Failures that trigger the release include:

    * the session lookup in ``direct_sessions`` (KeyError);
    * task cancellation (``asyncio.CancelledError``).
    """
    decode_url = config.topology.decode_workers[decision.decode_worker_index].url
    try:
        session = direct_sessions[request.session_id]
        return await _invoke_session_direct(
            client=client,
            request=request,
            config=config,
            session=session,
            execution_mode="kvcache-direct-to-d-session",
            decode_residency=decode_residency,
            reserved_tokens=reserved_tokens,
            direct_session_lock=direct_session_lock,
        )
    except BaseException:
        # BaseException on purpose: CancelledError is not an Exception subclass
        # (Python 3.8+), and a cancelled replay must not leak its reservation.
        async with direct_session_lock:
            _release_reserved_tokens(
                decode_residency,
                decode_url,
                reserved_tokens,
            )
        raise
|
|
|
|
|
|
def _should_bypass_prefill(
    *,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    block_token_budget: int = 24,
) -> bool:
    """Decide whether a follow-up turn is cheap enough to skip prefill.

    Only turns after the first (``turn_id > 1``) with a positive
    ``decision.observed_overlap_blocks`` qualify. The remaining work is
    estimated as ``decision.kv_transfer_blocks * block_token_budget``, clamped
    at zero. It must not exceed ``config.kvcache_direct_max_uncached_tokens``.
    """
    if request.turn_id <= 1 or decision.observed_overlap_blocks <= 0:
        return False
    estimated_uncached = decision.kv_transfer_blocks * block_token_budget
    return max(0, estimated_uncached) <= config.kvcache_direct_max_uncached_tokens
|
|
|
|
|
|
def _worker_url_by_id(workers, worker_id: str) -> str:
    """Return the ``url`` of the first worker whose ``worker_id`` matches.

    Raises:
        KeyError: if no worker in ``workers`` has the requested id.
    """
    matching_urls = (
        candidate.url for candidate in workers if candidate.worker_id == worker_id
    )
    for url in matching_urls:
        return url
    raise KeyError(f"Unknown worker id: {worker_id}")
|
|
|
|
|
|
def _build_headers(
    *,
    request: TraceRequest,
    header_mode: HeaderMode,
    decode_worker_index: int,
    policy_name: str,
) -> dict[str, str]:
    """Build the HTTP headers sent to the router for ``request``.

    ``x-request-id`` is always present. ``"auto"`` mode resolves to
    ``"routing-key"`` for the ``sticky`` policy and ``"none"`` otherwise.

    * ``"routing-key"`` adds ``x-smg-routing-key`` with the session id.
    * ``"target-worker"`` adds ``x-smg-target-worker`` with the decode worker
      index.
    """
    resolved_mode = header_mode
    if resolved_mode == "auto":
        resolved_mode = "routing-key" if policy_name == "sticky" else "none"

    extra: dict[str, str] = {}
    if resolved_mode == "routing-key":
        extra = {"x-smg-routing-key": request.session_id}
    elif resolved_mode == "target-worker":
        extra = {"x-smg-target-worker": str(decode_worker_index)}
    return {"x-request-id": request.request_id, **extra}
|
|
|
|
|
|
def _contains_token(payload: dict) -> bool:
    """Return True if an OpenAI-style streaming chunk carries generated output.

    Only the first choice's ``delta`` is inspected. Any of the following
    counts as output:

    * a non-empty ``reasoning_content`` string;
    * a non-empty ``content`` string;
    * a ``content`` list with a dict item that has truthy ``text``;
    * a tool call with an ``id`` or with a function ``name``/``arguments``.

    Malformed shapes (missing or non-dict ``choices[0]``/``delta``) yield False.
    """
    choices = payload.get("choices")
    if not isinstance(choices, list) or not choices:
        return False
    first_choice = choices[0]
    if not isinstance(first_choice, dict):
        return False
    delta = first_choice.get("delta")
    if not isinstance(delta, dict):
        return False

    reasoning_content = delta.get("reasoning_content")
    if isinstance(reasoning_content, str) and reasoning_content:
        return True

    content = delta.get("content")
    if isinstance(content, str) and content:
        return True
    # A list-valued content without text must still fall through to the
    # tool-call check, just like an empty string content does.
    if isinstance(content, list) and any(
        isinstance(item, dict) and item.get("text") for item in content
    ):
        return True

    tool_calls = delta.get("tool_calls")
    if isinstance(tool_calls, list):
        for tool_call in tool_calls:
            if not isinstance(tool_call, dict):
                continue
            if tool_call.get("id"):
                return True
            function = tool_call.get("function")
            if isinstance(function, dict) and (
                function.get("name") or function.get("arguments")
            ):
                return True
    return False
|
|
|
|
|
|
def _contains_generate_token(payload: dict) -> bool:
    """Return True if a native ``/generate`` chunk carries generated output.

    A chunk counts when its ``text`` is a non-empty string, or when
    ``meta_info.completion_tokens`` is positive. A missing, ``None`` or
    otherwise falsy ``completion_tokens`` is treated as 0, matching
    ``_extract_generate_cached_tokens``.
    """
    text = payload.get("text")
    if isinstance(text, str) and text:
        return True
    meta_info = payload.get("meta_info")
    if not isinstance(meta_info, dict):
        return False
    # ``or 0`` guards against an explicit ``None`` (int(None) -> TypeError).
    return int(meta_info.get("completion_tokens") or 0) > 0
|
|
|
|
|
|
def _is_generate_terminal_chunk(payload: dict) -> bool:
    """Return True when a ``/generate`` chunk reports a ``finish_reason``.

    Chunks without a dict ``meta_info`` are never terminal.
    """
    meta_info = payload.get("meta_info")
    return isinstance(meta_info, dict) and meta_info.get("finish_reason") is not None
|
|
|
|
|
|
def _extract_generate_cached_tokens(payload: dict) -> int:
    """Read ``meta_info.cached_tokens`` from a ``/generate`` chunk.

    Returns 0 when ``meta_info`` is not a dict or the value is missing/falsy.
    """
    meta_info = payload.get("meta_info")
    raw_value = meta_info.get("cached_tokens", 0) if isinstance(meta_info, dict) else 0
    return int(raw_value or 0)
|
|
|
|
|
|
def _extract_openai_cached_tokens(payload: dict) -> int:
    """Read ``usage.prompt_tokens_details.cached_tokens`` from an OpenAI chunk.

    Returns 0 when either nested object is not a dict or the value is
    missing/falsy.
    """
    usage = payload.get("usage")
    details = usage.get("prompt_tokens_details") if isinstance(usage, dict) else None
    if not isinstance(details, dict):
        return 0
    return int(details.get("cached_tokens", 0) or 0)
|
|
|
|
|
|
async def _aiter_lines(
|
|
response: httpx.Response,
|
|
*,
|
|
idle_timeout_s: float | None,
|
|
):
|
|
line_iterator = response.aiter_lines()
|
|
while True:
|
|
try:
|
|
if idle_timeout_s is None or idle_timeout_s <= 0:
|
|
line = await anext(line_iterator)
|
|
else:
|
|
line = await asyncio.wait_for(
|
|
anext(line_iterator),
|
|
timeout=idle_timeout_s,
|
|
)
|
|
except StopAsyncIteration:
|
|
return
|
|
yield line
|
|
|
|
|
|
def _is_terminal_chunk(payload: dict) -> bool:
    """Return True when the first choice of an OpenAI chunk has a ``finish_reason``.

    Payloads without a non-empty ``choices`` list are never terminal.
    """
    choices = payload.get("choices")
    if isinstance(choices, list) and choices:
        return choices[0].get("finish_reason") is not None
    return False
|