diff --git a/scripts/cache_aware_proxy.py b/scripts/cache_aware_proxy.py index 37e953b..89e0591 100644 --- a/scripts/cache_aware_proxy.py +++ b/scripts/cache_aware_proxy.py @@ -19,7 +19,7 @@ import os import time as _time import urllib.parse import uuid -from collections import OrderedDict, deque +from collections import OrderedDict from contextlib import asynccontextmanager from dataclasses import dataclass @@ -52,9 +52,6 @@ class Settings: cache_gate_ratio: float = 0.0 decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20) migration_discount_cap: int = 5 # max turns to discount - migration_request_factor: float = 1.5 # trigger migration when num_requests > avg * factor - migration_ttft_threshold: float = 5.0 # trigger migration when recent TTFT median > this (seconds) - migration_ttft_window: int = 8 # number of recent TTFTs to track per instance SETTINGS = Settings() @@ -77,7 +74,6 @@ class InstanceState: self.dp_size = 1 # OrderedDict acts as an LRU keyed by block hash; value is unused. self.cached_blocks: OrderedDict[int, None] = OrderedDict() - self.recent_ttfts: deque[float] = deque(maxlen=SETTINGS.migration_ttft_window) def estimate_cache_hit(self, token_ids: list[int] | None) -> int: if not token_ids or len(token_ids) < BLOCK_SIZE: @@ -451,11 +447,7 @@ async def _handle_local_request(api, req_data, headers, token_ids, input_length, if not prefill_done: chosen.pending_prefill_tokens -= estimated_new chosen.ongoing_decode_tokens += input_length - t_first = _time.monotonic() - breakdown["t_first_token"] = t_first - t_recv = breakdown.get("t_proxy_recv") - if t_recv: - chosen.recent_ttfts.append(t_first - t_recv) + breakdown["t_first_token"] = _time.monotonic() prefill_done = True yield chunk chosen.record_prefix( @@ -523,57 +515,6 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h cache_hits = [inst.estimate_cache_hit(token_ids) for inst in combined_instances] best_cache_idx = max(range(len(combined_instances)), key=lambda i: cache_hits[i]) best_cache_hit = cache_hits[best_cache_idx] - - # Session-level migration: force-migrate when instance has high recent TTFT - if (offload_enabled and session_id - and session_id in session_affinity_combined): - mig_src_idx = session_affinity_combined[session_id] - if mig_src_idx < len(combined_instances): - mig_src = combined_instances[mig_src_idx] - src_cache_ratio = cache_hits[mig_src_idx] / max(input_length, 1) - src_ttfts = mig_src.recent_ttfts - src_ttft_med = sorted(src_ttfts)[len(src_ttfts) // 2] if len(src_ttfts) >= 3 else 0 - - if (src_ttft_med > SETTINGS.migration_ttft_threshold - and src_cache_ratio > 0.5 - and input_length >= SETTINGS.heavy_threshold): - # Find instance with lowest recent TTFT - def _inst_ttft_score(i: int) -> float: - t = combined_instances[i].recent_ttfts - if len(t) < 2: - return 0.0 - return sorted(t)[len(t) // 2] - mig_tgt_idx = min(range(len(combined_instances)), key=_inst_ttft_score) - mig_tgt = combined_instances[mig_tgt_idx] - tgt_ttft_med = _inst_ttft_score(mig_tgt_idx) - - if tgt_ttft_med < src_ttft_med * 0.5: - estimated_new = max(0, input_length - cache_hits[mig_src_idx]) - breakdown = { - "request_id": headers.get("X-Request-Id", ""), - "input_length": input_length, - "cache_hit": cache_hits[mig_tgt_idx], - "estimated_new_tokens": estimated_new, - "t_proxy_recv": _time.monotonic(), - "policy": "session_migrate", - "push_cache_hit": cache_hits[mig_src_idx], - "c_inst": mig_src.url, - "routed_to": mig_tgt.url, - } - session_affinity_combined[session_id] = mig_tgt_idx - offload_mode = getattr(global_args, 'offload_mode', 'cached_prefill') - push_cache_hit = cache_hits[mig_src_idx] - if offload_mode == "cached_prefill": - return await _handle_cached_prefill_offload( - api, req_data, headers, token_ids, input_length, - mig_src, mig_tgt, push_cache_hit, estimated_new, - breakdown) - else: - return await _handle_direct_read_offload( - api, req_data, headers, token_ids, input_length, - mig_src, mig_tgt, push_cache_hit, estimated_new, - breakdown) - def _current_offloads() -> int: return sum(i.active_p_offloads for i in combined_instances) @@ -1016,10 +957,6 @@ def parse_args(): "(0.0 disables gate, 1.0 disables offload entirely)") p.add_argument("--decode-iteration-s", type=float, default=0.05, help="Estimated per-request decode iteration time in seconds") - p.add_argument("--migration-request-factor", type=float, default=1.5, - help="Trigger session migration when num_requests > avg * factor") - p.add_argument("--migration-ttft-threshold", type=float, default=5.0, - help="Trigger migration when instance median TTFT > this (seconds)") args = p.parse_args() args.prefill = [] @@ -1042,8 +979,6 @@ if __name__ == "__main__": SETTINGS.max_offload_inflight = global_args.max_offload_inflight SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05) - SETTINGS.migration_request_factor = getattr(global_args, 'migration_request_factor', 1.5) - SETTINGS.migration_ttft_threshold = getattr(global_args, 'migration_ttft_threshold', 5.0) print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % ( SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s, getattr(global_args, 'offload', False)))