from __future__ import annotations import asyncio import json import time from dataclasses import dataclass, field from pathlib import Path from typing import Any, Literal import httpx from agentic_pd_hybrid.metrics import ( RequestMetrics, write_metrics_jsonl, write_summary_json, ) from agentic_pd_hybrid.policies import RoutingState, create_policy from agentic_pd_hybrid.topology import SingleNodeTopology from agentic_pd_hybrid.trace import ( TraceRequest, build_synthetic_append_chunk, build_synthetic_prompt, load_trace, ) HeaderMode = Literal["none", "routing-key", "target-worker", "auto"] KvCacheAdmissionMode = Literal["router", "worker"] @dataclass(frozen=True) class ReplayConfig: trace_path: Path output_path: Path policy_name: str mechanism_name: str topology: SingleNodeTopology router_url: str | None = None model_name: str | None = None pace: bool = True time_scale: float = 1.0 request_limit: int | None = None concurrency_limit: int = 32 header_mode: HeaderMode = "auto" timeout_s: float = 600.0 stream: bool = True stream_idle_timeout_s: float | None = 900.0 kvcache_direct_max_uncached_tokens: int = 2048 kvcache_admission_mode: KvCacheAdmissionMode = "router" @dataclass class DirectSessionState: session_id: str server_url: str opened: bool = False last_trace_request: TraceRequest | None = None resident_tokens: int = 0 last_access_s: float = 0.0 active_requests: int = 0 prefill_server_url: str | None = None prefill_opened: bool = False prefill_resident_tokens: int = 0 prefill_last_access_s: float = 0.0 prefill_low_priority: bool = False @dataclass class DecodeResidencyState: capacity_tokens: dict[str, int] = field(default_factory=dict) headroom_tokens: dict[str, int] = field(default_factory=dict) reserved_decode_tokens: dict[str, int] = field(default_factory=dict) resident_tokens_by_server: dict[str, int] = field(default_factory=dict) reserved_tokens_by_server: dict[str, int] = field(default_factory=dict) prefill_capacity_tokens: dict[str, int] = field(default_factory=dict) prefill_headroom_tokens: dict[str, int] = field(default_factory=dict) prefill_resident_tokens_by_server: dict[str, int] = field(default_factory=dict) prefill_reserved_tokens_by_server: dict[str, int] = field(default_factory=dict) decode_evictions_prefill_backed: int = 0 decode_evictions_without_prefill_backup: int = 0 @dataclass(frozen=True) class DecodeLoadSnapshot: timestamp_s: float num_running_reqs: int num_waiting_reqs: int num_used_tokens: int max_total_num_tokens: int token_usage: float decode_prealloc_queue_reqs: int decode_transfer_queue_reqs: int decode_retracted_queue_reqs: int @dataclass(frozen=True) class ExecutionResult: execution_mode: str actual_kv_transfer_blocks: int effective_input_length: int | None cached_tokens: int session_reused: bool session_reset: bool latency_s: float | None ttft_s: float | None tpot_s: float | None error: str | None = None async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]: requests = load_trace(config.trace_path, request_limit=config.request_limit) policy = create_policy(config.policy_name) state = RoutingState.create(config.topology) state_lock = asyncio.Lock() semaphore = asyncio.Semaphore(config.concurrency_limit) start_time = time.perf_counter() first_timestamp = requests[0].timestamp_s if requests else 0.0 session_tail_tasks: dict[str, asyncio.Task[RequestMetrics]] = {} direct_sessions: dict[str, DirectSessionState] = {} direct_session_lock = asyncio.Lock() async with httpx.AsyncClient(timeout=config.timeout_s, trust_env=False) as client: decode_residency = await _discover_decode_residency( client=client, config=config, ) tasks = [] for request in requests: if config.pace: target_offset = (request.timestamp_s - first_timestamp) / config.time_scale sleep_s = target_offset - (time.perf_counter() - start_time) if sleep_s > 0: await asyncio.sleep(sleep_s) tasks.append( asyncio.create_task( _run_request( request=request, config=config, client=client, policy=policy, state=state, state_lock=state_lock, semaphore=semaphore, direct_sessions=direct_sessions, direct_session_lock=direct_session_lock, decode_residency=decode_residency, depends_on=session_tail_tasks.get(request.session_id), ) ) ) session_tail_tasks[request.session_id] = tasks[-1] results = await asyncio.gather(*tasks) for session in direct_sessions.values(): if session.opened: try: await _close_streaming_session( client=client, server_url=session.server_url, session_id=session.session_id, allow_missing=True, ) except Exception: pass if session.prefill_opened and session.prefill_server_url is not None: try: await _close_streaming_session( client=client, server_url=session.prefill_server_url, session_id=session.session_id, allow_missing=True, ) except Exception: pass write_metrics_jsonl(config.output_path, results) write_summary_json( config.output_path.with_suffix(config.output_path.suffix + ".summary.json"), results, trace_path=config.trace_path, router_url=config.router_url, ) return results async def _run_request( *, request: TraceRequest, config: ReplayConfig, client: httpx.AsyncClient, policy, state: RoutingState, state_lock: asyncio.Lock, semaphore: asyncio.Semaphore, direct_sessions: dict[str, DirectSessionState], direct_session_lock: asyncio.Lock, decode_residency: DecodeResidencyState, depends_on: asyncio.Task[RequestMetrics] | None, ) -> RequestMetrics: if depends_on is not None: await depends_on async with semaphore: async with state_lock: decision = policy.select(request, topology=config.topology, state=state) try: execution = await _execute_request( client=client, request=request, config=config, decision=decision, direct_sessions=direct_sessions, direct_session_lock=direct_session_lock, decode_residency=decode_residency, ) except Exception as exc: # pragma: no cover - defensive logging path execution = ExecutionResult( execution_mode=config.mechanism_name, actual_kv_transfer_blocks=0, effective_input_length=None, cached_tokens=0, session_reused=False, session_reset=False, latency_s=None, ttft_s=None, tpot_s=None, error=f"{type(exc).__name__}: {exc}", ) async with state_lock: state.finish(request, decision) return RequestMetrics.from_decision( request, decision, mechanism_name=config.mechanism_name, execution_mode=execution.execution_mode, actual_kv_transfer_blocks=execution.actual_kv_transfer_blocks, effective_input_length=execution.effective_input_length, cached_tokens=execution.cached_tokens, session_reused=execution.session_reused, session_reset=execution.session_reset, latency_s=execution.latency_s, ttft_s=execution.ttft_s, tpot_s=execution.tpot_s, error=execution.error, ) async def _invoke_router( *, client: httpx.AsyncClient, request: TraceRequest, config: ReplayConfig, decode_worker_index: int, session_id: str | None = None, ) -> tuple[float, float | None, float | None, int]: headers = _build_headers( request=request, header_mode=config.header_mode, decode_worker_index=decode_worker_index, policy_name=config.policy_name, ) assert config.router_url is not None payload: dict[str, object] = { "input_ids": _build_direct_full_input_ids(request), "sampling_params": { "temperature": 0, "max_new_tokens": max(1, request.output_length), "ignore_eos": True, "no_stop_trim": True, "skip_special_tokens": False, }, "stream": config.stream, } if session_id is not None: payload["session_params"] = {"id": session_id} return await _invoke_generate( client=client, base_url=config.router_url, headers=headers, payload=payload, timeout_s=config.timeout_s, stream_idle_timeout_s=config.stream_idle_timeout_s, stream=config.stream, ) def _build_payload( *, request: TraceRequest, model_name: str, prompt: str, stream: bool, session_params: dict[str, str] | None, exact_output_length: bool, ) -> dict[str, object]: payload: dict[str, object] = { "model": model_name, "messages": [{"role": "user", "content": prompt}], "max_tokens": max(1, request.output_length), "temperature": 0, "stream": stream, } if stream: payload["stream_options"] = {"include_usage": True} if exact_output_length: payload.update( { "min_tokens": max(1, request.output_length), "ignore_eos": True, "no_stop_trim": True, "skip_special_tokens": False, } ) if session_params is not None: payload["session_params"] = session_params return payload async def _invoke_chat_completion( *, client: httpx.AsyncClient, base_url: str, headers: dict[str, str], payload: dict[str, object], timeout_s: float, stream_idle_timeout_s: float | None, stream: bool, ) -> tuple[float, float | None, float | None, int]: start = time.perf_counter() ttft_s: float | None = None cached_tokens = 0 generated_tokens = int(payload.get("max_tokens", 1)) if stream: async with client.stream( "POST", f"{base_url.rstrip('/')}/v1/chat/completions", headers=headers, json=payload, timeout=timeout_s, ) as response: response.raise_for_status() async for line in _aiter_lines( response, idle_timeout_s=stream_idle_timeout_s, ): if not line.startswith("data:"): continue data = line[5:].strip() if data == "[DONE]": break parsed = json.loads(data) cached_tokens = max(cached_tokens, _extract_openai_cached_tokens(parsed)) if _contains_token(parsed) and ttft_s is None: ttft_s = time.perf_counter() - start if _is_terminal_chunk(parsed): break else: response = await client.post( f"{base_url.rstrip('/')}/v1/chat/completions", headers=headers, json=payload, timeout=timeout_s, ) response.raise_for_status() parsed = response.json() cached_tokens = _extract_openai_cached_tokens(parsed) latency_s = time.perf_counter() - start if stream and ttft_s is None and generated_tokens > 0: raise RuntimeError("generate stream ended before producing any token") if ttft_s is None: tpot_s = None else: tpot_s = max(0.0, latency_s - ttft_s) / max(1, generated_tokens) return latency_s, ttft_s, tpot_s, cached_tokens async def _invoke_generate( *, client: httpx.AsyncClient, base_url: str, headers: dict[str, str], payload: dict[str, object], timeout_s: float, stream_idle_timeout_s: float | None, stream: bool, ) -> tuple[float, float | None, float | None, int]: start = time.perf_counter() ttft_s: float | None = None cached_tokens = 0 sampling_params = payload.get("sampling_params", {}) generated_tokens = int(sampling_params.get("max_new_tokens", 1)) if stream: async with client.stream( "POST", f"{base_url.rstrip('/')}/generate", headers=headers, json=payload, timeout=timeout_s, ) as response: response.raise_for_status() async for line in _aiter_lines( response, idle_timeout_s=stream_idle_timeout_s, ): if not line.startswith("data:"): continue data = line[5:].strip() if data == "[DONE]": break parsed = json.loads(data) error = parsed.get("error") if isinstance(error, dict): raise ValueError(error.get("message", json.dumps(error))) cached_tokens = max(cached_tokens, _extract_generate_cached_tokens(parsed)) if _contains_generate_token(parsed) and ttft_s is None: ttft_s = time.perf_counter() - start if _is_generate_terminal_chunk(parsed): break else: response = await client.post( f"{base_url.rstrip('/')}/generate", headers=headers, json=payload, timeout=timeout_s, ) response.raise_for_status() parsed = response.json() error = parsed.get("error") if isinstance(error, dict): raise ValueError(error.get("message", json.dumps(error))) cached_tokens = _extract_generate_cached_tokens(parsed) latency_s = time.perf_counter() - start if stream and ttft_s is None and generated_tokens > 0: raise RuntimeError("generate stream ended before producing any token") if ttft_s is None: tpot_s = None else: tpot_s = max(0.0, latency_s - ttft_s) / max(1, generated_tokens) return latency_s, ttft_s, tpot_s, cached_tokens async def _open_streaming_session( *, client: httpx.AsyncClient, server_url: str, session_id: str, request: TraceRequest, ) -> None: capacity = max( 4096, request.input_length * 16, (request.input_length + request.output_length) * 16, ) response = await client.post( f"{server_url.rstrip('/')}/open_session", json={ "capacity_of_str_len": capacity, "session_id": session_id, "streaming": True, }, ) response.raise_for_status() opened_session_id = response.json() if opened_session_id != session_id: raise ValueError( f"Unexpected session id from {server_url}: {opened_session_id!r} != {session_id!r}" ) async def _close_streaming_session( *, client: httpx.AsyncClient, server_url: str, session_id: str, allow_missing: bool = False, ) -> None: response = await client.post( f"{server_url.rstrip('/')}/close_session", json={"session_id": session_id}, ) if response.is_success: return if allow_missing: response_text = response.text.lower() if response.status_code == 404 or "does not exist" in response_text: return response.raise_for_status() def _extract_internal_state(payload: dict[str, Any]) -> dict[str, Any]: internal_states = payload.get("internal_states") if isinstance(internal_states, list) and internal_states: internal_state = internal_states[0] if isinstance(internal_state, dict): return internal_state return payload def _extract_server_int( payload: dict[str, Any], key: str, ) -> int: internal_state = _extract_internal_state(payload) value = payload.get(key, internal_state.get(key, 0)) return int(value or 0) def _extract_session_cache( payload: dict[str, Any], ) -> dict[str, Any]: internal_state = _extract_internal_state(payload) session_cache = internal_state.get("session_cache") if isinstance(session_cache, dict): return session_cache return {} def _find_session_cache_status( session_cache: dict[str, Any], session_id: str, ) -> dict[str, Any] | None: sessions = session_cache.get("sessions") if not isinstance(sessions, list): return None for session in sessions: if isinstance(session, dict) and session.get("session_id") == session_id: return session return None async def _fetch_decode_server_state( *, client: httpx.AsyncClient, server_url: str, ) -> tuple[dict[str, Any], int, int]: try: response = await client.get(f"{server_url.rstrip('/')}/server_info") response.raise_for_status() payload = response.json() except Exception: return {}, 0, 0 return ( _extract_session_cache(payload), _extract_server_int(payload, "max_total_num_tokens"), _extract_server_int(payload, "num_reserved_decode_tokens"), ) async def _query_decode_direct_admission( *, client: httpx.AsyncClient, server_url: str, session_id: str, uncached_input_tokens: int, output_tokens: int, ) -> dict[str, Any]: try: response = await client.post( f"{server_url.rstrip('/')}/session_cache/admit_direct_append", json={ "session_id": session_id, "uncached_input_tokens": max(0, uncached_input_tokens), "output_tokens": max(0, output_tokens), }, ) response.raise_for_status() payload = response.json() if isinstance(payload, dict): return payload except Exception: pass return { "can_admit": False, "resident": False, "reason": "admission-query-failed", "required_tokens": 0, "available_tokens_before": 0, "available_tokens_after": 0, "evicted_session_count": 0, "freed_tokens": 0, } async def _discover_decode_residency( *, client: httpx.AsyncClient, config: ReplayConfig, ) -> DecodeResidencyState: residency = DecodeResidencyState() if config.mechanism_name != "kvcache-centric": return residency for worker in config.topology.decode_workers: _session_cache, max_total_num_tokens, reserved_decode_tokens = ( await _fetch_decode_server_state( client=client, server_url=worker.url, ) ) if max_total_num_tokens <= 0: continue safety_headroom = max( reserved_decode_tokens * 4, max_total_num_tokens // 20, 8192, ) residency.capacity_tokens[worker.url] = max_total_num_tokens residency.headroom_tokens[worker.url] = min( max_total_num_tokens, safety_headroom, ) residency.reserved_decode_tokens[worker.url] = reserved_decode_tokens for worker in config.topology.prefill_workers: _session_cache, max_total_num_tokens, _reserved_decode_tokens = ( await _fetch_decode_server_state( client=client, server_url=worker.url, ) ) if max_total_num_tokens <= 0: continue safety_headroom = max( max_total_num_tokens // 10, 16384, ) residency.prefill_capacity_tokens[worker.url] = max_total_num_tokens residency.prefill_headroom_tokens[worker.url] = min( max_total_num_tokens, safety_headroom, ) return residency def _estimate_session_resident_tokens(request: TraceRequest) -> int: return request.input_length + request.output_length def _inspect_direct_request( *, request: TraceRequest, session: DirectSessionState, ) -> tuple[int, bool, bool]: previous = session.last_trace_request if previous is None: return request.input_length, False, False append_length = request.input_length - ( previous.input_length + previous.output_length ) if append_length <= 0: return request.input_length, False, True return append_length, True, False def _add_reserved_tokens( residency: DecodeResidencyState, server_url: str, delta_tokens: int, ) -> None: if delta_tokens <= 0: return residency.reserved_tokens_by_server[server_url] = ( residency.reserved_tokens_by_server.get(server_url, 0) + delta_tokens ) def _release_reserved_tokens( residency: DecodeResidencyState, server_url: str, delta_tokens: int, ) -> None: if delta_tokens <= 0: return remaining = residency.reserved_tokens_by_server.get(server_url, 0) - delta_tokens if remaining > 0: residency.reserved_tokens_by_server[server_url] = remaining else: residency.reserved_tokens_by_server.pop(server_url, None) def _add_prefill_reserved_tokens( residency: DecodeResidencyState, server_url: str, delta_tokens: int, ) -> None: if delta_tokens <= 0: return residency.prefill_reserved_tokens_by_server[server_url] = ( residency.prefill_reserved_tokens_by_server.get(server_url, 0) + delta_tokens ) def _release_prefill_reserved_tokens( residency: DecodeResidencyState, server_url: str, delta_tokens: int, ) -> None: if delta_tokens <= 0: return remaining = ( residency.prefill_reserved_tokens_by_server.get(server_url, 0) - delta_tokens ) if remaining > 0: residency.prefill_reserved_tokens_by_server[server_url] = remaining else: residency.prefill_reserved_tokens_by_server.pop(server_url, None) def _usable_capacity_tokens( residency: DecodeResidencyState, server_url: str, ) -> int: return max( 0, residency.capacity_tokens.get(server_url, 0) - residency.headroom_tokens.get(server_url, 0), ) def _usable_prefill_backup_capacity_tokens( residency: DecodeResidencyState, server_url: str, ) -> int: return max( 0, residency.prefill_capacity_tokens.get(server_url, 0) - residency.prefill_headroom_tokens.get(server_url, 0), ) def _eviction_suffix(evicted_sessions: int, prefill_backed_evictions: int) -> str: if evicted_sessions <= 0: return "" if prefill_backed_evictions >= evicted_sessions: return "-after-prefill-backed-eviction" if prefill_backed_evictions > 0: return "-after-mixed-eviction" return "-after-eviction" def _decode_session_soft_cap( *, residency: DecodeResidencyState, server_url: str, request: TraceRequest, ) -> int: target_tokens = max(1, _estimate_session_resident_tokens(request)) usable_capacity_tokens = _usable_capacity_tokens(residency, server_url) if usable_capacity_tokens <= 0: usable_capacity_tokens = max( 0, residency.capacity_tokens.get(server_url, 0) - residency.headroom_tokens.get(server_url, 0), ) if usable_capacity_tokens <= 0: return 4 return max(1, min(4, usable_capacity_tokens // target_tokens)) def _should_admit_new_decode_session( *, residency: DecodeResidencyState, server_url: str, request: TraceRequest, session: DirectSessionState, direct_sessions: dict[str, DirectSessionState], treat_as_fresh_session: bool, ) -> bool: if ( not treat_as_fresh_session and session.opened and session.server_url == server_url ): return True open_sessions = sum( 1 for candidate in direct_sessions.values() if candidate.opened and candidate.server_url == server_url ) return open_sessions < _decode_session_soft_cap( residency=residency, server_url=server_url, request=request, ) async def _fetch_decode_load_snapshot( *, client: httpx.AsyncClient, server_url: str, ) -> DecodeLoadSnapshot | None: try: response = await client.get( f"{server_url.rstrip('/')}/v1/loads", params={"include": "core,disagg"}, ) response.raise_for_status() payload = response.json() except Exception: return None loads = payload.get("loads") if not isinstance(loads, list) or not loads: return None load = loads[0] disagg = load.get("disaggregation") or {} return DecodeLoadSnapshot( timestamp_s=time.perf_counter(), num_running_reqs=int(load.get("num_running_reqs", 0) or 0), num_waiting_reqs=int(load.get("num_waiting_reqs", 0) or 0), num_used_tokens=int(load.get("num_used_tokens", 0) or 0), max_total_num_tokens=int(load.get("max_total_num_tokens", 0) or 0), token_usage=float(load.get("token_usage", 0.0) or 0.0), decode_prealloc_queue_reqs=int( disagg.get("decode_prealloc_queue_reqs", 0) or 0 ), decode_transfer_queue_reqs=int( disagg.get("decode_transfer_queue_reqs", 0) or 0 ), decode_retracted_queue_reqs=int( disagg.get("decode_retracted_queue_reqs", 0) or 0 ), ) def _decode_load_backpressure_reason( snapshot: DecodeLoadSnapshot | None, *, routing_mode: Literal["direct", "seed"], ) -> str | None: if snapshot is None: return None if routing_mode == "direct": if snapshot.decode_retracted_queue_reqs > 0 and snapshot.token_usage >= 0.99: return "d-retracted" if snapshot.token_usage >= 0.992: return "d-token-usage-critical" else: if snapshot.decode_retracted_queue_reqs > 0: return "d-retracted" if snapshot.token_usage >= 0.985: return "d-token-usage-critical" if routing_mode == "seed" and snapshot.token_usage >= 0.94 and ( snapshot.decode_prealloc_queue_reqs > 0 or snapshot.decode_transfer_queue_reqs > 0 ): return "d-prealloc-backpressure" return None def _is_decode_backpressure_reason(reason: str | None) -> bool: return reason in { "d-retracted", "d-token-usage-critical", "d-prealloc-backpressure", } def _dynamic_decode_headroom_tokens( *, residency: DecodeResidencyState, server_url: str, snapshot: DecodeLoadSnapshot | None, routing_mode: Literal["direct", "seed"], ) -> int: if snapshot is None: return residency.headroom_tokens.get(server_url, 0) base_reserved = max(512, residency.reserved_decode_tokens.get(server_url, 0)) if routing_mode == "direct": direct_queue_pressure = max( 1, snapshot.decode_prealloc_queue_reqs + snapshot.decode_transfer_queue_reqs + snapshot.decode_retracted_queue_reqs, ) capacity_divisor = 24 minimum_headroom = 4096 return max( base_reserved * direct_queue_pressure, snapshot.max_total_num_tokens // capacity_divisor, minimum_headroom, ) disagg_queued = ( snapshot.decode_prealloc_queue_reqs + snapshot.decode_transfer_queue_reqs + snapshot.decode_retracted_queue_reqs ) active_decode_pressure = max(1, snapshot.num_running_reqs + disagg_queued) capacity_divisor = 15 minimum_headroom = 12288 return max( base_reserved * active_decode_pressure, snapshot.max_total_num_tokens // capacity_divisor, minimum_headroom, ) def _commit_session_residency( *, residency: DecodeResidencyState, session: DirectSessionState, request: TraceRequest, reserved_tokens: int, ) -> None: _release_reserved_tokens(residency, session.server_url, reserved_tokens) previous_tokens = session.resident_tokens if session.opened else 0 new_tokens = _estimate_session_resident_tokens(request) delta_tokens = new_tokens - previous_tokens if delta_tokens != 0: residency.resident_tokens_by_server[session.server_url] = ( residency.resident_tokens_by_server.get(session.server_url, 0) + delta_tokens ) session.opened = True session.resident_tokens = new_tokens session.last_trace_request = request session.last_access_s = time.perf_counter() if session.prefill_opened: session.prefill_low_priority = True def _commit_prefill_backup_residency( *, residency: DecodeResidencyState, session: DirectSessionState, request: TraceRequest, prefill_url: str, reserved_tokens: int, ) -> None: _release_prefill_reserved_tokens(residency, prefill_url, reserved_tokens) previous_tokens = session.prefill_resident_tokens if session.prefill_opened else 0 new_tokens = _estimate_session_resident_tokens(request) delta_tokens = new_tokens - previous_tokens if delta_tokens != 0: residency.prefill_resident_tokens_by_server[prefill_url] = ( residency.prefill_resident_tokens_by_server.get(prefill_url, 0) + delta_tokens ) session.prefill_server_url = prefill_url session.prefill_opened = True session.prefill_resident_tokens = new_tokens session.prefill_last_access_s = time.perf_counter() session.prefill_low_priority = session.opened async def _close_prefill_session( *, client: httpx.AsyncClient, session: DirectSessionState, residency: DecodeResidencyState, ) -> None: if not session.prefill_opened or session.prefill_server_url is None: session.prefill_opened = False session.prefill_resident_tokens = 0 session.prefill_low_priority = False return prefill_url = session.prefill_server_url await _close_streaming_session( client=client, server_url=prefill_url, session_id=session.session_id, allow_missing=True, ) remaining = ( residency.prefill_resident_tokens_by_server.get(prefill_url, 0) - session.prefill_resident_tokens ) if remaining > 0: residency.prefill_resident_tokens_by_server[prefill_url] = remaining else: residency.prefill_resident_tokens_by_server.pop(prefill_url, None) session.prefill_opened = False session.prefill_resident_tokens = 0 session.prefill_low_priority = False async def _close_decode_session( *, client: httpx.AsyncClient, session: DirectSessionState, residency: DecodeResidencyState, evicting_for_capacity: bool = False, ) -> None: if not session.opened: session.resident_tokens = 0 return await _close_streaming_session( client=client, server_url=session.server_url, session_id=session.session_id, allow_missing=True, ) remaining = ( residency.resident_tokens_by_server.get(session.server_url, 0) - session.resident_tokens ) if remaining > 0: residency.resident_tokens_by_server[session.server_url] = remaining else: residency.resident_tokens_by_server.pop(session.server_url, None) session.opened = False session.resident_tokens = 0 if session.prefill_opened: residency.decode_evictions_prefill_backed += int(evicting_for_capacity) session.prefill_low_priority = False elif evicting_for_capacity: residency.decode_evictions_without_prefill_backup += 1 async def _reserve_prefill_backup_capacity( *, client: httpx.AsyncClient, request: TraceRequest, prefill_url: str, session: DirectSessionState, direct_sessions: dict[str, DirectSessionState], residency: DecodeResidencyState, ) -> tuple[bool, int, int]: session_cache, max_total_num_tokens, _reserved_decode_tokens = ( await _fetch_decode_server_state( client=client, server_url=prefill_url, ) ) if max_total_num_tokens > 0: residency.prefill_capacity_tokens[prefill_url] = max_total_num_tokens capacity_tokens = residency.prefill_capacity_tokens.get(prefill_url, 0) headroom_tokens = residency.prefill_headroom_tokens.get(prefill_url, 0) if capacity_tokens <= 0: return True, 0, 0 low_occupancy_headroom_tokens = max( headroom_tokens, capacity_tokens // 2, ) target_session_status = _find_session_cache_status( session_cache, session.session_id, ) if ( isinstance(target_session_status, dict) and bool(target_session_status.get("resident")) ): current_tokens = int(target_session_status.get("resident_tokens", 0) or 0) else: current_tokens = ( session.prefill_resident_tokens if session.prefill_opened and session.prefill_server_url == prefill_url else 0 ) target_tokens = _estimate_session_resident_tokens(request) required_extra_tokens = max(0, target_tokens - current_tokens) evicted_sessions = 0 max_backup_sessions = max(1, capacity_tokens // max(1, target_tokens * 2)) max_backup_sessions = min(max_backup_sessions, 4) available_tokens = int(session_cache.get("available_tokens", 0) or 0) if available_tokens <= 0: held_tokens = int(session_cache.get("held_tokens", 0) or 0) available_tokens = max(0, capacity_tokens - held_tokens) available_tokens -= residency.prefill_reserved_tokens_by_server.get(prefill_url, 0) def has_enough_prefill_headroom() -> bool: return available_tokens - required_extra_tokens >= low_occupancy_headroom_tokens def prefill_backup_count() -> int: return sum( 1 for candidate in direct_sessions.values() if candidate.prefill_opened and candidate.prefill_server_url == prefill_url ) while ( required_extra_tokens > 0 and ( not has_enough_prefill_headroom() or ( not session.prefill_opened and prefill_backup_count() >= max_backup_sessions ) ) ): candidates = sorted( ( candidate for candidate in direct_sessions.values() if candidate.prefill_opened and candidate.prefill_server_url == prefill_url and candidate.session_id != session.session_id and candidate.active_requests <= 0 ), key=lambda candidate: ( 0 if candidate.prefill_low_priority else 1, candidate.prefill_last_access_s, ), ) if not candidates: break freed_tokens = candidates[0].prefill_resident_tokens await _close_prefill_session( client=client, session=candidates[0], residency=residency, ) available_tokens += freed_tokens evicted_sessions += 1 if not has_enough_prefill_headroom(): return False, 0, evicted_sessions _add_prefill_reserved_tokens(residency, prefill_url, required_extra_tokens) return True, required_extra_tokens, evicted_sessions async def _reserve_decode_session_capacity( *, client: httpx.AsyncClient, request: TraceRequest, server_url: str, session: DirectSessionState, direct_sessions: dict[str, DirectSessionState], residency: DecodeResidencyState, treat_as_fresh_session: bool, routing_mode: Literal["direct", "seed"], admission_mode: KvCacheAdmissionMode, ) -> tuple[bool, int, int, int, str | None]: if admission_mode == "router": return await _reserve_decode_session_capacity_from_router_state( client=client, request=request, server_url=server_url, session=session, direct_sessions=direct_sessions, residency=residency, treat_as_fresh_session=treat_as_fresh_session, routing_mode=routing_mode, ) if treat_as_fresh_session and session.opened: await _close_decode_session( client=client, session=session, residency=residency, ) current_tokens = 0 if treat_as_fresh_session else session.resident_tokens target_tokens = _estimate_session_resident_tokens(request) required_extra_tokens = max(0, target_tokens - current_tokens) prefill_backed_evictions = 0 if routing_mode == "direct" and not treat_as_fresh_session: if not session.opened: return False, 0, 0, 0, "d-session-not-resident" admission = await _query_decode_direct_admission( client=client, server_url=server_url, session_id=session.session_id, uncached_input_tokens=max(0, request.input_length - current_tokens), output_tokens=request.output_length, ) if not bool(admission.get("resident")): return False, 0, 0, 0, str(admission.get("reason") or "d-session-not-resident") if not bool(admission.get("can_admit")): return ( False, 0, int(admission.get("evicted_session_count", 0) or 0), 0, str(admission.get("reason") or "d-no-space"), ) reserved_tokens = int( admission.get("required_tokens", required_extra_tokens) or required_extra_tokens ) _add_reserved_tokens(residency, server_url, reserved_tokens) return ( True, reserved_tokens, int(admission.get("evicted_session_count", 0) or 0), 0, None, ) session_cache, max_total_num_tokens, reserved_decode_tokens = ( await _fetch_decode_server_state( client=client, server_url=server_url, ) ) if max_total_num_tokens > 0: residency.capacity_tokens[server_url] = max_total_num_tokens if reserved_decode_tokens > 0: residency.reserved_decode_tokens[server_url] = reserved_decode_tokens target_session_status = _find_session_cache_status( session_cache, session.session_id, ) if routing_mode == "direct" and not ( isinstance(target_session_status, dict) and bool(target_session_status.get("resident")) ): return False, 0, 0, 0, "d-session-not-resident" load_snapshot = await _fetch_decode_load_snapshot( client=client, server_url=server_url, ) if load_snapshot is not None and load_snapshot.max_total_num_tokens > 0: residency.capacity_tokens[server_url] = load_snapshot.max_total_num_tokens backpressure_reason = _decode_load_backpressure_reason( load_snapshot, routing_mode=routing_mode, ) if backpressure_reason is not None: return False, 0, 0, 0, backpressure_reason usable_capacity_tokens = _usable_capacity_tokens(residency, server_url) evicted_sessions = 0 while ( required_extra_tokens > 0 and residency.resident_tokens_by_server.get(server_url, 0) + residency.reserved_tokens_by_server.get(server_url, 0) + required_extra_tokens > usable_capacity_tokens + int(session_cache.get("idle_evictable_tokens", 0) or 0) ): candidates = sorted( ( candidate for candidate in direct_sessions.values() if candidate.opened and candidate.server_url == server_url and candidate.session_id != session.session_id and candidate.active_requests <= 0 ), key=lambda candidate: candidate.last_access_s, ) if not candidates: break await _close_decode_session( client=client, session=candidates[0], residency=residency, evicting_for_capacity=True, ) prefill_backed_evictions += int(candidates[0].prefill_opened) evicted_sessions += 1 if evicted_sessions > 0: load_snapshot = await _fetch_decode_load_snapshot( client=client, server_url=server_url, ) session_cache, max_total_num_tokens, reserved_decode_tokens = ( await _fetch_decode_server_state( client=client, server_url=server_url, ) ) if max_total_num_tokens > 0: residency.capacity_tokens[server_url] = max_total_num_tokens if reserved_decode_tokens > 0: residency.reserved_decode_tokens[server_url] = reserved_decode_tokens usable_capacity_tokens = _usable_capacity_tokens(residency, server_url) if load_snapshot is not None: dynamic_headroom = _dynamic_decode_headroom_tokens( residency=residency, server_url=server_url, snapshot=load_snapshot, routing_mode=routing_mode, ) residency.headroom_tokens[server_url] = min( residency.capacity_tokens.get(server_url, dynamic_headroom), dynamic_headroom, ) usable_capacity_tokens = max( 0, residency.capacity_tokens.get(server_url, 0) - dynamic_headroom, ) usable_capacity_tokens = _usable_capacity_tokens(residency, server_url) if load_snapshot is not None: dynamic_headroom = _dynamic_decode_headroom_tokens( residency=residency, server_url=server_url, snapshot=load_snapshot, routing_mode=routing_mode, ) residency.headroom_tokens[server_url] = min( residency.capacity_tokens.get(server_url, dynamic_headroom), dynamic_headroom, ) usable_capacity_tokens = max( 0, residency.capacity_tokens.get(server_url, 0) - dynamic_headroom, ) effective_used_tokens = ( load_snapshot.num_used_tokens if load_snapshot is not None else residency.resident_tokens_by_server.get(server_url, 0) ) + residency.reserved_tokens_by_server.get(server_url, 0) idle_evictable_tokens = int(session_cache.get("idle_evictable_tokens", 0) or 0) if ( routing_mode == "direct" and isinstance(target_session_status, dict) and bool(target_session_status.get("idle_evictable")) ): idle_evictable_tokens = max( 0, idle_evictable_tokens - int(target_session_status.get("resident_tokens", 0) or 0), ) if effective_used_tokens + required_extra_tokens > ( usable_capacity_tokens + idle_evictable_tokens ): return False, 0, evicted_sessions, prefill_backed_evictions, "d-no-space" _add_reserved_tokens(residency, server_url, required_extra_tokens) return True, required_extra_tokens, evicted_sessions, prefill_backed_evictions, None async def _reserve_decode_session_capacity_from_router_state( *, client: httpx.AsyncClient, request: TraceRequest, server_url: str, session: DirectSessionState, direct_sessions: dict[str, DirectSessionState], residency: DecodeResidencyState, treat_as_fresh_session: bool, routing_mode: Literal["direct", "seed"], ) -> tuple[bool, int, int, int, str | None]: if treat_as_fresh_session and session.opened: await _close_decode_session( client=client, session=session, residency=residency, ) if routing_mode == "direct" and not session.opened: return False, 0, 0, 0, "d-session-not-resident" current_tokens = 0 if treat_as_fresh_session else session.resident_tokens target_tokens = _estimate_session_resident_tokens(request) required_extra_tokens = max(0, target_tokens - current_tokens) usable_capacity_tokens = _usable_capacity_tokens(residency, server_url) # If discovery failed, do not force every request down the P/D fallback path. # The router can still preserve correctness; this only disables proactive # capacity admission until the worker reports capacity again in a later run. if usable_capacity_tokens <= 0: _add_reserved_tokens(residency, server_url, required_extra_tokens) return True, required_extra_tokens, 0, 0, None evicted_sessions = 0 prefill_backed_evictions = 0 while ( required_extra_tokens > 0 and residency.resident_tokens_by_server.get(server_url, 0) + residency.reserved_tokens_by_server.get(server_url, 0) + required_extra_tokens > usable_capacity_tokens ): candidates = sorted( ( candidate for candidate in direct_sessions.values() if candidate.opened and candidate.server_url == server_url and candidate.session_id != session.session_id and candidate.active_requests <= 0 ), key=lambda candidate: candidate.last_access_s, ) if not candidates: break await _close_decode_session( client=client, session=candidates[0], residency=residency, evicting_for_capacity=True, ) prefill_backed_evictions += int(candidates[0].prefill_opened) evicted_sessions += 1 if ( residency.resident_tokens_by_server.get(server_url, 0) + residency.reserved_tokens_by_server.get(server_url, 0) + required_extra_tokens > usable_capacity_tokens ): return False, 0, evicted_sessions, prefill_backed_evictions, "d-no-space" _add_reserved_tokens(residency, server_url, required_extra_tokens) return True, required_extra_tokens, evicted_sessions, prefill_backed_evictions, None def _build_direct_prompt( *, request: TraceRequest, session: DirectSessionState, ) -> tuple[str, int, bool, bool]: append_length, session_reused, session_reset = _inspect_direct_request( request=request, session=session, ) if session_reset: return build_synthetic_prompt(request), request.input_length, False, True if not session_reused: return build_synthetic_prompt(request), request.input_length, False, False return ( build_synthetic_append_chunk(request, append_length), append_length, session_reused, session_reset, ) def _build_direct_full_input_ids( request: TraceRequest, *, block_token_budget: int = 24, ) -> list[int]: input_ids: list[int] = [] for hash_id in request.hash_ids: base = 1000 + (hash_id % 3000) for offset in range(block_token_budget): input_ids.append(1000 + ((base + offset) % 3000)) while len(input_ids) < request.input_length: input_ids.append(5000 + (len(input_ids) % 2000)) return input_ids[: request.input_length] def _build_direct_append_input_ids( request: TraceRequest, append_length: int, ) -> list[int]: if append_length <= 0: return [] base = 9000 + ((request.chat_id + request.turn_id) % 2000) return [ 9000 + ((base + offset) % 2000) for offset in range(append_length) ] async def _invoke_plain_router( *, client: httpx.AsyncClient, request: TraceRequest, config: ReplayConfig, decision, execution_mode: str, ) -> ExecutionResult: latency_s, ttft_s, tpot_s, cached_tokens = await _invoke_router( client=client, request=request, config=config, decode_worker_index=decision.decode_worker_index, ) return ExecutionResult( execution_mode=execution_mode, actual_kv_transfer_blocks=decision.kv_transfer_blocks, effective_input_length=request.input_length, cached_tokens=cached_tokens, session_reused=False, session_reset=False, latency_s=latency_s, ttft_s=ttft_s, tpot_s=tpot_s, ) async def _invoke_kvcache_seeded_router( *, client: httpx.AsyncClient, request: TraceRequest, config: ReplayConfig, decision, prefill_url: str, decode_session: DirectSessionState, direct_sessions: dict[str, DirectSessionState], direct_session_lock: asyncio.Lock, decode_residency: DecodeResidencyState, reserved_tokens: int, execution_mode: str, ) -> ExecutionResult: async with direct_session_lock: keep_prefill_backup, prefill_reserved_tokens, _prefill_evicted = ( await _reserve_prefill_backup_capacity( client=client, request=request, prefill_url=prefill_url, session=decode_session, direct_sessions=direct_sessions, residency=decode_residency, ) ) if ( decode_session.prefill_opened and decode_session.prefill_server_url != prefill_url ): await _close_prefill_session( client=client, session=decode_session, residency=decode_residency, ) prefill_session_newly_opened = False async with direct_session_lock: if not decode_session.prefill_opened: await _open_streaming_session( client=client, server_url=prefill_url, session_id=request.session_id, request=request, ) decode_session.prefill_opened = True decode_session.prefill_server_url = prefill_url prefill_session_newly_opened = True decode_session_newly_opened = False try: async with direct_session_lock: if not decode_session.opened: await _open_streaming_session( client=client, server_url=decode_session.server_url, session_id=request.session_id, request=request, ) decode_session.opened = True decode_session_newly_opened = True decode_session.active_requests += 1 latency_s, ttft_s, tpot_s, cached_tokens = await _invoke_router( client=client, request=request, config=config, decode_worker_index=decision.decode_worker_index, session_id=request.session_id, ) except Exception: async with direct_session_lock: decode_session.active_requests = max(0, decode_session.active_requests - 1) _release_reserved_tokens( decode_residency, decode_session.server_url, reserved_tokens, ) _release_prefill_reserved_tokens( decode_residency, prefill_url, prefill_reserved_tokens, ) if decode_session_newly_opened: await _close_decode_session( client=client, session=decode_session, residency=decode_residency, ) if prefill_session_newly_opened: await _close_prefill_session( client=client, session=decode_session, residency=decode_residency, ) raise async with direct_session_lock: decode_session.active_requests = max(0, decode_session.active_requests - 1) if keep_prefill_backup: _commit_prefill_backup_residency( residency=decode_residency, session=decode_session, request=request, prefill_url=prefill_url, reserved_tokens=prefill_reserved_tokens, ) else: _release_prefill_reserved_tokens( decode_residency, prefill_url, prefill_reserved_tokens, ) await _close_prefill_session( client=client, session=decode_session, residency=decode_residency, ) _commit_session_residency( residency=decode_residency, session=decode_session, request=request, reserved_tokens=reserved_tokens, ) return ExecutionResult( execution_mode=execution_mode, actual_kv_transfer_blocks=decision.kv_transfer_blocks, effective_input_length=request.input_length, cached_tokens=cached_tokens, session_reused=False, session_reset=False, latency_s=latency_s, ttft_s=ttft_s, tpot_s=tpot_s, ) async def _execute_request( *, client: httpx.AsyncClient, request: TraceRequest, config: ReplayConfig, decision, direct_sessions: dict[str, DirectSessionState], direct_session_lock: asyncio.Lock, decode_residency: DecodeResidencyState, ) -> ExecutionResult: if config.mechanism_name == "pd-disaggregation": if not config.router_url: raise ValueError("router_url is required for pd-disaggregation replay") return await _invoke_plain_router( client=client, request=request, config=config, decision=decision, execution_mode="pd-disaggregation-router", ) if config.mechanism_name == "pd-colo": return await _invoke_direct( client=client, request=request, config=config, decision=decision, direct_sessions=direct_sessions, direct_session_lock=direct_session_lock, ) if config.mechanism_name == "kvcache-centric": if not config.router_url: raise ValueError("router_url is required for kvcache-centric replay") if not config.topology.decode_workers: raise ValueError("kvcache-centric mechanism requires at least one decode worker") prefill_url = _worker_url_by_id( config.topology.prefill_workers, decision.prefill_worker_id, ) decode_url = config.topology.decode_workers[decision.decode_worker_index].url async with direct_session_lock: decode_session = direct_sessions.get(request.session_id) if decode_session is None: decode_session = DirectSessionState( session_id=request.session_id, server_url=decode_url, ) direct_sessions[request.session_id] = decode_session elif decode_session.server_url != decode_url and decode_session.opened: await _close_decode_session( client=client, session=decode_session, residency=decode_residency, ) decode_session.server_url = decode_url else: decode_session.server_url = decode_url direct_append_length: int | None = None direct_session_reused = False direct_session_reset = False if request.turn_id > 1: async with direct_session_lock: ( direct_append_length, direct_session_reused, direct_session_reset, ) = _inspect_direct_request( request=request, session=decode_session, ) if request.turn_id == 1: async with direct_session_lock: admit_new_decode_session = _should_admit_new_decode_session( residency=decode_residency, server_url=decode_url, request=request, session=decode_session, direct_sessions=direct_sessions, treat_as_fresh_session=True, ) if not admit_new_decode_session: can_seed = False reserved_tokens = 0 seed_reason = "d-session-cap" else: can_seed, reserved_tokens, _evicted, _p_backed, seed_reason = ( await _reserve_decode_session_capacity( client=client, request=request, server_url=decode_url, session=decode_session, direct_sessions=direct_sessions, residency=decode_residency, treat_as_fresh_session=True, routing_mode="seed", admission_mode=config.kvcache_admission_mode, ) ) if seed_reason == "d-session-cap": return await _invoke_plain_router( client=client, request=request, config=config, decision=decision, execution_mode="pd-router-turn1-session-cap", ) if can_seed: return await _invoke_kvcache_seeded_router( client=client, request=request, config=config, decision=decision, prefill_url=prefill_url, decode_session=decode_session, direct_sessions=direct_sessions, direct_session_lock=direct_session_lock, decode_residency=decode_residency, reserved_tokens=reserved_tokens, execution_mode="pd-router-turn1-seed", ) return await _invoke_plain_router( client=client, request=request, config=config, decision=decision, execution_mode=( "pd-router-turn1-d-backpressure" if seed_reason is not None and seed_reason != "d-no-space" else "pd-router-turn1-no-d-capacity" ), ) if ( _should_bypass_prefill( request=request, config=config, decision=decision, ) and direct_append_length is not None and direct_session_reused and not direct_session_reset and direct_append_length <= config.kvcache_direct_max_uncached_tokens ): async with direct_session_lock: can_direct = ( decode_session.opened and decode_session.server_url == decode_url and direct_session_reused and not direct_session_reset ) direct_reserved_tokens = 0 direct_reason: str | None = None if can_direct: async with direct_session_lock: ( can_direct, direct_reserved_tokens, _evicted, _p_backed, direct_reason, ) = ( await _reserve_decode_session_capacity( client=client, request=request, server_url=decode_url, session=decode_session, direct_sessions=direct_sessions, residency=decode_residency, treat_as_fresh_session=False, routing_mode="direct", admission_mode=config.kvcache_admission_mode, ) ) if can_direct: return await _invoke_decode_session_direct( client=client, request=request, config=config, decision=decision, direct_sessions=direct_sessions, direct_session_lock=direct_session_lock, decode_residency=decode_residency, reserved_tokens=direct_reserved_tokens, ) if _is_decode_backpressure_reason(direct_reason): return await _invoke_plain_router( client=client, request=request, config=config, decision=decision, execution_mode="pd-router-fallback-d-backpressure", ) async with direct_session_lock: admit_new_decode_session = _should_admit_new_decode_session( residency=decode_residency, server_url=decode_url, request=request, session=decode_session, direct_sessions=direct_sessions, treat_as_fresh_session=True, ) if not admit_new_decode_session: can_seed = False reserved_tokens = 0 evicted_sessions = 0 prefill_backed_evictions = 0 seed_reason = "d-session-cap" else: ( can_seed, reserved_tokens, evicted_sessions, prefill_backed_evictions, seed_reason, ) = ( await _reserve_decode_session_capacity( client=client, request=request, server_url=decode_url, session=decode_session, direct_sessions=direct_sessions, residency=decode_residency, treat_as_fresh_session=True, routing_mode="seed", admission_mode=config.kvcache_admission_mode, ) ) if seed_reason == "d-session-cap": return await _invoke_plain_router( client=client, request=request, config=config, decision=decision, execution_mode="pd-router-fallback-session-cap", ) if can_seed: return await _invoke_kvcache_seeded_router( client=client, request=request, config=config, decision=decision, prefill_url=prefill_url, decode_session=decode_session, direct_sessions=direct_sessions, direct_session_lock=direct_session_lock, decode_residency=decode_residency, reserved_tokens=reserved_tokens, execution_mode=( "pd-router-d-session-reseed" + _eviction_suffix( evicted_sessions, prefill_backed_evictions, ) ), ) return await _invoke_plain_router( client=client, request=request, config=config, decision=decision, execution_mode=( "pd-router-fallback-d-backpressure" if _is_decode_backpressure_reason(seed_reason) else "pd-router-fallback-no-d-capacity" ), ) async with direct_session_lock: admit_new_decode_session = _should_admit_new_decode_session( residency=decode_residency, server_url=decode_url, request=request, session=decode_session, direct_sessions=direct_sessions, treat_as_fresh_session=True, ) if not admit_new_decode_session: can_seed = False reserved_tokens = 0 evicted_sessions = 0 prefill_backed_evictions = 0 seed_reason = "d-session-cap" else: ( can_seed, reserved_tokens, evicted_sessions, prefill_backed_evictions, seed_reason, ) = ( await _reserve_decode_session_capacity( client=client, request=request, server_url=decode_url, session=decode_session, direct_sessions=direct_sessions, residency=decode_residency, treat_as_fresh_session=True, routing_mode="seed", admission_mode=config.kvcache_admission_mode, ) ) if seed_reason == "d-session-cap": return await _invoke_plain_router( request=request, client=client, config=config, decision=decision, execution_mode="pd-router-fallback-large-append-session-cap", ) if can_seed: return await _invoke_kvcache_seeded_router( client=client, request=request, config=config, decision=decision, prefill_url=prefill_url, decode_session=decode_session, direct_sessions=direct_sessions, direct_session_lock=direct_session_lock, decode_residency=decode_residency, reserved_tokens=reserved_tokens, execution_mode=( "pd-router-large-append-reseed" + _eviction_suffix( evicted_sessions, prefill_backed_evictions, ) ), ) return await _invoke_plain_router( request=request, client=client, config=config, decision=decision, execution_mode=( "pd-router-fallback-d-backpressure" if _is_decode_backpressure_reason(seed_reason) else "pd-router-fallback-large-append" ), ) raise ValueError(f"Unsupported mechanism: {config.mechanism_name}") async def _invoke_direct( *, client: httpx.AsyncClient, request: TraceRequest, config: ReplayConfig, decision, direct_sessions: dict[str, DirectSessionState], direct_session_lock: asyncio.Lock, ) -> ExecutionResult: direct_workers = config.topology.direct_workers if not direct_workers: raise ValueError("pd-colo mechanism requires at least one direct worker") server_url = direct_workers[decision.decode_worker_index].url session = direct_sessions.get(request.session_id) if session is None or session.server_url != server_url: session = DirectSessionState(session_id=request.session_id, server_url=server_url) direct_sessions[request.session_id] = session return await _invoke_session_direct( client=client, request=request, config=config, session=session, execution_mode="pd-colo-direct-session", direct_session_lock=direct_session_lock, ) async def _invoke_session_direct( *, client: httpx.AsyncClient, request: TraceRequest, config: ReplayConfig, session: DirectSessionState, execution_mode: str, decode_residency: DecodeResidencyState | None = None, reserved_tokens: int = 0, direct_session_lock: asyncio.Lock | None = None, ) -> ExecutionResult: _prompt, effective_input_length, session_reused, session_reset = _build_direct_prompt( request=request, session=session, ) if session_reused: input_ids = _build_direct_append_input_ids(request, effective_input_length) else: input_ids = _build_direct_full_input_ids(request) if session.opened and (session_reset or not session_reused): if decode_residency is not None: await _close_decode_session( client=client, session=session, residency=decode_residency, ) else: await _close_streaming_session( client=client, server_url=session.server_url, session_id=session.session_id, allow_missing=True, ) session.opened = False session.resident_tokens = 0 if not session.opened: await _open_streaming_session( client=client, server_url=session.server_url, session_id=session.session_id, request=request, ) session.opened = True if direct_session_lock is not None: async with direct_session_lock: session.active_requests += 1 try: latency_s, ttft_s, tpot_s, cached_tokens = await _invoke_generate( client=client, base_url=session.server_url, headers={"x-request-id": request.request_id}, payload={ "input_ids": input_ids, "sampling_params": { "temperature": 0, "max_new_tokens": max(1, request.output_length), "min_new_tokens": max(1, request.output_length), "ignore_eos": True, "no_stop_trim": True, "skip_special_tokens": False, }, "session_params": {"id": session.session_id}, "stream": config.stream, }, timeout_s=config.timeout_s, stream_idle_timeout_s=config.stream_idle_timeout_s, stream=config.stream, ) finally: if direct_session_lock is not None: async with direct_session_lock: session.active_requests = max(0, session.active_requests - 1) if decode_residency is not None: _commit_session_residency( residency=decode_residency, session=session, request=request, reserved_tokens=reserved_tokens, ) else: session.last_trace_request = request session.last_access_s = time.perf_counter() return ExecutionResult( execution_mode=execution_mode, actual_kv_transfer_blocks=0, effective_input_length=len(input_ids), cached_tokens=cached_tokens, session_reused=session_reused, session_reset=session_reset, latency_s=latency_s, ttft_s=ttft_s, tpot_s=tpot_s, ) async def _invoke_decode_session_direct( *, client: httpx.AsyncClient, request: TraceRequest, config: ReplayConfig, decision, direct_sessions: dict[str, DirectSessionState], direct_session_lock: asyncio.Lock, decode_residency: DecodeResidencyState, reserved_tokens: int, ) -> ExecutionResult: decode_url = config.topology.decode_workers[decision.decode_worker_index].url session = direct_sessions[request.session_id] try: return await _invoke_session_direct( client=client, request=request, config=config, session=session, execution_mode="kvcache-direct-to-d-session", decode_residency=decode_residency, reserved_tokens=reserved_tokens, direct_session_lock=direct_session_lock, ) except Exception: async with direct_session_lock: _release_reserved_tokens( decode_residency, decode_url, reserved_tokens, ) raise def _should_bypass_prefill( *, request: TraceRequest, config: ReplayConfig, decision, block_token_budget: int = 24, ) -> bool: if request.turn_id <= 1: return False if decision.observed_overlap_blocks <= 0: return False uncached_tokens = max(0, decision.kv_transfer_blocks * block_token_budget) return uncached_tokens <= config.kvcache_direct_max_uncached_tokens def _worker_url_by_id(workers, worker_id: str) -> str: for worker in workers: if worker.worker_id == worker_id: return worker.url raise KeyError(f"Unknown worker id: {worker_id}") def _build_headers( *, request: TraceRequest, header_mode: HeaderMode, decode_worker_index: int, policy_name: str, ) -> dict[str, str]: if header_mode == "auto": header_mode = "routing-key" if policy_name == "sticky" else "none" headers = { "x-request-id": request.request_id, } if header_mode == "routing-key": headers["x-smg-routing-key"] = request.session_id elif header_mode == "target-worker": headers["x-smg-target-worker"] = str(decode_worker_index) return headers def _contains_token(payload: dict) -> bool: choices = payload.get("choices") if not isinstance(choices, list) or not choices: return False delta = choices[0].get("delta") if not isinstance(delta, dict): return False reasoning_content = delta.get("reasoning_content") if isinstance(reasoning_content, str) and reasoning_content: return True content = delta.get("content") if isinstance(content, str) and content: return True if isinstance(content, list): return any( isinstance(item, dict) and item.get("text") for item in content ) tool_calls = delta.get("tool_calls") if isinstance(tool_calls, list): for tool_call in tool_calls: if not isinstance(tool_call, dict): continue if tool_call.get("id"): return True function = tool_call.get("function") if not isinstance(function, dict): continue if function.get("name") or function.get("arguments"): return True return False def _contains_generate_token(payload: dict) -> bool: text = payload.get("text") if isinstance(text, str) and text: return True meta_info = payload.get("meta_info") if not isinstance(meta_info, dict): return False return int(meta_info.get("completion_tokens", 0)) > 0 def _is_generate_terminal_chunk(payload: dict) -> bool: meta_info = payload.get("meta_info") if not isinstance(meta_info, dict): return False return meta_info.get("finish_reason") is not None def _extract_generate_cached_tokens(payload: dict) -> int: meta_info = payload.get("meta_info") if not isinstance(meta_info, dict): return 0 return int(meta_info.get("cached_tokens", 0) or 0) def _extract_openai_cached_tokens(payload: dict) -> int: usage = payload.get("usage") if not isinstance(usage, dict): return 0 prompt_tokens_details = usage.get("prompt_tokens_details") if not isinstance(prompt_tokens_details, dict): return 0 return int(prompt_tokens_details.get("cached_tokens", 0) or 0) async def _aiter_lines( response: httpx.Response, *, idle_timeout_s: float | None, ): line_iterator = response.aiter_lines() while True: try: if idle_timeout_s is None or idle_timeout_s <= 0: line = await anext(line_iterator) else: line = await asyncio.wait_for( anext(line_iterator), timeout=idle_timeout_s, ) except StopAsyncIteration: return yield line def _is_terminal_chunk(payload: dict) -> bool: choices = payload.get("choices") if not isinstance(choices, list) or not choices: return False finish_reason = choices[0].get("finish_reason") return finish_reason is not None