E4 forensic (docs/E4_RESULTS_ZH.md): 272 admission rejections triggered the fallback seeded_router path, but zero /_snapshot/* HTTP calls reached the workers. Two root causes: 1. _attempt_d_to_p_sync was gated on the agentic-side `decode_session.opened` flag. By the time the fallback runs, agentic has already flipped that flag to False in response to the admission rejection, but the D-side SessionAwareCache may still hold the session (release_session is not called automatically on admission rejection). Remove the gate and let D respond authoritatively with "session-not-resident" if it has actually evicted the session. 2. _attempt_d_to_p_sync logged its decisions via logger.info, but agentic has no root-logger handler, so those events were silently dropped. Switch every branch (entry skip, prepare fail/not-ok, dump fail/not-ok, finalize fail/not-ok, ok) to write a structural-log line to outputs/<run>/structural/d-to-p-sync.jsonl. Each line records the stage, reason, durations, and bytes pushed. The results doc is updated to reflect the honest E4-1 outcome and the P1 fix list.
3219 lines
114 KiB
Python
3219 lines
114 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import time
|
|
from collections import Counter
|
|
from dataclasses import dataclass, field, replace
|
|
from pathlib import Path
|
|
from typing import Any, Literal
|
|
|
|
import httpx
|
|
|
|
from agentic_pd_hybrid.metrics import (
|
|
RequestMetrics,
|
|
write_metrics_jsonl,
|
|
write_summary_json,
|
|
)
|
|
from agentic_pd_hybrid.policies import RoutingState, create_policy
|
|
from agentic_pd_hybrid.topology import SingleNodeTopology
|
|
from agentic_pd_hybrid.trace import (
|
|
TraceRequest,
|
|
build_synthetic_append_chunk,
|
|
build_synthetic_prompt,
|
|
load_trace,
|
|
)
|
|
|
|
|
|
# Routing-hint header mode handed to _build_headers for router requests.
HeaderMode = Literal["none", "routing-key", "target-worker", "auto"]

# KV-cache admission mode selector.
KvCacheAdmissionMode = Literal["router", "worker"]

# What to do with the prefill-side copy of a session's KV cache.
KvCachePrefillBackupPolicy = Literal["release-after-transfer", "capacity-backup"]

# Timeout (seconds) for short control-plane calls: /server_info probes,
# /open_session, /close_session and the direct-append admission query.
_ADMISSION_PROBE_TIMEOUT_S = 2.0


# --- Structural event logging ---------------------------------------------
# JSONL side-channel for admission probes, backpressure pauses and session->D
# bindings. It is kept as module-level state so call sites only need
# `await _structural_emit(filename, event)` instead of threading a logger.

# Target directory; None disables structural logging entirely.
_STRUCTURAL_LOG_DIR: Path | None = None

# Guards the lazily-opened handle cache in _structural_emit.
_STRUCTURAL_LOG_LOCK = asyncio.Lock()

# filename -> open append-mode text handle.
_STRUCTURAL_LOG_FILES: dict[str, Any] = {}

# time.perf_counter() value captured by _structural_init; base of the "t" field.
_STRUCTURAL_RUN_START_S: float = 0.0
|
|
|
|
|
|
def _structural_init(log_dir: Path | None) -> None:
    """Point structural logging at *log_dir* and restart its relative clock.

    Handles still cached from a previous run are closed first. The handle
    cache is keyed by bare filename, so without this step a run following one
    that never reached ``_structural_close`` (for example because it raised)
    would keep appending to the *previous* run's files instead of *log_dir*.

    Args:
        log_dir: Directory for the JSONL files, created if missing. ``None``
            disables structural logging (``_structural_emit`` becomes a no-op).
    """
    global _STRUCTURAL_LOG_DIR, _STRUCTURAL_RUN_START_S
    _structural_close()
    _STRUCTURAL_LOG_DIR = log_dir
    _STRUCTURAL_RUN_START_S = time.perf_counter()
    if log_dir is not None:
        log_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
def _structural_close() -> None:
    """Close every cached structural-log handle, then empty the cache.

    Closing is best-effort: an exception from one handle is ignored so the
    remaining handles still get closed.
    """
    for open_handle in list(_STRUCTURAL_LOG_FILES.values()):
        try:
            open_handle.close()
        except Exception:
            continue
    _STRUCTURAL_LOG_FILES.clear()
|
|
|
|
|
|
async def _structural_emit(filename: str, event: dict[str, Any]) -> None:
    """Append *event* as one JSON line to ``<structural dir>/<filename>``.

    Does nothing until ``_structural_init`` has set a directory. Each record
    gets a leading ``t`` field: seconds since ``_structural_init``, rounded
    to 4 decimals. A ``t`` key already present in *event* takes precedence.
    Keys are written sorted. Handles are opened lazily in append mode, cached
    per filename and flushed after every line. *event* itself is not mutated.
    """
    if _STRUCTURAL_LOG_DIR is None:
        return
    record: dict[str, Any] = {
        "t": round(time.perf_counter() - _STRUCTURAL_RUN_START_S, 4)
    }
    record.update(event)
    async with _STRUCTURAL_LOG_LOCK:
        sink = _STRUCTURAL_LOG_FILES.get(filename)
        if sink is None:
            target = _STRUCTURAL_LOG_DIR / filename
            sink = target.open("a", encoding="utf-8")
            _STRUCTURAL_LOG_FILES[filename] = sink
        sink.write(json.dumps(record, sort_keys=True) + "\n")
        sink.flush()
|
|
|
|
|
|
@dataclass(frozen=True)
class ReplayConfig:
    """Immutable settings for one ``replay_trace`` run.

    Field order is part of the generated constructor, so new fields belong
    at the end.
    """

    # --- Inputs, outputs and endpoints ------------------------------------
    trace_path: Path  # trace file handed to load_trace
    # Metrics JSONL. Its parent directory also receives structural/ and
    # d-pool-timeseries.jsonl by default.
    output_path: Path
    policy_name: str  # routing policy name for create_policy
    mechanism_name: str  # recorded on every metrics row
    topology: SingleNodeTopology
    router_url: str | None = None  # required when requests go via the router
    model_name: str | None = None

    # --- Pacing and HTTP client behaviour ---------------------------------
    pace: bool = True  # sleep until each request's (scaled) trace offset
    time_scale: float = 1.0  # trace offsets are divided by this value
    request_limit: int | None = None  # forwarded to load_trace
    concurrency_limit: int = 32  # max requests executing at once
    header_mode: HeaderMode = "auto"
    timeout_s: float = 600.0  # httpx client timeout
    stream: bool = True
    stream_idle_timeout_s: float | None = 900.0

    # --- KV-cache direct / seeding / admission knobs ------------------------
    kvcache_direct_max_uncached_tokens: int = 2048
    kvcache_admission_mode: KvCacheAdmissionMode = "router"
    kvcache_seed_max_resident_tokens: int | None = None
    kvcache_seed_max_output_tokens: int | None = None
    kvcache_seed_min_turn_id: int = 1
    # When True, replay_trace overwrites kvcache_seed_allowed_session_ids with
    # the sessions that have more than one request in the trace.
    kvcache_seed_only_multiturn_sessions: bool = False
    kvcache_seed_allowed_session_ids: frozenset[str] | None = None
    kvcache_prefill_backup_policy: KvCachePrefillBackupPolicy = "release-after-transfer"
    kvcache_seed_max_inflight_decode: int | None = 3
    kvcache_seed_max_decode_transfer_queue_reqs: int | None = None
    kvcache_direct_max_decode_transfer_queue_reqs: int | None = None
    kvcache_prefill_priority_eviction: bool = False
    kvcache_prefill_direct_priority: int = -100
    kvcache_prefill_normal_priority: int = 100

    # --- Worker pool polling ----------------------------------------------
    # A value > 0 starts a background poller that samples every worker at this
    # period into d-pool-timeseries.jsonl next to output_path.
    pool_poll_interval_s: float = 0.0
    pool_poll_include_sessions: bool = True  # include per-session rows

    # --- Backpressure ---------------------------------------------------
    enable_backpressure: bool = False
    # Cap (seconds) on any single worker-recommended pause.
    backpressure_max_pause_s: float = 2.0

    # --- Session migration (REFACTOR_PLAN_V1 §6.2) ------------------------
    # After this many admission rejections of a session on one D worker,
    # KvAwarePolicy skips that D for the session, forcing a migration.
    # 0 disables the mechanism.
    kvcache_migration_reject_threshold: int = 3
    # Graduated boost for under-loaded D workers in KvAwarePolicy. It breaks
    # overlap-pinning imbalance when sessions share a cross-session prefix.
    # 0 disables it (docs/E1_E2_FIX_DESIGN_ZH.md §Q2).
    kvcache_load_floor_bonus: int = 0

    # --- D->P snapshot push (docs/D_TO_P_SYNC_DESIGN_ZH.md) ---------------
    # When a reseed fires, RDMA-dump the session's KV from the D worker that
    # last held it into the P worker's radix tree, so the following P prefill
    # hits cache.
    enable_d_to_p_sync: bool = False
    # Directory for structural JSONL logs. None means
    # <output_path.parent>/structural.
    structural_log_dir: Path | None = None
|
|
|
|
|
|
@dataclass
class DirectSessionState:
    """Mutable client-side bookkeeping for one session's direct worker sessions.

    ``replay_trace`` closes the decode-side session when ``opened`` is set,
    and the prefill-side one when ``prefill_opened`` is set and
    ``prefill_server_url`` is known.
    """

    session_id: str
    server_url: str  # decode-side worker URL
    opened: bool = False
    last_trace_request: TraceRequest | None = None
    resident_tokens: int = 0
    last_access_s: float = 0.0
    active_requests: int = 0
    # Prefill-side mirror of the fields above.
    prefill_server_url: str | None = None
    prefill_opened: bool = False
    prefill_resident_tokens: int = 0
    prefill_last_access_s: float = 0.0
    prefill_low_priority: bool = False
|
|
|
|
|
|
@dataclass
class DecodeResidencyState:
    """Mutable per-worker token accounting shared across replayed requests.

    Every mapping gets its own fresh dict per instance. ``pause_until_s`` is
    keyed by worker URL; the other maps are presumably keyed the same way.
    TODO confirm against _discover_decode_residency.
    """

    # Decode-side accounting.
    capacity_tokens: dict[str, int] = field(default_factory=dict)
    headroom_tokens: dict[str, int] = field(default_factory=dict)
    reserved_decode_tokens: dict[str, int] = field(default_factory=dict)
    resident_tokens_by_server: dict[str, int] = field(default_factory=dict)
    reserved_tokens_by_server: dict[str, int] = field(default_factory=dict)
    # Prefill-side accounting.
    prefill_capacity_tokens: dict[str, int] = field(default_factory=dict)
    prefill_headroom_tokens: dict[str, int] = field(default_factory=dict)
    prefill_resident_tokens_by_server: dict[str, int] = field(default_factory=dict)
    prefill_reserved_tokens_by_server: dict[str, int] = field(default_factory=dict)
    # Eviction counters.
    decode_evictions_prefill_backed: int = 0
    decode_evictions_without_prefill_backup: int = 0
    # Backpressure: perf_counter() deadline per D worker before which new
    # requests to it should pause.
    pause_until_s: dict[str, float] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass(frozen=True)
class DecodeLoadSnapshot:
    """Immutable load sample for a decode worker.

    Field names mirror scheduler counters. The code that fills them in is
    outside this view, so treat the exact source of each value as
    unconfirmed.
    """

    timestamp_s: float
    # Scheduler batch / queue sizes.
    num_running_reqs: int
    num_waiting_reqs: int
    # KV pool occupancy.
    num_used_tokens: int
    max_total_num_tokens: int
    token_usage: float
    # Disaggregation-specific decode queues.
    decode_prealloc_queue_reqs: int
    decode_transfer_queue_reqs: int
    decode_retracted_queue_reqs: int
|
|
|
|
|
|
@dataclass(frozen=True)
class ExecutionResult:
    """Outcome of executing one request, later copied into ``RequestMetrics``.

    Latency fields are ``None`` when the request failed (``error`` set) or
    when the value was not measured.
    """

    # How the request was actually served; may differ from the configured
    # mechanism when a fallback path was taken.
    execution_mode: str
    actual_kv_transfer_blocks: int
    effective_input_length: int | None
    cached_tokens: int
    session_reused: bool
    session_reset: bool
    # Timing (seconds).
    latency_s: float | None
    ttft_s: float | None
    tpot_s: float | None
    # Optional extras.
    prefill_request_priority: int | None = None
    decode_request_priority: int | None = None
    error: str | None = None
    actual_output_tokens: int | None = None
    requested_output_tokens: int | None = None
    finish_reason: str | None = None
|
|
|
|
|
|
async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]:
    """Replay every request of ``config.trace_path`` and persist the metrics.

    Scheduling:
        * When ``config.pace`` is true, each request waits until its trace
          offset divided by ``config.time_scale`` has elapsed.
        * At most ``config.concurrency_limit`` requests execute at once.
        * Requests of one session run in order: each waits for the previous
          request of the same session.

    Side effects:
        * structural JSONL logs under ``config.structural_log_dir``
          (default ``<output_path.parent>/structural``);
        * ``<output_path.parent>/d-pool-timeseries.jsonl`` when
          ``config.pool_poll_interval_s > 0``;
        * per-request metrics at ``config.output_path`` plus a
          ``<output_path>.summary.json`` summary;
        * best-effort ``/close_session`` calls for sessions left open.

    Cleanup runs even when the replay fails part-way:
        * still-running request tasks are cancelled before the HTTP client
          closes;
        * the pool poller is stopped;
        * structural log handles are closed.
    Previously any exception skipped all of this, leaking file handles and
    background tasks into the rest of the process.

    Returns:
        One ``RequestMetrics`` per trace request, in trace order.
    """
    structural_dir = config.structural_log_dir
    if structural_dir is None and config.output_path is not None:
        structural_dir = config.output_path.parent / "structural"
    _structural_init(structural_dir)
    try:
        requests = load_trace(config.trace_path, request_limit=config.request_limit)
        if config.kvcache_seed_only_multiturn_sessions:
            session_turns = Counter(request.session_id for request in requests)
            config = replace(
                config,
                kvcache_seed_allowed_session_ids=frozenset(
                    session_id
                    for session_id, turn_count in session_turns.items()
                    if turn_count > 1
                ),
            )
        policy = create_policy(
            config.policy_name,
            migration_reject_threshold=config.kvcache_migration_reject_threshold,
            load_floor_bonus=config.kvcache_load_floor_bonus,
        )
        state = RoutingState.create(config.topology)
        state_lock = asyncio.Lock()
        semaphore = asyncio.Semaphore(config.concurrency_limit)
        start_time = time.perf_counter()
        first_timestamp = requests[0].timestamp_s if requests else 0.0
        session_tail_tasks: dict[str, asyncio.Task[RequestMetrics]] = {}
        direct_sessions: dict[str, DirectSessionState] = {}
        direct_session_lock = asyncio.Lock()

        async with httpx.AsyncClient(timeout=config.timeout_s, trust_env=False) as client:
            decode_residency = await _discover_decode_residency(
                client=client,
                config=config,
            )
            poll_task: asyncio.Task[None] | None = None
            tasks: list[asyncio.Task[RequestMetrics]] = []
            try:
                if config.pool_poll_interval_s > 0:
                    poll_workers: list[tuple[str, str, str]] = []
                    for worker in config.topology.decode_workers:
                        poll_workers.append((worker.worker_id, "decode", worker.url))
                    for worker in config.topology.prefill_workers:
                        poll_workers.append((worker.worker_id, "prefill", worker.url))
                    if poll_workers:
                        poll_output = config.output_path.parent / "d-pool-timeseries.jsonl"
                        poll_task = asyncio.create_task(
                            _poll_pool_timeseries(
                                client=client,
                                workers=poll_workers,
                                interval_s=config.pool_poll_interval_s,
                                output_path=poll_output,
                                start_time=start_time,
                                include_sessions=config.pool_poll_include_sessions,
                            )
                        )

                for request in requests:
                    if config.pace:
                        target_offset = (request.timestamp_s - first_timestamp) / config.time_scale
                        sleep_s = target_offset - (time.perf_counter() - start_time)
                        if sleep_s > 0:
                            await asyncio.sleep(sleep_s)
                    task = asyncio.create_task(
                        _run_request(
                            request=request,
                            config=config,
                            client=client,
                            policy=policy,
                            state=state,
                            state_lock=state_lock,
                            semaphore=semaphore,
                            direct_sessions=direct_sessions,
                            direct_session_lock=direct_session_lock,
                            decode_residency=decode_residency,
                            depends_on=session_tail_tasks.get(request.session_id),
                        )
                    )
                    tasks.append(task)
                    session_tail_tasks[request.session_id] = task

                results = await asyncio.gather(*tasks)
            finally:
                # On failure or cancellation, do not leave request tasks
                # running against a client that is about to be closed.
                pending = [task for task in tasks if not task.done()]
                for task in pending:
                    task.cancel()
                if pending:
                    await asyncio.gather(*pending, return_exceptions=True)
                if poll_task is not None:
                    poll_task.cancel()
                    try:
                        await poll_task
                    except asyncio.CancelledError:
                        pass

            # Best-effort teardown of worker-side sessions; a failure to close
            # one must not prevent metrics from being written.
            for session in direct_sessions.values():
                if session.opened:
                    try:
                        await _close_streaming_session(
                            client=client,
                            server_url=session.server_url,
                            session_id=session.session_id,
                            allow_missing=True,
                        )
                    except Exception:
                        pass
                if session.prefill_opened and session.prefill_server_url is not None:
                    try:
                        await _close_streaming_session(
                            client=client,
                            server_url=session.prefill_server_url,
                            session_id=session.session_id,
                            allow_missing=True,
                        )
                    except Exception:
                        pass

        write_metrics_jsonl(config.output_path, results)
        write_summary_json(
            config.output_path.with_suffix(config.output_path.suffix + ".summary.json"),
            results,
            trace_path=config.trace_path,
            router_url=config.router_url,
        )
    finally:
        _structural_close()
    return results
|
|
|
|
|
|
async def _run_request(
    *,
    request: TraceRequest,
    config: ReplayConfig,
    client: httpx.AsyncClient,
    policy,
    state: RoutingState,
    state_lock: asyncio.Lock,
    semaphore: asyncio.Semaphore,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
    depends_on: asyncio.Task[RequestMetrics] | None,
) -> RequestMetrics:
    """Route, execute and account for a single trace request.

    Steps:
        1. Wait for *depends_on* (the previous request of the same session),
           then take a slot from *semaphore*.
        2. Call ``policy.select`` under *state_lock* and append the decision
           to ``session-d-binding.jsonl``.
        3. Execute the request. An exception from ``_execute_request`` is
           recorded as the ``error`` of the returned metrics instead of
           propagating.
        4. Under *state_lock*, finish the request in *state* and update the
           per-(session, D) admission-reject memory.
    """
    # Turns of one session must not overlap.
    if depends_on is not None:
        await depends_on

    async with semaphore:
        async with state_lock:
            decision = policy.select(request, topology=config.topology, state=state)

        await _structural_emit(
            "session-d-binding.jsonl",
            dict(
                session_id=request.session_id,
                request_id=request.request_id,
                turn_id=request.turn_id,
                decode_worker_index=decision.decode_worker_index,
                decode_worker_id=decision.decode_worker_id,
                prefill_worker_id=decision.prefill_worker_id,
                observed_overlap_blocks=decision.observed_overlap_blocks,
                kv_transfer_blocks=decision.kv_transfer_blocks,
                inflight_decode_load=decision.inflight_decode_load,
            ),
        )

        try:
            execution = await _execute_request(
                client=client,
                request=request,
                config=config,
                decision=decision,
                direct_sessions=direct_sessions,
                direct_session_lock=direct_session_lock,
                decode_residency=decode_residency,
            )
        except Exception as exc:  # pragma: no cover - defensive logging path
            # Keep the replay going; the failure shows up on this metrics row.
            execution = ExecutionResult(
                execution_mode=config.mechanism_name,
                actual_kv_transfer_blocks=0,
                effective_input_length=None,
                cached_tokens=0,
                session_reused=False,
                session_reset=False,
                latency_s=None,
                ttft_s=None,
                tpot_s=None,
                error=f"{type(exc).__name__}: {exc}",
            )

        async with state_lock:
            state.finish(request, decision)
            mode = execution.execution_mode
            if _is_admission_rejection_mode(mode):
                # The chosen D refused admission and the request fell back:
                # remember the (session, D) rejection so KvAwarePolicy can
                # migrate this session on a later turn.
                state.record_admission_reject(
                    request.session_id,
                    decision.decode_worker_id,
                )
            elif mode == "kvcache-direct-to-d-session":
                # A direct-to-D success proves this D can serve the session
                # now. Reset its counter so past saturation does not
                # permanently blacklist it (MIGRATION_V1_FINDINGS §4.1).
                reject_key = (request.session_id, decision.decode_worker_id)
                state.session_d_rejects[reject_key] = 0

    return RequestMetrics.from_decision(
        request,
        decision,
        mechanism_name=config.mechanism_name,
        execution_mode=execution.execution_mode,
        actual_kv_transfer_blocks=execution.actual_kv_transfer_blocks,
        effective_input_length=execution.effective_input_length,
        cached_tokens=execution.cached_tokens,
        session_reused=execution.session_reused,
        session_reset=execution.session_reset,
        latency_s=execution.latency_s,
        ttft_s=execution.ttft_s,
        tpot_s=execution.tpot_s,
        prefill_request_priority=execution.prefill_request_priority,
        decode_request_priority=execution.decode_request_priority,
        error=execution.error,
        actual_output_tokens=execution.actual_output_tokens,
        requested_output_tokens=execution.requested_output_tokens,
        finish_reason=execution.finish_reason,
    )
|
|
|
|
|
|
async def _invoke_router(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decode_worker_index: int,
    session_id: str | None = None,
    prefill_request_priority: int | None = None,
    decode_request_priority: int | None = None,
    decode_residency: "DecodeResidencyState | None" = None,
) -> GenerateResult:
    """Send *request* through the PD router's ``/generate`` endpoint.

    If backpressure is enabled and *decode_residency* is given, the call first
    waits out any pause window of the chosen decode worker. The payload asks
    for exactly ``max(1, request.output_length)`` greedy tokens (EOS ignored).
    It optionally carries ``session_params`` and the SMG prefill/decode
    priorities.

    Raises:
        ValueError: if ``config.router_url`` is not set. This used to be an
            ``assert``, which ``python -O`` strips. It is now checked before
            any backpressure wait, so a misconfiguration fails immediately.
        Anything raised by ``_invoke_generate``.
    """
    router_url = config.router_url
    if router_url is None:
        raise ValueError("config.router_url must be set to send requests through the router")

    if decode_residency is not None and config.enable_backpressure:
        decode_url = config.topology.decode_workers[decode_worker_index].url
        await _wait_for_decode_pause(
            config=config,
            residency=decode_residency,
            server_url=decode_url,
            request_id=request.request_id,
            session_id=session_id,
        )

    headers = _build_headers(
        request=request,
        header_mode=config.header_mode,
        decode_worker_index=decode_worker_index,
        policy_name=config.policy_name,
    )
    payload: dict[str, object] = {
        "input_ids": _build_direct_full_input_ids(request),
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max(1, request.output_length),
            "ignore_eos": True,
            "no_stop_trim": True,
            "skip_special_tokens": False,
        },
        "stream": config.stream,
    }
    if session_id is not None:
        payload["session_params"] = {"id": session_id}
    if prefill_request_priority is not None:
        payload["smg_prefill_priority"] = prefill_request_priority
    if decode_request_priority is not None:
        payload["smg_decode_priority"] = decode_request_priority

    return await _invoke_generate(
        client=client,
        base_url=router_url,
        headers=headers,
        payload=payload,
        timeout_s=config.timeout_s,
        stream_idle_timeout_s=config.stream_idle_timeout_s,
        stream=config.stream,
    )
|
|
|
|
|
|
def _build_payload(
    *,
    request: TraceRequest,
    model_name: str,
    prompt: str,
    stream: bool,
    session_params: dict[str, str] | None,
    exact_output_length: bool,
) -> dict[str, object]:
    """Build an OpenAI-style ``/v1/chat/completions`` request body.

    *prompt* is sent as a single user message. ``max_tokens`` is
    ``max(1, request.output_length)``. With *stream*, usage reporting is
    requested. With *exact_output_length*, ``min_tokens`` is pinned to the same
    value and EOS is ignored, so the reply has exactly that length.
    """
    token_budget = max(1, request.output_length)
    body: dict[str, object] = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": token_budget,
        "temperature": 0,
        "stream": stream,
    }
    if stream:
        body["stream_options"] = {"include_usage": True}
    if exact_output_length:
        body["min_tokens"] = token_budget
        body["ignore_eos"] = True
        body["no_stop_trim"] = True
        body["skip_special_tokens"] = False
    if session_params is not None:
        body["session_params"] = session_params
    return body
|
|
|
|
|
|
async def _invoke_chat_completion(
    *,
    client: httpx.AsyncClient,
    base_url: str,
    headers: dict[str, str],
    payload: dict[str, object],
    timeout_s: float,
    stream_idle_timeout_s: float | None,
    stream: bool,
) -> tuple[float, float | None, float | None, int]:
    """POST *payload* to ``<base_url>/v1/chat/completions`` and time it.

    Returns:
        ``(latency_s, ttft_s, tpot_s, cached_tokens)``. ``ttft_s``, and
        therefore ``tpot_s``, is only measured in streaming mode. TPOT divides
        the post-first-token time by the *requested* ``max_tokens``.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response.
        ValueError: when the server returns an ``error`` object, either as a
            stream chunk or as the non-streaming body. This matches
            ``_invoke_generate``; previously such chunks were skipped and the
            real error was lost.
        RuntimeError: when a stream ends without producing any token.
    """
    url = f"{base_url.rstrip('/')}/v1/chat/completions"
    start = time.perf_counter()
    ttft_s: float | None = None
    cached_tokens = 0
    generated_tokens = int(payload.get("max_tokens", 1))

    if stream:
        async with client.stream(
            "POST",
            url,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        ) as response:
            response.raise_for_status()
            async for line in _aiter_lines(
                response,
                idle_timeout_s=stream_idle_timeout_s,
            ):
                if not line.startswith("data:"):
                    continue
                data = line[5:].strip()
                if data == "[DONE]":
                    break
                parsed = json.loads(data)
                error = parsed.get("error") if isinstance(parsed, dict) else None
                if isinstance(error, dict):
                    raise ValueError(error.get("message", json.dumps(error)))
                cached_tokens = max(cached_tokens, _extract_openai_cached_tokens(parsed))
                if _contains_token(parsed) and ttft_s is None:
                    ttft_s = time.perf_counter() - start
                if _is_terminal_chunk(parsed):
                    break
    else:
        response = await client.post(
            url,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        )
        response.raise_for_status()
        parsed = response.json()
        error = parsed.get("error") if isinstance(parsed, dict) else None
        if isinstance(error, dict):
            raise ValueError(error.get("message", json.dumps(error)))
        cached_tokens = _extract_openai_cached_tokens(parsed)

    latency_s = time.perf_counter() - start
    if stream and ttft_s is None and generated_tokens > 0:
        raise RuntimeError("chat completion stream ended before producing any token")
    if ttft_s is None:
        tpot_s = None
    else:
        tpot_s = max(0.0, latency_s - ttft_s) / max(1, generated_tokens)
    return latency_s, ttft_s, tpot_s, cached_tokens
|
|
|
|
|
|
@dataclass(frozen=True)
class GenerateResult:
    """Measurements from one ``/generate`` call (see ``_invoke_generate``)."""

    # Timing (seconds); ttft/tpot are None when no first token was observed.
    latency_s: float
    ttft_s: float | None
    tpot_s: float | None
    # Token accounting.
    cached_tokens: int
    actual_output_tokens: int
    requested_output_tokens: int
    finish_reason: str | None
    # Last ``meta_info`` object returned by the server, if any.
    server_meta_info: dict | None
|
|
|
|
|
|
async def _invoke_generate(
    *,
    client: httpx.AsyncClient,
    base_url: str,
    headers: dict[str, str],
    payload: dict[str, object],
    timeout_s: float,
    stream_idle_timeout_s: float | None,
    stream: bool,
) -> GenerateResult:
    """POST *payload* to ``<base_url>/generate`` and measure the response.

    Streaming mode:
        * reads ``data:`` lines until ``[DONE]`` or a terminal chunk;
        * records time-to-first-token;
        * counts token-bearing chunks, raising the count to the server's
          ``meta_info.completion_tokens`` whenever that is larger.
    Non-streaming mode reads the counts from the single JSON body, and
    ``ttft_s``/``tpot_s`` stay ``None``.

    ``finish_reason`` is converted with ``str()`` in both modes. The
    non-streaming path used to return the raw ``meta_info`` value. That
    value is not guaranteed to be a string (the streaming path already
    wrapped it in ``str()``), which contradicted
    ``GenerateResult.finish_reason: str | None``. A ``completion_tokens`` of
    ``null`` is treated as 0 instead of crashing ``int()``.

    TPOT divides the post-first-token time by the actual output token count,
    or by the requested ``max_new_tokens`` when no count is known.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response.
        ValueError: when the server reports an ``error`` object.
        RuntimeError: when a stream yields no token although tokens were requested.
    """
    url = f"{base_url.rstrip('/')}/generate"
    start = time.perf_counter()
    ttft_s: float | None = None
    cached_tokens = 0
    sampling_params = payload.get("sampling_params", {})
    requested_output_tokens = int(sampling_params.get("max_new_tokens", 1))
    actual_token_count = 0
    finish_reason: str | None = None
    last_meta_info: dict | None = None

    if stream:
        async with client.stream(
            "POST",
            url,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        ) as response:
            response.raise_for_status()
            async for line in _aiter_lines(
                response,
                idle_timeout_s=stream_idle_timeout_s,
            ):
                if not line.startswith("data:"):
                    continue
                data = line[5:].strip()
                if data == "[DONE]":
                    break
                parsed = json.loads(data)
                error = parsed.get("error")
                if isinstance(error, dict):
                    raise ValueError(error.get("message", json.dumps(error)))
                cached_tokens = max(cached_tokens, _extract_generate_cached_tokens(parsed))
                if _contains_generate_token(parsed):
                    actual_token_count += 1
                    if ttft_s is None:
                        ttft_s = time.perf_counter() - start
                meta_info = parsed.get("meta_info")
                if isinstance(meta_info, dict):
                    last_meta_info = meta_info
                    # Keep the larger of our chunk count and the server's count.
                    actual_token_count = max(
                        actual_token_count,
                        int(meta_info.get("completion_tokens") or 0),
                    )
                    reason = meta_info.get("finish_reason")
                    if reason is not None:
                        finish_reason = str(reason)
                if _is_generate_terminal_chunk(parsed):
                    break
    else:
        response = await client.post(
            url,
            headers=headers,
            json=payload,
            timeout=timeout_s,
        )
        response.raise_for_status()
        parsed = response.json()
        error = parsed.get("error")
        if isinstance(error, dict):
            raise ValueError(error.get("message", json.dumps(error)))
        cached_tokens = _extract_generate_cached_tokens(parsed)
        meta_info = parsed.get("meta_info")
        if isinstance(meta_info, dict):
            last_meta_info = meta_info
            actual_token_count = int(meta_info.get("completion_tokens") or 0)
            reason = meta_info.get("finish_reason")
            finish_reason = None if reason is None else str(reason)

    latency_s = time.perf_counter() - start
    if stream and ttft_s is None and requested_output_tokens > 0:
        raise RuntimeError("generate stream ended before producing any token")

    # Use the actual token count for TPOT, falling back to the requested count.
    effective_tokens = (
        actual_token_count if actual_token_count > 0 else max(1, requested_output_tokens)
    )
    tpot_s = None if ttft_s is None else max(0.0, latency_s - ttft_s) / effective_tokens

    return GenerateResult(
        latency_s=latency_s,
        ttft_s=ttft_s,
        tpot_s=tpot_s,
        cached_tokens=cached_tokens,
        actual_output_tokens=actual_token_count,
        requested_output_tokens=requested_output_tokens,
        finish_reason=finish_reason,
        server_meta_info=last_meta_info,
    )
|
|
|
|
|
|
async def _open_streaming_session(
    *,
    client: httpx.AsyncClient,
    server_url: str,
    session_id: str,
    request: TraceRequest,
) -> None:
    """Open a streaming session named *session_id* on the worker at *server_url*.

    The requested ``capacity_of_str_len`` is 16 times the larger of the input
    length and the input+output length, with a floor of 4096.

    Raises:
        httpx.HTTPStatusError: if the worker answers with a non-2xx status.
        ValueError: if the worker echoes back a different session id.
    """
    longest = max(request.input_length, request.input_length + request.output_length)
    body = {
        "capacity_of_str_len": max(4096, longest * 16),
        "session_id": session_id,
        "streaming": True,
    }
    response = await client.post(
        f"{server_url.rstrip('/')}/open_session",
        json=body,
        timeout=_ADMISSION_PROBE_TIMEOUT_S,
    )
    response.raise_for_status()
    echoed_id = response.json()
    if echoed_id == session_id:
        return
    raise ValueError(
        f"Unexpected session id from {server_url}: {echoed_id!r} != {session_id!r}"
    )
|
|
|
|
|
|
async def _close_streaming_session(
    *,
    client: httpx.AsyncClient,
    server_url: str,
    session_id: str,
    allow_missing: bool = False,
) -> None:
    """Close *session_id* on the worker at *server_url*.

    With *allow_missing*, a 404 or a body mentioning "does not exist" counts
    as success (the session is already gone). Any other non-2xx status
    raises ``httpx.HTTPStatusError``.
    """
    response = await client.post(
        f"{server_url.rstrip('/')}/close_session",
        json={"session_id": session_id},
        timeout=_ADMISSION_PROBE_TIMEOUT_S,
    )
    if response.is_success:
        return
    already_gone = allow_missing and (
        response.status_code == 404 or "does not exist" in response.text.lower()
    )
    if already_gone:
        return
    response.raise_for_status()
|
|
|
|
|
|
def _extract_internal_state(payload: dict[str, Any]) -> dict[str, Any]:
    """Return ``payload["internal_states"][0]`` if that is a dict, else *payload*.

    Some ``/server_info`` payloads nest scheduler state under an
    ``internal_states`` list; others are already flat.
    """
    states = payload.get("internal_states")
    first = states[0] if isinstance(states, list) and states else None
    return first if isinstance(first, dict) else payload
|
|
|
|
|
|
def _extract_server_int(
    payload: dict[str, Any],
    key: str,
) -> int:
    """Read integer *key* from a ``/server_info`` payload.

    A top-level value wins over the nested internal state. Absent or falsy
    values become 0.
    """
    nested_value = _extract_internal_state(payload).get(key, 0)
    raw = payload.get(key, nested_value)
    return int(raw or 0)
|
|
|
|
|
|
def _extract_session_cache(
    payload: dict[str, Any],
) -> dict[str, Any]:
    """Return the ``session_cache`` dict from a ``/server_info`` payload, or ``{}``."""
    candidate = _extract_internal_state(payload).get("session_cache")
    return candidate if isinstance(candidate, dict) else {}
|
|
|
|
|
|
def _find_session_cache_status(
    session_cache: dict[str, Any],
    session_id: str,
) -> dict[str, Any] | None:
    """Return the first ``sessions`` entry whose ``session_id`` matches, else ``None``.

    Non-dict entries are skipped. A missing or non-list ``sessions`` field
    yields ``None``.
    """
    entries = session_cache.get("sessions")
    if not isinstance(entries, list):
        return None
    return next(
        (
            entry
            for entry in entries
            if isinstance(entry, dict) and entry.get("session_id") == session_id
        ),
        None,
    )
|
|
|
|
|
|
async def _fetch_decode_server_state(
    *,
    client: httpx.AsyncClient,
    server_url: str,
) -> tuple[dict[str, Any], int, int]:
    """Best-effort read of a worker's ``/server_info``.

    Returns:
        ``(session_cache, max_total_num_tokens, num_reserved_decode_tokens)``.
        Any failure yields ``({}, 0, 0)`` instead of raising:
            * transport error or non-2xx status;
            * undecodable or non-object JSON;
            * non-numeric counters.
        Previously the last two escaped as ``AttributeError``/``ValueError``
        from the extraction helpers, breaking this best-effort contract.
    """
    try:
        response = await client.get(
            f"{server_url.rstrip('/')}/server_info",
            timeout=_ADMISSION_PROBE_TIMEOUT_S,
        )
        response.raise_for_status()
        payload = response.json()
    except Exception:
        return {}, 0, 0

    # The extraction helpers call ``.get`` on the payload, so anything other
    # than a JSON object would crash them.
    if not isinstance(payload, dict):
        return {}, 0, 0
    try:
        return (
            _extract_session_cache(payload),
            _extract_server_int(payload, "max_total_num_tokens"),
            _extract_server_int(payload, "num_reserved_decode_tokens"),
        )
    except (TypeError, ValueError):
        # e.g. a counter reported as a string or a nested object.
        return {}, 0, 0
|
|
|
|
|
|
async def _query_pool_snapshot(
    *,
    client: httpx.AsyncClient,
    server_url: str,
    include_sessions: bool,
) -> dict[str, Any]:
    """Take one ``/server_info`` sample of a worker for the pool poller.

    Returns ``{"error": <exception class name>}`` when the request or JSON
    decoding fails. Otherwise returns a flat record containing:
        * session-cache counters;
        * ``memory_usage`` fields;
        * ``pool_breakdown`` buckets.
    Absent or falsy values become 0. The ``sessions`` list holds per-session
    rows when *include_sessions* is true and is empty otherwise.
    """
    try:
        response = await client.get(
            f"{server_url.rstrip('/')}/server_info",
            timeout=_ADMISSION_PROBE_TIMEOUT_S,
        )
        response.raise_for_status()
        payload = response.json()
    except Exception as exc:
        return {"error": type(exc).__name__}

    internal = _extract_internal_state(payload)
    cache = _extract_session_cache(payload)
    internal_is_dict = isinstance(internal, dict)

    def as_int(source: dict[str, Any], key: str) -> int:
        return int(source.get(key) or 0)

    def sub_dict(key: str) -> dict[str, Any]:
        value = internal.get(key) if internal_is_dict else None
        return value if isinstance(value, dict) else {}

    sessions: list[dict[str, Any]] = []
    if include_sessions and isinstance(cache.get("sessions"), list):
        sessions = [
            {
                "session_id": entry.get("session_id"),
                "resident": bool(entry.get("resident")),
                "resident_tokens": as_int(entry, "resident_tokens"),
                "idle_evictable": bool(entry.get("idle_evictable")),
                "timed_out": bool(entry.get("timed_out")),
            }
            for entry in cache["sessions"]
            if isinstance(entry, dict)
        ]

    memory_usage = sub_dict("memory_usage")
    # P1 instrument: pool_breakdown decomposes "other" into named buckets.
    breakdown = sub_dict("pool_breakdown")

    record: dict[str, Any] = {
        "session_cache_enabled": bool(cache.get("enabled")),
        "session_count": as_int(cache, "session_count"),
        "resident_session_count": as_int(cache, "resident_session_count"),
        "held_tokens": as_int(cache, "held_tokens"),
        "available_tokens": as_int(cache, "available_tokens"),
        "capacity_tokens": as_int(cache, "capacity_tokens"),
        "idle_evictable_session_count": as_int(cache, "idle_evictable_session_count"),
        "idle_evictable_tokens": as_int(cache, "idle_evictable_tokens"),
        "kvcache_mem_gb": float(memory_usage.get("kvcache") or 0.0),
        "token_capacity": as_int(memory_usage, "token_capacity"),
        "max_total_num_tokens": (
            as_int(internal, "max_total_num_tokens") if internal_is_dict else 0
        ),
        "last_gen_throughput": (
            float(internal.get("last_gen_throughput") or 0.0) if internal_is_dict else 0.0
        ),
    }
    for bucket in (
        "radix_evictable_tokens",
        "radix_protected_tokens",
        "slot_private_held_tokens",
        "session_slot_count",
        "running_batch_reqs",
        "running_batch_kv_tokens",
        "transfer_queue_reqs",
        "transfer_queue_tokens",
        "prealloc_queue_reqs",
        "prealloc_queue_tokens",
        "retracted_queue_reqs",
        "retracted_queue_tokens",
    ):
        record[bucket] = as_int(breakdown, bucket)
    record["sessions"] = sessions
    return record
|
|
|
|
|
|
async def _poll_pool_timeseries(
    *,
    client: httpx.AsyncClient,
    workers: list[tuple[str, str, str]],
    interval_s: float,
    output_path: Path,
    start_time: float,
    include_sessions: bool,
) -> None:
    """Sample every worker roughly every *interval_s* seconds into a JSONL file.

    *workers* holds ``(worker_id, role, url)`` triples. Each tick queries all
    workers concurrently and writes one row per worker. A row carries:
        * ``ts`` (wall-clock) and ``wall_s`` (seconds since *start_time*);
        * the worker identity;
        * the snapshot fields, or ``error`` naming the exception type.
    *output_path* is truncated at start. The loop runs until cancelled;
    cancellation ends it quietly and returns normally.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as sink:
        try:
            while True:
                tick_started = time.perf_counter()
                ts = time.time()
                wall_s = tick_started - start_time
                probes = [
                    _query_pool_snapshot(
                        client=client,
                        server_url=url,
                        include_sessions=include_sessions,
                    )
                    for _, _, url in workers
                ]
                snapshots = await asyncio.gather(*probes, return_exceptions=True)
                for (worker_id, role, url), snap in zip(workers, snapshots):
                    base: dict[str, Any] = {
                        "ts": ts,
                        "wall_s": wall_s,
                        "worker_id": worker_id,
                        "worker_role": role,
                        "worker_url": url,
                    }
                    if isinstance(snap, BaseException):
                        row = {**base, "error": type(snap).__name__}
                    else:
                        row = {**base, **snap}
                    sink.write(json.dumps(row, sort_keys=True) + "\n")
                    sink.flush()
                remaining_s = interval_s - (time.perf_counter() - tick_started)
                if remaining_s > 0:
                    await asyncio.sleep(remaining_s)
        except asyncio.CancelledError:
            return
|
|
|
|
|
|
async def _query_decode_direct_admission(
    *,
    client: httpx.AsyncClient,
    server_url: str,
    session_id: str,
    uncached_input_tokens: int,
    output_tokens: int,
    mode: str = "direct_append",
    config: "ReplayConfig | None" = None,
    residency: "DecodeResidencyState | None" = None,
    request_id: str | None = None,
    turn_id: int | None = None,
) -> dict[str, Any]:
    """Ask a decode worker whether it can admit tokens for ``session_id``.

    The request is a POST to ``<server_url>/session_cache/admit_direct_append``
    whose body carries ``session_id``, ``uncached_input_tokens`` and
    ``output_tokens`` (both clamped to >= 0), and ``mode`` (for example
    ``"direct_append"`` or ``"seed"``). It uses the
    ``_ADMISSION_PROBE_TIMEOUT_S`` timeout.

    Failure handling: any failure (transport error, HTTP error status,
    undecodable body) or a JSON body that is not an object yields a
    synthetic rejection. That rejection has ``can_admit``/``resident`` set
    to False, ``reason="admission-query-failed"`` and zeroed counters. The
    seed path of ``_reserve_decode_session_capacity`` checks for this
    reason to fall back to router-side estimation.

    Side effects:
        * When ``config`` and ``residency`` are both given,
          ``config.enable_backpressure`` is set and the payload recommends a
          pause, ``residency.pause_until_s[server_url]`` is extended. The new
          deadline is now plus ``min(pause, config.backpressure_max_pause_s)``.
          The deadline is never shortened.
        * An ``admission-events.jsonl`` structural line is always written.
          On failure it includes the cause in ``error``.

    Returns:
        The worker's admission payload, or the synthetic rejection above.
    """
    started = time.perf_counter()
    payload: dict[str, Any] | None = None
    # Cause of a failed probe. It used to be computed and then discarded;
    # it is now recorded so "admission-query-failed" rows can be told apart.
    error: str | None = None
    try:
        response = await client.post(
            f"{server_url.rstrip('/')}/session_cache/admit_direct_append",
            json={
                "session_id": session_id,
                "uncached_input_tokens": max(0, uncached_input_tokens),
                "output_tokens": max(0, output_tokens),
                "mode": mode,
            },
            timeout=_ADMISSION_PROBE_TIMEOUT_S,
        )
        response.raise_for_status()
        decoded = response.json()
    except Exception as exc:  # best-effort probe: every failure is a rejection
        error = type(exc).__name__
    else:
        if isinstance(decoded, dict):
            payload = decoded
        else:
            error = "non-dict-response"

    if payload is None:
        payload = {
            "can_admit": False,
            "resident": False,
            "reason": "admission-query-failed",
            "required_tokens": 0,
            "available_tokens_before": 0,
            "available_tokens_after": 0,
            "evicted_session_count": 0,
            "freed_tokens": 0,
        }

    rtt_s = time.perf_counter() - started
    pause_ms = int(payload.get("recommended_pause_ms", 0) or 0)

    # Update the per-D pause window when backpressure is enabled.
    if (
        config is not None
        and residency is not None
        and config.enable_backpressure
        and pause_ms > 0
    ):
        max_pause_s = max(0.0, config.backpressure_max_pause_s)
        applied_pause_s = min(pause_ms / 1000.0, max_pause_s)
        new_until = time.perf_counter() + applied_pause_s
        prev = residency.pause_until_s.get(server_url, 0.0)
        if new_until > prev:
            residency.pause_until_s[server_url] = new_until

    # Always emit the admission event for analysis, even with backpressure disabled.
    await _structural_emit(
        "admission-events.jsonl",
        {
            "server_url": server_url,
            "session_id": session_id,
            "request_id": request_id,
            "turn_id": turn_id,
            "mode": mode,
            "rtt_s": round(rtt_s, 4),
            "can_admit": bool(payload.get("can_admit")),
            "resident": bool(payload.get("resident")),
            "reason": payload.get("reason"),
            "error": error,
            "queue_depth": int(payload.get("decode_transfer_queue_reqs", 0) or 0),
            "retracted_depth": int(payload.get("decode_retracted_queue_reqs", 0) or 0),
            "available_tokens_after": int(payload.get("available_tokens_after", 0) or 0),
            "token_usage": float(payload.get("token_usage", 0.0) or 0.0),
            "evicted_session_count": int(payload.get("evicted_session_count", 0) or 0),
            "recommended_pause_ms": pause_ms,
            "uncached_input_tokens": int(uncached_input_tokens),
            "output_tokens": int(output_tokens),
        },
    )
    return payload
|
|
|
|
|
|
async def _wait_for_decode_pause(
    *,
    config: "ReplayConfig",
    residency: "DecodeResidencyState",
    server_url: str,
    request_id: str | None = None,
    session_id: str | None = None,
) -> None:
    """Sleep out any backpressure pause window recorded for ``server_url``.

    Returns immediately in three cases:
        * backpressure is disabled;
        * no positive deadline is stored in ``residency.pause_until_s``;
        * the deadline has already passed.

    Otherwise a ``backpressure-events.jsonl`` structural line is written. The
    coroutine then sleeps for the time left until the deadline, capped at
    ``config.backpressure_max_pause_s``.
    """
    if not config.enable_backpressure:
        return

    deadline = residency.pause_until_s.get(server_url, 0.0)
    if deadline <= 0:
        return

    now = time.perf_counter()
    if now >= deadline:
        return

    pause_s = min(deadline - now, config.backpressure_max_pause_s)
    event = {
        "server_url": server_url,
        "session_id": session_id,
        "request_id": request_id,
        "sleep_s": round(pause_s, 4),
        "until_offset_s": round(deadline - _STRUCTURAL_RUN_START_S, 4),
    }
    await _structural_emit("backpressure-events.jsonl", event)
    await asyncio.sleep(pause_s)
|
|
|
|
|
|
async def _discover_decode_residency(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
) -> DecodeResidencyState:
    """Build the initial per-worker capacity and headroom picture.

    Residency is only tracked for the ``kvcache-centric`` mechanism. Any
    other mechanism gets an empty ``DecodeResidencyState``. Workers whose
    server state reports ``max_total_num_tokens <= 0`` are left out.

    Decode workers record:
        * ``capacity_tokens``;
        * ``reserved_decode_tokens``;
        * ``headroom_tokens`` = max(4 x reserved decode tokens, capacity // 20,
          8192), capped at capacity.

    Prefill workers record:
        * ``prefill_capacity_tokens``;
        * ``prefill_headroom_tokens`` = max(capacity // 10, 16384), capped at
          capacity.
    """
    state = DecodeResidencyState()
    if config.mechanism_name != "kvcache-centric":
        return state

    for decode_worker in config.topology.decode_workers:
        url = decode_worker.url
        _cache, total_tokens, reserved_tokens = await _fetch_decode_server_state(
            client=client,
            server_url=url,
        )
        if total_tokens <= 0:
            continue
        wanted_headroom = max(reserved_tokens * 4, total_tokens // 20, 8192)
        state.capacity_tokens[url] = total_tokens
        state.headroom_tokens[url] = min(total_tokens, wanted_headroom)
        state.reserved_decode_tokens[url] = reserved_tokens

    for prefill_worker in config.topology.prefill_workers:
        url = prefill_worker.url
        _cache, total_tokens, _unused_reserved = await _fetch_decode_server_state(
            client=client,
            server_url=url,
        )
        if total_tokens <= 0:
            continue
        wanted_headroom = max(total_tokens // 10, 16384)
        state.prefill_capacity_tokens[url] = total_tokens
        state.prefill_headroom_tokens[url] = min(total_tokens, wanted_headroom)

    return state
|
|
|
|
|
|
def _estimate_session_resident_tokens(request: TraceRequest) -> int:
    """Estimate the KV tokens a session holds once ``request`` is served.

    The estimate is the request's prompt length plus its output length.
    """
    prompt_tokens = request.input_length
    generated_tokens = request.output_length
    return prompt_tokens + generated_tokens
|
|
|
|
|
|
def _seed_filter_reason(
    *,
    request: TraceRequest,
    config: ReplayConfig,
    inflight_decode_load: int | None = None,
) -> str | None:
    """Explain why ``request`` must not seed a decode session, or return ``None``.

    The filters are evaluated in this order, and the first match wins:

    1. ``request.turn_id`` below ``kvcache_seed_min_turn_id``.
    2. ``inflight_decode_load`` above ``kvcache_seed_max_inflight_decode``.
       This check is skipped when either value is ``None``.
    3. ``kvcache_seed_allowed_session_ids`` is set and does not contain the
       session.
    4. The estimated resident tokens exceed ``kvcache_seed_max_resident_tokens``.
    5. ``request.output_length`` exceeds ``kvcache_seed_max_output_tokens``.

    Limits configured as ``None`` are treated as disabled.
    """
    if request.turn_id < config.kvcache_seed_min_turn_id:
        return "seed-filter-early-turn"

    inflight_limit = config.kvcache_seed_max_inflight_decode
    if (
        inflight_limit is not None
        and inflight_decode_load is not None
        and inflight_decode_load > inflight_limit
    ):
        return "seed-filter-inflight-decode-load"

    allowed_sessions = config.kvcache_seed_allowed_session_ids
    if allowed_sessions is not None and request.session_id not in allowed_sessions:
        return "seed-filter-single-turn-session"

    footprint = _estimate_session_resident_tokens(request)
    resident_limit = config.kvcache_seed_max_resident_tokens
    if resident_limit is not None and footprint > resident_limit:
        return "seed-filter-resident-tokens"

    output_limit = config.kvcache_seed_max_output_tokens
    if output_limit is not None and request.output_length > output_limit:
        return "seed-filter-output-tokens"

    return None
|
|
|
|
|
|
def _prefill_priority_for_router_request(
    *,
    config: ReplayConfig,
    direct_to_d_predicted: bool,
) -> int | None:
    """Choose the prefill priority to attach to a router request.

    Returns ``None`` when ``kvcache_prefill_priority_eviction`` is off.
    Otherwise it returns one of two configured priorities:
        * ``kvcache_prefill_direct_priority`` when the request is predicted
          to go direct to D;
        * ``kvcache_prefill_normal_priority`` otherwise.
    """
    if not config.kvcache_prefill_priority_eviction:
        return None
    return (
        config.kvcache_prefill_direct_priority
        if direct_to_d_predicted
        else config.kvcache_prefill_normal_priority
    )
|
|
|
|
|
|
def _inspect_direct_request(
    *,
    request: TraceRequest,
    session: DirectSessionState,
) -> tuple[int, bool, bool]:
    """Compare ``request`` with the session's previous turn.

    The previous turn's context is its input plus its output.

    Returns:
        ``(uncached_tokens, reused, reset)``:

        * No previous turn: ``(input_length, False, False)``.
        * The new prompt extends past the previous context:
          ``(tokens beyond that context, True, False)``.
        * Otherwise, the new prompt is not longer than the previous context:
          ``(input_length, False, True)``.
    """
    prior = session.last_trace_request
    if prior is None:
        return request.input_length, False, False

    prior_context = prior.input_length + prior.output_length
    appended = request.input_length - prior_context
    if appended > 0:
        return appended, True, False
    return request.input_length, False, True
|
|
|
|
|
|
def _add_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Grow the router's pending decode reservation on ``server_url``.

    ``delta_tokens <= 0`` leaves the ledger untouched.
    """
    if delta_tokens > 0:
        ledger = residency.reserved_tokens_by_server
        ledger[server_url] = ledger.get(server_url, 0) + delta_tokens
|
|
|
|
|
|
def _release_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Shrink the pending decode reservation on ``server_url``.

    ``delta_tokens <= 0`` is a no-op. When the balance would drop to zero or
    below, the server's entry is removed instead of storing a non-positive
    value.
    """
    if delta_tokens <= 0:
        return
    ledger = residency.reserved_tokens_by_server
    balance = ledger.get(server_url, 0) - delta_tokens
    if balance <= 0:
        ledger.pop(server_url, None)
    else:
        ledger[server_url] = balance
|
|
|
|
|
|
def _add_prefill_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Grow the router's pending prefill-backup reservation on ``server_url``.

    ``delta_tokens <= 0`` leaves the ledger untouched.
    """
    if delta_tokens > 0:
        ledger = residency.prefill_reserved_tokens_by_server
        ledger[server_url] = ledger.get(server_url, 0) + delta_tokens
|
|
|
|
|
|
def _release_prefill_reserved_tokens(
    residency: DecodeResidencyState,
    server_url: str,
    delta_tokens: int,
) -> None:
    """Shrink the pending prefill-backup reservation on ``server_url``.

    ``delta_tokens <= 0`` is a no-op. A balance that would reach zero or go
    negative removes the server's entry.
    """
    if delta_tokens <= 0:
        return
    ledger = residency.prefill_reserved_tokens_by_server
    balance = ledger.get(server_url, 0) - delta_tokens
    if balance <= 0:
        ledger.pop(server_url, None)
    else:
        ledger[server_url] = balance
|
|
|
|
|
|
def _usable_capacity_tokens(
    residency: DecodeResidencyState,
    server_url: str,
) -> int:
    """Return decode capacity minus headroom for ``server_url``, floored at 0.

    Unknown servers count as having 0 capacity and 0 headroom.
    """
    total = residency.capacity_tokens.get(server_url, 0)
    held_back = residency.headroom_tokens.get(server_url, 0)
    return max(0, total - held_back)
|
|
|
|
|
|
def _usable_prefill_backup_capacity_tokens(
    residency: DecodeResidencyState,
    server_url: str,
) -> int:
    """Return prefill capacity minus prefill headroom for ``server_url``, floored at 0.

    Unknown servers count as having 0 capacity and 0 headroom.
    """
    total = residency.prefill_capacity_tokens.get(server_url, 0)
    held_back = residency.prefill_headroom_tokens.get(server_url, 0)
    return max(0, total - held_back)
|
|
|
|
|
|
def _eviction_suffix(evicted_sessions: int, prefill_backed_evictions: int) -> str:
    """Build the execution-mode suffix describing evictions done for a request.

    Returns:
        * ``""`` when nothing was evicted;
        * ``"-after-eviction"`` when none of the evicted sessions had a
          prefill backup;
        * ``"-after-prefill-backed-eviction"`` when all of them did;
        * ``"-after-mixed-eviction"`` when only some did.
    """
    if evicted_sessions <= 0:
        return ""
    if prefill_backed_evictions <= 0:
        return "-after-eviction"
    if prefill_backed_evictions < evicted_sessions:
        return "-after-mixed-eviction"
    return "-after-prefill-backed-eviction"
|
|
|
|
|
|
def _decode_session_soft_cap(
    *,
    residency: DecodeResidencyState,
    server_url: str,
    request: TraceRequest,
    max_sessions: int = 16,
) -> int:
    """Return how many sessions shaped like ``request`` may stay open on one D.

    The cap is computed as follows:
        * Usable capacity is capacity minus headroom.
        * That is divided by the session's estimated resident tokens (at
          least 1).
        * The result is clamped to ``[1, max_sessions]``.

    When usable capacity is unknown or exhausted (``<= 0``), the cap is
    ``max_sessions`` with no capacity-based reduction.

    Args:
        residency: Router-side capacity/headroom bookkeeping.
        server_url: Decode worker being considered.
        request: Request whose footprint sizes the per-session estimate.
        max_sessions: Upper bound on the cap. Default 16, the previously
            hard-coded value. Values below 1 are treated as 1.
    """
    session_limit = max(1, max_sessions)
    per_session_tokens = max(1, _estimate_session_resident_tokens(request))
    # _usable_capacity_tokens already yields max(0, capacity - headroom), so
    # no separate fallback recomputation is needed when it returns 0.
    usable_capacity_tokens = _usable_capacity_tokens(residency, server_url)
    if usable_capacity_tokens <= 0:
        return session_limit
    return max(1, min(session_limit, usable_capacity_tokens // per_session_tokens))
|
|
|
|
|
|
def _should_admit_new_decode_session(
    *,
    residency: DecodeResidencyState,
    server_url: str,
    request: TraceRequest,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    treat_as_fresh_session: bool,
    admission_mode: KvCacheAdmissionMode = "router",
) -> bool:
    """Decide whether ``session`` may (re)open a decode session on ``server_url``.

    The session is always admitted in two cases:
        * it is already open on this server and is not being treated as
          fresh;
        * ``admission_mode == "worker"``, because D's seed-mode admission
          then makes the capacity call with real pool state and its own LRU
          eviction.

    Otherwise (router mode) the number of sessions the router already has
    open on this server must stay below ``_decode_session_soft_cap``.
    """
    already_open_here = (
        not treat_as_fresh_session
        and session.opened
        and session.server_url == server_url
    )
    if already_open_here or admission_mode == "worker":
        return True

    open_here = 0
    for other in direct_sessions.values():
        if other.opened and other.server_url == server_url:
            open_here += 1
    soft_cap = _decode_session_soft_cap(
        residency=residency,
        server_url=server_url,
        request=request,
    )
    return open_here < soft_cap
|
|
|
|
|
|
async def _fetch_decode_load_snapshot(
    *,
    client: httpx.AsyncClient,
    server_url: str,
) -> DecodeLoadSnapshot | None:
    """Fetch a decode worker's current load as a ``DecodeLoadSnapshot``.

    Sends ``GET <server_url>/v1/loads?include=core,disagg`` and summarizes
    the first entry of the ``loads`` list. Counters come from that entry,
    and the ``decode_*_queue_reqs`` values come from its ``disaggregation``
    sub-object. Missing or falsy fields count as 0. ``timestamp_s`` is
    ``time.perf_counter()`` at parse time.

    This is a best-effort probe. It returns ``None`` when:
        * the request fails or the body cannot be decoded;
        * the body is not a JSON object;
        * ``loads`` is missing, is not a list, or is empty;
        * the first ``loads`` entry is not a JSON object.

    A non-object ``disaggregation`` is treated as empty.
    """
    try:
        response = await client.get(
            f"{server_url.rstrip('/')}/v1/loads",
            params={"include": "core,disagg"},
            timeout=_ADMISSION_PROBE_TIMEOUT_S,
        )
        response.raise_for_status()
        payload = response.json()
    except Exception:
        return None

    # A malformed body must degrade to "no snapshot" like a transport error,
    # rather than raising AttributeError into the routing path.
    if not isinstance(payload, dict):
        return None
    loads = payload.get("loads")
    if not isinstance(loads, list) or not loads:
        return None
    load = loads[0]
    if not isinstance(load, dict):
        return None
    disagg = load.get("disaggregation")
    if not isinstance(disagg, dict):
        disagg = {}

    return DecodeLoadSnapshot(
        timestamp_s=time.perf_counter(),
        num_running_reqs=int(load.get("num_running_reqs", 0) or 0),
        num_waiting_reqs=int(load.get("num_waiting_reqs", 0) or 0),
        num_used_tokens=int(load.get("num_used_tokens", 0) or 0),
        max_total_num_tokens=int(load.get("max_total_num_tokens", 0) or 0),
        token_usage=float(load.get("token_usage", 0.0) or 0.0),
        decode_prealloc_queue_reqs=int(
            disagg.get("decode_prealloc_queue_reqs", 0) or 0
        ),
        decode_transfer_queue_reqs=int(
            disagg.get("decode_transfer_queue_reqs", 0) or 0
        ),
        decode_retracted_queue_reqs=int(
            disagg.get("decode_retracted_queue_reqs", 0) or 0
        ),
    )
|
|
|
|
|
|
def _decode_load_backpressure_reason(
    snapshot: DecodeLoadSnapshot | None,
    *,
    config: ReplayConfig,
    routing_mode: Literal["direct", "seed"],
) -> str | None:
    """Return a decode-backpressure reason for ``snapshot``, or ``None``.

    A missing snapshot never signals backpressure. Both modes first check
    the transfer queue against the mode's configured limit
    (``kvcache_direct_*`` / ``kvcache_seed_max_decode_transfer_queue_reqs``;
    ``None`` disables the check).

    Direct mode then checks:
        * retracted requests while token usage >= 0.99 -> ``"d-retracted"``;
        * token usage >= 0.992 -> ``"d-token-usage-critical"``.

    Any other mode then checks:
        * any retracted request -> ``"d-retracted"``;
        * token usage >= 0.985 -> ``"d-token-usage-critical"``;
        * for ``"seed"`` only: token usage >= 0.94 with a non-empty prealloc
          or transfer queue -> ``"d-prealloc-backpressure"``.
    """
    if snapshot is None:
        return None

    is_direct = routing_mode == "direct"
    transfer_limit = (
        config.kvcache_direct_max_decode_transfer_queue_reqs
        if is_direct
        else config.kvcache_seed_max_decode_transfer_queue_reqs
    )
    if transfer_limit is not None and snapshot.decode_transfer_queue_reqs > transfer_limit:
        return "d-transfer-backpressure"

    has_retracted = snapshot.decode_retracted_queue_reqs > 0
    usage = snapshot.token_usage

    if is_direct:
        if has_retracted and usage >= 0.99:
            return "d-retracted"
        if usage >= 0.992:
            return "d-token-usage-critical"
        return None

    if has_retracted:
        return "d-retracted"
    if usage >= 0.985:
        return "d-token-usage-critical"
    if (
        routing_mode == "seed"
        and usage >= 0.94
        and (
            snapshot.decode_prealloc_queue_reqs > 0
            or snapshot.decode_transfer_queue_reqs > 0
        )
    ):
        return "d-prealloc-backpressure"
    return None
|
|
|
|
|
|
# Reject reasons produced by _decode_load_backpressure_reason.
_DECODE_BACKPRESSURE_REASON_SET = frozenset(
    (
        "d-retracted",
        "d-token-usage-critical",
        "d-prealloc-backpressure",
        "d-transfer-backpressure",
    )
)


def _is_decode_backpressure_reason(reason: str | None) -> bool:
    """Return True if ``reason`` is one of the decode load-backpressure reasons."""
    return reason in _DECODE_BACKPRESSURE_REASON_SET
|
|
|
|
|
|
def _is_stale_decode_session_error(exc: Exception) -> bool:
    """Return True when ``exc`` is an ``httpx.HTTPStatusError`` with status 400.

    NOTE(review): the name suggests callers read such a 400 as the decode
    worker no longer knowing the session; confirm against the worker API.
    """
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    return exc.response.status_code == 400
|
|
|
|
|
|
# Markers inside an ``execution_mode`` label showing that D-side admission
# rejected the request. _run_request uses them (through
# _is_admission_rejection_mode) to update state.session_d_rejects, so that
# KvAwarePolicy can move persistently starved sessions to a different D on
# the next turn.
_ADMISSION_REJECTION_SUBSTRINGS = (
    "session-cap",
    "no-d-capacity",
    "d-backpressure",
)


def _is_admission_rejection_mode(execution_mode: str) -> bool:
    """Return True if ``execution_mode`` contains any admission-rejection marker."""
    for marker in _ADMISSION_REJECTION_SUBSTRINGS:
        if marker in execution_mode:
            return True
    return False
|
|
|
|
|
|
def _fallthrough_reason(
    *,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    direct_append_length: int | None,
    direct_session_reused: bool,
    direct_session_reset: bool,
) -> str:
    """Label why a turn-2+ KVC request fell through to the seed/large-append branch.

    The label is an execution_mode suffix that replaces the misleading
    'large-append' tag (TEAM_REPORT §2.7). Labels are checked in this order:

    * ``session-not-resident``: the session was not reused, i.e. it was never
      opened on the policy-chosen D. This is the §1 starvation signature.
    * ``session-was-evicted``: the session was reset.
    * ``no-direct-info``: no direct append length is known.
    * ``real-large-append``: the append exceeds
      ``kvcache_direct_max_uncached_tokens``.
    * ``other-large-append``: ``_should_bypass_prefill`` says the request
      could bypass prefill anyway.
    * ``policy-no-bypass``: ``_should_bypass_prefill`` says it cannot.
    """
    label: str
    if not direct_session_reused:
        label = "session-not-resident"
    elif direct_session_reset:
        label = "session-was-evicted"
    elif direct_append_length is None:
        label = "no-direct-info"
    elif direct_append_length > config.kvcache_direct_max_uncached_tokens:
        label = "real-large-append"
    elif _should_bypass_prefill(request=request, config=config, decision=decision):
        label = "other-large-append"
    else:
        label = "policy-no-bypass"
    return label
|
|
|
|
|
|
def _dynamic_decode_headroom_tokens(
    *,
    residency: DecodeResidencyState,
    server_url: str,
    snapshot: DecodeLoadSnapshot | None,
    routing_mode: Literal["direct", "seed"],
) -> int:
    """Compute load-aware decode headroom (tokens held back) for ``server_url``.

    Without a snapshot, the statically discovered headroom is returned (0 if
    unknown).

    With a snapshot, the headroom is the maximum of three terms:
        * the per-request reserve (the server's reserved decode tokens, at
          least 512) times a pressure count;
        * the pool size divided by a mode-specific divisor;
        * a mode-specific floor.

    Mode parameters:
        * direct: pressure = queued disagg requests (prealloc + transfer +
          retracted, at least 1), divisor 24, floor 4096;
        * any other mode: pressure = running requests + queued disagg
          requests (at least 1), divisor 15, floor 12288.
    """
    if snapshot is None:
        return residency.headroom_tokens.get(server_url, 0)

    per_request_reserve = max(512, residency.reserved_decode_tokens.get(server_url, 0))
    queued_reqs = (
        snapshot.decode_prealloc_queue_reqs
        + snapshot.decode_transfer_queue_reqs
        + snapshot.decode_retracted_queue_reqs
    )
    if routing_mode == "direct":
        pressure = max(1, queued_reqs)
        divisor, floor_tokens = 24, 4096
    else:
        pressure = max(1, snapshot.num_running_reqs + queued_reqs)
        divisor, floor_tokens = 15, 12288

    return max(
        per_request_reserve * pressure,
        snapshot.max_total_num_tokens // divisor,
        floor_tokens,
    )
|
|
|
|
|
|
def _commit_session_residency(
    *,
    residency: DecodeResidencyState,
    session: DirectSessionState,
    request: TraceRequest,
    reserved_tokens: int,
) -> None:
    """Commit ``request``'s estimated footprint as resident for ``session`` on D.

    Steps:
        1. Release ``reserved_tokens`` of pending reservation on
           ``session.server_url``.
        2. Shift that server's resident-token total by the change in the
           session's footprint. A previously closed session counts as 0.
        3. Mark the session open and record the request and access time.

    If the session also has a prefill backup, that backup becomes
    low-priority.
    """
    server = session.server_url
    _release_reserved_tokens(residency, server, reserved_tokens)

    old_footprint = session.resident_tokens if session.opened else 0
    new_footprint = _estimate_session_resident_tokens(request)
    change = new_footprint - old_footprint
    if change != 0:
        totals = residency.resident_tokens_by_server
        totals[server] = totals.get(server, 0) + change

    session.opened = True
    session.resident_tokens = new_footprint
    session.last_trace_request = request
    session.last_access_s = time.perf_counter()
    if session.prefill_opened:
        session.prefill_low_priority = True
|
|
|
|
|
|
def _commit_prefill_backup_residency(
    *,
    residency: DecodeResidencyState,
    session: DirectSessionState,
    request: TraceRequest,
    prefill_url: str,
    reserved_tokens: int,
) -> None:
    """Record ``request``'s footprint as ``session``'s backup on prefill ``prefill_url``.

    Steps:
        1. Release ``reserved_tokens`` of pending prefill reservation on
           ``prefill_url``.
        2. Shift that server's prefill resident-token total by the change in
           the session's footprint. No existing backup counts as 0.
        3. Mark the backup open on ``prefill_url`` and stamp its access time.

    The backup is flagged low-priority exactly when the session is also open
    on decode.
    """
    _release_prefill_reserved_tokens(residency, prefill_url, reserved_tokens)

    old_footprint = session.prefill_resident_tokens if session.prefill_opened else 0
    new_footprint = _estimate_session_resident_tokens(request)
    change = new_footprint - old_footprint
    if change != 0:
        totals = residency.prefill_resident_tokens_by_server
        totals[prefill_url] = totals.get(prefill_url, 0) + change

    session.prefill_server_url = prefill_url
    session.prefill_opened = True
    session.prefill_resident_tokens = new_footprint
    session.prefill_last_access_s = time.perf_counter()
    session.prefill_low_priority = session.opened
|
|
|
|
|
|
async def _close_prefill_session(
    *,
    client: httpx.AsyncClient,
    session: DirectSessionState,
    residency: DecodeResidencyState,
) -> None:
    """Drop ``session``'s prefill backup and reset its prefill bookkeeping.

    If a backup is open on a known prefill server, the following happens:
        * The streaming session there is closed with ``allow_missing=True``.
          Presumably this tolerates a session the worker already dropped;
          confirm in ``_close_streaming_session``.
        * The session's tokens are removed from that server's prefill
          resident total; the entry is removed when it reaches zero.

    In every case the session ends with no open backup, zero prefill tokens
    and no low-priority flag.
    """
    backup_url = session.prefill_server_url
    if session.prefill_opened and backup_url is not None:
        await _close_streaming_session(
            client=client,
            server_url=backup_url,
            session_id=session.session_id,
            allow_missing=True,
        )
        totals = residency.prefill_resident_tokens_by_server
        left = totals.get(backup_url, 0) - session.prefill_resident_tokens
        if left > 0:
            totals[backup_url] = left
        else:
            totals.pop(backup_url, None)

    session.prefill_opened = False
    session.prefill_resident_tokens = 0
    session.prefill_low_priority = False
|
|
|
|
|
|
async def _close_decode_session(
    *,
    client: httpx.AsyncClient,
    session: DirectSessionState,
    residency: DecodeResidencyState,
    evicting_for_capacity: bool = False,
) -> None:
    """Close ``session`` on its decode worker and release its resident tokens.

    For a session that is not open, ``resident_tokens`` is zeroed and nothing
    else happens.

    Otherwise:
        * The streaming session is closed with ``allow_missing=True``.
        * The session's tokens are removed from the server's resident total;
          the entry is removed once it reaches zero.
        * The session is marked closed.
        * If it has a prefill backup, that backup loses its low-priority flag,
          and a capacity eviction is counted as prefill-backed.
        * A capacity eviction of a session without a backup is counted
          separately.
    """
    if not session.opened:
        session.resident_tokens = 0
        return

    await _close_streaming_session(
        client=client,
        server_url=session.server_url,
        session_id=session.session_id,
        allow_missing=True,
    )

    totals = residency.resident_tokens_by_server
    server = session.server_url
    left = totals.get(server, 0) - session.resident_tokens
    if left > 0:
        totals[server] = left
    else:
        totals.pop(server, None)

    session.opened = False
    session.resident_tokens = 0

    if session.prefill_opened:
        residency.decode_evictions_prefill_backed += int(evicting_for_capacity)
        session.prefill_low_priority = False
    elif evicting_for_capacity:
        residency.decode_evictions_without_prefill_backup += 1
|
|
|
|
|
|
async def _reserve_prefill_backup_capacity(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    prefill_url: str,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    residency: DecodeResidencyState,
) -> tuple[bool, int, int]:
    """Reserve room on prefill worker ``prefill_url`` to back up ``session``'s KV.

    The session's target footprint is ``_estimate_session_resident_tokens(request)``.
    To make room, other sessions' prefill backups on the same worker may be
    closed. Only sessions with no active requests are candidates. Low-priority
    backups go first; within a priority class, the least recently accessed
    backup goes first.

    Returns:
        ``(ok, reserved_tokens, evicted_sessions)``:

        * ``(True, 0, 0)`` when the worker's capacity is unknown (``<= 0``).
          Nothing is reserved in that case.
        * ``(False, 0, evicted)`` when, even after evictions, the free tokens
          minus the extra tokens needed would fall below the low-occupancy
          headroom.
        * ``(True, extra, evicted)`` otherwise. ``extra`` is added to
          ``residency.prefill_reserved_tokens_by_server[prefill_url]``.
    """
    # Refresh capacity from the worker; keep the previously discovered value
    # when the worker reports none.
    session_cache, max_total_num_tokens, _reserved_decode_tokens = (
        await _fetch_decode_server_state(
            client=client,
            server_url=prefill_url,
        )
    )
    if max_total_num_tokens > 0:
        residency.prefill_capacity_tokens[prefill_url] = max_total_num_tokens

    capacity_tokens = residency.prefill_capacity_tokens.get(prefill_url, 0)
    headroom_tokens = residency.prefill_headroom_tokens.get(prefill_url, 0)
    if capacity_tokens <= 0:
        return True, 0, 0
    # Backups must leave at least half the pool (or the configured prefill
    # headroom, whichever is larger) free after this reservation.
    low_occupancy_headroom_tokens = max(
        headroom_tokens,
        capacity_tokens // 2,
    )

    # Prefer the worker's own view of what this session already holds there;
    # fall back to local tracking when the worker does not report it resident.
    target_session_status = _find_session_cache_status(
        session_cache,
        session.session_id,
    )
    if (
        isinstance(target_session_status, dict)
        and bool(target_session_status.get("resident"))
    ):
        current_tokens = int(target_session_status.get("resident_tokens", 0) or 0)
    else:
        current_tokens = (
            session.prefill_resident_tokens
            if session.prefill_opened and session.prefill_server_url == prefill_url
            else 0
        )

    target_tokens = _estimate_session_resident_tokens(request)
    required_extra_tokens = max(0, target_tokens - current_tokens)
    evicted_sessions = 0
    # Cap on concurrent backups on this worker: capacity // (2 * target),
    # clamped to [1, 4].
    max_backup_sessions = max(1, capacity_tokens // max(1, target_tokens * 2))
    max_backup_sessions = min(max_backup_sessions, 4)
    # Free tokens as reported by the worker (or capacity - held when that is
    # not positive), minus reservations this router already has in flight.
    available_tokens = int(session_cache.get("available_tokens", 0) or 0)
    if available_tokens <= 0:
        held_tokens = int(session_cache.get("held_tokens", 0) or 0)
        available_tokens = max(0, capacity_tokens - held_tokens)
    available_tokens -= residency.prefill_reserved_tokens_by_server.get(prefill_url, 0)

    # Both helpers read ``available_tokens`` / ``direct_sessions`` from the
    # enclosing scope at call time, so they see updates made by the loop below.
    def has_enough_prefill_headroom() -> bool:
        return available_tokens - required_extra_tokens >= low_occupancy_headroom_tokens

    def prefill_backup_count() -> int:
        return sum(
            1
            for candidate in direct_sessions.values()
            if candidate.prefill_opened and candidate.prefill_server_url == prefill_url
        )

    # Keep evicting while the reservation would break the headroom rule, or
    # while a brand-new backup would exceed the backup-count cap.
    while (
        required_extra_tokens > 0
        and (
            not has_enough_prefill_headroom()
            or (
                not session.prefill_opened
                and prefill_backup_count() >= max_backup_sessions
            )
        )
    ):
        candidates = sorted(
            (
                candidate
                for candidate in direct_sessions.values()
                if candidate.prefill_opened
                and candidate.prefill_server_url == prefill_url
                and candidate.session_id != session.session_id
                and candidate.active_requests <= 0
            ),
            # Low-priority backups sort first (0); then oldest access first.
            key=lambda candidate: (
                0 if candidate.prefill_low_priority else 1,
                candidate.prefill_last_access_s,
            ),
        )
        if not candidates:
            break
        freed_tokens = candidates[0].prefill_resident_tokens
        await _close_prefill_session(
            client=client,
            session=candidates[0],
            residency=residency,
        )
        # Credit the freed tokens locally; the worker is not re-queried here.
        available_tokens += freed_tokens
        evicted_sessions += 1

    # Only the headroom rule is enforced here. If the loop ran out of eviction
    # candidates, a new backup may still be admitted above max_backup_sessions.
    if not has_enough_prefill_headroom():
        return False, 0, evicted_sessions

    _add_prefill_reserved_tokens(residency, prefill_url, required_extra_tokens)
    return True, required_extra_tokens, evicted_sessions
|
|
|
|
|
|
def _reservation_from_worker_admission(
    *,
    residency: DecodeResidencyState,
    server_url: str,
    admission: dict[str, Any],
    required_extra_tokens: int,
) -> tuple[bool, int, int, int, str | None]:
    """Turn a D-side admission payload into a reservation result tuple.

    A rejected admission (``can_admit`` falsy) yields
    ``(False, 0, evicted, 0, reason)``, where the reason defaults to
    ``"d-no-space"``.

    An accepted admission books the tokens D reports as ``required_tokens``
    against ``server_url``. It falls back to ``required_extra_tokens`` when
    that field is missing or zero. The result is
    ``(True, reserved, evicted, 0, None)``.

    In both cases ``evicted`` is D's ``evicted_session_count`` (0 if absent).
    """
    evicted = int(admission.get("evicted_session_count", 0) or 0)
    if not bool(admission.get("can_admit")):
        return False, 0, evicted, 0, str(admission.get("reason") or "d-no-space")
    reserved_tokens = int(
        admission.get("required_tokens", required_extra_tokens)
        or required_extra_tokens
    )
    _add_reserved_tokens(residency, server_url, reserved_tokens)
    return True, reserved_tokens, evicted, 0, None


async def _reserve_decode_session_capacity(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
    request: TraceRequest,
    server_url: str,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    residency: DecodeResidencyState,
    treat_as_fresh_session: bool,
    routing_mode: Literal["direct", "seed"],
    admission_mode: KvCacheAdmissionMode,
) -> tuple[bool, int, int, int, str | None]:
    """Decide whether ``request`` for ``session`` may be placed on decode worker ``server_url``.

    Returns:
        ``(admitted, reserved_tokens, evicted_sessions,
        prefill_backed_evictions, reject_reason)``. On admission,
        ``reserved_tokens`` has been added to
        ``residency.reserved_tokens_by_server[server_url]``.

    ``admission_mode == "router"`` delegates entirely to
    ``_reserve_decode_session_capacity_from_router_state``.

    Otherwise (worker admission), a session being treated as fresh is first
    closed on D. The decision path then depends on the request:

    * Direct append to a non-reset session: the session must be open locally
      and reported resident by D's ``direct_append`` admission. D's answer
      decides.
    * Seed / reseed, or a reset direct request: D's ``seed`` admission
      decides. D runs its own LRU eviction to make room.
    * Only if that seed query itself failed (``"admission-query-failed"``)
      does the legacy router-state estimate run. It:
        1. refreshes D's state and checks load backpressure;
        2. closes idle local sessions on D, oldest access first, until the
           request fits;
        3. compares used tokens against usable capacity plus D's
           idle-evictable tokens.
    """
    if admission_mode == "router":
        return await _reserve_decode_session_capacity_from_router_state(
            client=client,
            config=config,
            request=request,
            server_url=server_url,
            session=session,
            direct_sessions=direct_sessions,
            residency=residency,
            treat_as_fresh_session=treat_as_fresh_session,
            routing_mode=routing_mode,
        )

    if treat_as_fresh_session and session.opened:
        await _close_decode_session(
            client=client,
            session=session,
            residency=residency,
        )

    current_tokens = 0 if treat_as_fresh_session else session.resident_tokens
    target_tokens = _estimate_session_resident_tokens(request)
    required_extra_tokens = max(0, target_tokens - current_tokens)
    uncached_input_tokens = max(0, request.input_length - current_tokens)
    prefill_backed_evictions = 0

    if routing_mode == "direct" and not treat_as_fresh_session:
        if not session.opened:
            return False, 0, 0, 0, "d-session-not-resident"
        admission = await _query_decode_direct_admission(
            client=client,
            server_url=server_url,
            session_id=session.session_id,
            uncached_input_tokens=uncached_input_tokens,
            output_tokens=request.output_length,
            mode="direct_append",
            config=config,
            residency=residency,
            request_id=request.request_id,
            turn_id=request.turn_id,
        )
        if not bool(admission.get("resident")):
            return False, 0, 0, 0, str(admission.get("reason") or "d-session-not-resident")
        return _reservation_from_worker_admission(
            residency=residency,
            server_url=server_url,
            admission=admission,
            required_extra_tokens=required_extra_tokens,
        )

    # Seed / reseed path: ask D itself via the seed-mode admission endpoint
    # instead of estimating capacity from a stale router-state snapshot. D
    # runs LRU eviction internally to make room. Falls through to the legacy
    # router-state logic below only if the endpoint is unavailable.
    seed_admission = await _query_decode_direct_admission(
        client=client,
        server_url=server_url,
        session_id=session.session_id,
        uncached_input_tokens=uncached_input_tokens,
        output_tokens=request.output_length,
        mode="seed",
        config=config,
        residency=residency,
        request_id=request.request_id,
        turn_id=request.turn_id,
    )
    if seed_admission.get("reason") != "admission-query-failed":
        return _reservation_from_worker_admission(
            residency=residency,
            server_url=server_url,
            admission=seed_admission,
            required_extra_tokens=required_extra_tokens,
        )

    # ---- Legacy fallback: estimate from router state and D's reported state.
    session_cache, max_total_num_tokens, reserved_decode_tokens = (
        await _fetch_decode_server_state(
            client=client,
            server_url=server_url,
        )
    )
    if max_total_num_tokens > 0:
        residency.capacity_tokens[server_url] = max_total_num_tokens
    if reserved_decode_tokens > 0:
        residency.reserved_decode_tokens[server_url] = reserved_decode_tokens

    target_session_status = _find_session_cache_status(
        session_cache,
        session.session_id,
    )
    if routing_mode == "direct" and not (
        isinstance(target_session_status, dict)
        and bool(target_session_status.get("resident"))
    ):
        return False, 0, 0, 0, "d-session-not-resident"

    load_snapshot = await _fetch_decode_load_snapshot(
        client=client,
        server_url=server_url,
    )
    if load_snapshot is not None and load_snapshot.max_total_num_tokens > 0:
        residency.capacity_tokens[server_url] = load_snapshot.max_total_num_tokens
    backpressure_reason = _decode_load_backpressure_reason(
        load_snapshot,
        config=config,
        routing_mode=routing_mode,
    )
    if backpressure_reason is not None:
        return False, 0, 0, 0, backpressure_reason

    usable_capacity_tokens = _usable_capacity_tokens(residency, server_url)
    evicted_sessions = 0
    while (
        required_extra_tokens > 0
        and residency.resident_tokens_by_server.get(server_url, 0)
        + residency.reserved_tokens_by_server.get(server_url, 0)
        + required_extra_tokens
        > usable_capacity_tokens + int(session_cache.get("idle_evictable_tokens", 0) or 0)
    ):
        candidates = sorted(
            (
                candidate
                for candidate in direct_sessions.values()
                if candidate.opened
                and candidate.server_url == server_url
                and candidate.session_id != session.session_id
                and candidate.active_requests <= 0
            ),
            key=lambda candidate: candidate.last_access_s,
        )
        if not candidates:
            break
        await _close_decode_session(
            client=client,
            session=candidates[0],
            residency=residency,
            evicting_for_capacity=True,
        )
        prefill_backed_evictions += int(candidates[0].prefill_opened)
        evicted_sessions += 1

    if evicted_sessions > 0:
        # Evictions changed D's state: refresh both views. The headroom and
        # usable capacity derived from them are recomputed once just below;
        # the former in-branch copy of that computation was redundant.
        load_snapshot = await _fetch_decode_load_snapshot(
            client=client,
            server_url=server_url,
        )
        session_cache, max_total_num_tokens, reserved_decode_tokens = (
            await _fetch_decode_server_state(
                client=client,
                server_url=server_url,
            )
        )
        if max_total_num_tokens > 0:
            residency.capacity_tokens[server_url] = max_total_num_tokens
        if reserved_decode_tokens > 0:
            residency.reserved_decode_tokens[server_url] = reserved_decode_tokens

    usable_capacity_tokens = _usable_capacity_tokens(residency, server_url)
    if load_snapshot is not None:
        dynamic_headroom = _dynamic_decode_headroom_tokens(
            residency=residency,
            server_url=server_url,
            snapshot=load_snapshot,
            routing_mode=routing_mode,
        )
        residency.headroom_tokens[server_url] = min(
            residency.capacity_tokens.get(server_url, dynamic_headroom),
            dynamic_headroom,
        )
        usable_capacity_tokens = max(
            0,
            residency.capacity_tokens.get(server_url, 0) - dynamic_headroom,
        )

    effective_used_tokens = (
        load_snapshot.num_used_tokens
        if load_snapshot is not None
        else residency.resident_tokens_by_server.get(server_url, 0)
    ) + residency.reserved_tokens_by_server.get(server_url, 0)
    idle_evictable_tokens = int(session_cache.get("idle_evictable_tokens", 0) or 0)
    if (
        routing_mode == "direct"
        and isinstance(target_session_status, dict)
        and bool(target_session_status.get("idle_evictable"))
    ):
        # The target session itself must not count as evictable room.
        idle_evictable_tokens = max(
            0,
            idle_evictable_tokens
            - int(target_session_status.get("resident_tokens", 0) or 0),
        )

    if effective_used_tokens + required_extra_tokens > (
        usable_capacity_tokens + idle_evictable_tokens
    ):
        return False, 0, evicted_sessions, prefill_backed_evictions, "d-no-space"

    _add_reserved_tokens(residency, server_url, required_extra_tokens)
    return True, required_extra_tokens, evicted_sessions, prefill_backed_evictions, None
|
|
|
|
|
|
async def _reserve_decode_session_capacity_from_router_state(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
    request: TraceRequest,
    server_url: str,
    session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    residency: DecodeResidencyState,
    treat_as_fresh_session: bool,
    routing_mode: Literal["direct", "seed"],
) -> tuple[bool, int, int, int, str | None]:
    """Reserve decode tokens on ``server_url`` using only router-side bookkeeping.

    Idle sessions on the same server (not ``session``, no active requests) are
    closed in least-recently-accessed order until the extra tokens fit.

    Returns:
        ``(admitted, reserved_tokens, evicted_sessions,
        prefill_backed_evictions, reason)``. ``reason`` is ``None`` on success,
        ``"d-session-not-resident"`` when a direct request's session is closed,
        or ``"d-no-space"`` when capacity could not be freed. ``config`` is not
        read here.
    """
    if treat_as_fresh_session and session.opened:
        await _close_decode_session(
            client=client,
            session=session,
            residency=residency,
        )

    if routing_mode == "direct" and not session.opened:
        return False, 0, 0, 0, "d-session-not-resident"

    baseline_tokens = session.resident_tokens if not treat_as_fresh_session else 0
    needed_tokens = max(0, _estimate_session_resident_tokens(request) - baseline_tokens)
    capacity_tokens = _usable_capacity_tokens(residency, server_url)

    # If discovery failed, do not force every request down the P/D fallback path.
    # The router can still preserve correctness; this only disables proactive
    # capacity admission until the worker reports capacity again in a later run.
    if capacity_tokens <= 0:
        _add_reserved_tokens(residency, server_url, needed_tokens)
        return True, needed_tokens, 0, 0, None

    def _would_overflow() -> bool:
        # Re-read residency each time: closing a session changes the totals.
        committed = residency.resident_tokens_by_server.get(
            server_url, 0
        ) + residency.reserved_tokens_by_server.get(server_url, 0)
        return committed + needed_tokens > capacity_tokens

    evicted_count = 0
    prefill_backed_count = 0
    while needed_tokens > 0 and _would_overflow():
        idle_peers = [
            peer
            for peer in direct_sessions.values()
            if peer.opened
            and peer.server_url == server_url
            and peer.session_id != session.session_id
            and peer.active_requests <= 0
        ]
        if not idle_peers:
            break
        # min() returns the first least-recently-used peer, as a stable sort would.
        victim = min(idle_peers, key=lambda peer: peer.last_access_s)
        await _close_decode_session(
            client=client,
            session=victim,
            residency=residency,
            evicting_for_capacity=True,
        )
        # Read after the close, matching the original ordering.
        prefill_backed_count += int(victim.prefill_opened)
        evicted_count += 1

    if _would_overflow():
        return False, 0, evicted_count, prefill_backed_count, "d-no-space"

    _add_reserved_tokens(residency, server_url, needed_tokens)
    return True, needed_tokens, evicted_count, prefill_backed_count, None
|
|
|
|
|
|
def _build_direct_prompt(
    *,
    request: TraceRequest,
    session: DirectSessionState,
) -> tuple[str, int, bool, bool]:
    """Build the prompt to send for ``request`` on an existing direct session.

    Returns ``(prompt, effective_input_length, session_reused, session_reset)``.
    When the session was reset or is not reusable, the full synthetic prompt is
    returned with ``session_reused=False``. Otherwise only the synthetic append
    chunk for the new tokens is returned.
    """
    append_length, reused, reset = _inspect_direct_request(
        request=request,
        session=session,
    )
    if reset or not reused:
        return build_synthetic_prompt(request), request.input_length, False, reset

    append_chunk = build_synthetic_append_chunk(request, append_length)
    return append_chunk, append_length, reused, reset
|
|
|
|
|
|
def _build_direct_full_input_ids(
|
|
request: TraceRequest,
|
|
*,
|
|
block_token_budget: int = 24,
|
|
) -> list[int]:
|
|
input_ids: list[int] = []
|
|
for hash_id in request.hash_ids:
|
|
base = 1000 + (hash_id % 3000)
|
|
for offset in range(block_token_budget):
|
|
input_ids.append(1000 + ((base + offset) % 3000))
|
|
|
|
while len(input_ids) < request.input_length:
|
|
input_ids.append(5000 + (len(input_ids) % 2000))
|
|
|
|
return input_ids[: request.input_length]
|
|
|
|
|
|
def _build_direct_append_input_ids(
|
|
request: TraceRequest,
|
|
append_length: int,
|
|
) -> list[int]:
|
|
if append_length <= 0:
|
|
return []
|
|
|
|
base = 9000 + ((request.chat_id + request.turn_id) % 2000)
|
|
return [
|
|
9000 + ((base + offset) % 2000)
|
|
for offset in range(append_length)
|
|
]
|
|
|
|
|
|
async def _invoke_plain_router(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    execution_mode: str,
    decode_residency: "DecodeResidencyState | None" = None,
) -> ExecutionResult:
    """Send ``request`` through the P/D router with no D-session affinity.

    The prefill priority is computed with ``direct_to_d_predicted=False``. The
    result is labelled with ``execution_mode`` and reports no session reuse.
    """
    priority = _prefill_priority_for_router_request(
        config=config,
        direct_to_d_predicted=False,
    )
    routed = await _invoke_router(
        client=client,
        request=request,
        config=config,
        decode_worker_index=decision.decode_worker_index,
        prefill_request_priority=priority,
        decode_residency=decode_residency,
    )
    return ExecutionResult(
        execution_mode=execution_mode,
        actual_kv_transfer_blocks=decision.kv_transfer_blocks,
        effective_input_length=request.input_length,
        cached_tokens=routed.cached_tokens,
        prefill_request_priority=priority,
        session_reused=False,
        session_reset=False,
        latency_s=routed.latency_s,
        ttft_s=routed.ttft_s,
        tpot_s=routed.tpot_s,
        actual_output_tokens=routed.actual_output_tokens,
        requested_output_tokens=routed.requested_output_tokens,
        finish_reason=routed.finish_reason,
    )
|
|
|
|
|
|
async def _attempt_d_to_p_sync(
|
|
*,
|
|
client: httpx.AsyncClient,
|
|
request: TraceRequest,
|
|
config: ReplayConfig,
|
|
prefill_url: str,
|
|
decode_session: DirectSessionState,
|
|
) -> dict | None:
|
|
"""Try to RDMA-dump session KV from the D that last held it to ``prefill_url``.
|
|
|
|
Returns a dict with status info on success/skip, or ``None`` on a
|
|
non-recoverable error. The caller falls back to normal re-prefill on
|
|
any failure. Each path emits a structural-log line so we can forensic
|
|
why sync skipped vs succeeded vs failed.
|
|
"""
|
|
if not config.enable_d_to_p_sync:
|
|
return None
|
|
source_d_url = decode_session.server_url
|
|
sid = request.session_id
|
|
rid = request.request_id
|
|
if not source_d_url:
|
|
await _structural_emit(
|
|
"d-to-p-sync.jsonl",
|
|
{"event": "skipped", "stage": "entry", "sid": sid, "rid": rid,
|
|
"reason": "no-source-d"},
|
|
)
|
|
return {"status": "skipped-no-source-d"}
|
|
# NB: do NOT gate on decode_session.opened. By the time we reach the
|
|
# fallback seeded_router, agentic has already flipped that flag to False
|
|
# in response to admission rejection. But the D-side scheduler's
|
|
# SessionAwareCache may STILL hold the session resident (release_session
|
|
# is only called explicitly, not from admission events). Let D be the
|
|
# source of truth via its own snapshot_dump response.
|
|
target_tokens = max(0, int(_estimate_session_resident_tokens(request)))
|
|
if target_tokens <= 0:
|
|
await _structural_emit(
|
|
"d-to-p-sync.jsonl",
|
|
{"event": "skipped", "stage": "entry", "sid": sid, "rid": rid,
|
|
"reason": "zero-target-tokens"},
|
|
)
|
|
return {"status": "skipped-zero-tokens"}
|
|
|
|
t_prep0 = time.perf_counter()
|
|
try:
|
|
prep_resp = await client.post(
|
|
f"{prefill_url}/_snapshot/prepare_receive",
|
|
json={
|
|
"session_id": request.session_id,
|
|
"num_tokens": target_tokens,
|
|
},
|
|
timeout=30.0,
|
|
)
|
|
prep_resp.raise_for_status()
|
|
prep = prep_resp.json()
|
|
except Exception as exc:
|
|
await _structural_emit(
|
|
"d-to-p-sync.jsonl",
|
|
{"event": "failed", "stage": "prepare", "sid": sid, "rid": rid,
|
|
"error": repr(exc)[:200]},
|
|
)
|
|
return {"status": "prepare-failed", "error": repr(exc)}
|
|
t_prep1 = time.perf_counter()
|
|
if not prep.get("ok"):
|
|
await _structural_emit(
|
|
"d-to-p-sync.jsonl",
|
|
{"event": "skipped", "stage": "prepare", "sid": sid, "rid": rid,
|
|
"reason": prep.get("reason"),
|
|
"prepare_dur_ms": round((t_prep1 - t_prep0) * 1000, 2)},
|
|
)
|
|
return {"status": "prepare-not-ok", "reason": prep.get("reason")}
|
|
|
|
t_dump0 = time.perf_counter()
|
|
try:
|
|
dump_resp = await client.post(
|
|
f"{source_d_url}/_snapshot/dump",
|
|
json={
|
|
"session_id": request.session_id,
|
|
"target_snapshot_session_id": prep["snapshot_session_id"],
|
|
"target_k_base_ptrs": prep["k_base_ptrs"],
|
|
"target_v_base_ptrs": prep["v_base_ptrs"],
|
|
"target_slot_indices": prep["slot_indices"],
|
|
"target_stride_k_bytes": prep["stride_k_bytes"],
|
|
"target_stride_v_bytes": prep["stride_v_bytes"],
|
|
},
|
|
timeout=60.0,
|
|
)
|
|
dump_resp.raise_for_status()
|
|
dump = dump_resp.json()
|
|
except Exception as exc:
|
|
await _structural_emit(
|
|
"d-to-p-sync.jsonl",
|
|
{"event": "failed", "stage": "dump", "sid": sid, "rid": rid,
|
|
"error": repr(exc)[:200]},
|
|
)
|
|
return {"status": "dump-failed", "error": repr(exc)}
|
|
t_dump1 = time.perf_counter()
|
|
if not dump.get("ok"):
|
|
await _structural_emit(
|
|
"d-to-p-sync.jsonl",
|
|
{"event": "skipped", "stage": "dump", "sid": sid, "rid": rid,
|
|
"reason": dump.get("reason"),
|
|
"dump_dur_ms": round((t_dump1 - t_dump0) * 1000, 2),
|
|
"kv_committed_len": int(dump.get("kv_committed_len", 0))},
|
|
)
|
|
return {"status": "dump-not-ok", "reason": dump.get("reason"),
|
|
"bytes_pushed": dump.get("bytes_pushed", 0)}
|
|
|
|
# We need token_ids for radix insert. The caller has request.input_token_ids
|
|
# for the first N — use that as best-available approximation.
|
|
tokens = list(getattr(request, "input_token_ids", []) or [])
|
|
if not tokens:
|
|
# No token_ids available — can't insert into radix. P will fall back
|
|
# to normal prefill but will have wasted slots. Discard.
|
|
try:
|
|
await client.post(
|
|
f"{prefill_url}/_snapshot/finalize_ingest",
|
|
json={
|
|
"session_id": request.session_id,
|
|
"token_ids": [],
|
|
"slot_indices": prep["slot_indices"],
|
|
},
|
|
timeout=15.0,
|
|
)
|
|
except Exception:
|
|
pass
|
|
await _structural_emit(
|
|
"d-to-p-sync.jsonl",
|
|
{"event": "skipped", "stage": "post-dump", "sid": sid, "rid": rid,
|
|
"reason": "no-input-token-ids",
|
|
"bytes_pushed": int(dump.get("bytes_pushed", 0))},
|
|
)
|
|
return {"status": "no-tokens-discard", "bytes_pushed": dump.get("bytes_pushed", 0)}
|
|
|
|
n = min(len(tokens), len(prep["slot_indices"]))
|
|
t_fin0 = time.perf_counter()
|
|
try:
|
|
fin_resp = await client.post(
|
|
f"{prefill_url}/_snapshot/finalize_ingest",
|
|
json={
|
|
"session_id": request.session_id,
|
|
"token_ids": tokens[:n],
|
|
"slot_indices": prep["slot_indices"][:n],
|
|
},
|
|
timeout=30.0,
|
|
)
|
|
fin_resp.raise_for_status()
|
|
fin = fin_resp.json()
|
|
except Exception as exc:
|
|
await _structural_emit(
|
|
"d-to-p-sync.jsonl",
|
|
{"event": "failed", "stage": "finalize", "sid": sid, "rid": rid,
|
|
"error": repr(exc)[:200],
|
|
"bytes_pushed": int(dump.get("bytes_pushed", 0))},
|
|
)
|
|
return {"status": "finalize-failed", "error": repr(exc),
|
|
"bytes_pushed": dump.get("bytes_pushed", 0)}
|
|
t_fin1 = time.perf_counter()
|
|
if not fin.get("ok"):
|
|
await _structural_emit(
|
|
"d-to-p-sync.jsonl",
|
|
{"event": "skipped", "stage": "finalize", "sid": sid, "rid": rid,
|
|
"reason": fin.get("reason"),
|
|
"bytes_pushed": int(dump.get("bytes_pushed", 0))},
|
|
)
|
|
return {"status": "finalize-not-ok", "reason": fin.get("reason"),
|
|
"bytes_pushed": dump.get("bytes_pushed", 0)}
|
|
await _structural_emit(
|
|
"d-to-p-sync.jsonl",
|
|
{"event": "ok", "sid": sid, "rid": rid,
|
|
"bytes_pushed": int(dump.get("bytes_pushed", 0)),
|
|
"kv_committed_len": int(dump.get("kv_committed_len", 0)),
|
|
"inserted_prefix_len": int(fin.get("inserted_prefix_len", 0)),
|
|
"prepare_dur_ms": round((t_prep1 - t_prep0) * 1000, 2),
|
|
"dump_dur_ms": round((t_dump1 - t_dump0) * 1000, 2),
|
|
"finalize_dur_ms": round((t_fin1 - t_fin0) * 1000, 2),
|
|
"snapshot_session_id": prep.get("snapshot_session_id")},
|
|
)
|
|
return {
|
|
"status": "ok",
|
|
"bytes_pushed": int(dump.get("bytes_pushed", 0)),
|
|
"inserted_prefix_len": int(fin.get("inserted_prefix_len", 0)),
|
|
"snapshot_session_id": prep.get("snapshot_session_id"),
|
|
}
|
|
|
|
|
|
async def _invoke_kvcache_seeded_router(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    prefill_url: str,
    decode_session: DirectSessionState,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
    reserved_tokens: int,
    execution_mode: str,
) -> ExecutionResult:
    """Run ``request`` through the P/D router with streaming sessions open on P and D.

    Streaming sessions for ``request.session_id`` are opened on ``prefill_url``
    and on ``decode_session.server_url``, and the router call carries the
    session id. Presumably this is so later turns can reuse D-resident KV
    (TODO confirm).

    Ownership: this function owns the caller's decode reservation
    ``reserved_tokens``. On success it is committed via
    ``_commit_session_residency``. If any step raises, it is released together
    with any prefill-backup reservation, and any session this call newly opened
    is closed, before the exception propagates.

    The D→P snapshot sync (when ``config.enable_d_to_p_sync``) is best-effort.
    Its outcome is recorded in the structural log and never fails the request.
    """
    keep_prefill_backup = False
    prefill_reserved_tokens = 0
    prefill_session_newly_opened = False
    decode_session_newly_opened = False
    # Tracks whether this call incremented active_requests, so the failure path
    # never decrements a count it did not add.
    active_request_registered = False
    try:
        async with direct_session_lock:
            if config.kvcache_prefill_backup_policy == "capacity-backup":
                keep_prefill_backup, prefill_reserved_tokens, _prefill_evicted = (
                    await _reserve_prefill_backup_capacity(
                        client=client,
                        request=request,
                        prefill_url=prefill_url,
                        session=decode_session,
                        direct_sessions=direct_sessions,
                        residency=decode_residency,
                    )
                )
            if (
                decode_session.prefill_opened
                and decode_session.prefill_server_url != prefill_url
            ):
                await _close_prefill_session(
                    client=client,
                    session=decode_session,
                    residency=decode_residency,
                )

        async with direct_session_lock:
            if not decode_session.prefill_opened:
                await _open_streaming_session(
                    client=client,
                    server_url=prefill_url,
                    session_id=request.session_id,
                    request=request,
                )
                decode_session.prefill_opened = True
                decode_session.prefill_server_url = prefill_url
                prefill_session_newly_opened = True

        # D→P snapshot push (Phase 3) — best-effort. _attempt_d_to_p_sync writes
        # a structural-log line for each of its outcomes; here we only stop an
        # unexpected exception from failing the request.
        if config.enable_d_to_p_sync:
            try:
                await _attempt_d_to_p_sync(
                    client=client,
                    request=request,
                    config=config,
                    prefill_url=prefill_url,
                    decode_session=decode_session,
                )
            except Exception as exc:
                try:
                    await _structural_emit(
                        "d-to-p-sync.jsonl",
                        {"event": "failed", "stage": "unexpected",
                         "sid": request.session_id, "rid": request.request_id,
                         "error": repr(exc)[:200]},
                    )
                except Exception:
                    pass  # Logging is itself best-effort here.

        prefill_priority = _prefill_priority_for_router_request(
            config=config,
            direct_to_d_predicted=True,
        )
        async with direct_session_lock:
            if not decode_session.opened:
                await _open_streaming_session(
                    client=client,
                    server_url=decode_session.server_url,
                    session_id=request.session_id,
                    request=request,
                )
                decode_session.opened = True
                decode_session_newly_opened = True
            decode_session.active_requests += 1
            active_request_registered = True
        gen = await _invoke_router(
            client=client,
            request=request,
            config=config,
            decode_worker_index=decision.decode_worker_index,
            session_id=request.session_id,
            prefill_request_priority=prefill_priority,
            decode_residency=decode_residency,
        )
    except Exception:
        async with direct_session_lock:
            if active_request_registered:
                decode_session.active_requests = max(
                    0, decode_session.active_requests - 1
                )
            _release_reserved_tokens(
                decode_residency,
                decode_session.server_url,
                reserved_tokens,
            )
            _release_prefill_reserved_tokens(
                decode_residency,
                prefill_url,
                prefill_reserved_tokens,
            )
            if decode_session_newly_opened:
                await _close_decode_session(
                    client=client,
                    session=decode_session,
                    residency=decode_residency,
                )
            if prefill_session_newly_opened:
                await _close_prefill_session(
                    client=client,
                    session=decode_session,
                    residency=decode_residency,
                )
        raise

    async with direct_session_lock:
        decode_session.active_requests = max(0, decode_session.active_requests - 1)
        if keep_prefill_backup:
            _commit_prefill_backup_residency(
                residency=decode_residency,
                session=decode_session,
                request=request,
                prefill_url=prefill_url,
                reserved_tokens=prefill_reserved_tokens,
            )
        else:
            _release_prefill_reserved_tokens(
                decode_residency,
                prefill_url,
                prefill_reserved_tokens,
            )
            await _close_prefill_session(
                client=client,
                session=decode_session,
                residency=decode_residency,
            )
        _commit_session_residency(
            residency=decode_residency,
            session=decode_session,
            request=request,
            reserved_tokens=reserved_tokens,
        )
    return ExecutionResult(
        execution_mode=execution_mode,
        actual_kv_transfer_blocks=decision.kv_transfer_blocks,
        effective_input_length=request.input_length,
        cached_tokens=gen.cached_tokens,
        prefill_request_priority=prefill_priority,
        session_reused=False,
        session_reset=False,
        latency_s=gen.latency_s,
        ttft_s=gen.ttft_s,
        tpot_s=gen.tpot_s,
        actual_output_tokens=gen.actual_output_tokens,
        requested_output_tokens=gen.requested_output_tokens,
        finish_reason=gen.finish_reason,
    )
|
|
|
|
|
|
async def _execute_request(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
) -> ExecutionResult:
    """Execute one trace request under ``config.mechanism_name``.

    * ``pd-disaggregation``: always the plain P/D router.
    * ``pd-colo``: the plain router, reported with zero KV-transfer blocks.
    * ``kvcache-centric``: session-affine routing. Turn 1 tries to seed a
      D-side session via the router. Later turns go straight to the pinned D
      when ``_should_bypass_prefill`` allows it, the session is reusable, and
      the uncached append fits ``config.kvcache_direct_max_uncached_tokens``.
      Otherwise they try to reseed. When neither is possible, the request
      falls back to the plain router with an ``execution_mode`` label that
      records why.

    Raises:
        ValueError: ``router_url`` is missing, ``kvcache-centric`` has no decode
            workers, or the mechanism is unknown.
    """
    if config.mechanism_name == "pd-disaggregation":
        if not config.router_url:
            raise ValueError("router_url is required for pd-disaggregation replay")
        return await _invoke_plain_router(
            client=client,
            request=request,
            config=config,
            decision=decision,
            execution_mode="pd-disaggregation-router",
            decode_residency=decode_residency,
        )

    if config.mechanism_name == "pd-colo":
        if not config.router_url:
            raise ValueError("router_url is required for pd-colo replay")
        result = await _invoke_plain_router(
            client=client,
            request=request,
            config=config,
            decision=decision,
            execution_mode="dp-colo-router",
            decode_residency=decode_residency,
        )
        return replace(result, actual_kv_transfer_blocks=0)

    if config.mechanism_name == "kvcache-centric":
        if not config.router_url:
            raise ValueError("router_url is required for kvcache-centric replay")
        if not config.topology.decode_workers:
            raise ValueError("kvcache-centric mechanism requires at least one decode worker")
        prefill_url = _worker_url_by_id(
            config.topology.prefill_workers,
            decision.prefill_worker_id,
        )
        decode_url = config.topology.decode_workers[decision.decode_worker_index].url
        async with direct_session_lock:
            decode_session = direct_sessions.get(request.session_id)
            if decode_session is None:
                decode_session = DirectSessionState(
                    session_id=request.session_id,
                    server_url=decode_url,
                )
                direct_sessions[request.session_id] = decode_session
            elif decode_session.server_url != decode_url and decode_session.opened:
                # The session moved to a different D: close it on the old one.
                await _close_decode_session(
                    client=client,
                    session=decode_session,
                    residency=decode_residency,
                )
                decode_session.server_url = decode_url
            else:
                decode_session.server_url = decode_url

        direct_append_length: int | None = None
        direct_session_reused = False
        direct_session_reset = False
        if request.turn_id > 1:
            async with direct_session_lock:
                (
                    direct_append_length,
                    direct_session_reused,
                    direct_session_reset,
                ) = _inspect_direct_request(
                    request=request,
                    session=decode_session,
                )

        # --- Turn 1: seed a D-side session, or fall back to the router. ----
        if request.turn_id == 1:
            seed_filter_reason = _seed_filter_reason(
                request=request,
                config=config,
                inflight_decode_load=decision.inflight_decode_load,
            )
            if seed_filter_reason is not None:
                return await _invoke_plain_router(
                    client=client,
                    request=request,
                    config=config,
                    decision=decision,
                    execution_mode=f"pd-router-turn1-{seed_filter_reason}",
                    decode_residency=decode_residency,
                )
            async with direct_session_lock:
                admit_new_decode_session = _should_admit_new_decode_session(
                    residency=decode_residency,
                    server_url=decode_url,
                    request=request,
                    session=decode_session,
                    direct_sessions=direct_sessions,
                    treat_as_fresh_session=True,
                    admission_mode=config.kvcache_admission_mode,
                )
                if not admit_new_decode_session:
                    can_seed = False
                    reserved_tokens = 0
                    seed_reason = "d-session-cap"
                else:
                    can_seed, reserved_tokens, _evicted, _p_backed, seed_reason = (
                        await _reserve_decode_session_capacity(
                            client=client,
                            config=config,
                            request=request,
                            server_url=decode_url,
                            session=decode_session,
                            direct_sessions=direct_sessions,
                            residency=decode_residency,
                            treat_as_fresh_session=True,
                            routing_mode="seed",
                            admission_mode=config.kvcache_admission_mode,
                        )
                    )
            if seed_reason == "d-session-cap":
                return await _invoke_plain_router(
                    client=client,
                    request=request,
                    config=config,
                    decision=decision,
                    execution_mode="pd-router-turn1-session-cap",
                    decode_residency=decode_residency,
                )
            if can_seed:
                return await _invoke_kvcache_seeded_router(
                    client=client,
                    request=request,
                    config=config,
                    decision=decision,
                    prefill_url=prefill_url,
                    decode_session=decode_session,
                    direct_sessions=direct_sessions,
                    direct_session_lock=direct_session_lock,
                    decode_residency=decode_residency,
                    reserved_tokens=reserved_tokens,
                    execution_mode="pd-router-turn1-seed",
                )
            return await _invoke_plain_router(
                client=client,
                request=request,
                config=config,
                decision=decision,
                execution_mode=(
                    "pd-router-turn1-d-backpressure"
                    if seed_reason is not None and seed_reason != "d-no-space"
                    else "pd-router-turn1-no-d-capacity"
                ),
                decode_residency=decode_residency,
            )

        # --- Later turn with a small reusable append: try direct-to-D. -----
        if (
            _should_bypass_prefill(
                request=request,
                config=config,
                decision=decision,
            )
            and direct_append_length is not None
            and direct_session_reused
            and not direct_session_reset
            and direct_append_length <= config.kvcache_direct_max_uncached_tokens
        ):
            async with direct_session_lock:
                can_direct = (
                    decode_session.opened
                    and decode_session.server_url == decode_url
                    and direct_session_reused
                    and not direct_session_reset
                )
            direct_reserved_tokens = 0
            direct_reason: str | None = None
            if can_direct:
                async with direct_session_lock:
                    (
                        can_direct,
                        direct_reserved_tokens,
                        _evicted,
                        _p_backed,
                        direct_reason,
                    ) = (
                        await _reserve_decode_session_capacity(
                            client=client,
                            config=config,
                            request=request,
                            server_url=decode_url,
                            session=decode_session,
                            direct_sessions=direct_sessions,
                            residency=decode_residency,
                            treat_as_fresh_session=False,
                            routing_mode="direct",
                            admission_mode=config.kvcache_admission_mode,
                        )
                    )
            if can_direct:
                try:
                    return await _invoke_decode_session_direct(
                        client=client,
                        request=request,
                        config=config,
                        decision=decision,
                        direct_sessions=direct_sessions,
                        direct_session_lock=direct_session_lock,
                        decode_residency=decode_residency,
                        reserved_tokens=direct_reserved_tokens,
                    )
                except Exception as exc:
                    if not _is_stale_decode_session_error(exc):
                        raise
                    async with direct_session_lock:
                        await _close_decode_session(
                            client=client,
                            session=decode_session,
                            residency=decode_residency,
                        )
                    return await _invoke_plain_router(
                        client=client,
                        request=request,
                        config=config,
                        decision=decision,
                        execution_mode="pd-router-fallback-stale-d-session",
                        decode_residency=decode_residency,
                    )
            if _is_decode_backpressure_reason(direct_reason):
                return await _invoke_plain_router(
                    client=client,
                    request=request,
                    config=config,
                    decision=decision,
                    execution_mode="pd-router-fallback-d-backpressure",
                    decode_residency=decode_residency,
                )

            # Direct path unavailable: try to reseed the D session.
            seed_filter_reason = _seed_filter_reason(
                request=request,
                config=config,
                inflight_decode_load=decision.inflight_decode_load,
            )
            if seed_filter_reason is not None:
                return await _invoke_plain_router(
                    client=client,
                    request=request,
                    config=config,
                    decision=decision,
                    execution_mode=f"pd-router-fallback-{seed_filter_reason}",
                    decode_residency=decode_residency,
                )
            async with direct_session_lock:
                admit_new_decode_session = _should_admit_new_decode_session(
                    residency=decode_residency,
                    server_url=decode_url,
                    request=request,
                    session=decode_session,
                    direct_sessions=direct_sessions,
                    treat_as_fresh_session=True,
                    admission_mode=config.kvcache_admission_mode,
                )
                if not admit_new_decode_session:
                    can_seed = False
                    reserved_tokens = 0
                    evicted_sessions = 0
                    prefill_backed_evictions = 0
                    seed_reason = "d-session-cap"
                else:
                    (
                        can_seed,
                        reserved_tokens,
                        evicted_sessions,
                        prefill_backed_evictions,
                        seed_reason,
                    ) = (
                        await _reserve_decode_session_capacity(
                            client=client,
                            config=config,
                            request=request,
                            server_url=decode_url,
                            session=decode_session,
                            direct_sessions=direct_sessions,
                            residency=decode_residency,
                            treat_as_fresh_session=True,
                            routing_mode="seed",
                            admission_mode=config.kvcache_admission_mode,
                        )
                    )
            if seed_reason == "d-session-cap":
                return await _invoke_plain_router(
                    client=client,
                    request=request,
                    config=config,
                    decision=decision,
                    execution_mode="pd-router-fallback-session-cap",
                    decode_residency=decode_residency,
                )
            if can_seed:
                return await _invoke_kvcache_seeded_router(
                    client=client,
                    request=request,
                    config=config,
                    decision=decision,
                    prefill_url=prefill_url,
                    decode_session=decode_session,
                    direct_sessions=direct_sessions,
                    direct_session_lock=direct_session_lock,
                    decode_residency=decode_residency,
                    reserved_tokens=reserved_tokens,
                    execution_mode=(
                        "pd-router-d-session-reseed"
                        + _eviction_suffix(
                            evicted_sessions,
                            prefill_backed_evictions,
                        )
                    ),
                )
            return await _invoke_plain_router(
                client=client,
                request=request,
                config=config,
                decision=decision,
                execution_mode=(
                    "pd-router-fallback-d-backpressure"
                    if _is_decode_backpressure_reason(seed_reason)
                    else "pd-router-fallback-no-d-capacity"
                ),
                decode_residency=decode_residency,
            )

        # TEAM_REPORT §2.7: 'large-append' is misleading — most fallthroughs are
        # actually 'session-not-resident-on-pinned-D' (§1 starvation). Classify
        # the real reason and embed it in the execution_mode label.
        fallthrough = _fallthrough_reason(
            request=request,
            config=config,
            decision=decision,
            direct_append_length=direct_append_length,
            direct_session_reused=direct_session_reused,
            direct_session_reset=direct_session_reset,
        )
        seed_filter_reason = _seed_filter_reason(
            request=request,
            config=config,
            inflight_decode_load=decision.inflight_decode_load,
        )
        if seed_filter_reason is not None:
            return await _invoke_plain_router(
                request=request,
                client=client,
                config=config,
                decision=decision,
                execution_mode=f"pd-router-fallback-{fallthrough}-{seed_filter_reason}",
                decode_residency=decode_residency,
            )
        async with direct_session_lock:
            # Pass the configured admission mode, as the turn-1 and reseed
            # paths do; omitting it silently used the helper's default here.
            admit_new_decode_session = _should_admit_new_decode_session(
                residency=decode_residency,
                server_url=decode_url,
                request=request,
                session=decode_session,
                direct_sessions=direct_sessions,
                treat_as_fresh_session=True,
                admission_mode=config.kvcache_admission_mode,
            )
            if not admit_new_decode_session:
                can_seed = False
                reserved_tokens = 0
                evicted_sessions = 0
                prefill_backed_evictions = 0
                seed_reason = "d-session-cap"
            else:
                (
                    can_seed,
                    reserved_tokens,
                    evicted_sessions,
                    prefill_backed_evictions,
                    seed_reason,
                ) = (
                    await _reserve_decode_session_capacity(
                        client=client,
                        config=config,
                        request=request,
                        server_url=decode_url,
                        session=decode_session,
                        direct_sessions=direct_sessions,
                        residency=decode_residency,
                        treat_as_fresh_session=True,
                        routing_mode="seed",
                        admission_mode=config.kvcache_admission_mode,
                    )
                )
        if seed_reason == "d-session-cap":
            return await _invoke_plain_router(
                request=request,
                client=client,
                config=config,
                decision=decision,
                execution_mode=f"pd-router-fallback-{fallthrough}-session-cap",
                decode_residency=decode_residency,
            )
        if can_seed:
            return await _invoke_kvcache_seeded_router(
                client=client,
                request=request,
                config=config,
                decision=decision,
                prefill_url=prefill_url,
                decode_session=decode_session,
                direct_sessions=direct_sessions,
                direct_session_lock=direct_session_lock,
                decode_residency=decode_residency,
                reserved_tokens=reserved_tokens,
                execution_mode=(
                    f"pd-router-{fallthrough}-reseed"
                    + _eviction_suffix(
                        evicted_sessions,
                        prefill_backed_evictions,
                    )
                ),
            )
        # Preserve seed_reason in the label so migration feedback fires for
        # 'd-no-space' / 'd-*-backpressure' (matched via _is_admission_rejection_mode).
        if _is_decode_backpressure_reason(seed_reason):
            mode_label = f"pd-router-fallback-{fallthrough}-d-backpressure"
        elif seed_reason == "d-no-space":
            mode_label = f"pd-router-fallback-{fallthrough}-no-d-capacity"
        else:
            mode_label = f"pd-router-fallback-{fallthrough}"
        return await _invoke_plain_router(
            request=request,
            client=client,
            config=config,
            decision=decision,
            execution_mode=mode_label,
            decode_residency=decode_residency,
        )

    raise ValueError(f"Unsupported mechanism: {config.mechanism_name}")
|
|
|
|
|
|
async def _invoke_direct(
|
|
*,
|
|
client: httpx.AsyncClient,
|
|
request: TraceRequest,
|
|
config: ReplayConfig,
|
|
decision,
|
|
direct_sessions: dict[str, DirectSessionState],
|
|
direct_session_lock: asyncio.Lock,
|
|
) -> ExecutionResult:
|
|
direct_workers = config.topology.direct_workers
|
|
if not direct_workers:
|
|
raise ValueError("pd-colo mechanism requires at least one direct worker")
|
|
|
|
server_url = direct_workers[decision.decode_worker_index].url
|
|
session = direct_sessions.get(request.session_id)
|
|
if session is None or session.server_url != server_url:
|
|
session = DirectSessionState(session_id=request.session_id, server_url=server_url)
|
|
direct_sessions[request.session_id] = session
|
|
|
|
return await _invoke_session_direct(
|
|
client=client,
|
|
request=request,
|
|
config=config,
|
|
session=session,
|
|
execution_mode="pd-colo-direct-session",
|
|
direct_session_lock=direct_session_lock,
|
|
)
|
|
|
|
|
|
async def _invoke_session_direct(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    session: DirectSessionState,
    execution_mode: str,
    decode_residency: DecodeResidencyState | None = None,
    reserved_tokens: int = 0,
    direct_session_lock: asyncio.Lock | None = None,
) -> ExecutionResult:
    """Send ``request`` straight to ``session.server_url`` via ``/generate`` with session params.

    If the session cannot be reused (it was reset or has no reusable prefix),
    an open session is closed first and then reopened. Only the append token ids
    are sent when the session is reused; otherwise the full synthetic ids are
    sent. When ``decode_residency`` is given, backpressure pauses are honored and
    residency is committed after a successful generate. ``active_requests`` is
    tracked only when ``direct_session_lock`` is provided.
    """
    tracks_residency = decode_residency is not None
    if tracks_residency and config.enable_backpressure:
        await _wait_for_decode_pause(
            config=config,
            residency=decode_residency,
            server_url=session.server_url,
            request_id=request.request_id,
            session_id=session.session_id,
        )

    _prompt, effective_input_length, session_reused, session_reset = _build_direct_prompt(
        request=request,
        session=session,
    )
    input_ids = (
        _build_direct_append_input_ids(request, effective_input_length)
        if session_reused
        else _build_direct_full_input_ids(request)
    )

    needs_fresh_session = session_reset or not session_reused
    if session.opened and needs_fresh_session:
        if tracks_residency:
            await _close_decode_session(
                client=client,
                session=session,
                residency=decode_residency,
            )
        else:
            await _close_streaming_session(
                client=client,
                server_url=session.server_url,
                session_id=session.session_id,
                allow_missing=True,
            )
        session.opened = False
        session.resident_tokens = 0
    if not session.opened:
        await _open_streaming_session(
            client=client,
            server_url=session.server_url,
            session_id=session.session_id,
            request=request,
        )
        session.opened = True

    if direct_session_lock is not None:
        async with direct_session_lock:
            session.active_requests += 1

    sampling_params = {
        "temperature": 0,
        "max_new_tokens": max(1, request.output_length),
        "min_new_tokens": max(1, request.output_length),
        "ignore_eos": True,
        "no_stop_trim": True,
        "skip_special_tokens": False,
    }
    payload = {
        "input_ids": input_ids,
        "sampling_params": sampling_params,
        "session_params": {"id": session.session_id},
        "stream": config.stream,
    }
    try:
        gen = await _invoke_generate(
            client=client,
            base_url=session.server_url,
            headers={"x-request-id": request.request_id},
            payload=payload,
            timeout_s=config.timeout_s,
            stream_idle_timeout_s=config.stream_idle_timeout_s,
            stream=config.stream,
        )
    finally:
        if direct_session_lock is not None:
            async with direct_session_lock:
                session.active_requests = max(0, session.active_requests - 1)

    if tracks_residency:
        _commit_session_residency(
            residency=decode_residency,
            session=session,
            request=request,
            reserved_tokens=reserved_tokens,
        )
    else:
        session.last_trace_request = request
        session.last_access_s = time.perf_counter()

    return ExecutionResult(
        execution_mode=execution_mode,
        actual_kv_transfer_blocks=0,
        effective_input_length=len(input_ids),
        cached_tokens=gen.cached_tokens,
        session_reused=session_reused,
        session_reset=session_reset,
        latency_s=gen.latency_s,
        ttft_s=gen.ttft_s,
        tpot_s=gen.tpot_s,
        actual_output_tokens=gen.actual_output_tokens,
        requested_output_tokens=gen.requested_output_tokens,
        finish_reason=gen.finish_reason,
    )
|
|
|
|
|
|
async def _invoke_decode_session_direct(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    decision,
    direct_sessions: dict[str, DirectSessionState],
    direct_session_lock: asyncio.Lock,
    decode_residency: DecodeResidencyState,
    reserved_tokens: int,
) -> ExecutionResult:
    """Send ``request`` straight to its decode worker's streaming session.

    The decode URL is taken from
    ``config.topology.decode_workers[decision.decode_worker_index]``. The
    session state is looked up in ``direct_sessions`` by
    ``request.session_id``. The work is delegated to
    ``_invoke_session_direct`` with execution mode
    ``"kvcache-direct-to-d-session"``.

    If anything fails after the decode URL is resolved, the
    ``reserved_tokens`` held for this request on that worker are released
    under ``direct_session_lock`` and the exception is re-raised. This covers
    a missing session entry and task cancellation as well as ordinary errors.

    Returns:
        The ``ExecutionResult`` produced by ``_invoke_session_direct``.

    Raises:
        Whatever the lookup or ``_invoke_session_direct`` raises, including
        ``KeyError`` for an unknown session and ``asyncio.CancelledError``.
    """
    decode_url = config.topology.decode_workers[decision.decode_worker_index].url
    try:
        # The lookup lives inside the try so a missing session also releases
        # the reservation instead of leaking it.
        session = direct_sessions[request.session_id]
        return await _invoke_session_direct(
            client=client,
            request=request,
            config=config,
            session=session,
            execution_mode="kvcache-direct-to-d-session",
            decode_residency=decode_residency,
            reserved_tokens=reserved_tokens,
            direct_session_lock=direct_session_lock,
        )
    except (Exception, asyncio.CancelledError):
        # CancelledError derives from BaseException (Python 3.8+). Without
        # listing it explicitly, a cancelled request would leak its reserved
        # decode capacity.
        async with direct_session_lock:
            _release_reserved_tokens(
                decode_residency,
                decode_url,
                reserved_tokens,
            )
        raise
|
|
|
|
|
|
def _should_bypass_prefill(
|
|
*,
|
|
request: TraceRequest,
|
|
config: ReplayConfig,
|
|
decision,
|
|
block_token_budget: int = 24,
|
|
) -> bool:
|
|
if request.turn_id <= 1:
|
|
return False
|
|
if decision.observed_overlap_blocks <= 0:
|
|
return False
|
|
uncached_tokens = max(0, decision.kv_transfer_blocks * block_token_budget)
|
|
return uncached_tokens <= config.kvcache_direct_max_uncached_tokens
|
|
|
|
|
|
def _worker_url_by_id(workers, worker_id: str) -> str:
|
|
for worker in workers:
|
|
if worker.worker_id == worker_id:
|
|
return worker.url
|
|
raise KeyError(f"Unknown worker id: {worker_id}")
|
|
|
|
|
|
def _build_headers(
|
|
*,
|
|
request: TraceRequest,
|
|
header_mode: HeaderMode,
|
|
decode_worker_index: int,
|
|
policy_name: str,
|
|
) -> dict[str, str]:
|
|
if header_mode == "auto":
|
|
header_mode = "routing-key" if policy_name == "sticky" else "none"
|
|
|
|
headers = {
|
|
"x-request-id": request.request_id,
|
|
}
|
|
if header_mode == "routing-key":
|
|
headers["x-smg-routing-key"] = request.session_id
|
|
elif header_mode == "target-worker":
|
|
headers["x-smg-target-worker"] = str(decode_worker_index)
|
|
return headers
|
|
|
|
|
|
def _contains_token(payload: dict) -> bool:
|
|
choices = payload.get("choices")
|
|
if not isinstance(choices, list) or not choices:
|
|
return False
|
|
delta = choices[0].get("delta")
|
|
if not isinstance(delta, dict):
|
|
return False
|
|
|
|
reasoning_content = delta.get("reasoning_content")
|
|
if isinstance(reasoning_content, str) and reasoning_content:
|
|
return True
|
|
|
|
content = delta.get("content")
|
|
if isinstance(content, str) and content:
|
|
return True
|
|
if isinstance(content, list):
|
|
return any(
|
|
isinstance(item, dict) and item.get("text")
|
|
for item in content
|
|
)
|
|
|
|
tool_calls = delta.get("tool_calls")
|
|
if isinstance(tool_calls, list):
|
|
for tool_call in tool_calls:
|
|
if not isinstance(tool_call, dict):
|
|
continue
|
|
if tool_call.get("id"):
|
|
return True
|
|
function = tool_call.get("function")
|
|
if not isinstance(function, dict):
|
|
continue
|
|
if function.get("name") or function.get("arguments"):
|
|
return True
|
|
return False
|
|
|
|
|
|
def _contains_generate_token(payload: dict) -> bool:
|
|
text = payload.get("text")
|
|
if isinstance(text, str) and text:
|
|
return True
|
|
meta_info = payload.get("meta_info")
|
|
if not isinstance(meta_info, dict):
|
|
return False
|
|
return int(meta_info.get("completion_tokens", 0)) > 0
|
|
|
|
|
|
def _is_generate_terminal_chunk(payload: dict) -> bool:
|
|
meta_info = payload.get("meta_info")
|
|
if not isinstance(meta_info, dict):
|
|
return False
|
|
return meta_info.get("finish_reason") is not None
|
|
|
|
|
|
def _extract_generate_cached_tokens(payload: dict) -> int:
|
|
meta_info = payload.get("meta_info")
|
|
if not isinstance(meta_info, dict):
|
|
return 0
|
|
return int(meta_info.get("cached_tokens", 0) or 0)
|
|
|
|
|
|
def _extract_openai_cached_tokens(payload: dict) -> int:
|
|
usage = payload.get("usage")
|
|
if not isinstance(usage, dict):
|
|
return 0
|
|
prompt_tokens_details = usage.get("prompt_tokens_details")
|
|
if not isinstance(prompt_tokens_details, dict):
|
|
return 0
|
|
return int(prompt_tokens_details.get("cached_tokens", 0) or 0)
|
|
|
|
|
|
async def _aiter_lines(
|
|
response: httpx.Response,
|
|
*,
|
|
idle_timeout_s: float | None,
|
|
):
|
|
line_iterator = response.aiter_lines()
|
|
while True:
|
|
try:
|
|
if idle_timeout_s is None or idle_timeout_s <= 0:
|
|
line = await anext(line_iterator)
|
|
else:
|
|
line = await asyncio.wait_for(
|
|
anext(line_iterator),
|
|
timeout=idle_timeout_s,
|
|
)
|
|
except StopAsyncIteration:
|
|
return
|
|
yield line
|
|
|
|
|
|
def _is_terminal_chunk(payload: dict) -> bool:
|
|
choices = payload.get("choices")
|
|
if not isinstance(choices, list) or not choices:
|
|
return False
|
|
finish_reason = choices[0].get("finish_reason")
|
|
return finish_reason is not None
|