Direct RDMA read: D reads cached KV from C's GPU without C's scheduler
Complete implementation of direct RDMA read for KV cache migration: vLLM Mooncake connector (mooncake_connector.py): - PullReqMeta: add direct_read flag + block_hashes - MooncakeConnectorMetadata: add hash_table_updates/removals for scheduler->worker block hash sync - MooncakeConnectorScheduler: set_block_pool() to access BlockPool, build_connector_meta() computes hash table deltas each step, update_state_after_alloc() captures request block hashes for direct_read - MooncakeConnectorWorker: _start_direct_read() + _direct_read_single() implements D-side RDMA read via batch_transfer_sync_read, with HTTP query/unpin to C's bootstrap server Bootstrap server (mooncake_utils.py): - POST /query_blocks: look up block hashes, return block_ids + GPU layout - POST /unpin_blocks: release pin tracking - set_worker_kv_info(): register GPU addresses at init - update_hash_table(): receive scheduler deltas each step Scheduler (scheduler.py): - One-line hookup: pass block_pool to connector after KVCacheManager init Proxy (cache_aware_proxy.py): - _handle_direct_read_offload: sends request ONLY to D with direct_read=True + remote_bootstrap_addr. No request to C at all. - C's scheduler is completely uninvolved (0 GPU time on C) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -351,33 +351,28 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
offload_reason = "colocated_cheaper_%.1fvs%.1f" % (colocated_cost, offload_cost)
|
||||
|
||||
if use_offload:
|
||||
p_inst = p_candidate
|
||||
# Direct RDMA read: D reads cached KV from C_s's GPU, no request to C_s
|
||||
c_inst = best_inst # has cache (not doing any work)
|
||||
d_inst = d_candidate
|
||||
d_idx = combined_instances.index(d_inst)
|
||||
|
||||
# Accounting: reserve both P and D immediately so router sees the load
|
||||
p_new = max(0, input_length - p_inst.estimate_cache_hit(token_ids)) if token_ids else input_length
|
||||
p_inst.ongoing_tokens += input_length
|
||||
p_inst.pending_prefill_tokens += p_new
|
||||
p_inst.num_requests += 1
|
||||
p_inst.active_p_offloads += 1
|
||||
breakdown["p_new_tokens"] = p_new
|
||||
|
||||
d_inst.ongoing_tokens += input_length
|
||||
d_inst.pending_prefill_tokens += estimated_new
|
||||
d_inst.num_requests += 1
|
||||
c_inst.active_p_offloads += 1
|
||||
|
||||
breakdown["route_class"] = "HEAVY_OFFLOAD"
|
||||
breakdown["offload_reason"] = offload_reason
|
||||
breakdown["p_inst"] = p_inst.url
|
||||
breakdown["c_inst"] = c_inst.url
|
||||
breakdown["d_inst"] = d_inst.url
|
||||
breakdown["p_load"] = p_inst.ongoing_tokens
|
||||
breakdown["d_load"] = d_inst.ongoing_tokens
|
||||
breakdown["cache_hit_tokens"] = cache_hit
|
||||
|
||||
if session_id:
|
||||
session_affinity[session_id] = d_idx
|
||||
|
||||
return await _handle_heavy_offload(api, req_data, headers, token_ids,
|
||||
input_length, p_inst, d_inst, breakdown)
|
||||
return await _handle_direct_read_offload(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
c_inst, d_inst, cache_hit, estimated_new, breakdown)
|
||||
else:
|
||||
if estimated_new >= HEAVY_THRESHOLD:
|
||||
breakdown["route_class"] = "HEAVY_COLO"
|
||||
@@ -423,112 +418,31 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill
|
||||
|
||||
|
||||
async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
|
||||
p_inst, d_inst, breakdown):
|
||||
"""HEAVY request with cache-aware KV migration.
|
||||
|
||||
C_s (p_inst, has cache) exports cached KV blocks via Mooncake.
|
||||
D (d_inst, idle) pulls cached blocks + does local prefill for new tokens + decodes.
|
||||
C_s's blocks are auto-freed by Mooncake after D confirms receipt.
|
||||
|
||||
On export failure, falls back to co-located prefill+decode on d_inst.
|
||||
async def _handle_direct_read_offload(api, req_data, headers, token_ids,
|
||||
input_length, c_inst, d_inst,
|
||||
cache_hit, estimated_new, breakdown):
|
||||
"""HEAVY request: D direct-RDMA-reads cached KV from C_s, then does
|
||||
local prefill for new tokens + decode. C_s's scheduler is NOT involved.
|
||||
"""
|
||||
request_id = headers.get("X-Request-Id", "")
|
||||
estimated_new = breakdown.get("estimated_new_tokens", 0)
|
||||
cache_hit = breakdown.get("cache_hit", 0)
|
||||
p_prefill_release = breakdown.get("p_new_tokens", estimated_new)
|
||||
|
||||
# Step 1: C_s exports cached KV blocks
|
||||
# Send TRUNCATED prompt (only cached portion) so C_s does 0 compute
|
||||
# (full prefix cache hit), then pushes cached blocks to Mooncake.
|
||||
breakdown["t_export_sent"] = _time.monotonic()
|
||||
export_ok = False
|
||||
|
||||
# Truncate prompt to cached portion (aligned to BLOCK_SIZE)
|
||||
# Align cache_hit to block boundary for remote_num_tokens
|
||||
cached_tokens = (cache_hit // BLOCK_SIZE) * BLOCK_SIZE
|
||||
if cached_tokens > 0 and token_ids:
|
||||
export_prompt = token_ids[:cached_tokens]
|
||||
else:
|
||||
export_prompt = token_ids
|
||||
breakdown["t_offload_sent"] = _time.monotonic()
|
||||
|
||||
try:
|
||||
export_data = {
|
||||
"model": req_data.get("model", ""),
|
||||
"prompt": export_prompt,
|
||||
"max_tokens": 1,
|
||||
"temperature": 0,
|
||||
"stream": False,
|
||||
"kv_transfer_params": {
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
"transfer_id": "xfer-" + request_id,
|
||||
},
|
||||
}
|
||||
|
||||
p_headers = {**headers, "X-data-parallel-rank": "0"}
|
||||
resp = await asyncio.wait_for(
|
||||
p_inst.client.post(api, json=export_data, headers=p_headers),
|
||||
timeout=PREFILL_TIMEOUT_S,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
await resp.aclose()
|
||||
breakdown["t_export_done"] = _time.monotonic()
|
||||
breakdown["exported_tokens"] = cached_tokens if cached_tokens > 0 else len(export_prompt)
|
||||
export_ok = True
|
||||
except Exception as e:
|
||||
breakdown["t_export_done"] = _time.monotonic()
|
||||
breakdown["export_error"] = str(e)
|
||||
finally:
|
||||
p_inst.ongoing_tokens -= input_length
|
||||
p_inst.pending_prefill_tokens -= p_prefill_release
|
||||
p_inst.num_requests -= 1
|
||||
p_inst.active_p_offloads = max(0, p_inst.active_p_offloads - 1)
|
||||
|
||||
if not export_ok:
|
||||
breakdown["route_class"] = "HEAVY_COLO_FALLBACK"
|
||||
d_inst.pending_prefill_tokens += estimated_new
|
||||
|
||||
async def generate_fallback():
|
||||
prefill_done = False
|
||||
try:
|
||||
async with d_inst.client.stream("POST", api, json=req_data, headers=headers) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if not prefill_done:
|
||||
d_inst.pending_prefill_tokens -= estimated_new
|
||||
d_inst.ongoing_decode_tokens += input_length
|
||||
breakdown["t_first_token"] = _time.monotonic()
|
||||
prefill_done = True
|
||||
yield chunk
|
||||
d_inst.record_prefix(token_ids)
|
||||
finally:
|
||||
if not prefill_done:
|
||||
d_inst.pending_prefill_tokens -= estimated_new
|
||||
else:
|
||||
d_inst.ongoing_decode_tokens -= input_length
|
||||
d_inst.ongoing_tokens -= input_length
|
||||
d_inst.num_requests -= 1
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
_breakdown_log.append(breakdown)
|
||||
|
||||
return StreamingResponse(generate_fallback(), media_type="text/event-stream")
|
||||
|
||||
# Step 2: D pulls cached blocks + does local prefill for new tokens + decodes
|
||||
exported_tokens = breakdown.get("exported_tokens", 0)
|
||||
d_inst.pending_prefill_tokens += estimated_new
|
||||
breakdown["t_decode_sent"] = _time.monotonic()
|
||||
|
||||
parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
|
||||
bootstrap_addr = "http://%s:%s" % (parsed.hostname, p_inst.bootstrap_port)
|
||||
parsed = urllib.parse.urlparse(str(c_inst.client.base_url))
|
||||
bootstrap_addr = "http://%s:%s" % (parsed.hostname, c_inst.bootstrap_port)
|
||||
|
||||
# Send full prompt to D with direct_read flag
|
||||
decode_data = req_data.copy()
|
||||
decode_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": False,
|
||||
"do_remote_prefill": True,
|
||||
"direct_read": True,
|
||||
"remote_bootstrap_addr": bootstrap_addr,
|
||||
"remote_engine_id": p_inst.engine_id.get(0, ""),
|
||||
"remote_engine_id": c_inst.engine_id.get(0, ""),
|
||||
"transfer_id": "xfer-" + request_id,
|
||||
"remote_num_tokens": exported_tokens,
|
||||
"remote_num_tokens": cached_tokens,
|
||||
}
|
||||
|
||||
async def generate():
|
||||
@@ -551,6 +465,7 @@ async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
|
||||
d_inst.ongoing_decode_tokens -= input_length
|
||||
d_inst.ongoing_tokens -= input_length
|
||||
d_inst.num_requests -= 1
|
||||
c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
_breakdown_log.append(breakdown)
|
||||
|
||||
|
||||
@@ -108,6 +108,9 @@ class PullReqMeta:
|
||||
expire_time: float = float("inf")
|
||||
# Designed for one D pairing to multiple P
|
||||
pull_tasks_count: int = 0
|
||||
# Direct RDMA read: D reads from C's GPU memory without C's scheduler
|
||||
direct_read: bool = False
|
||||
block_hashes: list[bytes] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -124,11 +127,12 @@ class SendBlockMeta:
|
||||
|
||||
class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
def __init__(self):
|
||||
# Use (engine_id, dp_rank) to group reqs with same dp.
|
||||
# See comments in MooncakeBootstrapServer.
|
||||
self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
|
||||
self.reqs_to_send: dict[ReqId, tuple[TransferId, list[int]]] = {}
|
||||
self.reqs_not_processed: set[TransferId] = set()
|
||||
# Hash table sync: scheduler → worker (for direct RDMA read)
|
||||
self.hash_table_updates: dict[str, int] = {} # hex hash → block_id
|
||||
self.hash_table_removals: set[str] = set()
|
||||
|
||||
def add_new_req(
|
||||
self,
|
||||
@@ -136,6 +140,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
local_block_ids: list[int],
|
||||
kv_transfer_params: dict[str, Any],
|
||||
load_remote_cache: bool = True,
|
||||
block_hashes: list[bytes] | None = None,
|
||||
):
|
||||
transfer_id = kv_transfer_params["transfer_id"]
|
||||
if load_remote_cache:
|
||||
@@ -146,6 +151,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
remote_engine_id=remote_engine_id,
|
||||
remote_bootstrap_addr=kv_transfer_params["remote_bootstrap_addr"],
|
||||
transfer_id=transfer_id,
|
||||
direct_read=bool(kv_transfer_params.get("direct_read")),
|
||||
block_hashes=block_hashes or [],
|
||||
)
|
||||
else:
|
||||
self.reqs_to_send[request_id] = (transfer_id, local_block_ids)
|
||||
@@ -173,6 +180,10 @@ class MooncakeConnector(KVConnectorBase_V1):
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)
|
||||
|
||||
def set_block_pool(self, block_pool):
|
||||
if self.connector_scheduler is not None:
|
||||
self.connector_scheduler.set_block_pool(block_pool)
|
||||
|
||||
############################################################
|
||||
# Scheduler Side Methods
|
||||
############################################################
|
||||
@@ -250,6 +261,8 @@ class MooncakeConnectorScheduler:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
self.vllm_config = vllm_config
|
||||
self._block_pool = None
|
||||
self._known_hash_keys: set = set()
|
||||
|
||||
assert vllm_config.kv_transfer_config
|
||||
self.is_kv_producer: bool = (
|
||||
@@ -260,14 +273,14 @@ class MooncakeConnectorScheduler:
|
||||
)
|
||||
logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id)
|
||||
|
||||
# Requests that need to start recv/send.
|
||||
# New requests are added by update_state_after_alloc in
|
||||
# the scheduler. Used to make metadata passed to Worker.
|
||||
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
|
||||
self._reqs_need_send: dict[ReqId, tuple[Request, list[int]]] = {}
|
||||
# Reqs to remove from processed set because they're not to send after
|
||||
# remote prefill or aborted.
|
||||
self._reqs_not_processed: set[TransferId] = set()
|
||||
# Per-request block hashes for direct_read passthrough
|
||||
self._req_block_hashes: dict[ReqId, list[bytes]] = {}
|
||||
|
||||
def set_block_pool(self, block_pool):
|
||||
self._block_pool = block_pool
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request", num_computed_tokens: int
|
||||
@@ -348,13 +361,21 @@ class MooncakeConnectorScheduler:
|
||||
else:
|
||||
local_block_ids = []
|
||||
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
|
||||
# Capture block hashes for direct_read
|
||||
if params.get("direct_read") and hasattr(request, "block_hashes"):
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
num_remote_blocks = (
|
||||
(num_external_tokens + block_size - 1) // block_size
|
||||
)
|
||||
self._req_block_hashes[request.request_id] = [
|
||||
bytes(h) for h in request.block_hashes[:num_remote_blocks]
|
||||
]
|
||||
else:
|
||||
logger.warning(
|
||||
"Got invalid KVTransferParams: %s. This "
|
||||
"request will not utilize KVTransfer",
|
||||
params,
|
||||
)
|
||||
# Only trigger 1 KV transfer per request.
|
||||
params["do_remote_prefill"] = False
|
||||
|
||||
if params.get("do_remote_decode"):
|
||||
@@ -371,7 +392,6 @@ class MooncakeConnectorScheduler:
|
||||
) -> KVConnectorMetadata:
|
||||
meta = MooncakeConnectorMetadata()
|
||||
|
||||
# Loop through scheduled reqs and convert to PullReqMeta.
|
||||
if not self.is_kv_producer:
|
||||
for req_id, (req, block_ids) in self._reqs_need_recv.items():
|
||||
assert req.kv_transfer_params is not None
|
||||
@@ -379,9 +399,30 @@ class MooncakeConnectorScheduler:
|
||||
request_id=req_id,
|
||||
local_block_ids=block_ids,
|
||||
kv_transfer_params=req.kv_transfer_params,
|
||||
block_hashes=self._req_block_hashes.pop(req_id, None),
|
||||
)
|
||||
self._reqs_need_recv.clear()
|
||||
|
||||
# Sync hash table to worker for direct RDMA read block lookups
|
||||
if self._block_pool is not None:
|
||||
cache = self._block_pool.cached_block_hash_to_block._cache
|
||||
current_keys = set(cache.keys())
|
||||
new_keys = current_keys - self._known_hash_keys
|
||||
removed_keys = self._known_hash_keys - current_keys
|
||||
if new_keys or removed_keys:
|
||||
from vllm.v1.core.kv_cache_utils import get_block_hash
|
||||
for k in new_keys:
|
||||
block = cache[k]
|
||||
if isinstance(block, dict):
|
||||
bid = next(iter(block.values())).block_id
|
||||
else:
|
||||
bid = block.block_id
|
||||
meta.hash_table_updates[get_block_hash(k).hex()] = bid
|
||||
meta.hash_table_removals = {
|
||||
get_block_hash(k).hex() for k in removed_keys
|
||||
}
|
||||
self._known_hash_keys = current_keys.copy()
|
||||
|
||||
if not self.is_kv_consumer:
|
||||
for req_id, (req, block_ids) in self._reqs_need_send.items():
|
||||
assert req.kv_transfer_params is not None
|
||||
@@ -971,7 +1012,6 @@ class MooncakeConnectorWorker:
|
||||
"registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len
|
||||
)
|
||||
|
||||
# No need to launch server for D node.
|
||||
if self.is_kv_consumer:
|
||||
return
|
||||
|
||||
@@ -979,7 +1019,13 @@ class MooncakeConnectorWorker:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._mooncake_sender_listener(ready_event), self.sender_loop
|
||||
)
|
||||
ready_event.wait() # Wait for listener ZMQ socket to be ready.
|
||||
ready_event.wait()
|
||||
|
||||
if self.bootstrap_server is not None:
|
||||
self.bootstrap_server.set_worker_kv_info(
|
||||
self.kv_caches_base_addr, self.block_len,
|
||||
self.hostname, self.rpc_port,
|
||||
)
|
||||
|
||||
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
|
||||
finished_recving_reqs = self.finished_recving_reqs
|
||||
@@ -1197,6 +1243,93 @@ class MooncakeConnectorWorker:
|
||||
|
||||
self.receive_kv(remote_engine_id, pull_metas)
|
||||
|
||||
async def _start_direct_read(
|
||||
self, reqs_by_engine: dict[EngineId, dict[ReqId, PullReqMeta]]
|
||||
):
|
||||
"""Direct RDMA read: D reads cached KV blocks from C's GPU memory
|
||||
without involving C's scheduler.
|
||||
"""
|
||||
for _engine_id, pull_metas in reqs_by_engine.items():
|
||||
for req_id, pm in pull_metas.items():
|
||||
asyncio.create_task(
|
||||
self._direct_read_single(req_id, pm)
|
||||
)
|
||||
|
||||
async def _direct_read_single(self, req_id: ReqId, pm: PullReqMeta):
|
||||
pin_token = f"dr-{req_id}-{self.tp_rank}"
|
||||
bootstrap_url = pm.remote_bootstrap_addr
|
||||
|
||||
try:
|
||||
# 1. Query C's bootstrap for block mapping
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
resp = await client.post(
|
||||
f"{bootstrap_url}/query_blocks",
|
||||
json={
|
||||
"block_hashes": [h.hex() for h in pm.block_hashes],
|
||||
"pin_token": pin_token,
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
mapping = resp.json()
|
||||
|
||||
remote_block_ids = [b for b in mapping["block_ids"] if b is not None]
|
||||
num_matched = len(remote_block_ids)
|
||||
|
||||
if num_matched == 0:
|
||||
logger.debug("direct_read %s: no cache hit on remote", req_id)
|
||||
self.finished_recving_reqs.add(req_id)
|
||||
return
|
||||
|
||||
# 2. Compute RDMA addresses
|
||||
local_block_ids = pm.local_block_ids[:num_matched]
|
||||
remote_base_addrs = mapping["kv_caches_base_addr"]
|
||||
remote_block_len = mapping["block_len"]
|
||||
remote_session = f"{mapping['hostname']}:{mapping['rpc_port']}"
|
||||
|
||||
src_ptrs: list[int] = []
|
||||
dst_ptrs: list[int] = []
|
||||
lengths: list[int] = []
|
||||
|
||||
for local_layer_addr, remote_layer_addr in zip(
|
||||
self.kv_caches_base_addr, remote_base_addrs
|
||||
):
|
||||
grp_local, grp_remote = group_concurrent_contiguous(
|
||||
local_block_ids, remote_block_ids
|
||||
)
|
||||
for gl, gr in zip(grp_local, grp_remote):
|
||||
src_ptrs.append(remote_layer_addr + gr[0] * remote_block_len)
|
||||
dst_ptrs.append(local_layer_addr + gl[0] * self.block_len)
|
||||
lengths.append(remote_block_len * len(gr))
|
||||
|
||||
# 3. RDMA READ (D pulls from C's GPU memory)
|
||||
logger.debug(
|
||||
"direct_read %s: reading %d blocks from %s",
|
||||
req_id, num_matched, remote_session,
|
||||
)
|
||||
ret = self.engine.batch_transfer_sync_read(
|
||||
remote_session, src_ptrs, dst_ptrs, lengths,
|
||||
)
|
||||
|
||||
if ret != 0:
|
||||
logger.error("direct_read %s: RDMA read failed ret=%d", req_id, ret)
|
||||
return
|
||||
|
||||
logger.debug("direct_read %s: success (%d blocks)", req_id, num_matched)
|
||||
self.finished_recving_reqs.add(req_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("direct_read %s failed: %s", req_id, e)
|
||||
finally:
|
||||
# 4. Unpin blocks on C
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
await client.post(
|
||||
f"{bootstrap_url}/unpin_blocks",
|
||||
json={"pin_token": pin_token},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _start_load_kv(
|
||||
self, reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]]
|
||||
):
|
||||
@@ -1237,11 +1370,34 @@ class MooncakeConnectorWorker:
|
||||
assert not send_meta.ready.is_set()
|
||||
|
||||
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
|
||||
if not self.is_kv_producer and metadata.reqs_to_recv:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._start_load_kv(metadata.reqs_to_recv), self.receiver_loop
|
||||
# Sync hash table to bootstrap server (for direct RDMA read queries)
|
||||
if self.bootstrap_server is not None and (
|
||||
metadata.hash_table_updates or metadata.hash_table_removals
|
||||
):
|
||||
self.bootstrap_server.update_hash_table(
|
||||
metadata.hash_table_updates, metadata.hash_table_removals
|
||||
)
|
||||
|
||||
if not self.is_kv_producer and metadata.reqs_to_recv:
|
||||
# Split direct_read vs normal pull requests
|
||||
direct_reqs: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
|
||||
normal_reqs: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
|
||||
for engine_id, pull_metas in metadata.reqs_to_recv.items():
|
||||
for req_id, pm in pull_metas.items():
|
||||
if pm.direct_read:
|
||||
direct_reqs[engine_id][req_id] = pm
|
||||
else:
|
||||
normal_reqs[engine_id][req_id] = pm
|
||||
|
||||
if normal_reqs:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._start_load_kv(normal_reqs), self.receiver_loop
|
||||
)
|
||||
if direct_reqs:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._start_direct_read(direct_reqs), self.receiver_loop
|
||||
)
|
||||
|
||||
if not self.is_kv_consumer and (
|
||||
metadata.reqs_to_send or metadata.reqs_not_processed
|
||||
):
|
||||
|
||||
@@ -32,10 +32,28 @@ class EngineEntry:
|
||||
worker_addr: dict[int, dict[int, WorkerAddr]]
|
||||
|
||||
|
||||
class QueryBlocksRequest(BaseModel):
|
||||
block_hashes: list[str] # hex-encoded BlockHash values
|
||||
pin_token: str
|
||||
|
||||
|
||||
class QueryBlocksResponse(BaseModel):
|
||||
block_ids: list[int | None] # None = cache miss (prefix match stops)
|
||||
kv_caches_base_addr: list[int]
|
||||
block_len: int
|
||||
hostname: str
|
||||
rpc_port: int
|
||||
|
||||
|
||||
class UnpinBlocksRequest(BaseModel):
|
||||
pin_token: str
|
||||
|
||||
|
||||
class MooncakeBootstrapServer:
|
||||
"""
|
||||
A centralized server running on the global rank 0 prefiller worker.
|
||||
Prefiller workers register their connection info (IP, port, ranks) here.
|
||||
Also serves block mapping queries for direct RDMA read.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, host: str, port: int):
|
||||
@@ -48,13 +66,19 @@ class MooncakeBootstrapServer:
|
||||
self.server_thread: threading.Thread | None = None
|
||||
self.server: uvicorn.Server | None = None
|
||||
|
||||
# Direct RDMA read support: block hash → block_id mapping
|
||||
self._hash_table: dict[str, int] = {} # hex BlockHash → block_id
|
||||
self._kv_info: dict | None = None # set by worker at register_kv_caches
|
||||
self._pinned: dict[str, list[int]] = {} # pin_token → block_ids
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
def _register_routes(self):
|
||||
# All methods are async. No need to use lock to protect data.
|
||||
self.app.post("/register")(self.register_worker)
|
||||
self.app.get("/query", response_model=dict[int, EngineEntry])(self.query)
|
||||
self.app.post("/query_blocks")(self.query_blocks)
|
||||
self.app.post("/unpin_blocks")(self.unpin_blocks)
|
||||
|
||||
def start(self):
|
||||
if self.server_thread:
|
||||
@@ -125,3 +149,54 @@ class MooncakeBootstrapServer:
|
||||
|
||||
async def query(self) -> dict[int, EngineEntry]:
|
||||
return self.workers
|
||||
|
||||
def set_worker_kv_info(
|
||||
self,
|
||||
kv_caches_base_addr: list[int],
|
||||
block_len: int,
|
||||
hostname: str,
|
||||
rpc_port: int,
|
||||
):
|
||||
self._kv_info = {
|
||||
"kv_caches_base_addr": kv_caches_base_addr,
|
||||
"block_len": block_len,
|
||||
"hostname": hostname,
|
||||
"rpc_port": rpc_port,
|
||||
}
|
||||
|
||||
def update_hash_table(
|
||||
self,
|
||||
updates: dict[str, int],
|
||||
removals: set[str],
|
||||
):
|
||||
for k in removals:
|
||||
self._hash_table.pop(k, None)
|
||||
self._hash_table.update(updates)
|
||||
|
||||
async def query_blocks(self, req: QueryBlocksRequest):
|
||||
if self._kv_info is None:
|
||||
raise HTTPException(503, "Worker KV info not registered yet")
|
||||
|
||||
block_ids: list[int | None] = []
|
||||
pinned_ids: list[int] = []
|
||||
for h in req.block_hashes:
|
||||
bid = self._hash_table.get(h)
|
||||
if bid is not None:
|
||||
block_ids.append(bid)
|
||||
pinned_ids.append(bid)
|
||||
else:
|
||||
block_ids.append(None)
|
||||
break # prefix match: stop at first miss
|
||||
|
||||
self._pinned[req.pin_token] = pinned_ids
|
||||
return QueryBlocksResponse(
|
||||
block_ids=block_ids,
|
||||
kv_caches_base_addr=self._kv_info["kv_caches_base_addr"],
|
||||
block_len=self._kv_info["block_len"],
|
||||
hostname=self._kv_info["hostname"],
|
||||
rpc_port=self._kv_info["rpc_port"],
|
||||
)
|
||||
|
||||
async def unpin_blocks(self, req: UnpinBlocksRequest):
|
||||
self._pinned.pop(req.pin_token, None)
|
||||
return {"status": "ok"}
|
||||
|
||||
@@ -234,6 +234,8 @@ class Scheduler(SchedulerInterface):
|
||||
hash_block_size=self.block_size,
|
||||
metrics_collector=self.kv_metrics_collector,
|
||||
)
|
||||
if self.connector is not None and hasattr(self.connector, "set_block_pool"):
|
||||
self.connector.set_block_pool(self.kv_cache_manager.block_pool)
|
||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
|
||||
|
||||
|
||||
Reference in New Issue
Block a user