from __future__ import annotations import asyncio import json import time from collections import Counter from dataclasses import dataclass, field, replace from pathlib import Path from typing import Any, Literal import httpx from agentic_pd_hybrid.metrics import ( RequestMetrics, write_metrics_jsonl, write_summary_json, ) from agentic_pd_hybrid.policies import RoutingState, create_policy from agentic_pd_hybrid.topology import SingleNodeTopology from agentic_pd_hybrid.trace import ( TraceRequest, build_synthetic_append_chunk, build_synthetic_prompt, load_trace, ) HeaderMode = Literal["none", "routing-key", "target-worker", "auto"] KvCacheAdmissionMode = Literal["router", "worker"] KvCachePrefillBackupPolicy = Literal["release-after-transfer", "capacity-backup"] _ADMISSION_PROBE_TIMEOUT_S = 2.0 @dataclass(frozen=True) class ReplayConfig: trace_path: Path output_path: Path policy_name: str mechanism_name: str topology: SingleNodeTopology router_url: str | None = None model_name: str | None = None pace: bool = True time_scale: float = 1.0 request_limit: int | None = None concurrency_limit: int = 32 header_mode: HeaderMode = "auto" timeout_s: float = 600.0 stream: bool = True stream_idle_timeout_s: float | None = 900.0 kvcache_direct_max_uncached_tokens: int = 2048 kvcache_admission_mode: KvCacheAdmissionMode = "router" kvcache_seed_max_resident_tokens: int | None = None kvcache_seed_max_output_tokens: int | None = None kvcache_seed_min_turn_id: int = 1 kvcache_seed_only_multiturn_sessions: bool = False kvcache_seed_allowed_session_ids: frozenset[str] | None = None kvcache_prefill_backup_policy: KvCachePrefillBackupPolicy = ( "release-after-transfer" ) kvcache_seed_max_inflight_decode: int | None = 3 kvcache_seed_max_decode_transfer_queue_reqs: int | None = None kvcache_direct_max_decode_transfer_queue_reqs: int | None = None kvcache_prefill_priority_eviction: bool = False kvcache_prefill_direct_priority: int = -100 kvcache_prefill_normal_priority: int = 100 @dataclass class DirectSessionState: session_id: str server_url: str opened: bool = False last_trace_request: TraceRequest | None = None resident_tokens: int = 0 last_access_s: float = 0.0 active_requests: int = 0 prefill_server_url: str | None = None prefill_opened: bool = False prefill_resident_tokens: int = 0 prefill_last_access_s: float = 0.0 prefill_low_priority: bool = False @dataclass class DecodeResidencyState: capacity_tokens: dict[str, int] = field(default_factory=dict) headroom_tokens: dict[str, int] = field(default_factory=dict) reserved_decode_tokens: dict[str, int] = field(default_factory=dict) resident_tokens_by_server: dict[str, int] = field(default_factory=dict) reserved_tokens_by_server: dict[str, int] = field(default_factory=dict) prefill_capacity_tokens: dict[str, int] = field(default_factory=dict) prefill_headroom_tokens: dict[str, int] = field(default_factory=dict) prefill_resident_tokens_by_server: dict[str, int] = field(default_factory=dict) prefill_reserved_tokens_by_server: dict[str, int] = field(default_factory=dict) decode_evictions_prefill_backed: int = 0 decode_evictions_without_prefill_backup: int = 0 @dataclass(frozen=True) class DecodeLoadSnapshot: timestamp_s: float num_running_reqs: int num_waiting_reqs: int num_used_tokens: int max_total_num_tokens: int token_usage: float decode_prealloc_queue_reqs: int decode_transfer_queue_reqs: int decode_retracted_queue_reqs: int @dataclass(frozen=True) class ExecutionResult: execution_mode: str actual_kv_transfer_blocks: int effective_input_length: int | None cached_tokens: int session_reused: bool session_reset: bool latency_s: float | None ttft_s: float | None tpot_s: float | None prefill_request_priority: int | None = None decode_request_priority: int | None = None error: str | None = None async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]: requests = load_trace(config.trace_path, request_limit=config.request_limit) if config.kvcache_seed_only_multiturn_sessions: session_turns = Counter(request.session_id for request in requests) config = replace( config, kvcache_seed_allowed_session_ids=frozenset( session_id for session_id, turn_count in session_turns.items() if turn_count > 1 ), ) policy = create_policy(config.policy_name) state = RoutingState.create(config.topology) state_lock = asyncio.Lock() semaphore = asyncio.Semaphore(config.concurrency_limit) start_time = time.perf_counter() first_timestamp = requests[0].timestamp_s if requests else 0.0 session_tail_tasks: dict[str, asyncio.Task[RequestMetrics]] = {} direct_sessions: dict[str, DirectSessionState] = {} direct_session_lock = asyncio.Lock() async with httpx.AsyncClient(timeout=config.timeout_s, trust_env=False) as client: decode_residency = await _discover_decode_residency( client=client, config=config, ) tasks = [] for request in requests: if config.pace: target_offset = (request.timestamp_s - first_timestamp) / config.time_scale sleep_s = target_offset - (time.perf_counter() - start_time) if sleep_s > 0: await asyncio.sleep(sleep_s) tasks.append( asyncio.create_task( _run_request( request=request, config=config, client=client, policy=policy, state=state, state_lock=state_lock, semaphore=semaphore, direct_sessions=direct_sessions, direct_session_lock=direct_session_lock, decode_residency=decode_residency, depends_on=session_tail_tasks.get(request.session_id), ) ) ) session_tail_tasks[request.session_id] = tasks[-1] results = await asyncio.gather(*tasks) for session in direct_sessions.values(): if session.opened: try: await _close_streaming_session( client=client, server_url=session.server_url, session_id=session.session_id, allow_missing=True, ) except Exception: pass if session.prefill_opened and session.prefill_server_url is not None: try: await _close_streaming_session( client=client, server_url=session.prefill_server_url, session_id=session.session_id, allow_missing=True, ) except Exception: pass write_metrics_jsonl(config.output_path, results) write_summary_json( config.output_path.with_suffix(config.output_path.suffix + ".summary.json"), results, trace_path=config.trace_path, router_url=config.router_url, ) return results async def _run_request( *, request: TraceRequest, config: ReplayConfig, client: httpx.AsyncClient, policy, state: RoutingState, state_lock: asyncio.Lock, semaphore: asyncio.Semaphore, direct_sessions: dict[str, DirectSessionState], direct_session_lock: asyncio.Lock, decode_residency: DecodeResidencyState, depends_on: asyncio.Task[RequestMetrics] | None, ) -> RequestMetrics: if depends_on is not None: await depends_on async with semaphore: async with state_lock: decision = policy.select(request, topology=config.topology, state=state) try: execution = await _execute_request( client=client, request=request, config=config, decision=decision, direct_sessions=direct_sessions, direct_session_lock=direct_session_lock, decode_residency=decode_residency, ) except Exception as exc: # pragma: no cover - defensive logging path execution = ExecutionResult( execution_mode=config.mechanism_name, actual_kv_transfer_blocks=0, effective_input_length=None, cached_tokens=0, session_reused=False, session_reset=False, latency_s=None, ttft_s=None, tpot_s=None, error=f"{type(exc).__name__}: {exc}", ) async with state_lock: state.finish(request, decision) return RequestMetrics.from_decision( request, decision, mechanism_name=config.mechanism_name, execution_mode=execution.execution_mode, actual_kv_transfer_blocks=execution.actual_kv_transfer_blocks, effective_input_length=execution.effective_input_length, cached_tokens=execution.cached_tokens, session_reused=execution.session_reused, session_reset=execution.session_reset, latency_s=execution.latency_s, ttft_s=execution.ttft_s, tpot_s=execution.tpot_s, prefill_request_priority=execution.prefill_request_priority, decode_request_priority=execution.decode_request_priority, error=execution.error, ) async def _invoke_router( *, client: httpx.AsyncClient, request: TraceRequest, config: ReplayConfig, decode_worker_index: int, session_id: str | None = None, prefill_request_priority: int | None = None, decode_request_priority: int | None = None, ) -> tuple[float, float | None, float | None, int]: headers = _build_headers( request=request, header_mode=config.header_mode, decode_worker_index=decode_worker_index, policy_name=config.policy_name, ) assert config.router_url is not None payload: dict[str, object] = { "input_ids": _build_direct_full_input_ids(request), "sampling_params": { "temperature": 0, "max_new_tokens": max(1, request.output_length), "ignore_eos": True, "no_stop_trim": True, "skip_special_tokens": False, }, "stream": config.stream, } if session_id is not None: payload["session_params"] = {"id": session_id} if prefill_request_priority is not None: payload["smg_prefill_priority"] = prefill_request_priority if decode_request_priority is not None: payload["smg_decode_priority"] = decode_request_priority return await _invoke_generate( client=client, base_url=config.router_url, headers=headers, payload=payload, timeout_s=config.timeout_s, stream_idle_timeout_s=config.stream_idle_timeout_s, stream=config.stream, ) def _build_payload( *, request: TraceRequest, model_name: str, prompt: str, stream: bool, session_params: dict[str, str] | None, exact_output_length: bool, ) -> dict[str, object]: payload: dict[str, object] = { "model": model_name, "messages": [{"role": "user", "content": prompt}], "max_tokens": max(1, request.output_length), "temperature": 0, "stream": stream, } if stream: payload["stream_options"] = {"include_usage": True} if exact_output_length: payload.update( { "min_tokens": max(1, request.output_length), "ignore_eos": True, "no_stop_trim": True, "skip_special_tokens": False, } ) if session_params is not None: payload["session_params"] = session_params return payload async def _invoke_chat_completion( *, client: httpx.AsyncClient, base_url: str, headers: dict[str, str], payload: dict[str, object], timeout_s: float, stream_idle_timeout_s: float | None, stream: bool, ) -> tuple[float, float | None, float | None, int]: start = time.perf_counter() ttft_s: float | None = None cached_tokens = 0 generated_tokens = int(payload.get("max_tokens", 1)) if stream: async with client.stream( "POST", f"{base_url.rstrip('/')}/v1/chat/completions", headers=headers, json=payload, timeout=timeout_s, ) as response: response.raise_for_status() async for line in _aiter_lines( response, idle_timeout_s=stream_idle_timeout_s, ): if not line.startswith("data:"): continue data = line[5:].strip() if data == "[DONE]": break parsed = json.loads(data) cached_tokens = max(cached_tokens, _extract_openai_cached_tokens(parsed)) if _contains_token(parsed) and ttft_s is None: ttft_s = time.perf_counter() - start if _is_terminal_chunk(parsed): break else: response = await client.post( f"{base_url.rstrip('/')}/v1/chat/completions", headers=headers, json=payload, timeout=timeout_s, ) response.raise_for_status() parsed = response.json() cached_tokens = _extract_openai_cached_tokens(parsed) latency_s = time.perf_counter() - start if stream and ttft_s is None and generated_tokens > 0: raise RuntimeError("generate stream ended before producing any token") if ttft_s is None: tpot_s = None else: tpot_s = max(0.0, latency_s - ttft_s) / max(1, generated_tokens) return latency_s, ttft_s, tpot_s, cached_tokens async def _invoke_generate( *, client: httpx.AsyncClient, base_url: str, headers: dict[str, str], payload: dict[str, object], timeout_s: float, stream_idle_timeout_s: float | None, stream: bool, ) -> tuple[float, float | None, float | None, int]: start = time.perf_counter() ttft_s: float | None = None cached_tokens = 0 sampling_params = payload.get("sampling_params", {}) generated_tokens = int(sampling_params.get("max_new_tokens", 1)) if stream: async with client.stream( "POST", f"{base_url.rstrip('/')}/generate", headers=headers, json=payload, timeout=timeout_s, ) as response: response.raise_for_status() async for line in _aiter_lines( response, idle_timeout_s=stream_idle_timeout_s, ): if not line.startswith("data:"): continue data = line[5:].strip() if data == "[DONE]": break parsed = json.loads(data) error = parsed.get("error") if isinstance(error, dict): raise ValueError(error.get("message", json.dumps(error))) cached_tokens = max(cached_tokens, _extract_generate_cached_tokens(parsed)) if _contains_generate_token(parsed) and ttft_s is None: ttft_s = time.perf_counter() - start if _is_generate_terminal_chunk(parsed): break else: response = await client.post( f"{base_url.rstrip('/')}/generate", headers=headers, json=payload, timeout=timeout_s, ) response.raise_for_status() parsed = response.json() error = parsed.get("error") if isinstance(error, dict): raise ValueError(error.get("message", json.dumps(error))) cached_tokens = _extract_generate_cached_tokens(parsed) latency_s = time.perf_counter() - start if stream and ttft_s is None and generated_tokens > 0: raise RuntimeError("generate stream ended before producing any token") if ttft_s is None: tpot_s = None else: tpot_s = max(0.0, latency_s - ttft_s) / max(1, generated_tokens) return latency_s, ttft_s, tpot_s, cached_tokens async def _open_streaming_session( *, client: httpx.AsyncClient, server_url: str, session_id: str, request: TraceRequest, ) -> None: capacity = max( 4096, request.input_length * 16, (request.input_length + request.output_length) * 16, ) response = await client.post( f"{server_url.rstrip('/')}/open_session", json={ "capacity_of_str_len": capacity, "session_id": session_id, "streaming": True, }, timeout=_ADMISSION_PROBE_TIMEOUT_S, ) response.raise_for_status() opened_session_id = response.json() if opened_session_id != session_id: raise ValueError( f"Unexpected session id from {server_url}: {opened_session_id!r} != {session_id!r}" ) async def _close_streaming_session( *, client: httpx.AsyncClient, server_url: str, session_id: str, allow_missing: bool = False, ) -> None: response = await client.post( f"{server_url.rstrip('/')}/close_session", json={"session_id": session_id}, timeout=_ADMISSION_PROBE_TIMEOUT_S, ) if response.is_success: return if allow_missing: response_text = response.text.lower() if response.status_code == 404 or "does not exist" in response_text: return response.raise_for_status() def _extract_internal_state(payload: dict[str, Any]) -> dict[str, Any]: internal_states = payload.get("internal_states") if isinstance(internal_states, list) and internal_states: internal_state = internal_states[0] if isinstance(internal_state, dict): return internal_state return payload def _extract_server_int( payload: dict[str, Any], key: str, ) -> int: internal_state = _extract_internal_state(payload) value = payload.get(key, internal_state.get(key, 0)) return int(value or 0) def _extract_session_cache( payload: dict[str, Any], ) -> dict[str, Any]: internal_state = _extract_internal_state(payload) session_cache = internal_state.get("session_cache") if isinstance(session_cache, dict): return session_cache return {} def _find_session_cache_status( session_cache: dict[str, Any], session_id: str, ) -> dict[str, Any] | None: sessions = session_cache.get("sessions") if not isinstance(sessions, list): return None for session in sessions: if isinstance(session, dict) and session.get("session_id") == session_id: return session return None async def _fetch_decode_server_state( *, client: httpx.AsyncClient, server_url: str, ) -> tuple[dict[str, Any], int, int]: try: response = await client.get( f"{server_url.rstrip('/')}/server_info", timeout=_ADMISSION_PROBE_TIMEOUT_S, ) response.raise_for_status() payload = response.json() except Exception: return {}, 0, 0 return ( _extract_session_cache(payload), _extract_server_int(payload, "max_total_num_tokens"), _extract_server_int(payload, "num_reserved_decode_tokens"), ) async def _query_decode_direct_admission( *, client: httpx.AsyncClient, server_url: str, session_id: str, uncached_input_tokens: int, output_tokens: int, ) -> dict[str, Any]: try: response = await client.post( f"{server_url.rstrip('/')}/session_cache/admit_direct_append", json={ "session_id": session_id, "uncached_input_tokens": max(0, uncached_input_tokens), "output_tokens": max(0, output_tokens), }, timeout=_ADMISSION_PROBE_TIMEOUT_S, ) response.raise_for_status() payload = response.json() if isinstance(payload, dict): return payload except Exception: pass return { "can_admit": False, "resident": False, "reason": "admission-query-failed", "required_tokens": 0, "available_tokens_before": 0, "available_tokens_after": 0, "evicted_session_count": 0, "freed_tokens": 0, } async def _discover_decode_residency( *, client: httpx.AsyncClient, config: ReplayConfig, ) -> DecodeResidencyState: residency = DecodeResidencyState() if config.mechanism_name != "kvcache-centric": return residency for worker in config.topology.decode_workers: _session_cache, max_total_num_tokens, reserved_decode_tokens = ( await _fetch_decode_server_state( client=client, server_url=worker.url, ) ) if max_total_num_tokens <= 0: continue safety_headroom = max( reserved_decode_tokens * 4, max_total_num_tokens // 20, 8192, ) residency.capacity_tokens[worker.url] = max_total_num_tokens residency.headroom_tokens[worker.url] = min( max_total_num_tokens, safety_headroom, ) residency.reserved_decode_tokens[worker.url] = reserved_decode_tokens for worker in config.topology.prefill_workers: _session_cache, max_total_num_tokens, _reserved_decode_tokens = ( await _fetch_decode_server_state( client=client, server_url=worker.url, ) ) if max_total_num_tokens <= 0: continue safety_headroom = max( max_total_num_tokens // 10, 16384, ) residency.prefill_capacity_tokens[worker.url] = max_total_num_tokens residency.prefill_headroom_tokens[worker.url] = min( max_total_num_tokens, safety_headroom, ) return residency def _estimate_session_resident_tokens(request: TraceRequest) -> int: return request.input_length + request.output_length def _seed_filter_reason( *, request: TraceRequest, config: ReplayConfig, inflight_decode_load: int | None = None, ) -> str | None: if request.turn_id < config.kvcache_seed_min_turn_id: return "seed-filter-early-turn" if ( config.kvcache_seed_max_inflight_decode is not None and inflight_decode_load is not None and inflight_decode_load > config.kvcache_seed_max_inflight_decode ): return "seed-filter-inflight-decode-load" if ( config.kvcache_seed_allowed_session_ids is not None and request.session_id not in config.kvcache_seed_allowed_session_ids ): return "seed-filter-single-turn-session" resident_tokens = _estimate_session_resident_tokens(request) if ( config.kvcache_seed_max_resident_tokens is not None and resident_tokens > config.kvcache_seed_max_resident_tokens ): return "seed-filter-resident-tokens" if ( config.kvcache_seed_max_output_tokens is not None and request.output_length > config.kvcache_seed_max_output_tokens ): return "seed-filter-output-tokens" return None def _prefill_priority_for_router_request( *, config: ReplayConfig, direct_to_d_predicted: bool, ) -> int | None: if not config.kvcache_prefill_priority_eviction: return None if direct_to_d_predicted: return config.kvcache_prefill_direct_priority return config.kvcache_prefill_normal_priority def _inspect_direct_request( *, request: TraceRequest, session: DirectSessionState, ) -> tuple[int, bool, bool]: previous = session.last_trace_request if previous is None: return request.input_length, False, False append_length = request.input_length - ( previous.input_length + previous.output_length ) if append_length <= 0: return request.input_length, False, True return append_length, True, False def _add_reserved_tokens( residency: DecodeResidencyState, server_url: str, delta_tokens: int, ) -> None: if delta_tokens <= 0: return residency.reserved_tokens_by_server[server_url] = ( residency.reserved_tokens_by_server.get(server_url, 0) + delta_tokens ) def _release_reserved_tokens( residency: DecodeResidencyState, server_url: str, delta_tokens: int, ) -> None: if delta_tokens <= 0: return remaining = residency.reserved_tokens_by_server.get(server_url, 0) - delta_tokens if remaining > 0: residency.reserved_tokens_by_server[server_url] = remaining else: residency.reserved_tokens_by_server.pop(server_url, None) def _add_prefill_reserved_tokens( residency: DecodeResidencyState, server_url: str, delta_tokens: int, ) -> None: if delta_tokens <= 0: return residency.prefill_reserved_tokens_by_server[server_url] = ( residency.prefill_reserved_tokens_by_server.get(server_url, 0) + delta_tokens ) def _release_prefill_reserved_tokens( residency: DecodeResidencyState, server_url: str, delta_tokens: int, ) -> None: if delta_tokens <= 0: return remaining = ( residency.prefill_reserved_tokens_by_server.get(server_url, 0) - delta_tokens ) if remaining > 0: residency.prefill_reserved_tokens_by_server[server_url] = remaining else: residency.prefill_reserved_tokens_by_server.pop(server_url, None) def _usable_capacity_tokens( residency: DecodeResidencyState, server_url: str, ) -> int: return max( 0, residency.capacity_tokens.get(server_url, 0) - residency.headroom_tokens.get(server_url, 0), ) def _usable_prefill_backup_capacity_tokens( residency: DecodeResidencyState, server_url: str, ) -> int: return max( 0, residency.prefill_capacity_tokens.get(server_url, 0) - residency.prefill_headroom_tokens.get(server_url, 0), ) def _eviction_suffix(evicted_sessions: int, prefill_backed_evictions: int) -> str: if evicted_sessions <= 0: return "" if prefill_backed_evictions >= evicted_sessions: return "-after-prefill-backed-eviction" if prefill_backed_evictions > 0: return "-after-mixed-eviction" return "-after-eviction" def _decode_session_soft_cap( *, residency: DecodeResidencyState, server_url: str, request: TraceRequest, ) -> int: target_tokens = max(1, _estimate_session_resident_tokens(request)) usable_capacity_tokens = _usable_capacity_tokens(residency, server_url) if usable_capacity_tokens <= 0: usable_capacity_tokens = max( 0, residency.capacity_tokens.get(server_url, 0) - residency.headroom_tokens.get(server_url, 0), ) if usable_capacity_tokens <= 0: return 4 return max(1, min(4, usable_capacity_tokens // target_tokens)) def _should_admit_new_decode_session( *, residency: DecodeResidencyState, server_url: str, request: TraceRequest, session: DirectSessionState, direct_sessions: dict[str, DirectSessionState], treat_as_fresh_session: bool, ) -> bool: if ( not treat_as_fresh_session and session.opened and session.server_url == server_url ): return True open_sessions = sum( 1 for candidate in direct_sessions.values() if candidate.opened and candidate.server_url == server_url ) return open_sessions < _decode_session_soft_cap( residency=residency, server_url=server_url, request=request, ) async def _fetch_decode_load_snapshot( *, client: httpx.AsyncClient, server_url: str, ) -> DecodeLoadSnapshot | None: try: response = await client.get( f"{server_url.rstrip('/')}/v1/loads", params={"include": "core,disagg"}, timeout=_ADMISSION_PROBE_TIMEOUT_S, ) response.raise_for_status() payload = response.json() except Exception: return None loads = payload.get("loads") if not isinstance(loads, list) or not loads: return None load = loads[0] disagg = load.get("disaggregation") or {} return DecodeLoadSnapshot( timestamp_s=time.perf_counter(), num_running_reqs=int(load.get("num_running_reqs", 0) or 0), num_waiting_reqs=int(load.get("num_waiting_reqs", 0) or 0), num_used_tokens=int(load.get("num_used_tokens", 0) or 0), max_total_num_tokens=int(load.get("max_total_num_tokens", 0) or 0), token_usage=float(load.get("token_usage", 0.0) or 0.0), decode_prealloc_queue_reqs=int( disagg.get("decode_prealloc_queue_reqs", 0) or 0 ), decode_transfer_queue_reqs=int( disagg.get("decode_transfer_queue_reqs", 0) or 0 ), decode_retracted_queue_reqs=int( disagg.get("decode_retracted_queue_reqs", 0) or 0 ), ) def _decode_load_backpressure_reason( snapshot: DecodeLoadSnapshot | None, *, config: ReplayConfig, routing_mode: Literal["direct", "seed"], ) -> str | None: if snapshot is None: return None if routing_mode == "direct": if ( config.kvcache_direct_max_decode_transfer_queue_reqs is not None and snapshot.decode_transfer_queue_reqs > config.kvcache_direct_max_decode_transfer_queue_reqs ): return "d-transfer-backpressure" if snapshot.decode_retracted_queue_reqs > 0 and snapshot.token_usage >= 0.99: return "d-retracted" if snapshot.token_usage >= 0.992: return "d-token-usage-critical" else: if ( config.kvcache_seed_max_decode_transfer_queue_reqs is not None and snapshot.decode_transfer_queue_reqs > config.kvcache_seed_max_decode_transfer_queue_reqs ): return "d-transfer-backpressure" if snapshot.decode_retracted_queue_reqs > 0: return "d-retracted" if snapshot.token_usage >= 0.985: return "d-token-usage-critical" if routing_mode == "seed" and snapshot.token_usage >= 0.94 and ( snapshot.decode_prealloc_queue_reqs > 0 or snapshot.decode_transfer_queue_reqs > 0 ): return "d-prealloc-backpressure" return None def _is_decode_backpressure_reason(reason: str | None) -> bool: return reason in { "d-retracted", "d-token-usage-critical", "d-prealloc-backpressure", "d-transfer-backpressure", } def _is_stale_decode_session_error(exc: Exception) -> bool: return ( isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 400 ) def _dynamic_decode_headroom_tokens( *, residency: DecodeResidencyState, server_url: str, snapshot: DecodeLoadSnapshot | None, routing_mode: Literal["direct", "seed"], ) -> int: if snapshot is None: return residency.headroom_tokens.get(server_url, 0) base_reserved = max(512, residency.reserved_decode_tokens.get(server_url, 0)) if routing_mode == "direct": direct_queue_pressure = max( 1, snapshot.decode_prealloc_queue_reqs + snapshot.decode_transfer_queue_reqs + snapshot.decode_retracted_queue_reqs, ) capacity_divisor = 24 minimum_headroom = 4096 return max( base_reserved * direct_queue_pressure, snapshot.max_total_num_tokens // capacity_divisor, minimum_headroom, ) disagg_queued = ( snapshot.decode_prealloc_queue_reqs + snapshot.decode_transfer_queue_reqs + snapshot.decode_retracted_queue_reqs ) active_decode_pressure = max(1, snapshot.num_running_reqs + disagg_queued) capacity_divisor = 15 minimum_headroom = 12288 return max( base_reserved * active_decode_pressure, snapshot.max_total_num_tokens // capacity_divisor, minimum_headroom, ) def _commit_session_residency( *, residency: DecodeResidencyState, session: DirectSessionState, request: TraceRequest, reserved_tokens: int, ) -> None: _release_reserved_tokens(residency, session.server_url, reserved_tokens) previous_tokens = session.resident_tokens if session.opened else 0 new_tokens = _estimate_session_resident_tokens(request) delta_tokens = new_tokens - previous_tokens if delta_tokens != 0: residency.resident_tokens_by_server[session.server_url] = ( residency.resident_tokens_by_server.get(session.server_url, 0) + delta_tokens ) session.opened = True session.resident_tokens = new_tokens session.last_trace_request = request session.last_access_s = time.perf_counter() if session.prefill_opened: session.prefill_low_priority = True def _commit_prefill_backup_residency( *, residency: DecodeResidencyState, session: DirectSessionState, request: TraceRequest, prefill_url: str, reserved_tokens: int, ) -> None: _release_prefill_reserved_tokens(residency, prefill_url, reserved_tokens) previous_tokens = session.prefill_resident_tokens if session.prefill_opened else 0 new_tokens = _estimate_session_resident_tokens(request) delta_tokens = new_tokens - previous_tokens if delta_tokens != 0: residency.prefill_resident_tokens_by_server[prefill_url] = ( residency.prefill_resident_tokens_by_server.get(prefill_url, 0) + delta_tokens ) session.prefill_server_url = prefill_url session.prefill_opened = True session.prefill_resident_tokens = new_tokens session.prefill_last_access_s = time.perf_counter() session.prefill_low_priority = session.opened async def _close_prefill_session( *, client: httpx.AsyncClient, session: DirectSessionState, residency: DecodeResidencyState, ) -> None: if not session.prefill_opened or session.prefill_server_url is None: session.prefill_opened = False session.prefill_resident_tokens = 0 session.prefill_low_priority = False return prefill_url = session.prefill_server_url await _close_streaming_session( client=client, server_url=prefill_url, session_id=session.session_id, allow_missing=True, ) remaining = ( residency.prefill_resident_tokens_by_server.get(prefill_url, 0) - session.prefill_resident_tokens ) if remaining > 0: residency.prefill_resident_tokens_by_server[prefill_url] = remaining else: residency.prefill_resident_tokens_by_server.pop(prefill_url, None) session.prefill_opened = False session.prefill_resident_tokens = 0 session.prefill_low_priority = False async def _close_decode_session( *, client: httpx.AsyncClient, session: DirectSessionState, residency: DecodeResidencyState, evicting_for_capacity: bool = False, ) -> None: if not session.opened: session.resident_tokens = 0 return await _close_streaming_session( client=client, server_url=session.server_url, session_id=session.session_id, allow_missing=True, ) remaining = ( residency.resident_tokens_by_server.get(session.server_url, 0) - session.resident_tokens ) if remaining > 0: residency.resident_tokens_by_server[session.server_url] = remaining else: residency.resident_tokens_by_server.pop(session.server_url, None) session.opened = False session.resident_tokens = 0 if session.prefill_opened: residency.decode_evictions_prefill_backed += int(evicting_for_capacity) session.prefill_low_priority = False elif evicting_for_capacity: residency.decode_evictions_without_prefill_backup += 1 async def _reserve_prefill_backup_capacity( *, client: httpx.AsyncClient, request: TraceRequest, prefill_url: str, session: DirectSessionState, direct_sessions: dict[str, DirectSessionState], residency: DecodeResidencyState, ) -> tuple[bool, int, int]: session_cache, max_total_num_tokens, _reserved_decode_tokens = ( await _fetch_decode_server_state( client=client, server_url=prefill_url, ) ) if max_total_num_tokens > 0: residency.prefill_capacity_tokens[prefill_url] = max_total_num_tokens capacity_tokens = residency.prefill_capacity_tokens.get(prefill_url, 0) headroom_tokens = residency.prefill_headroom_tokens.get(prefill_url, 0) if capacity_tokens <= 0: return True, 0, 0 low_occupancy_headroom_tokens = max( headroom_tokens, capacity_tokens // 2, ) target_session_status = _find_session_cache_status( session_cache, session.session_id, ) if ( isinstance(target_session_status, dict) and bool(target_session_status.get("resident")) ): current_tokens = int(target_session_status.get("resident_tokens", 0) or 0) else: current_tokens = ( session.prefill_resident_tokens if session.prefill_opened and session.prefill_server_url == prefill_url else 0 ) target_tokens = _estimate_session_resident_tokens(request) required_extra_tokens = max(0, target_tokens - current_tokens) evicted_sessions = 0 max_backup_sessions = max(1, capacity_tokens // max(1, target_tokens * 2)) max_backup_sessions = min(max_backup_sessions, 4) available_tokens = int(session_cache.get("available_tokens", 0) or 0) if available_tokens <= 0: held_tokens = int(session_cache.get("held_tokens", 0) or 0) available_tokens = max(0, capacity_tokens - held_tokens) available_tokens -= residency.prefill_reserved_tokens_by_server.get(prefill_url, 0) def has_enough_prefill_headroom() -> bool: return available_tokens - required_extra_tokens >= low_occupancy_headroom_tokens def prefill_backup_count() -> int: return sum( 1 for candidate in direct_sessions.values() if candidate.prefill_opened and candidate.prefill_server_url == prefill_url ) while ( required_extra_tokens > 0 and ( not has_enough_prefill_headroom() or ( not session.prefill_opened and prefill_backup_count() >= max_backup_sessions ) ) ): candidates = sorted( ( candidate for candidate in direct_sessions.values() if candidate.prefill_opened and candidate.prefill_server_url == prefill_url and candidate.session_id != session.session_id and candidate.active_requests <= 0 ), key=lambda candidate: ( 0 if candidate.prefill_low_priority else 1, candidate.prefill_last_access_s, ), ) if not candidates: break freed_tokens = candidates[0].prefill_resident_tokens await _close_prefill_session( client=client, session=candidates[0], residency=residency, ) available_tokens += freed_tokens evicted_sessions += 1 if not has_enough_prefill_headroom(): return False, 0, evicted_sessions _add_prefill_reserved_tokens(residency, prefill_url, required_extra_tokens) return True, required_extra_tokens, evicted_sessions async def _reserve_decode_session_capacity( *, client: httpx.AsyncClient, config: ReplayConfig, request: TraceRequest, server_url: str, session: DirectSessionState, direct_sessions: dict[str, DirectSessionState], residency: DecodeResidencyState, treat_as_fresh_session: bool, routing_mode: Literal["direct", "seed"], admission_mode: KvCacheAdmissionMode, ) -> tuple[bool, int, int, int, str | None]: if admission_mode == "router": return await _reserve_decode_session_capacity_from_router_state( client=client, config=config, request=request, server_url=server_url, session=session, direct_sessions=direct_sessions, residency=residency, treat_as_fresh_session=treat_as_fresh_session, routing_mode=routing_mode, ) if treat_as_fresh_session and session.opened: await _close_decode_session( client=client, session=session, residency=residency, ) current_tokens = 0 if treat_as_fresh_session else session.resident_tokens target_tokens = _estimate_session_resident_tokens(request) required_extra_tokens = max(0, target_tokens - current_tokens) prefill_backed_evictions = 0 if routing_mode == "direct" and not treat_as_fresh_session: if not session.opened: return False, 0, 0, 0, "d-session-not-resident" admission = await _query_decode_direct_admission( client=client, server_url=server_url, session_id=session.session_id, uncached_input_tokens=max(0, request.input_length - current_tokens), output_tokens=request.output_length, ) if not bool(admission.get("resident")): return False, 0, 0, 0, str(admission.get("reason") or "d-session-not-resident") if not bool(admission.get("can_admit")): return ( False, 0, int(admission.get("evicted_session_count", 0) or 0), 0, str(admission.get("reason") or "d-no-space"), ) reserved_tokens = int( admission.get("required_tokens", required_extra_tokens) or required_extra_tokens ) _add_reserved_tokens(residency, server_url, reserved_tokens) return ( True, reserved_tokens, int(admission.get("evicted_session_count", 0) or 0), 0, None, ) session_cache, max_total_num_tokens, reserved_decode_tokens = ( await _fetch_decode_server_state( client=client, server_url=server_url, ) ) if max_total_num_tokens > 0: residency.capacity_tokens[server_url] = max_total_num_tokens if reserved_decode_tokens > 0: residency.reserved_decode_tokens[server_url] = reserved_decode_tokens target_session_status = _find_session_cache_status( session_cache, session.session_id, ) if routing_mode == "direct" and not ( isinstance(target_session_status, dict) and bool(target_session_status.get("resident")) ): return False, 0, 0, 0, "d-session-not-resident" load_snapshot = await _fetch_decode_load_snapshot( client=client, server_url=server_url, ) if load_snapshot is not None and load_snapshot.max_total_num_tokens > 0: residency.capacity_tokens[server_url] = load_snapshot.max_total_num_tokens backpressure_reason = _decode_load_backpressure_reason( load_snapshot, config=config, routing_mode=routing_mode, ) if backpressure_reason is not None: return False, 0, 0, 0, backpressure_reason usable_capacity_tokens = _usable_capacity_tokens(residency, server_url) evicted_sessions = 0 while ( required_extra_tokens > 0 and residency.resident_tokens_by_server.get(server_url, 0) + residency.reserved_tokens_by_server.get(server_url, 0) + required_extra_tokens > usable_capacity_tokens + int(session_cache.get("idle_evictable_tokens", 0) or 0) ): candidates = sorted( ( candidate for candidate in direct_sessions.values() if candidate.opened and candidate.server_url == server_url and candidate.session_id != session.session_id and candidate.active_requests <= 0 ), key=lambda candidate: candidate.last_access_s, ) if not candidates: break await _close_decode_session( client=client, session=candidates[0], residency=residency, evicting_for_capacity=True, ) prefill_backed_evictions += int(candidates[0].prefill_opened) evicted_sessions += 1 if evicted_sessions > 0: load_snapshot = await _fetch_decode_load_snapshot( client=client, server_url=server_url, ) session_cache, max_total_num_tokens, reserved_decode_tokens = ( await _fetch_decode_server_state( client=client, server_url=server_url, ) ) if max_total_num_tokens > 0: residency.capacity_tokens[server_url] = max_total_num_tokens if reserved_decode_tokens > 0: residency.reserved_decode_tokens[server_url] = reserved_decode_tokens usable_capacity_tokens = _usable_capacity_tokens(residency, server_url) if load_snapshot is not None: dynamic_headroom = _dynamic_decode_headroom_tokens( residency=residency, server_url=server_url, snapshot=load_snapshot, routing_mode=routing_mode, ) residency.headroom_tokens[server_url] = min( residency.capacity_tokens.get(server_url, dynamic_headroom), dynamic_headroom, ) usable_capacity_tokens = max( 0, residency.capacity_tokens.get(server_url, 0) - dynamic_headroom, ) usable_capacity_tokens = _usable_capacity_tokens(residency, server_url) if load_snapshot is not None: dynamic_headroom = _dynamic_decode_headroom_tokens( residency=residency, server_url=server_url, snapshot=load_snapshot, routing_mode=routing_mode, ) residency.headroom_tokens[server_url] = min( residency.capacity_tokens.get(server_url, dynamic_headroom), dynamic_headroom, ) usable_capacity_tokens = max( 0, residency.capacity_tokens.get(server_url, 0) - dynamic_headroom, ) effective_used_tokens = ( load_snapshot.num_used_tokens if load_snapshot is not None else residency.resident_tokens_by_server.get(server_url, 0) ) + residency.reserved_tokens_by_server.get(server_url, 0) idle_evictable_tokens = int(session_cache.get("idle_evictable_tokens", 0) or 0) if ( routing_mode == "direct" and isinstance(target_session_status, dict) and bool(target_session_status.get("idle_evictable")) ): idle_evictable_tokens = max( 0, idle_evictable_tokens - int(target_session_status.get("resident_tokens", 0) or 0), ) if effective_used_tokens + required_extra_tokens > ( usable_capacity_tokens + idle_evictable_tokens ): return False, 0, evicted_sessions, prefill_backed_evictions, "d-no-space" _add_reserved_tokens(residency, server_url, required_extra_tokens) return True, required_extra_tokens, evicted_sessions, prefill_backed_evictions, None async def _reserve_decode_session_capacity_from_router_state( *, client: httpx.AsyncClient, config: ReplayConfig, request: TraceRequest, server_url: str, session: DirectSessionState, direct_sessions: dict[str, DirectSessionState], residency: DecodeResidencyState, treat_as_fresh_session: bool, routing_mode: Literal["direct", "seed"], ) -> tuple[bool, int, int, int, str | None]: if treat_as_fresh_session and session.opened: await _close_decode_session( client=client, session=session, residency=residency, ) if routing_mode == "direct" and not session.opened: return False, 0, 0, 0, "d-session-not-resident" current_tokens = 0 if treat_as_fresh_session else session.resident_tokens target_tokens = _estimate_session_resident_tokens(request) required_extra_tokens = max(0, target_tokens - current_tokens) usable_capacity_tokens = _usable_capacity_tokens(residency, server_url) # If discovery failed, do not force every request down the P/D fallback path. # The router can still preserve correctness; this only disables proactive # capacity admission until the worker reports capacity again in a later run. if usable_capacity_tokens <= 0: _add_reserved_tokens(residency, server_url, required_extra_tokens) return True, required_extra_tokens, 0, 0, None evicted_sessions = 0 prefill_backed_evictions = 0 while ( required_extra_tokens > 0 and residency.resident_tokens_by_server.get(server_url, 0) + residency.reserved_tokens_by_server.get(server_url, 0) + required_extra_tokens > usable_capacity_tokens ): candidates = sorted( ( candidate for candidate in direct_sessions.values() if candidate.opened and candidate.server_url == server_url and candidate.session_id != session.session_id and candidate.active_requests <= 0 ), key=lambda candidate: candidate.last_access_s, ) if not candidates: break await _close_decode_session( client=client, session=candidates[0], residency=residency, evicting_for_capacity=True, ) prefill_backed_evictions += int(candidates[0].prefill_opened) evicted_sessions += 1 if ( residency.resident_tokens_by_server.get(server_url, 0) + residency.reserved_tokens_by_server.get(server_url, 0) + required_extra_tokens > usable_capacity_tokens ): return False, 0, evicted_sessions, prefill_backed_evictions, "d-no-space" _add_reserved_tokens(residency, server_url, required_extra_tokens) return True, required_extra_tokens, evicted_sessions, prefill_backed_evictions, None def _build_direct_prompt( *, request: TraceRequest, session: DirectSessionState, ) -> tuple[str, int, bool, bool]: append_length, session_reused, session_reset = _inspect_direct_request( request=request, session=session, ) if session_reset: return build_synthetic_prompt(request), request.input_length, False, True if not session_reused: return build_synthetic_prompt(request), request.input_length, False, False return ( build_synthetic_append_chunk(request, append_length), append_length, session_reused, session_reset, ) def _build_direct_full_input_ids( request: TraceRequest, *, block_token_budget: int = 24, ) -> list[int]: input_ids: list[int] = [] for hash_id in request.hash_ids: base = 1000 + (hash_id % 3000) for offset in range(block_token_budget): input_ids.append(1000 + ((base + offset) % 3000)) while len(input_ids) < request.input_length: input_ids.append(5000 + (len(input_ids) % 2000)) return input_ids[: request.input_length] def _build_direct_append_input_ids( request: TraceRequest, append_length: int, ) -> list[int]: if append_length <= 0: return [] base = 9000 + ((request.chat_id + request.turn_id) % 2000) return [ 9000 + ((base + offset) % 2000) for offset in range(append_length) ] async def _invoke_plain_router( *, client: httpx.AsyncClient, request: TraceRequest, config: ReplayConfig, decision, execution_mode: str, ) -> ExecutionResult: prefill_priority = _prefill_priority_for_router_request( config=config, direct_to_d_predicted=False, ) latency_s, ttft_s, tpot_s, cached_tokens = await _invoke_router( client=client, request=request, config=config, decode_worker_index=decision.decode_worker_index, prefill_request_priority=prefill_priority, ) return ExecutionResult( execution_mode=execution_mode, actual_kv_transfer_blocks=decision.kv_transfer_blocks, effective_input_length=request.input_length, cached_tokens=cached_tokens, prefill_request_priority=prefill_priority, session_reused=False, session_reset=False, latency_s=latency_s, ttft_s=ttft_s, tpot_s=tpot_s, ) async def _invoke_kvcache_seeded_router( *, client: httpx.AsyncClient, request: TraceRequest, config: ReplayConfig, decision, prefill_url: str, decode_session: DirectSessionState, direct_sessions: dict[str, DirectSessionState], direct_session_lock: asyncio.Lock, decode_residency: DecodeResidencyState, reserved_tokens: int, execution_mode: str, ) -> ExecutionResult: keep_prefill_backup = False prefill_reserved_tokens = 0 async with direct_session_lock: if config.kvcache_prefill_backup_policy == "capacity-backup": keep_prefill_backup, prefill_reserved_tokens, _prefill_evicted = ( await _reserve_prefill_backup_capacity( client=client, request=request, prefill_url=prefill_url, session=decode_session, direct_sessions=direct_sessions, residency=decode_residency, ) ) if ( decode_session.prefill_opened and decode_session.prefill_server_url != prefill_url ): await _close_prefill_session( client=client, session=decode_session, residency=decode_residency, ) prefill_session_newly_opened = False async with direct_session_lock: if not decode_session.prefill_opened: await _open_streaming_session( client=client, server_url=prefill_url, session_id=request.session_id, request=request, ) decode_session.prefill_opened = True decode_session.prefill_server_url = prefill_url prefill_session_newly_opened = True decode_session_newly_opened = False try: prefill_priority = _prefill_priority_for_router_request( config=config, direct_to_d_predicted=True, ) async with direct_session_lock: if not decode_session.opened: await _open_streaming_session( client=client, server_url=decode_session.server_url, session_id=request.session_id, request=request, ) decode_session.opened = True decode_session_newly_opened = True decode_session.active_requests += 1 latency_s, ttft_s, tpot_s, cached_tokens = await _invoke_router( client=client, request=request, config=config, decode_worker_index=decision.decode_worker_index, session_id=request.session_id, prefill_request_priority=prefill_priority, ) except Exception: async with direct_session_lock: decode_session.active_requests = max(0, decode_session.active_requests - 1) _release_reserved_tokens( decode_residency, decode_session.server_url, reserved_tokens, ) _release_prefill_reserved_tokens( decode_residency, prefill_url, prefill_reserved_tokens, ) if decode_session_newly_opened: await _close_decode_session( client=client, session=decode_session, residency=decode_residency, ) if prefill_session_newly_opened: await _close_prefill_session( client=client, session=decode_session, residency=decode_residency, ) raise async with direct_session_lock: decode_session.active_requests = max(0, decode_session.active_requests - 1) if keep_prefill_backup: _commit_prefill_backup_residency( residency=decode_residency, session=decode_session, request=request, prefill_url=prefill_url, reserved_tokens=prefill_reserved_tokens, ) else: _release_prefill_reserved_tokens( decode_residency, prefill_url, prefill_reserved_tokens, ) await _close_prefill_session( client=client, session=decode_session, residency=decode_residency, ) _commit_session_residency( residency=decode_residency, session=decode_session, request=request, reserved_tokens=reserved_tokens, ) return ExecutionResult( execution_mode=execution_mode, actual_kv_transfer_blocks=decision.kv_transfer_blocks, effective_input_length=request.input_length, cached_tokens=cached_tokens, prefill_request_priority=prefill_priority, session_reused=False, session_reset=False, latency_s=latency_s, ttft_s=ttft_s, tpot_s=tpot_s, ) async def _execute_request( *, client: httpx.AsyncClient, request: TraceRequest, config: ReplayConfig, decision, direct_sessions: dict[str, DirectSessionState], direct_session_lock: asyncio.Lock, decode_residency: DecodeResidencyState, ) -> ExecutionResult: if config.mechanism_name == "pd-disaggregation": if not config.router_url: raise ValueError("router_url is required for pd-disaggregation replay") return await _invoke_plain_router( client=client, request=request, config=config, decision=decision, execution_mode="pd-disaggregation-router", ) if config.mechanism_name == "pd-colo": return await _invoke_direct( client=client, request=request, config=config, decision=decision, direct_sessions=direct_sessions, direct_session_lock=direct_session_lock, ) if config.mechanism_name == "kvcache-centric": if not config.router_url: raise ValueError("router_url is required for kvcache-centric replay") if not config.topology.decode_workers: raise ValueError("kvcache-centric mechanism requires at least one decode worker") prefill_url = _worker_url_by_id( config.topology.prefill_workers, decision.prefill_worker_id, ) decode_url = config.topology.decode_workers[decision.decode_worker_index].url async with direct_session_lock: decode_session = direct_sessions.get(request.session_id) if decode_session is None: decode_session = DirectSessionState( session_id=request.session_id, server_url=decode_url, ) direct_sessions[request.session_id] = decode_session elif decode_session.server_url != decode_url and decode_session.opened: await _close_decode_session( client=client, session=decode_session, residency=decode_residency, ) decode_session.server_url = decode_url else: decode_session.server_url = decode_url direct_append_length: int | None = None direct_session_reused = False direct_session_reset = False if request.turn_id > 1: async with direct_session_lock: ( direct_append_length, direct_session_reused, direct_session_reset, ) = _inspect_direct_request( request=request, session=decode_session, ) if request.turn_id == 1: seed_filter_reason = _seed_filter_reason( request=request, config=config, inflight_decode_load=decision.inflight_decode_load, ) if seed_filter_reason is not None: return await _invoke_plain_router( client=client, request=request, config=config, decision=decision, execution_mode=f"pd-router-turn1-{seed_filter_reason}", ) async with direct_session_lock: admit_new_decode_session = _should_admit_new_decode_session( residency=decode_residency, server_url=decode_url, request=request, session=decode_session, direct_sessions=direct_sessions, treat_as_fresh_session=True, ) if not admit_new_decode_session: can_seed = False reserved_tokens = 0 seed_reason = "d-session-cap" else: can_seed, reserved_tokens, _evicted, _p_backed, seed_reason = ( await _reserve_decode_session_capacity( client=client, config=config, request=request, server_url=decode_url, session=decode_session, direct_sessions=direct_sessions, residency=decode_residency, treat_as_fresh_session=True, routing_mode="seed", admission_mode=config.kvcache_admission_mode, ) ) if seed_reason == "d-session-cap": return await _invoke_plain_router( client=client, request=request, config=config, decision=decision, execution_mode="pd-router-turn1-session-cap", ) if can_seed: return await _invoke_kvcache_seeded_router( client=client, request=request, config=config, decision=decision, prefill_url=prefill_url, decode_session=decode_session, direct_sessions=direct_sessions, direct_session_lock=direct_session_lock, decode_residency=decode_residency, reserved_tokens=reserved_tokens, execution_mode="pd-router-turn1-seed", ) return await _invoke_plain_router( client=client, request=request, config=config, decision=decision, execution_mode=( "pd-router-turn1-d-backpressure" if seed_reason is not None and seed_reason != "d-no-space" else "pd-router-turn1-no-d-capacity" ), ) if ( _should_bypass_prefill( request=request, config=config, decision=decision, ) and direct_append_length is not None and direct_session_reused and not direct_session_reset and direct_append_length <= config.kvcache_direct_max_uncached_tokens ): async with direct_session_lock: can_direct = ( decode_session.opened and decode_session.server_url == decode_url and direct_session_reused and not direct_session_reset ) direct_reserved_tokens = 0 direct_reason: str | None = None if can_direct: async with direct_session_lock: ( can_direct, direct_reserved_tokens, _evicted, _p_backed, direct_reason, ) = ( await _reserve_decode_session_capacity( client=client, config=config, request=request, server_url=decode_url, session=decode_session, direct_sessions=direct_sessions, residency=decode_residency, treat_as_fresh_session=False, routing_mode="direct", admission_mode=config.kvcache_admission_mode, ) ) if can_direct: try: return await _invoke_decode_session_direct( client=client, request=request, config=config, decision=decision, direct_sessions=direct_sessions, direct_session_lock=direct_session_lock, decode_residency=decode_residency, reserved_tokens=direct_reserved_tokens, ) except Exception as exc: if not _is_stale_decode_session_error(exc): raise async with direct_session_lock: await _close_decode_session( client=client, session=decode_session, residency=decode_residency, ) return await _invoke_plain_router( client=client, request=request, config=config, decision=decision, execution_mode="pd-router-fallback-stale-d-session", ) if _is_decode_backpressure_reason(direct_reason): return await _invoke_plain_router( client=client, request=request, config=config, decision=decision, execution_mode="pd-router-fallback-d-backpressure", ) seed_filter_reason = _seed_filter_reason( request=request, config=config, inflight_decode_load=decision.inflight_decode_load, ) if seed_filter_reason is not None: return await _invoke_plain_router( client=client, request=request, config=config, decision=decision, execution_mode=f"pd-router-fallback-{seed_filter_reason}", ) async with direct_session_lock: admit_new_decode_session = _should_admit_new_decode_session( residency=decode_residency, server_url=decode_url, request=request, session=decode_session, direct_sessions=direct_sessions, treat_as_fresh_session=True, ) if not admit_new_decode_session: can_seed = False reserved_tokens = 0 evicted_sessions = 0 prefill_backed_evictions = 0 seed_reason = "d-session-cap" else: ( can_seed, reserved_tokens, evicted_sessions, prefill_backed_evictions, seed_reason, ) = ( await _reserve_decode_session_capacity( client=client, config=config, request=request, server_url=decode_url, session=decode_session, direct_sessions=direct_sessions, residency=decode_residency, treat_as_fresh_session=True, routing_mode="seed", admission_mode=config.kvcache_admission_mode, ) ) if seed_reason == "d-session-cap": return await _invoke_plain_router( client=client, request=request, config=config, decision=decision, execution_mode="pd-router-fallback-session-cap", ) if can_seed: return await _invoke_kvcache_seeded_router( client=client, request=request, config=config, decision=decision, prefill_url=prefill_url, decode_session=decode_session, direct_sessions=direct_sessions, direct_session_lock=direct_session_lock, decode_residency=decode_residency, reserved_tokens=reserved_tokens, execution_mode=( "pd-router-d-session-reseed" + _eviction_suffix( evicted_sessions, prefill_backed_evictions, ) ), ) return await _invoke_plain_router( client=client, request=request, config=config, decision=decision, execution_mode=( "pd-router-fallback-d-backpressure" if _is_decode_backpressure_reason(seed_reason) else "pd-router-fallback-no-d-capacity" ), ) seed_filter_reason = _seed_filter_reason( request=request, config=config, inflight_decode_load=decision.inflight_decode_load, ) if seed_filter_reason is not None: return await _invoke_plain_router( request=request, client=client, config=config, decision=decision, execution_mode=f"pd-router-fallback-large-append-{seed_filter_reason}", ) async with direct_session_lock: admit_new_decode_session = _should_admit_new_decode_session( residency=decode_residency, server_url=decode_url, request=request, session=decode_session, direct_sessions=direct_sessions, treat_as_fresh_session=True, ) if not admit_new_decode_session: can_seed = False reserved_tokens = 0 evicted_sessions = 0 prefill_backed_evictions = 0 seed_reason = "d-session-cap" else: ( can_seed, reserved_tokens, evicted_sessions, prefill_backed_evictions, seed_reason, ) = ( await _reserve_decode_session_capacity( client=client, config=config, request=request, server_url=decode_url, session=decode_session, direct_sessions=direct_sessions, residency=decode_residency, treat_as_fresh_session=True, routing_mode="seed", admission_mode=config.kvcache_admission_mode, ) ) if seed_reason == "d-session-cap": return await _invoke_plain_router( request=request, client=client, config=config, decision=decision, execution_mode="pd-router-fallback-large-append-session-cap", ) if can_seed: return await _invoke_kvcache_seeded_router( client=client, request=request, config=config, decision=decision, prefill_url=prefill_url, decode_session=decode_session, direct_sessions=direct_sessions, direct_session_lock=direct_session_lock, decode_residency=decode_residency, reserved_tokens=reserved_tokens, execution_mode=( "pd-router-large-append-reseed" + _eviction_suffix( evicted_sessions, prefill_backed_evictions, ) ), ) return await _invoke_plain_router( request=request, client=client, config=config, decision=decision, execution_mode=( "pd-router-fallback-d-backpressure" if _is_decode_backpressure_reason(seed_reason) else "pd-router-fallback-large-append" ), ) raise ValueError(f"Unsupported mechanism: {config.mechanism_name}") async def _invoke_direct( *, client: httpx.AsyncClient, request: TraceRequest, config: ReplayConfig, decision, direct_sessions: dict[str, DirectSessionState], direct_session_lock: asyncio.Lock, ) -> ExecutionResult: direct_workers = config.topology.direct_workers if not direct_workers: raise ValueError("pd-colo mechanism requires at least one direct worker") server_url = direct_workers[decision.decode_worker_index].url session = direct_sessions.get(request.session_id) if session is None or session.server_url != server_url: session = DirectSessionState(session_id=request.session_id, server_url=server_url) direct_sessions[request.session_id] = session return await _invoke_session_direct( client=client, request=request, config=config, session=session, execution_mode="pd-colo-direct-session", direct_session_lock=direct_session_lock, ) async def _invoke_session_direct( *, client: httpx.AsyncClient, request: TraceRequest, config: ReplayConfig, session: DirectSessionState, execution_mode: str, decode_residency: DecodeResidencyState | None = None, reserved_tokens: int = 0, direct_session_lock: asyncio.Lock | None = None, ) -> ExecutionResult: _prompt, effective_input_length, session_reused, session_reset = _build_direct_prompt( request=request, session=session, ) if session_reused: input_ids = _build_direct_append_input_ids(request, effective_input_length) else: input_ids = _build_direct_full_input_ids(request) if session.opened and (session_reset or not session_reused): if decode_residency is not None: await _close_decode_session( client=client, session=session, residency=decode_residency, ) else: await _close_streaming_session( client=client, server_url=session.server_url, session_id=session.session_id, allow_missing=True, ) session.opened = False session.resident_tokens = 0 if not session.opened: await _open_streaming_session( client=client, server_url=session.server_url, session_id=session.session_id, request=request, ) session.opened = True if direct_session_lock is not None: async with direct_session_lock: session.active_requests += 1 try: latency_s, ttft_s, tpot_s, cached_tokens = await _invoke_generate( client=client, base_url=session.server_url, headers={"x-request-id": request.request_id}, payload={ "input_ids": input_ids, "sampling_params": { "temperature": 0, "max_new_tokens": max(1, request.output_length), "min_new_tokens": max(1, request.output_length), "ignore_eos": True, "no_stop_trim": True, "skip_special_tokens": False, }, "session_params": {"id": session.session_id}, "stream": config.stream, }, timeout_s=config.timeout_s, stream_idle_timeout_s=config.stream_idle_timeout_s, stream=config.stream, ) finally: if direct_session_lock is not None: async with direct_session_lock: session.active_requests = max(0, session.active_requests - 1) if decode_residency is not None: _commit_session_residency( residency=decode_residency, session=session, request=request, reserved_tokens=reserved_tokens, ) else: session.last_trace_request = request session.last_access_s = time.perf_counter() return ExecutionResult( execution_mode=execution_mode, actual_kv_transfer_blocks=0, effective_input_length=len(input_ids), cached_tokens=cached_tokens, session_reused=session_reused, session_reset=session_reset, latency_s=latency_s, ttft_s=ttft_s, tpot_s=tpot_s, ) async def _invoke_decode_session_direct( *, client: httpx.AsyncClient, request: TraceRequest, config: ReplayConfig, decision, direct_sessions: dict[str, DirectSessionState], direct_session_lock: asyncio.Lock, decode_residency: DecodeResidencyState, reserved_tokens: int, ) -> ExecutionResult: decode_url = config.topology.decode_workers[decision.decode_worker_index].url session = direct_sessions[request.session_id] try: return await _invoke_session_direct( client=client, request=request, config=config, session=session, execution_mode="kvcache-direct-to-d-session", decode_residency=decode_residency, reserved_tokens=reserved_tokens, direct_session_lock=direct_session_lock, ) except Exception: async with direct_session_lock: _release_reserved_tokens( decode_residency, decode_url, reserved_tokens, ) raise def _should_bypass_prefill( *, request: TraceRequest, config: ReplayConfig, decision, block_token_budget: int = 24, ) -> bool: if request.turn_id <= 1: return False if decision.observed_overlap_blocks <= 0: return False uncached_tokens = max(0, decision.kv_transfer_blocks * block_token_budget) return uncached_tokens <= config.kvcache_direct_max_uncached_tokens def _worker_url_by_id(workers, worker_id: str) -> str: for worker in workers: if worker.worker_id == worker_id: return worker.url raise KeyError(f"Unknown worker id: {worker_id}") def _build_headers( *, request: TraceRequest, header_mode: HeaderMode, decode_worker_index: int, policy_name: str, ) -> dict[str, str]: if header_mode == "auto": header_mode = "routing-key" if policy_name == "sticky" else "none" headers = { "x-request-id": request.request_id, } if header_mode == "routing-key": headers["x-smg-routing-key"] = request.session_id elif header_mode == "target-worker": headers["x-smg-target-worker"] = str(decode_worker_index) return headers def _contains_token(payload: dict) -> bool: choices = payload.get("choices") if not isinstance(choices, list) or not choices: return False delta = choices[0].get("delta") if not isinstance(delta, dict): return False reasoning_content = delta.get("reasoning_content") if isinstance(reasoning_content, str) and reasoning_content: return True content = delta.get("content") if isinstance(content, str) and content: return True if isinstance(content, list): return any( isinstance(item, dict) and item.get("text") for item in content ) tool_calls = delta.get("tool_calls") if isinstance(tool_calls, list): for tool_call in tool_calls: if not isinstance(tool_call, dict): continue if tool_call.get("id"): return True function = tool_call.get("function") if not isinstance(function, dict): continue if function.get("name") or function.get("arguments"): return True return False def _contains_generate_token(payload: dict) -> bool: text = payload.get("text") if isinstance(text, str) and text: return True meta_info = payload.get("meta_info") if not isinstance(meta_info, dict): return False return int(meta_info.get("completion_tokens", 0)) > 0 def _is_generate_terminal_chunk(payload: dict) -> bool: meta_info = payload.get("meta_info") if not isinstance(meta_info, dict): return False return meta_info.get("finish_reason") is not None def _extract_generate_cached_tokens(payload: dict) -> int: meta_info = payload.get("meta_info") if not isinstance(meta_info, dict): return 0 return int(meta_info.get("cached_tokens", 0) or 0) def _extract_openai_cached_tokens(payload: dict) -> int: usage = payload.get("usage") if not isinstance(usage, dict): return 0 prompt_tokens_details = usage.get("prompt_tokens_details") if not isinstance(prompt_tokens_details, dict): return 0 return int(prompt_tokens_details.get("cached_tokens", 0) or 0) async def _aiter_lines( response: httpx.Response, *, idle_timeout_s: float | None, ): line_iterator = response.aiter_lines() while True: try: if idle_timeout_s is None or idle_timeout_s <= 0: line = await anext(line_iterator) else: line = await asyncio.wait_for( anext(line_iterator), timeout=idle_timeout_s, ) except StopAsyncIteration: return yield line def _is_terminal_chunk(payload: dict) -> bool: choices = payload.get("choices") if not isinstance(choices, list) or not choices: return False finish_reason = choices[0].get("finish_reason") return finish_reason is not None