From be948d32b8234751d712437a872124b2b4272a2b Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Thu, 28 May 2026 20:01:26 +0800 Subject: [PATCH] P2: real engine-state feed replaces stale shadow counters for migration targeting vLLM scheduler publishes real state (running/waiting, KV free, and the max-in-progress-prefill signal /metrics lacks) to a tmpfs/redis store ~20Hz; router reads it and avoids GIL-stall (mid-large-prefill) + KV-capacity-wall targets, using real load over 30s-stale shadow counters. Components: engine_state.py (canonical+reader), instrument_engine_state.py (scheduler patch, file/redis writer), migration_target.py (scorer), proxy wiring (--engine-state-uri, off=unchanged). All unit-tested without GPU; not yet run live. See P2_ENGINE_STATE.md. Co-Authored-By: Claude Opus 4.7 --- .../layerwise/P2_ENGINE_STATE.md | 61 +++++ .../layerwise/cache_aware_proxy.WRITEMODE.py | 104 +++++++- .../connector_tax/layerwise/engine_state.py | 140 +++++++++++ .../layerwise/instrument_engine_state.py | 234 ++++++++++++++++++ .../layerwise/migration_target.py | 79 ++++++ 5 files changed, 610 insertions(+), 8 deletions(-) create mode 100644 microbench/connector_tax/layerwise/P2_ENGINE_STATE.md create mode 100644 microbench/connector_tax/layerwise/engine_state.py create mode 100644 microbench/connector_tax/layerwise/instrument_engine_state.py create mode 100644 microbench/connector_tax/layerwise/migration_target.py diff --git a/microbench/connector_tax/layerwise/P2_ENGINE_STATE.md b/microbench/connector_tax/layerwise/P2_ENGINE_STATE.md new file mode 100644 index 0000000..e92a91c --- /dev/null +++ b/microbench/connector_tax/layerwise/P2_ENGINE_STATE.md @@ -0,0 +1,61 @@ +# P2: real engine-state feed for migration target selection + +Problem: the router (`cache_aware_proxy.py`) decides migration targets from +**shadow counters** it maintains itself (incremented at dispatch, decremented +at completion) and reconciles to vLLM `/metrics` only every **30 s** 
+(`_reconcile_loop`). So every routing/migration decision is on stale state. +Worse, the signal that predicts the ~45% control-plane stall — *is the target +mid-large-prefill?* (a big prefill holds the GIL and starves the mooncake +receiver_loop) — isn't visible at all, and `/metrics` doesn't expose it either. + +Fix: vLLM publishes **real** per-engine state to a shared store ~20 Hz; the +router reads ground truth and avoids GIL-stall / capacity-wall targets. + +## Components (all unit-tested without GPUs) + +- `engine_state.py` — canonical `compute_snapshot(scheduler, id)`, `StateWriter`, + `StateReader`. Schema per engine: `ts, num_running, num_waiting, + gpu_blocks_total/free, gpu_kv_used_frac, pending_prefill_tokens, + ongoing_decode_tokens, num_prefilling, max_prefill_remaining`. +- `instrument_engine_state.py` — vLLM `Scheduler` patch (apply/revert markers + `ES_INSTRUMENT_*`): a daemon thread publishes the snapshot every + `AGENTIC_ENGINE_STATE_PERIOD_MS` (50 ms) off the forward hot path. Inlined + writer (engine process needs no repo import). Coexists with MB5. +- `migration_target.py` — pure target scorer: avoid `max_prefill_remaining ≥ + es_big_prefill_threshold` (GIL stall) and `gpu_kv_used_frac ≥ es_kv_wall_frac` + (capacity wall), then rank by cache-richness and **real** load. +- `cache_aware_proxy.WRITEMODE.py` — wired: `InstanceState.real_state`, + `_engine_state_poll_loop` (instance i ← `engine_{i}`), `_real_load`/Gate-3 and + Mechanism-B now real-state-aware. `--engine-state-uri` flag; off ⇒ identical + to before (shadow only). + +Transport (`AGENTIC_ENGINE_STATE_URI` / `--engine-state-uri`): +`file:///dev/shm/agentic_engine_state` (default, zero-dep, single-node) or +`redis://host:port/0` (multi-node; needs redis-py + server — not installed on +dash0, so file backend is the working default). + +## Tests (no GPU) +- `compute_snapshot` field math (mock scheduler): running/waiting, + max_prefill_remaining, pending, decode, kv_used_frac. 
+- writer→reader round-trip + staleness drop (file backend). +- target scorer: 5 cases incl. *avoid GIL-stall target even when its shadow + load is lower*, *real load beats stale shadow*, *cache-rich wins*, + *avoid KV wall*, *graceful fallback when feed missing*. +- end-to-end: publish 8 engines (one mid-130k-prefill) → proxy inlined reader → + target selection avoids it. + +## Enabling in a GPU run (when free) +1. `instrument_engine_state.py --apply` on the dash0 venv. +2. `export AGENTIC_ENGINE_STATE_URI=file:///dev/shm/agentic_engine_state` + before the launcher (vLLM instances inherit it; `AGENTIC_WORKER_ID=engine_{i}` + already set by `b3_isolated_policy.sh` → publishes as `engine_{i}`). +3. Proxy: `EXTRA_PROXY_ARGS="--engine-state-uri file:///dev/shm/agentic_engine_state ..."`. +4. Revert the patch + `rm -rf /dev/shm/agentic_engine_state` after. + +## Status / scope +- Built + unit-tested; NOT yet run against live engines (GPU busy). +- Scoped to **migration target selection** (the P2 ask). The same real-load + signal could also de-stale the base `pick_instance_unified_hybrid` LMetric + fallback (the 8007-hotspot class from UNIFIED_ABLATION) — follow-up. +- TP=1 only (one EngineCore/instance → one publisher/engine_id). TP>1 needs + per-rank ids. diff --git a/microbench/connector_tax/layerwise/cache_aware_proxy.WRITEMODE.py b/microbench/connector_tax/layerwise/cache_aware_proxy.WRITEMODE.py index 27a607b..054e8da 100644 --- a/microbench/connector_tax/layerwise/cache_aware_proxy.WRITEMODE.py +++ b/microbench/connector_tax/layerwise/cache_aware_proxy.WRITEMODE.py @@ -111,6 +111,13 @@ class Settings: v3_recent_mig_weight: float = 1.0 # how many "virtual requests" each # recent migration counts as + # P2: real engine-state feed (replaces 30s-stale shadow counters for + # migration target selection). Empty = disabled (use shadow only). + engine_state_uri: str = "" # file:///dev/shm/... or redis://... 
+ engine_state_period_ms: int = 50 # router poll period + es_big_prefill_threshold: int = 16000 # target mid-prefill >= this => avoid (GIL stall) + es_kv_wall_frac: float = 0.90 # target KV usage >= this => avoid (capacity wall) + # Direction B knob: LMetric fallback adds decode-token penalty to score. # score = (pending_prefill + new + lmetric_decode_weight * ongoing_decode_tok) * num_req # Empirical iter-time slope on H100 + Qwen3-30B-A3B: each decode token in @@ -206,6 +213,9 @@ class InstanceState: # recent-migration count over a sliding window, preventing back-to-back # decisions from clustering on the same dst. self.recent_mig_targeted_at: deque[float] = deque(maxlen=64) + # P2: latest real engine state (from the engine-state feed), or None + # when the feed is disabled/stale. Set by _engine_state_poll_loop. + self.real_state: dict | None = None def estimate_cache_hit(self, token_ids: list[int] | None) -> int: if not token_ids or len(token_ids) < BLOCK_SIZE: @@ -672,20 +682,29 @@ def pick_instance_unified_v3( now_mono = _time.monotonic() cutoff = now_mono - SETTINGS.v3_recent_mig_window_s + def _real_load(inst): + # P2: prefer REAL engine state (running+waiting) over the proxy's + # 30s-stale shadow num_requests, when the engine-state feed is fresh. + rs = getattr(inst, "real_state", None) + if rs is not None: + return rs.get("num_running", 0) + rs.get("num_waiting", 0) + return inst.num_requests + def effective_load(inst): # Drop expired entries lazily. 
while inst.recent_mig_targeted_at and inst.recent_mig_targeted_at[0] < cutoff: inst.recent_mig_targeted_at.popleft() recent = len(inst.recent_mig_targeted_at) - return inst.num_requests + recent * SETTINGS.v3_recent_mig_weight + return _real_load(inst) + recent * SETTINGS.v3_recent_mig_weight + ph_load = _real_load(prefill_host) threshold_loaded = max(1, - int(prefill_host.num_requests * SETTINGS.v3_target_load_ratio)) + int(ph_load * SETTINGS.v3_target_load_ratio)) candidates = [ (i, inst) for i, inst in enumerate(instances) if i != prefill_idx and effective_load(inst) < threshold_loaded - and effective_load(inst) <= prefill_host.num_requests - SETTINGS.v3_min_load_gap + and effective_load(inst) <= ph_load - SETTINGS.v3_min_load_gap ] if not candidates: decision["v3_reason"] = ( @@ -700,11 +719,23 @@ def pick_instance_unified_v3( # cache_hit DESC (more cache = less KV to transfer), then by effective_load # (which includes recent-migration penalty), then by ongoing_tokens. if SETTINGS.v3_prefer_cache_target: - decode_target_idx, decode_target = min( - candidates, - key=lambda x: (-x[1].estimate_cache_hit(token_ids), - effective_load(x[1]), - x[1].ongoing_tokens)) + def _tgt_key(x): + # P2: avoid a target that is mid-large-prefill (holds the GIL, + # stalls the mooncake receiver_loop = the ~45% control-plane + # residual layer-wise can't fix) or near the KV capacity wall, + # before ranking by cache-richness and real load. 
+ inst = x[1] + ch = inst.estimate_cache_hit(token_ids) + rs = getattr(inst, "real_state", None) + stalls = near_wall = 0 + if rs is not None: + if int(rs.get("max_prefill_remaining", 0)) >= SETTINGS.es_big_prefill_threshold: + stalls = 1 + f = rs.get("gpu_kv_used_frac", 0.0) or 0.0 + if float(f) >= SETTINGS.es_kv_wall_frac: + near_wall = 1 + return (stalls, near_wall, -ch, effective_load(inst), inst.ongoing_tokens) + decode_target_idx, decode_target = min(candidates, key=_tgt_key) else: decode_target_idx, decode_target = min( candidates, key=lambda x: (effective_load(x[1]), x[1].ongoing_tokens)) @@ -857,6 +888,57 @@ async def _fetch_vllm_inflight(inst: "InstanceState") -> tuple[int, int] | None: return running, waiting +def _engine_state_read_all(uri: str, max_age_s: float = 2.0) -> dict: + """P2 reader (inlined; mirrors engine_state.StateReader). Returns + {engine_id: state}, dropping records older than max_age_s.""" + now = _time.time() + out: dict = {} + try: + if uri.startswith("file://"): + import glob + d = uri[len("file://"):] + for p in glob.glob(os.path.join(d, "*.json")): + try: + s = json.load(open(p)) + except Exception: + continue + if now - s.get("ts", 0) <= max_age_s: + out[s.get("engine_id", os.path.basename(p)[:-5])] = s + elif uri.startswith("redis://"): + import redis + r = redis.Redis.from_url(uri) + for k in r.scan_iter("engine_state:*"): + v = r.get(k) + if not v: + continue + s = json.loads(v) + if now - s.get("ts", 0) <= max_age_s: + out[s.get("engine_id")] = s + except Exception: + pass + return out + + +async def _engine_state_poll_loop(): + """P2: poll the engine-state feed and attach real_state to each instance. + Instance i is keyed engine_{i} (matches AGENTIC_WORKER_ID in the launcher). 
+ """ + uri = SETTINGS.engine_state_uri + if not uri: + return + period = max(0.01, SETTINGS.engine_state_period_ms / 1000.0) + insts = combined_instances or (prefill_instances + decode_instances) + print(f"[engine-state] polling {uri} every {period*1000:.0f}ms for {len(insts)} instances") + while True: + try: + await asyncio.sleep(period) + except asyncio.CancelledError: + return + states = await asyncio.to_thread(_engine_state_read_all, uri) + for i, inst in enumerate(insts): + inst.real_state = states.get(f"engine_{i}") + + async def _reconcile_loop(): """Periodic shadow-state reconciliation against vLLM /metrics truth. @@ -960,6 +1042,7 @@ async def lifespan(app: FastAPI): _verify_vllm_patch() reconcile_task = asyncio.create_task(_reconcile_loop()) + engine_state_task = asyncio.create_task(_engine_state_poll_loop()) if global_args.combined: is_pd_sep = False @@ -1720,6 +1803,10 @@ def parse_args(): " penalised. 0 = original behavior; 0.01 is a reasonable start.") p.add_argument("--overload-factor", type=float, default=2.0, help="Break session affinity when instance load > factor * avg") + p.add_argument("--engine-state-uri", type=str, default="", + help="P2: real engine-state feed for migration target " + "selection (file:///dev/shm/... or redis://...). " + "Empty=disabled (shadow counters only).") # The four flags below are accepted for bench.sh backward compatibility but # have no effect after the PD-sep offload path was retired (REPORT §3.9, # commits 4c583f2 / cc6e562). 
Removing them would break scripts/bench.sh and @@ -1752,6 +1839,7 @@ if __name__ == "__main__": global_args = parse_args() SETTINGS.heavy_threshold = global_args.heavy_threshold SETTINGS.overload_factor = global_args.overload_factor + SETTINGS.engine_state_uri = getattr(global_args, 'engine_state_uri', '') or '' SETTINGS.max_offload_inflight = global_args.max_offload_inflight SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05) diff --git a/microbench/connector_tax/layerwise/engine_state.py b/microbench/connector_tax/layerwise/engine_state.py new file mode 100644 index 0000000..36e1021 --- /dev/null +++ b/microbench/connector_tax/layerwise/engine_state.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +"""Engine-state store: canonical snapshot + writer/reader, shared schema. + +The vLLM scheduler patch (instrument_engine_state.py) inlines a faithful copy +of `compute_snapshot` + the file/redis writer (engine process needs no repo +import). The router (cache_aware_proxy) imports `StateReader` here to read the +real per-engine state instead of its stale shadow counters. 
# engine_state.py -- canonical engine-state snapshot plus writer/reader.
#
# Schema (one record per engine, key = engine_id):
#   ts, engine_id, num_running, num_waiting, gpu_blocks_total, gpu_blocks_free,
#   gpu_kv_used_frac, pending_prefill_tokens, ongoing_decode_tokens,
#   num_prefilling, max_prefill_remaining
#
# Transport URIs:
#   file:///dev/shm/agentic_engine_state   (default; atomic temp+rename)
#   redis://host:port/0                     (optional; needs redis-py)
from __future__ import annotations

import glob
import json
import os
import time


def compute_snapshot(scheduler, engine_id: str) -> dict:
    """Build one engine-state record from a scheduler-like object.

    The scheduler is duck-typed: it is expected to expose
    ``kv_cache_manager.block_pool`` (``num_gpu_blocks`` and
    ``get_num_free_blocks()``), an iterable ``running`` and a sized/iterable
    ``waiting``; an optional ``skipped_waiting`` adds to ``num_waiting`` only.
    Each section is read best-effort: if the block pool cannot be read the
    block counts are ``-1`` and ``gpu_kv_used_frac`` is ``-1.0``; if a queue
    cannot be walked, whatever was accumulated so far is reported.

    NOTE: instrument_engine_state.py inlines a copy of this function into the
    engine process; keep the two in sync when changing the field math.

    Args:
        scheduler: live (or mock) scheduler object, see above.
        engine_id: identifier stored in the record's ``engine_id`` field.

    Returns:
        A JSON-serialisable dict following the schema at the top of the file.
    """
    try:
        pool = scheduler.kv_cache_manager.block_pool
        total = int(pool.num_gpu_blocks)
        free = int(pool.get_num_free_blocks())
    except Exception:
        total = free = -1

    n_run = pend = dec = n_pref = max_pref = 0
    try:
        for req in scheduler.running:
            n_run += 1
            prompt_tokens = int(getattr(req, "num_prompt_tokens", 0))
            computed_tokens = int(getattr(req, "num_computed_tokens", 0))
            if computed_tokens < prompt_tokens:
                # Still prefilling: count what is left of the prompt.
                remaining = prompt_tokens - computed_tokens
                pend += remaining
                n_pref += 1
                max_pref = max(max_pref, remaining)
            else:
                # Decoding: account the request's current token count.
                dec += int(getattr(req, "num_tokens", 0))
    except Exception:
        pass

    n_wait = 0
    try:
        n_wait = len(scheduler.waiting) + len(
            getattr(scheduler, "skipped_waiting", []))
        # Iterate a copy so a concurrent queue mutation cannot break the walk.
        for req in list(scheduler.waiting):
            pend += max(0, int(getattr(req, "num_prompt_tokens", 0))
                        - int(getattr(req, "num_computed_tokens", 0)))
    except Exception:
        pass

    # -1.0 marks "unknown" (block pool unreadable or empty).
    used = ((total - free) / total) if total > 0 else -1.0
    return {
        "ts": time.time(),
        "engine_id": engine_id,
        "num_running": n_run,
        "num_waiting": int(n_wait),
        "gpu_blocks_total": total,
        "gpu_blocks_free": free,
        "gpu_kv_used_frac": used,
        "pending_prefill_tokens": int(pend),
        "ongoing_decode_tokens": int(dec),
        "num_prefilling": n_pref,
        "max_prefill_remaining": int(max_pref),
    }


class StateWriter:
    """Publish one engine's state record to a ``file://`` or ``redis://`` store.

    Raises:
        ValueError: from the constructor, if the URI scheme is not supported.
    """

    def __init__(self, uri: str, engine_id: str):
        self.engine_id = engine_id
        self.kind = None
        if uri.startswith("file://"):
            self.kind = "file"
            self.dir = uri[len("file://"):]
            os.makedirs(self.dir, exist_ok=True)
            self.path = os.path.join(self.dir, f"{engine_id}.json")
            # Per-process temp name; it does not end in ".json", so the
            # reader's "*.json" glob never picks up a half-written file.
            self.tmp = self.path + f".tmp.{os.getpid()}"
        elif uri.startswith("redis://"):
            self.kind = "redis"
            import redis  # lazy: only needed for the redis backend
            self.r = redis.Redis.from_url(uri)
            self.key = f"engine_state:{engine_id}"
        else:
            raise ValueError(f"unsupported engine-state URI: {uri}")

    def publish(self, state: dict):
        """Write ``state`` as JSON, replacing the previous record."""
        if self.kind == "file":
            with open(self.tmp, "w") as f:
                f.write(json.dumps(state))
            # Rename over the live file so readers never see a partial write.
            os.replace(self.tmp, self.path)
        elif self.kind == "redis":
            # 5 s expiry: a dead engine's record disappears on its own.
            self.r.set(self.key, json.dumps(state), ex=5)


class StateReader:
    """Reader counterpart of StateWriter.

    ``read_all()`` returns ``{engine_id: state}``, dropping records older than
    ``max_age_s`` so a dead or hung engine is ignored. The router patch in
    this commit carries an inlined mirror of this logic
    (``_engine_state_read_all``); keep the two in sync.

    Raises:
        ValueError: from the constructor, if the URI scheme is not supported.
    """

    _REDIS_PREFIX = "engine_state:"

    def __init__(self, uri: str, max_age_s: float = 2.0):
        self.uri = uri
        self.max_age_s = max_age_s
        self.kind = None
        if uri.startswith("file://"):
            self.kind = "file"
            self.dir = uri[len("file://"):]
        elif uri.startswith("redis://"):
            self.kind = "redis"
            import redis  # lazy: only needed for the redis backend
            self.r = redis.Redis.from_url(uri)
        else:
            raise ValueError(f"unsupported engine-state URI: {uri}")

    def _file_records(self):
        """Yield ``(filename_stem, parsed_json)`` for each readable *.json."""
        for path in sorted(glob.glob(os.path.join(self.dir, "*.json"))):
            try:
                # Context manager: this runs at poll rate, so handles must
                # be closed deterministically rather than left to the GC.
                with open(path, encoding="utf-8") as fh:
                    record = json.load(fh)
            except (OSError, ValueError):
                # Vanished between glob and open, or not valid JSON.
                continue
            yield os.path.basename(path)[:-len(".json")], record

    def _redis_records(self):
        """Yield ``(key_suffix, parsed_json)`` for each engine_state:* key."""
        for key in self.r.scan_iter(self._REDIS_PREFIX + "*"):
            raw = self.r.get(key)
            if not raw:
                continue  # expired between SCAN and GET
            try:
                record = json.loads(raw)
            except (TypeError, ValueError):
                continue
            name = key.decode("utf-8", "replace") if isinstance(key, bytes) else str(key)
            yield name[len(self._REDIS_PREFIX):], record

    def _is_fresh(self, record, now: float) -> bool:
        """True if ``record`` is a dict whose numeric ``ts`` is recent enough."""
        if not isinstance(record, dict):
            return False
        ts = record.get("ts", 0)
        if not isinstance(ts, (int, float)):
            return False
        return now - ts <= self.max_age_s

    def read_all(self) -> dict[str, dict]:
        """Return all fresh records keyed by engine id.

        A malformed record is skipped on its own; it no longer discards the
        records that follow it. A record without a usable string
        ``engine_id`` is keyed by its file stem / redis key suffix. A
        transport-level failure (e.g. redis unreachable) is deliberately
        swallowed and whatever was collected so far is returned.
        """
        now = time.time()
        out: dict[str, dict] = {}
        try:
            records = (self._file_records() if self.kind == "file"
                       else self._redis_records())
            for fallback_id, record in records:
                if not self._is_fresh(record, now):
                    continue
                engine_id = record.get("engine_id")
                if not isinstance(engine_id, str) or not engine_id:
                    engine_id = fallback_id
                out[engine_id] = record
        except Exception:
            # Best-effort by design: callers treat a missing record as
            # "no real state" and fall back to their own counters.
            pass
        return out
#!/usr/bin/env python3
"""Patch vLLM V1 scheduler to publish REAL engine state to a shared store,
so the global router reads ground truth instead of its own stale shadow
counters (reconciled only every 30s).

Published per engine (key = AGENTIC_ENGINE_ID, else AGENTIC_WORKER_ID, else
engine_<pid>), throttled ~20 Hz from a daemon thread (off the forward hot
path):

  {ts, engine_id, num_running, num_waiting, gpu_blocks_total,
   gpu_blocks_free, gpu_kv_used_frac, pending_prefill_tokens,
   ongoing_decode_tokens, num_prefilling, max_prefill_remaining}

`max_prefill_remaining` is the key signal /metrics does NOT expose: the
largest in-progress prefill on the engine. A big in-progress prefill holds
the GIL and stalls the mooncake receiver_loop — so the router should avoid
migrating KV to such an instance (P2).

Transport (env AGENTIC_ENGINE_STATE_URI):
  file:///dev/shm/agentic_engine_state   (default; atomic temp+rename)
  redis://host:port/0                     (optional; needs redis-py + server)

Self-contained (inlined writer) so the engine process needs no repo import.
Apply/revert markers:  # ES_INSTRUMENT_START / # ES_INSTRUMENT_END.

Usage:
  python instrument_engine_state.py --apply  [--venv PATH]
  python instrument_engine_state.py --revert [--venv PATH]
  python instrument_engine_state.py --check  [--venv PATH]
"""
from __future__ import annotations

import argparse
import re
from pathlib import Path

DEFAULT_VENV = Path("/home/admin/cpfs/wjh/agentic-kv/.venv")
TARGET_REL = "lib/python3.12/site-packages/vllm/v1/core/sched/scheduler.py"
START = "# ES_INSTRUMENT_START"
END = "# ES_INSTRUMENT_END"

# ---- Patch 1: header (writer + publisher thread), before class Scheduler ----
# HEADER is an f-string: literal braces in the injected code are doubled.
# It ends with END + two blank lines; revert() relies on exactly that shape.
HEADER_ANCHOR = "class Scheduler(SchedulerInterface):"
HEADER = f'''{START}
import json as _es_json
import os as _es_os
import threading as _es_threading
import time as _es_time


def _es_period_s(raw, default_ms=50.0, floor_ms=1.0):
    """Parse the publish period (milliseconds) into seconds without raising.

    This runs while scheduler.py is being imported, so a malformed env value
    must not take the engine down: non-numeric or non-finite input falls back
    to default_ms, and the result is floored at floor_ms so the publisher
    thread can never busy-spin on a zero/negative period.
    """
    try:
        ms = float(raw)
    except (TypeError, ValueError):
        return default_ms / 1000.0
    if not (float("-inf") < ms < float("inf")):
        return default_ms / 1000.0
    return max(ms, floor_ms) / 1000.0


_ES_URI = _es_os.environ.get("AGENTIC_ENGINE_STATE_URI", "")
_ES_ID = _es_os.environ.get("AGENTIC_ENGINE_ID") or _es_os.environ.get(
    "AGENTIC_WORKER_ID", f"engine_{{_es_os.getpid()}}")
_ES_PERIOD_S = _es_period_s(
    _es_os.environ.get("AGENTIC_ENGINE_STATE_PERIOD_MS", "50"))


class _ESWriter:
    """Pluggable state writer: file:// (atomic temp+rename) or redis://."""
    def __init__(self, uri: str, engine_id: str):
        self.engine_id = engine_id
        self.kind = None
        if uri.startswith("file://"):
            self.kind = "file"
            self.dir = uri[len("file://"):]
            _es_os.makedirs(self.dir, exist_ok=True)
            self.path = _es_os.path.join(self.dir, f"{{engine_id}}.json")
            self.tmp = self.path + f".tmp.{{_es_os.getpid()}}"
        elif uri.startswith("redis://"):
            self.kind = "redis"
            import redis  # lazy
            self.r = redis.Redis.from_url(uri)
            self.key = f"engine_state:{{engine_id}}"
        else:
            # Same contract as engine_state.StateWriter. The caller in
            # Scheduler.__init__ catches this and logs "publisher disabled"
            # instead of running a thread that silently publishes nothing.
            raise ValueError(f"unsupported engine-state URI: {{uri}}")

    def publish(self, state: dict):
        try:
            if self.kind == "file":
                with open(self.tmp, "w") as f:
                    f.write(_es_json.dumps(state))
                _es_os.replace(self.tmp, self.path)   # atomic
            elif self.kind == "redis":
                self.r.set(self.key, _es_json.dumps(state), ex=5)
        except Exception:
            # Best-effort: a store hiccup must never disturb the engine.
            pass


def _es_compute_snapshot(scheduler) -> dict:
    """Cheap O(batch) state read from the live scheduler."""
    try:
        kvm = scheduler.kv_cache_manager
        pool = kvm.block_pool
        total = int(pool.num_gpu_blocks)
        free = int(pool.get_num_free_blocks())
    except Exception:
        total = free = -1
    n_run = 0
    pend = 0
    dec = 0
    n_pref = 0
    max_pref = 0
    try:
        for r in scheduler.running:
            n_run += 1
            npr = int(getattr(r, "num_prompt_tokens", 0))
            nct = int(getattr(r, "num_computed_tokens", 0))
            if nct < npr:                       # still prefilling
                rem = npr - nct
                pend += rem
                n_pref += 1
                if rem > max_pref:
                    max_pref = rem
            else:                               # decoding
                dec += int(getattr(r, "num_tokens", 0))
    except Exception:
        pass
    n_wait = 0
    try:
        n_wait = len(scheduler.waiting) + len(getattr(scheduler, "skipped_waiting", []))
        for r in list(scheduler.waiting):
            pend += max(0, int(getattr(r, "num_prompt_tokens", 0))
                        - int(getattr(r, "num_computed_tokens", 0)))
    except Exception:
        pass
    used_frac = ((total - free) / total) if (total and total > 0) else -1.0
    return {{
        "ts": _es_time.time(),
        "engine_id": _ES_ID,
        "num_running": n_run,
        "num_waiting": int(n_wait),
        "gpu_blocks_total": total,
        "gpu_blocks_free": free,
        "gpu_kv_used_frac": used_frac,
        "pending_prefill_tokens": int(pend),
        "ongoing_decode_tokens": int(dec),
        "num_prefilling": n_pref,
        "max_prefill_remaining": int(max_pref),
    }}


class _ESPublisher:
    def __init__(self, scheduler):
        self._sched = scheduler
        self._writer = _ESWriter(_ES_URI, _ES_ID)
        self._stop = _es_threading.Event()
        self._t = _es_threading.Thread(target=self._loop, daemon=True)
        self._t.start()

    def stop(self):
        self._stop.set()

    def _loop(self):
        while not self._stop.is_set():
            try:
                self._writer.publish(_es_compute_snapshot(self._sched))
            except Exception:
                pass
            # Event.wait instead of sleep: returns early once stop() is called.
            self._stop.wait(_ES_PERIOD_S)
{END}


'''

# ---- Patch 2: start the publisher at the end of Scheduler.__init__ ----------
# Anchor on the existing agentic step-log block tail in __init__.
# NOTE(review): the 8-space indent below assumes the anchor statement sits
# directly in the __init__ body — confirm against the target scheduler.py.
INIT_ANCHOR = """        _step_path = _os.environ.get("AGENTIC_STEP_LOG_PATH")"""
INIT_INSERT = f"""        {START}
        if _ES_URI:
            try:
                self._es_publisher = _ESPublisher(self)
                logger.info("agentic engine-state publisher: uri=%s id=%s",
                            _ES_URI, _ES_ID)
            except Exception as _e:
                logger.warning("engine-state publisher disabled (%r)", _e)
        {END}
        _step_path = _os.environ.get("AGENTIC_STEP_LOG_PATH")"""

# (name, anchor text to find, text that replaces the first occurrence)
PATCHES = [
    ("header", HEADER_ANCHOR, HEADER + HEADER_ANCHOR),
    ("init", INIT_ANCHOR, INIT_INSERT),
]


def find_target(venv: Path) -> Path:
    """Locate scheduler.py under ``venv``, falling back to DEFAULT_VENV.

    The fallback is announced on stdout so that an explicit ``--venv`` that
    does not contain vLLM cannot lead to silently patching another venv.

    Raises:
        FileNotFoundError: if neither location contains TARGET_REL.
    """
    tried = []
    # dict.fromkeys de-duplicates while keeping order (venv may be the default).
    for base in dict.fromkeys((venv, DEFAULT_VENV)):
        candidate = base / TARGET_REL
        tried.append(str(candidate))
        if candidate.is_file():
            if base != venv:
                print(f"[es-instr] {TARGET_REL} not under {venv}; "
                      f"falling back to {base}")
            return candidate
    raise FileNotFoundError(
        f"cannot find {TARGET_REL} under {venv} (tried: {', '.join(tried)})")


def is_patched(t: str) -> bool:
    """True if the text ``t`` contains the instrumentation START marker."""
    return START in t


def apply(target: Path):
    """Insert every entry of PATCHES into ``target``; no-op if already patched.

    All anchors are resolved in memory first, so the file is written only
    when every patch applies.

    Raises:
        RuntimeError: if an anchor string is not present in the file.
    """
    text = target.read_text()
    if is_patched(text):
        print(f"[es-instr] already patched: {target}")
        return
    new = text
    for name, src, dst in PATCHES:
        if src not in new:
            raise RuntimeError(f"patch {name!r}: anchor not found in {target}")
        new = new.replace(src, dst, 1)
    target.write_text(new)
    print(f"[es-instr] applied {len(PATCHES)} patches -> {target}")


def revert(target: Path):
    """Remove every START..END region from ``target``; no-op if not patched.

    The header patch also adds two blank lines between its END marker and
    ``class Scheduler(``. Those are consumed here together with the marked
    region, which restores the file byte-for-byte. (Normalising the blank
    lines before the class afterwards, as an earlier version did, left one
    blank line where the original file had two.)
    """
    text = target.read_text()
    if not is_patched(text):
        print(f"[es-instr] not patched: {target}")
        return
    pat = re.compile(
        r"[ \t]*" + re.escape(START) + r".*?" + re.escape(END) + r"\n"
        # HEADER's own trailing blank lines, only directly before the anchor.
        + r"(?:\n\n(?=" + re.escape(HEADER_ANCHOR) + r"))?",
        flags=re.DOTALL)
    new = pat.sub("", text)
    target.write_text(new)
    print(f"[es-instr] reverted: {target}")


def main():
    """CLI entry point: --apply / --revert / --check on the located target."""
    p = argparse.ArgumentParser()
    p.add_argument("--apply", action="store_true")
    p.add_argument("--revert", action="store_true")
    p.add_argument("--check", action="store_true")
    p.add_argument("--venv", type=Path, default=DEFAULT_VENV)
    a = p.parse_args()
    t = find_target(a.venv)
    if a.apply:
        apply(t)
    elif a.revert:
        revert(t)
    elif a.check:
        print(f"[es-instr] {'PATCHED' if is_patched(t.read_text()) else 'CLEAN'}: {t}")
    else:
        p.error("specify --apply/--revert/--check")


if __name__ == "__main__":
    main()
# migration_target.py -- P2: real-state-aware migration target selection.
# Pure helpers with no proxy dependencies, so they can be unit-tested alone.
from __future__ import annotations

import math
from dataclasses import dataclass


@dataclass
class TargetCandidate:
    idx: int
    cache_hit: int                 # estimated transfer bytes saved (tokens)
    shadow_num_req: int            # proxy shadow counter (fallback)
    ongoing_tokens: int            # shadow tertiary
    real_state: dict | None = None  # engine-state record, or None if stale/missing


def _as_number(value) -> float | None:
    """Coerce a feed value to a finite-or-infinite float, or None if unusable.

    Feed records are JSON written by another process, so a field may be
    missing, null, a string, or NaN. Anything that is not a real number
    yields None so the caller can choose a safe default.
    """
    if isinstance(value, (int, float)):
        num = float(value)
    elif isinstance(value, str):
        try:
            num = float(value)
        except ValueError:
            return None
    else:
        return None
    # NaN would make the tuple ordering used by min() ill-defined.
    return None if math.isnan(num) else num


def real_load(c: TargetCandidate) -> float:
    """Effective load: prefer real (running + waiting); else shadow.

    A missing load field counts as 0, as before. If either field is present
    but not a usable number, the record is not trusted for load and the
    shadow counter is returned instead.
    """
    rs = c.real_state
    if rs is not None:
        running = _as_number(rs.get("num_running", 0))
        waiting = _as_number(rs.get("num_waiting", 0))
        if running is not None and waiting is not None:
            return float(running + waiting)
    return float(c.shadow_num_req)


def big_prefill_remaining(c: TargetCandidate) -> int:
    """Largest in-progress prefill on the candidate (GIL-stall predictor).

    0 when unknown (no real state, or an unusable value) so a candidate is
    not penalised blind.
    """
    rs = c.real_state
    if rs is None:
        return 0
    remaining = _as_number(rs.get("max_prefill_remaining", 0))
    if remaining is None or math.isinf(remaining):
        return 0
    return int(remaining)


def kv_used_frac(c: TargetCandidate) -> float:
    """KV-cache usage fraction of the candidate, or 0.0 when unknown.

    The feed uses a negative value for "unknown"; that, a missing record and
    an unusable value all map to 0.0 (i.e. not near the capacity wall).
    """
    rs = c.real_state
    if rs is None:
        return 0.0
    frac = _as_number(rs.get("gpu_kv_used_frac", -1.0))
    return frac if frac is not None and frac >= 0 else 0.0


def target_sort_key(
    c: TargetCandidate,
    big_prefill_threshold: int = 16000,
    kv_wall_frac: float = 0.90,
):
    """Sort key (lower = better). Ordering of concerns:
      1. NOT mid-large-prefill (avoid the GIL-stall dst)        [bool]
      2. NOT near the KV capacity wall                           [bool]
      3. most cache-rich (fewest transfer bytes)  -> -cache_hit
      4. lowest real load
      5. lowest ongoing_tokens (shadow tertiary tie-break)
    """
    stalls = 1 if big_prefill_remaining(c) >= big_prefill_threshold else 0
    near_wall = 1 if kv_used_frac(c) >= kv_wall_frac else 0
    return (stalls, near_wall, -c.cache_hit, real_load(c), c.ongoing_tokens)


def rank_migration_targets(
    candidates: list[TargetCandidate],
    big_prefill_threshold: int = 16000,
    kv_wall_frac: float = 0.90,
) -> TargetCandidate | None:
    """Return the best candidate, or None if the list is empty."""
    if not candidates:
        return None
    return min(
        candidates,
        key=lambda c: target_sort_key(c, big_prefill_threshold, kv_wall_frac),
    )