Chunk-safe + concurrent layer-wise connector (per-step incremental shipping)

The scheduler tracks per-producer block_ids (accumulated from scheduler_output)
and emits a per-step LWSendMeta carrying the cumulative computed_tokens. On the
worker, lw_wait_for_save records a CUDA event per step and enqueues that step's
progress; the ship loop running on the sender loop drains the queue, shipping
only blocks that are computed, wanted by the dst, and not yet shipped, in block
order (correct under chunked prefill). State is kept per transfer, which makes
it safe for concurrent transfers. The v1 single-transfer version is kept as a
reference.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-28 17:15:54 +08:00
parent e77bdcac5a
commit 4242bba034
2 changed files with 1892 additions and 179 deletions

View File

@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import math
import os
import threading
import time
from collections import defaultdict
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import IntEnum
@@ -135,11 +136,30 @@ class SendBlockMeta:
sending: int = 0
@dataclass
class LWSendMeta:
    """LAYERWISE: per-step incremental producer state (scheduler -> worker).

    Emitted every step a producer prefill makes progress. The worker ships
    only blocks whose KV is computed this step (chunk-safe), matched to the
    dst's requested blocks, in block order, exactly once each.
    """

    # Transfer id taken from the request's kv_transfer_params.
    transfer_id: TransferId
    # Request id of the producer (prefill-side) request.
    p_req_id: ReqId
    # Full current src block list for the request; the scheduler sends a
    # fresh copy of its accumulated list every step.
    local_block_ids: list[int]
    # ceil(num_prompt_tokens / block_size)
    total_blocks: int
    # Cumulative tokens computed AFTER this step
    # (num_computed_tokens + tokens scheduled this step).
    computed_tokens: int
    # True once computed_tokens has reached num_prompt_tokens, i.e. the
    # request's prefill finishes with this step.
    is_last: bool
class MooncakeConnectorMetadata(KVConnectorMetadata):
def __init__(self):
self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
self.reqs_to_send: dict[ReqId, tuple[TransferId, list[int]]] = {}
self.reqs_not_processed: set[TransferId] = set()
# LAYERWISE: per-step incremental producer progress (separate channel
# so the stock reqs_to_send path is untouched when layerwise is off).
self.lw_send: dict[ReqId, LWSendMeta] = {}
# Hash table sync: scheduler → worker (for direct RDMA read)
self.hash_table_updates: dict[str, int] = {} # hex hash → block_id
self.hash_table_removals: set[str] = set()
@@ -270,13 +290,13 @@ class MooncakeConnector(KVConnectorBase_V1):
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
"""LAYERWISE: record that this layer's KV is computed so the sender
can push it during prefill. No-op unless MOONCAKE_LAYERWISE=1."""
if self.connector_worker is not None:
self.connector_worker.note_layer_computed(layer_name)
"""MooncakeConnector does not save explicitly."""
pass
def wait_for_save(self):
pass
if self.connector_worker is not None:
assert isinstance(self._connector_metadata, MooncakeConnectorMetadata)
self.connector_worker.lw_wait_for_save(self._connector_metadata)
class MooncakeConnectorScheduler:
@@ -301,9 +321,11 @@ class MooncakeConnectorScheduler:
self._reqs_not_processed: set[TransferId] = set()
self._req_block_hashes: dict[ReqId, list[bytes]] = {}
self._req_token_ids: dict[ReqId, list[int]] = {}
# LAYERWISE: capture producer block_ids at alloc (before prefill done).
# LAYERWISE producer tracking: req_id -> dict(request, transfer_id,
# block_ids[list, grows], total_blocks). Emitted each step in lw_send.
self._lw_enabled = os.environ.get("MOONCAKE_LAYERWISE", "0") == "1"
self._lw_sent_once: set[ReqId] = set()
self._lw_prod: dict[ReqId, dict] = {}
self._block_size = vllm_config.cache_config.block_size
def set_block_pool(self, block_pool):
self._block_pool = block_pool
@@ -414,16 +436,22 @@ class MooncakeConnectorScheduler:
if not params.get("transfer_id"):
logger.warning("Missing transfer_id in kv_transfer_params from router!")
elif self._lw_enabled:
# LAYERWISE: capture the producer block_ids NOW (at alloc), so
# the worker learns local block_ids and sets `ready` before
# prefill finishes — enabling per-layer writes during prefill.
# LAYERWISE: register producer; emit incremental progress each
# step in build_connector_meta. Block list accumulates from
# scheduler_output across chunked-prefill steps.
try:
block_groups = blocks.get_block_ids()
local_block_ids = list(block_groups[0]) if block_groups else []
except Exception as e:
logger.warning("LAYERWISE: failed to get block_ids at alloc: %s", e)
local_block_ids = []
self._reqs_need_send[request.request_id] = (request, local_block_ids)
bg = blocks.get_block_ids()
block_ids = list(bg[0]) if bg else []
except Exception:
block_ids = []
total_blocks = math.ceil(
request.num_prompt_tokens / self._block_size)
self._lw_prod[request.request_id] = {
"request": request,
"transfer_id": params["transfer_id"],
"block_ids": block_ids,
"total_blocks": total_blocks,
}
else:
# Add an empty list to worker to create event.
self._reqs_need_send[request.request_id] = (request, [])
@@ -485,8 +513,53 @@ class MooncakeConnectorScheduler:
meta.reqs_not_processed = self._reqs_not_processed
self._reqs_not_processed = set()
if self._lw_enabled and self._lw_prod:
self._lw_emit(scheduler_output, meta)
return meta
def _lw_emit(self, scheduler_output, meta):
"""LAYERWISE: accumulate block_ids from scheduler_output and emit
per-step incremental producer progress for each active producer."""
# 1) accumulate newly-allocated blocks this step.
for nr in scheduler_output.scheduled_new_reqs:
st = self._lw_prod.get(nr.req_id)
if st is not None and nr.block_ids:
st["block_ids"] = list(nr.block_ids[0])
cached = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached.req_ids):
st = self._lw_prod.get(req_id)
if st is None:
continue
nb = cached.new_block_ids[i]
if nb is not None and nb[0]:
if req_id in cached.resumed_req_ids:
st["block_ids"] = list(nb[0])
else:
st["block_ids"].extend(nb[0])
# 2) emit progress; drop finished producers.
nsched = scheduler_output.num_scheduled_tokens
finished = []
for req_id, st in self._lw_prod.items():
req = st["request"]
scheduled = nsched.get(req_id, 0)
computed_after = req.num_computed_tokens + scheduled
is_last = computed_after >= req.num_prompt_tokens
if scheduled == 0 and not is_last:
continue # no progress this step
meta.lw_send[req_id] = LWSendMeta(
transfer_id=st["transfer_id"],
p_req_id=req_id,
local_block_ids=list(st["block_ids"]),
total_blocks=st["total_blocks"],
computed_tokens=computed_after,
is_last=is_last,
)
if is_last:
finished.append(req_id)
for req_id in finished:
self._lw_prod.pop(req_id, None)
def request_finished(
self,
request: "Request",
@@ -539,9 +612,10 @@ class MooncakeConnectorScheduler:
delay_free_blocks = len(send_block_ids) > 0
if self._lw_enabled:
# LAYERWISE: the transfer was already driven from the alloc-time
# block_ids during prefill; do NOT re-enqueue (would double-send).
# Keep blocks alive until the worker signals finished_sending.
# LAYERWISE: the worker drives the transfer incrementally during
# prefill and signals finished_sending when done; just keep the
# blocks alive until then. Do NOT re-enqueue (no post-hoc send).
self._lw_prod.pop(request.request_id, None)
return delay_free_blocks, None
if delay_free_blocks:
@@ -613,24 +687,15 @@ class MooncakeConnectorWorker:
self.device_kv_caches: dict[str, torch.Tensor] = {}
self.reqs_need_send: dict[TransferId, SendBlockMeta] = {}
# --- LAYERWISE (opt-in via MOONCAKE_LAYERWISE=1) ---------------------
# Push KV per-layer as prefill computes it, so the RDMA write overlaps
# the remaining prefill compute instead of being a post-hoc full
# transfer. Off by default => byte-identical upstream behaviour.
# --- LAYERWISE worker state (opt-in via MOONCAKE_LAYERWISE=1) --------
# Per-transfer incremental-shipping state. wait_for_save (main thread)
# enqueues per-step (LWSendMeta, cuda_event); the ship loop on the
# sender_loop drains it once the dst handshake provides dst addrs.
self._lw_enabled = os.environ.get("MOONCAKE_LAYERWISE", "0") == "1"
self._lw_layer_pos: dict[str, int] = {} # layer_name -> 0..N-1
self._lw_addr_idx: dict[int, list[int]] = {} # layer_pos -> base-addr idxs
self._lw_num_layers: int = 0
# Single-producer-transfer-at-a-time (sufficient for the microbench):
# _lw_gmax is the highest layer position whose KV is computed in the
# current producer prefill; reset in record_send_reqs (which runs when
# the producer request is scheduled, before its forward starts).
self._lw_gmax: int = -1
self._lw_producing: bool = False
self._lw_events: dict[int, Any] = {} # layer_pos -> cuda Event (HBM-ready)
self._lw_xfers: dict[str, dict] = {}
self._lw_lock = threading.Lock()
if self._lw_enabled:
logger.info("Mooncake LAYERWISE mode ENABLED")
logger.info("Mooncake LAYERWISE worker ENABLED")
# For kv_both, we will act both prefiller and decoder.
if not self.is_kv_consumer:
@@ -813,6 +878,9 @@ class MooncakeConnectorWorker:
async def send_kv_to_decode(
self, identity: bytes, sock: zmq.asyncio.Socket, meta: MooncakeXferMetadata
):
if self._lw_enabled:
await self._lw_send_kv(identity, sock, meta)
return
pending_reqs: dict[ReqId, SendBlockMeta] = {}
remote_tp_ranks = self.kv_topo.get_target_remote_ranks(meta.remote_tp_size)
if self.tp_rank not in remote_tp_ranks:
@@ -837,11 +905,6 @@ class MooncakeConnectorWorker:
send_meta = self.reqs_need_send[transfer_id]
pending_reqs[d_req_id] = send_meta
if self._lw_enabled:
await self._send_kv_layerwise(
identity, sock, meta, pending_reqs, remote_tp_ranks)
return
async def wait_and_ret(
d_req_id: ReqId, send_meta: SendBlockMeta
) -> tuple[ReqId, SendBlockMeta]:
@@ -962,123 +1025,118 @@ class MooncakeConnectorWorker:
"Mooncake: Heterogeneous TP is not supported yet."
)
def note_layer_computed(self, layer_name: str):
"""LAYERWISE: called from save_kv_layer after layer L's attention runs.
Records a CUDA event so the sender can wait until L's KV is actually
in HBM before RDMA-reading it, and bumps the per-transfer high-water
mark of computed layers.
"""
if not self._lw_enabled or not self._lw_producing:
return
pos = self._lw_layer_pos.get(layer_name)
if pos is None:
# ---------------- LAYERWISE worker methods ----------------
def lw_wait_for_save(self, metadata: "MooncakeConnectorMetadata"):
"""End-of-forward (main thread): record a CUDA event marking this
step's KV as in-HBM, and enqueue the step's progress for each active
producer. The ship loop (sender_loop) drains it."""
if not self._lw_enabled or not metadata.lw_send:
return
ev = torch.cuda.Event()
ev.record(torch.cuda.current_stream())
with self._lw_lock:
self._lw_events[pos] = ev
if pos > self._lw_gmax:
self._lw_gmax = pos
for req_id, lw in metadata.lw_send.items():
tid = lw.transfer_id
x = self._lw_xfers.get(tid)
if x is None:
x = {"pending": deque(), "dst_meta": None,
"remote_block_ids": None, "num_remote": None,
"d_req_id": None, "shipped": 0, "last_seen": False,
"p_req_id": lw.p_req_id}
self._lw_xfers[tid] = x
x["pending"].append((lw, ev))
if lw.is_last:
x["last_seen"] = True
def _build_layer_params(self, d_req_id, send_meta, agent_meta, layer_pos):
"""Like _build_transfer_params but only this layer's base-addr slots."""
async def _lw_send_kv(self, identity, sock, meta: MooncakeXferMetadata):
    """Handshake arrived: register dst addrs, then drain each transfer.

    For every ``(d_req_id, (transfer_id, remote_blocks))`` in
    ``meta.req_blocks`` the per-transfer state in ``self._lw_xfers`` is
    created if the worker has not enqueued anything for it yet, and the
    destination side (``dst_meta``, ``d_req_id``, ``remote_block_ids``,
    ``num_remote``) is filled in. Each transfer's per-step queue is then
    drained by ``_lw_drain_one`` and a single FINISH response is sent back
    on ``sock``.

    Requests whose drain did not complete (timeout or a failed send) are
    reported in ``err_reqs`` instead of ``ok_reqs``, so the peer is not
    told that KV arrived when it did not.

    Args:
        identity: routing frame echoed back as the first part of the reply.
        sock: socket the reply is sent on.
        meta: handshake metadata from the destination.
    """
    remote_session = f"{meta.remote_hostname}:{meta.remote_port}"
    targets = []
    for d_req_id, (tid, remote_blocks) in meta.req_blocks.items():
        with self._lw_lock:
            x = self._lw_xfers.get(tid)
            if x is None:
                # Handshake beat the first lw_wait_for_save for this
                # transfer: create the state so steps can queue onto it.
                x = {"pending": deque(), "shipped": 0, "last_seen": False,
                     "p_req_id": None}
                self._lw_xfers[tid] = x
            x["dst_meta"] = meta
            x["d_req_id"] = d_req_id
            x["remote_block_ids"] = list(remote_blocks)
            x["num_remote"] = len(remote_blocks)
        targets.append((d_req_id, tid))

    ok_reqs = []
    err_reqs = []
    # NOTE(review): transfers of one handshake are drained one after the
    # other and share a single deadline.
    deadline = time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
    for d_req_id, tid in targets:
        if await self._lw_drain_one(tid, remote_session, deadline):
            ok_reqs.append(d_req_id)
        else:
            err_reqs.append(d_req_id)

    if err_reqs:
        response = MooncakeXferResponse(
            status=MooncakeXferResponseStatus.FINISH,
            ok_reqs=ok_reqs,
            err_reqs=err_reqs,
            err_msg="layerwise: transfer incomplete (timeout or send error)")
    else:
        response = MooncakeXferResponse(
            status=MooncakeXferResponseStatus.FINISH, ok_reqs=ok_reqs)
    await sock.send_multipart((identity, self._encoder.encode(response)))

async def _lw_drain_one(self, tid, remote_session, deadline):
    """Drain one transfer's per-step queue, shipping newly computed blocks.

    Each queue entry is ``(LWSendMeta, cuda_event)``. After the event has
    completed, the blocks that are computed, fall inside the window of the
    last ``num_remote`` source blocks, and are not shipped yet are sent in
    order to the matching remote block ids.

    The loop ends when all ``num_remote`` blocks are shipped or when
    ``deadline`` (a ``time.perf_counter()`` value) passes with an empty
    queue. In both cases the transfer state is removed and, if at least one
    step was seen, the producer request id is added to
    ``self.finished_sending_reqs``.

    Returns:
        True if every requested block was shipped and every send returned
        0; False on timeout or if any send reported a non-zero result.
    """
    x = self._lw_xfers[tid]
    num_remote = x["num_remote"]
    bs = self.block_size
    p_req_id = None
    send_failed = False
    while True:
        with self._lw_lock:
            entry = x["pending"].popleft() if x["pending"] else None
            last_seen = x["last_seen"]
        if entry is None:
            if x["shipped"] >= num_remote:
                break
            if time.perf_counter() > deadline:
                logger.error("lw drain timeout %s shipped=%d/%d",
                             tid, x["shipped"], num_remote)
                break
            await asyncio.sleep(0.0005)
            continue
        lw, ev = entry
        p_req_id = lw.p_req_id
        # Wait until this step's KV is in HBM. The wait is blocking, so it
        # runs on the executor instead of stalling the event loop (and
        # with it every other transfer being drained on this loop).
        await self.sender_loop.run_in_executor(
            self._sender_executor, ev.synchronize)
        total = lw.total_blocks
        # The dst wants the LAST num_remote blocks of the source list.
        start = total - num_remote
        # On the final step the (possibly partial) last block is complete.
        computed_blocks = total if lw.is_last else (lw.computed_tokens // bs)
        ship_end = min(computed_blocks, total)
        cursor = start + x["shipped"]
        if ship_end > cursor and ship_end <= len(lw.local_block_ids):
            count = ship_end - cursor
            local_slice = lw.local_block_ids[cursor:ship_end]
            remote_slice = x["remote_block_ids"][
                x["shipped"]: x["shipped"] + count]
            if (local_slice and remote_slice
                    and len(local_slice) == len(remote_slice)):
                ret = await self.sender_loop.run_in_executor(
                    self._sender_executor, self._lw_build_send,
                    x["dst_meta"], local_slice, remote_slice, remote_session)
                if ret != 0:
                    logger.error("lw send ret=%d tid=%s", ret, tid)
                    # Remember the failure but keep draining, so the
                    # finished_sending signal is raised at the same point
                    # as for a successful transfer.
                    send_failed = True
                x["shipped"] += count
        if last_seen and x["shipped"] >= num_remote and not x["pending"]:
            break
    with self._lw_lock:
        self._lw_xfers.pop(tid, None)
    self.reqs_need_send.pop(tid, None)
    if p_req_id is not None:
        self.finished_sending_reqs.add(p_req_id)
    logger.info("lw: transfer %s done (shipped %d/%d blocks)",
                tid, x["shipped"], num_remote)
    return (not send_failed) and x["shipped"] >= num_remote
def _lw_build_send(self, dst_meta, local_slice, remote_slice, remote_session):
"""Ship local_slice -> remote_slice for ALL layers (one RDMA batch)."""
g_local, g_remote = group_concurrent_contiguous(local_slice, remote_slice)
block_len = self.block_len
src_ptrs: list[int] = []
dst_ptrs: list[int] = []
lengths: list[int] = []
_, remote_block_ids = agent_meta.req_blocks[d_req_id]
num_remote = len(remote_block_ids)
if num_remote == 0:
return src_ptrs, dst_ptrs, lengths
local_block_ids = send_meta.local_block_ids
if len(local_block_ids) < num_remote:
logger.error("layerwise %s: local blocks(%d) < remote(%d)",
d_req_id, len(local_block_ids), num_remote)
return src_ptrs, dst_ptrs, lengths
if len(local_block_ids) > num_remote:
local_block_ids = local_block_ids[-num_remote:]
g_local, g_remote = group_concurrent_contiguous(
local_block_ids, remote_block_ids)
block_len = self.block_len
for addr_idx in self._lw_addr_idx[layer_pos]:
local_layer_addr = self.kv_caches_base_addr[addr_idx]
remote_layer_addr = agent_meta.kv_caches_base_addr[addr_idx]
for laddr, raddr in zip(self.kv_caches_base_addr,
dst_meta.kv_caches_base_addr):
for gl, gr in zip(g_local, g_remote):
src_ptrs.append(local_layer_addr + gl[0] * block_len)
dst_ptrs.append(remote_layer_addr + gr[0] * block_len)
src_ptrs.append(laddr + gl[0] * block_len)
dst_ptrs.append(raddr + gr[0] * block_len)
lengths.append(block_len * len(gl))
return src_ptrs, dst_ptrs, lengths
async def _send_kv_layerwise(
self, identity, sock, meta, pending_reqs, remote_tp_ranks,
):
"""Write each layer's KV as soon as prefill computes it (write mode)."""
ready_reqs: list[tuple[ReqId, SendBlockMeta]] = []
for d_req_id, send_meta in pending_reqs.items():
try:
await asyncio.wait_for(
send_meta.ready.wait(),
timeout=envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT)
except asyncio.TimeoutError:
logger.warning("layerwise: timeout waiting block_ids for %s", d_req_id)
continue
send_meta.sending += 1
if not send_meta.need_send:
self.resolve_need_send(send_meta, remote_tp_ranks)
ready_reqs.append((d_req_id, send_meta))
if not ready_reqs:
response = MooncakeXferResponse(
status=MooncakeXferResponseStatus.FINISH,
err_reqs=list(pending_reqs), err_msg="layerwise: no ready reqs")
await sock.send_multipart((identity, self._encoder.encode(response)))
return
remote_session = f"{meta.remote_hostname}:{meta.remote_port}"
t0 = time.perf_counter()
deadline = t0 + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
for layer_pos in range(self._lw_num_layers):
# Wait until this layer's KV is computed (poll; layers are ms).
while True:
with self._lw_lock:
done = self._lw_gmax
if done >= layer_pos:
break
if time.perf_counter() > deadline:
logger.error("layerwise: timeout at layer %d (gmax=%d)",
layer_pos, done)
break
await asyncio.sleep(0.0005)
# Ensure layer's KV write to HBM is complete before RDMA read.
with self._lw_lock:
ev = self._lw_events.get(layer_pos)
if ev is not None:
ev.synchronize()
for d_req_id, send_meta in ready_reqs:
src_ptrs, dst_ptrs, lengths = self._build_layer_params(
d_req_id, send_meta, meta, layer_pos)
if src_ptrs:
ret = await self.sender_loop.run_in_executor(
self._sender_executor, self._send_blocks,
remote_session, src_ptrs, dst_ptrs, lengths)
if ret != 0:
logger.error("layerwise: _send_blocks ret=%d layer=%d",
ret, layer_pos)
logger.info("layerwise: transfer done in %.3fs (%d layers, %d reqs)",
time.perf_counter() - t0, self._lw_num_layers, len(ready_reqs))
with self._lw_lock:
self._lw_producing = False
for d_req_id, send_meta in ready_reqs:
send_meta.sending -= 1
send_meta.sent += 1
if send_meta.sent == send_meta.need_send:
self.reqs_need_send.pop(send_meta.transfer_id, None)
self.finished_sending_reqs.add(send_meta.p_req_id)
response = MooncakeXferResponse(
status=MooncakeXferResponseStatus.FINISH,
ok_reqs=[d for d, _ in ready_reqs])
await sock.send_multipart((identity, self._encoder.encode(response)))
return self._send_blocks(remote_session, src_ptrs, dst_ptrs, lengths)
async def _build_transfer_params(
self,
@@ -1214,18 +1272,6 @@ class MooncakeConnectorWorker:
"registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len
)
# --- LAYERWISE: map layer_name -> position, position -> base-addr idxs.
if self._lw_enabled:
n_per = 2 if split_k_and_v else 1
layer_names = list(kv_caches.keys())
self._lw_num_layers = len(layer_names)
for pos, ln in enumerate(layer_names):
self._lw_layer_pos[ln] = pos
self._lw_addr_idx[pos] = [pos * n_per + j for j in range(n_per)]
logger.info("LAYERWISE: %d layers, %d base-addrs (split_k_and_v=%s)",
self._lw_num_layers, len(self.kv_caches_base_addr),
split_k_and_v)
if self.is_kv_consumer:
return
@@ -1546,32 +1592,14 @@ class MooncakeConnectorWorker:
async def record_send_reqs(self, metadata: MooncakeConnectorMetadata):
for p_req_id, (transfer_id, block_ids) in metadata.reqs_to_send.items():
if block_ids:
# Already gone through request_finished() — OR, in LAYERWISE
# mode, block_ids arrive at alloc time and the SendBlockMeta
# may not exist yet, so create it on demand.
send_meta = self.reqs_need_send.get(transfer_id)
if send_meta is None:
send_meta = SendBlockMeta(
p_req_id=p_req_id,
transfer_id=transfer_id,
local_block_ids=[],
ready=asyncio.Event(),
)
self.reqs_need_send[transfer_id] = send_meta
# Already gone through request_finished()
send_meta = self.reqs_need_send[transfer_id]
send_meta.p_req_id = p_req_id
send_meta.local_block_ids = block_ids
send_meta.expire_time = (
time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
)
send_meta.ready.set()
if self._lw_enabled:
# Producer scheduled (before its prefill forward) — reset the
# layer high-water mark so note_layer_computed tracks THIS
# request's layers from scratch.
with self._lw_lock:
self._lw_gmax = -1
self._lw_events.clear()
self._lw_producing = True
else:
# From update_state_after_alloc(),
# but not reach request_finished() yet

File diff suppressed because it is too large Load Diff