Files
agentic-kvc/microbench/connector_tax/layerwise/engine_state.py
Gahow Wang be948d32b8 P2: real engine-state feed replaces stale shadow counters for migration targeting
The vLLM scheduler publishes its real state (running/waiting counts, free KV
blocks, and the max-in-progress-prefill signal that /metrics lacks) to a
tmpfs/redis store at ~20 Hz; the router reads it and avoids GIL-stalled
(mid-large-prefill) and KV-capacity-wall targets, using real load instead of
30s-stale shadow counters. Components:
engine_state.py (canonical+reader), instrument_engine_state.py (scheduler
patch, file/redis writer), migration_target.py (scorer), proxy wiring
(--engine-state-uri, off=unchanged). All unit-tested without GPU; not yet
run live. See P2_ENGINE_STATE.md.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-28 20:01:26 +08:00

141 lines
5.0 KiB
Python

#!/usr/bin/env python3
"""Engine-state store: canonical snapshot + writer/reader, shared schema.
The vLLM scheduler patch (instrument_engine_state.py) inlines a faithful copy
of `compute_snapshot` + the file/redis writer (engine process needs no repo
import). The router (cache_aware_proxy) imports `StateReader` here to read the
real per-engine state instead of its stale shadow counters.
Schema (one record per engine, key = engine_id):
ts, engine_id, num_running, num_waiting, gpu_blocks_total, gpu_blocks_free,
gpu_kv_used_frac, pending_prefill_tokens, ongoing_decode_tokens,
num_prefilling, max_prefill_remaining
Transport URIs:
file:///dev/shm/agentic_engine_state (default; atomic temp+rename)
redis://host:port/0 (optional; needs redis-py)
"""
from __future__ import annotations
import json
import os
import time
def compute_snapshot(scheduler, engine_id: str) -> dict:
    """Cheap O(batch) read of routing-relevant real state from a live
    vLLM V1 Scheduler (duck-typed for testability).

    Each section is best-effort: if reading an attribute raises, that
    section keeps its defaults (-1 for the KV block counts, 0 for the
    request counters) instead of failing the caller.
    NOTE: instrument_engine_state.py inlines a copy of this function; keep
    the two in sync.

    Args:
        scheduler: object exposing ``kv_cache_manager.block_pool`` (with
            ``num_gpu_blocks`` and ``get_num_free_blocks()``), iterable
            ``running`` and ``waiting`` request queues and, optionally, a
            ``skipped_waiting`` queue.
        engine_id: identifier copied into the record.

    Returns:
        A flat dict following the module-level schema. ``gpu_kv_used_frac``
        is -1.0 when the block pool could not be read.
    """
    try:
        pool = scheduler.kv_cache_manager.block_pool
        total = int(pool.num_gpu_blocks)
        free = int(pool.get_num_free_blocks())
    except Exception:
        total = free = -1
    n_run = pend = dec = n_pref = max_pref = 0
    try:
        for r in scheduler.running:
            n_run += 1
            npr = int(getattr(r, "num_prompt_tokens", 0))
            nct = int(getattr(r, "num_computed_tokens", 0))
            if nct < npr:
                # Prompt not fully computed yet: this request is prefilling.
                rem = npr - nct
                pend += rem
                n_pref += 1
                max_pref = max(max_pref, rem)
            else:
                dec += int(getattr(r, "num_tokens", 0))
    except Exception:
        pass
    n_wait = 0
    try:
        # Skipped-waiting requests are counted as waiting, so their uncomputed
        # prompt tokens must also count toward pending prefill; otherwise the
        # two fields disagree and prefill load is under-reported.
        queued = list(scheduler.waiting)
        queued += list(getattr(scheduler, "skipped_waiting", None) or ())
        n_wait = len(queued)
        for r in queued:
            pend += max(0, int(getattr(r, "num_prompt_tokens", 0))
                        - int(getattr(r, "num_computed_tokens", 0)))
    except Exception:
        pass
    # total is an int; -1 (unreadable pool) and 0 both map to the -1.0 sentinel.
    used = ((total - free) / total) if total > 0 else -1.0
    return {
        "ts": time.time(),
        "engine_id": engine_id,
        "num_running": n_run,
        "num_waiting": int(n_wait),
        "gpu_blocks_total": total,
        "gpu_blocks_free": free,
        "gpu_kv_used_frac": used,
        "pending_prefill_tokens": int(pend),
        "ongoing_decode_tokens": int(dec),
        "num_prefilling": n_pref,
        "max_prefill_remaining": int(max_pref),
    }
class StateWriter:
    """Engine-side publisher of a single engine's state record.

    ``file://DIR`` writes ``DIR/<engine_id>.json`` through a per-process temp
    file followed by ``os.replace``, so a reader sees either the old or the
    new record. ``redis://...`` SETs key ``engine_state:<engine_id>`` with an
    expiry, so a dead engine's record disappears by itself.
    """

    def __init__(self, uri: str, engine_id: str, ttl_s: int = 5):
        """Args:
            uri: transport URI, ``file://`` or ``redis://``.
            engine_id: key under which this engine's record is stored.
            ttl_s: redis key expiry in seconds (unused for ``file://``).

        Raises:
            ValueError: for an unsupported URI scheme.
        """
        self.engine_id = engine_id
        self.ttl_s = ttl_s
        self.kind = None
        if uri.startswith("file://"):
            self.kind = "file"
            self.dir = uri[len("file://"):]
            os.makedirs(self.dir, exist_ok=True)
            self.path = os.path.join(self.dir, f"{engine_id}.json")
            # Suffix does not end in ".json", so readers globbing *.json skip it.
            self.tmp = self.path + f".tmp.{os.getpid()}"
        elif uri.startswith("redis://"):
            self.kind = "redis"
            import redis
            self.r = redis.Redis.from_url(uri)
            self.key = f"engine_state:{engine_id}"
        else:
            raise ValueError(f"unsupported engine-state URI: {uri}")

    def publish(self, state: dict) -> None:
        """Replace this engine's record with ``state``.

        The record is serialized before any I/O. A non-JSON-serializable
        state therefore raises without truncating a temp file or touching
        the published record.
        """
        payload = json.dumps(state)
        if self.kind == "file":
            with open(self.tmp, "w") as f:
                f.write(payload)
            os.replace(self.tmp, self.path)
        elif self.kind == "redis":
            self.r.set(self.key, payload, ex=self.ttl_s)
class StateReader:
    """Router-side reader. read_all() returns {engine_id: state}, dropping
    records older than max_age_s (so a dead/hung engine is ignored).

    Reading is best-effort per record. A missing, malformed, non-object or
    bad-timestamp record is skipped on its own and does not abort the scan,
    so one bad engine cannot hide the others.
    """

    _REDIS_PREFIX = "engine_state:"

    def __init__(self, uri: str, max_age_s: float = 2.0):
        """Args:
            uri: transport URI, ``file://DIR`` or ``redis://...``.
            max_age_s: records whose ``ts`` is older than this are dropped.

        Raises:
            ValueError: for an unsupported URI scheme.
        """
        self.uri = uri
        self.max_age_s = max_age_s
        self.kind = None
        if uri.startswith("file://"):
            self.kind = "file"
            self.dir = uri[len("file://"):]
        elif uri.startswith("redis://"):
            self.kind = "redis"
            import redis
            self.r = redis.Redis.from_url(uri)
        else:
            raise ValueError(f"unsupported engine-state URI: {uri}")

    def read_all(self) -> dict[str, dict]:
        """Return fresh records keyed by engine id ({} if nothing is readable)."""
        now = time.time()
        out: dict[str, dict] = {}
        try:
            if self.kind == "file":
                self._read_files(out, now)
            elif self.kind == "redis":
                self._read_redis(out, now)
        except Exception:
            # Transport-level failure (e.g. redis unreachable): return what was
            # collected so the router degrades instead of crashing.
            pass
        return out

    def _keep_if_fresh(self, out: dict, s, fallback_id: str, now: float) -> None:
        """Store record ``s`` in ``out`` if it is a dict younger than max_age_s."""
        if not isinstance(s, dict):
            return
        try:
            fresh = now - s.get("ts", 0) <= self.max_age_s
        except TypeError:  # non-numeric ts
            return
        if fresh:
            out[s.get("engine_id") or fallback_id] = s

    def _read_files(self, out: dict, now: float) -> None:
        """Collect fresh records from ``<dir>/*.json`` (temp files don't match)."""
        import glob
        for p in glob.glob(os.path.join(self.dir, "*.json")):
            try:
                with open(p) as f:
                    s = json.load(f)
            except (OSError, ValueError):
                # Vanished between glob and open, or not valid JSON.
                continue
            self._keep_if_fresh(out, s, os.path.basename(p)[:-len(".json")], now)

    def _read_redis(self, out: dict, now: float) -> None:
        """Collect fresh records from ``engine_state:*`` keys."""
        prefix = self._REDIS_PREFIX
        for k in self.r.scan_iter(prefix + "*"):
            v = self.r.get(k)
            if not v:
                continue  # expired between SCAN and GET
            try:
                s = json.loads(v)
            except ValueError:
                continue
            key = k.decode("utf-8", "replace") if isinstance(k, bytes) else str(k)
            self._keep_if_fresh(out, s, key[len(prefix):], now)