vLLM scheduler publishes real state (running/waiting, KV free, and the max-in-progress-prefill signal /metrics lacks) to a tmpfs/redis store ~20Hz; router reads it and avoids GIL-stall (mid-large-prefill) + KV-capacity-wall targets, using real load over 30s-stale shadow counters. Components: engine_state.py (canonical+reader), instrument_engine_state.py (scheduler patch, file/redis writer), migration_target.py (scorer), proxy wiring (--engine-state-uri, off=unchanged). All unit-tested without GPU; not yet run live. See P2_ENGINE_STATE.md. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
141 lines
5.0 KiB
Python
141 lines
5.0 KiB
Python
#!/usr/bin/env python3
|
|
"""Engine-state store: canonical snapshot + writer/reader, shared schema.
|
|
|
|
The vLLM scheduler patch (instrument_engine_state.py) inlines a faithful copy
|
|
of `compute_snapshot` + the file/redis writer (engine process needs no repo
|
|
import). The router (cache_aware_proxy) imports `StateReader` here to read the
|
|
real per-engine state instead of its stale shadow counters.
|
|
|
|
Schema (one record per engine, key = engine_id):
|
|
ts, engine_id, num_running, num_waiting, gpu_blocks_total, gpu_blocks_free,
|
|
gpu_kv_used_frac, pending_prefill_tokens, ongoing_decode_tokens,
|
|
num_prefilling, max_prefill_remaining
|
|
|
|
Transport URIs:
|
|
file:///dev/shm/agentic_engine_state (default; atomic temp+rename)
|
|
redis://host:port/0 (optional; needs redis-py)
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import time
|
|
|
|
|
|
def compute_snapshot(scheduler, engine_id: str) -> dict:
    """Build one engine-state record from a live scheduler object.

    The scheduler is duck-typed: only ``kv_cache_manager.block_pool``,
    ``running``, ``waiting`` and (optionally) ``skipped_waiting`` are read,
    and every read is best-effort. If the block pool cannot be read, both
    block counts are reported as ``-1`` and ``gpu_kv_used_frac`` as ``-1.0``.
    If walking ``running`` or ``waiting`` raises, whatever was accumulated up
    to that point is reported.

    Args:
        scheduler: Object exposing the attributes listed above.
        engine_id: Identifier stored verbatim in the record.

    Returns:
        A dict with the keys ``ts``, ``engine_id``, ``num_running``,
        ``num_waiting``, ``gpu_blocks_total``, ``gpu_blocks_free``,
        ``gpu_kv_used_frac``, ``pending_prefill_tokens``,
        ``ongoing_decode_tokens``, ``num_prefilling`` and
        ``max_prefill_remaining``.
    """
    # KV block pool occupancy; -1 is the "could not read" sentinel.
    blocks_total = blocks_free = -1
    try:
        block_pool = scheduler.kv_cache_manager.block_pool
        total_read = int(block_pool.num_gpu_blocks)
        free_read = int(block_pool.get_num_free_blocks())
    except Exception:
        pass
    else:
        blocks_total, blocks_free = total_read, free_read

    running_count = 0
    prefilling_count = 0
    prefill_backlog = 0
    largest_prefill = 0
    decode_tokens = 0
    try:
        for req in scheduler.running:
            # Counted before attributes are read, so a request whose
            # attributes fail to convert is still included in num_running.
            running_count += 1
            prompt_len = int(getattr(req, "num_prompt_tokens", 0))
            computed = int(getattr(req, "num_computed_tokens", 0))
            if computed >= prompt_len:
                # Prompt fully computed: treat the request as decoding.
                decode_tokens += int(getattr(req, "num_tokens", 0))
                continue
            remaining = prompt_len - computed
            prefilling_count += 1
            prefill_backlog += remaining
            if remaining > largest_prefill:
                largest_prefill = remaining
    except Exception:
        pass

    waiting_count = 0
    try:
        waiting_count = len(scheduler.waiting) + len(
            getattr(scheduler, "skipped_waiting", [])
        )
        # Only the main waiting queue contributes to the prefill backlog.
        for req in list(scheduler.waiting):
            deficit = int(getattr(req, "num_prompt_tokens", 0)) - int(
                getattr(req, "num_computed_tokens", 0)
            )
            if deficit > 0:
                prefill_backlog += deficit
    except Exception:
        pass

    if blocks_total > 0:
        used_frac = (blocks_total - blocks_free) / blocks_total
    else:
        used_frac = -1.0

    snapshot = {"ts": time.time(), "engine_id": engine_id}
    snapshot["num_running"] = running_count
    snapshot["num_waiting"] = int(waiting_count)
    snapshot["gpu_blocks_total"] = blocks_total
    snapshot["gpu_blocks_free"] = blocks_free
    snapshot["gpu_kv_used_frac"] = used_frac
    snapshot["pending_prefill_tokens"] = int(prefill_backlog)
    snapshot["ongoing_decode_tokens"] = int(decode_tokens)
    snapshot["num_prefilling"] = prefilling_count
    snapshot["max_prefill_remaining"] = int(largest_prefill)
    return snapshot
|
|
|
|
|
|
class StateWriter:
    """Engine-side publisher of one engine's state record.

    Supported transports, selected by the URI scheme:

    * ``file://<dir>``: the record is written as ``<dir>/<engine_id>.json``
      via a per-process temp file followed by ``os.replace``.
    * ``redis://...``: the record is stored under ``engine_state:<engine_id>``
      with a 5 second expiry (requires the ``redis`` package).
    """

    def __init__(self, uri: str, engine_id: str):
        """Configure the transport for ``uri``.

        Raises:
            ValueError: if the URI scheme is neither ``file://`` nor
                ``redis://``.
        """
        self.engine_id = engine_id
        self.kind = None

        uses_file = uri.startswith("file://")
        if not uses_file and not uri.startswith("redis://"):
            raise ValueError(f"unsupported engine-state URI: {uri}")

        if uses_file:
            self.kind = "file"
            # Everything after the scheme is the target directory.
            self.dir = uri.partition("file://")[2]
            os.makedirs(self.dir, exist_ok=True)
            self.path = os.path.join(self.dir, f"{engine_id}.json")
            # The pid suffix keeps temp files of different processes apart.
            self.tmp = self.path + f".tmp.{os.getpid()}"
            return

        self.kind = "redis"
        import redis

        self.r = redis.Redis.from_url(uri)
        self.key = f"engine_state:{engine_id}"

    def publish(self, state: dict):
        """Serialize ``state`` as JSON and store it on the configured transport."""
        if self.kind == "redis":
            self.r.set(self.key, json.dumps(state), ex=5)
            return
        if self.kind != "file":
            return
        with open(self.tmp, "w") as handle:
            payload = json.dumps(state)
            handle.write(payload)
        # Rename over the final path so a reader never opens a half-written file.
        os.replace(self.tmp, self.path)
|
|
|
|
|
|
class StateReader:
    """Router-side reader of per-engine state records.

    ``read_all()`` returns ``{engine_id: state}`` and drops records older
    than ``max_age_s``, so a dead or hung engine is ignored. Reading is
    best-effort: a malformed record is skipped on its own, and a transport
    failure yields whatever was collected so far instead of raising.
    """

    # Key prefix of state records in redis.
    _REDIS_PREFIX = "engine_state:"

    def __init__(self, uri: str, max_age_s: float = 2.0):
        """Configure the transport for ``uri``.

        Args:
            uri: ``file://<dir>`` or ``redis://...`` (the latter requires the
                ``redis`` package).
            max_age_s: Maximum accepted age of a record, in seconds, measured
                against the record's ``ts`` field.

        Raises:
            ValueError: if the URI scheme is not supported.
        """
        self.uri = uri
        self.max_age_s = max_age_s
        self.kind = None
        if uri.startswith("file://"):
            self.kind = "file"
            self.dir = uri[len("file://"):]
        elif uri.startswith("redis://"):
            self.kind = "redis"
            import redis

            self.r = redis.Redis.from_url(uri)
        else:
            raise ValueError(f"unsupported engine-state URI: {uri}")

    def read_all(self) -> dict[str, dict]:
        """Return every fresh, well-formed record keyed by engine id."""
        now = time.time()
        out: dict[str, dict] = {}
        try:
            if self.kind == "file":
                self._collect_from_files(out, now)
            elif self.kind == "redis":
                self._collect_from_redis(out, now)
        except Exception:
            # Deliberately best-effort: a transport failure (e.g. redis being
            # unreachable) must not raise into the caller, so return the
            # records gathered before the failure.
            pass
        return out

    def _collect_from_files(self, out: dict, now: float) -> None:
        """Load each ``*.json`` file in the state directory into ``out``."""
        import glob

        for path in glob.glob(os.path.join(self.dir, "*.json")):
            try:
                # Context manager so the handle is closed on every read.
                with open(path) as handle:
                    record = json.load(handle)
            except (OSError, ValueError, RecursionError):
                # Unreadable or unparsable file: skip just this one.
                continue
            fallback_id = os.path.basename(path)[: -len(".json")]
            self._accept(out, record, fallback_id, now)

    def _collect_from_redis(self, out: dict, now: float) -> None:
        """Load each ``engine_state:*`` key from redis into ``out``."""
        prefix = self._REDIS_PREFIX
        for raw_key in self.r.scan_iter(prefix + "*"):
            payload = self.r.get(raw_key)
            if not payload:
                # Key expired between the scan and the get, or is empty.
                continue
            try:
                record = json.loads(payload)
            except (ValueError, RecursionError):
                # One corrupt value must not hide the other engines.
                continue
            if isinstance(raw_key, bytes):
                key_text = raw_key.decode("utf-8", "replace")
            else:
                key_text = str(raw_key)
            self._accept(out, record, key_text[len(prefix):], now)

    def _accept(self, out: dict, record, fallback_id: str, now: float) -> None:
        """Store ``record`` in ``out`` if it is a dict with a fresh numeric ts."""
        if not isinstance(record, dict):
            return
        ts = record.get("ts", 0)
        if not isinstance(ts, (int, float)):
            return
        # Positive form of the comparison so a NaN timestamp is rejected.
        if not (now - ts <= self.max_age_s):
            return
        engine_id = record.get("engine_id")
        if engine_id is None:
            # Fall back to the file stem / redis key suffix.
            engine_id = fallback_id
        try:
            out[engine_id] = record
        except TypeError:
            # Unhashable engine_id: skip the record.
            return
|