Chunk-safe + concurrent layer-wise connector (per-step incremental shipping)
The scheduler tracks per-producer block_ids (accumulated from scheduler_output) and emits a per-step LWSendMeta carrying the cumulative computed_tokens. On the worker, lw_wait_for_save records one CUDA event per step and enqueues that step's progress; the ship loop running on the sender loop drains the queue, shipping only blocks that are computed, wanted by the dst, and not yet shipped, in block order (correct under chunked prefill). State is kept per transfer, so concurrent transfers do not share mutable state. The v1 single-transfer version is kept as a reference. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import math
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, deque
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
@@ -135,11 +136,30 @@ class SendBlockMeta:
|
||||
sending: int = 0
|
||||
|
||||
|
||||
@dataclass
class LWSendMeta:
    """LAYERWISE: one step of producer prefill progress (scheduler -> worker).

    The scheduler emits one instance per producer for every step in which
    that producer's prefill makes progress. The worker uses it to ship only
    the blocks whose KV is computed as of this step (chunk-safe), matched to
    the dst's requested blocks, in block order, exactly once each.
    """

    # Transfer this progress record belongs to.
    transfer_id: TransferId
    # Producer-side request id.
    p_req_id: ReqId
    # Full current src block list for the request.
    local_block_ids: list[int]
    # ceil(num_prompt_tokens / block_size)
    total_blocks: int
    # Cumulative tokens computed AFTER this step.
    computed_tokens: int
    # True when the request's prefill finished this step.
    is_last: bool
|
||||
|
||||
|
||||
class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
def __init__(self):
|
||||
self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
|
||||
self.reqs_to_send: dict[ReqId, tuple[TransferId, list[int]]] = {}
|
||||
self.reqs_not_processed: set[TransferId] = set()
|
||||
# LAYERWISE: per-step incremental producer progress (separate channel
|
||||
# so the stock reqs_to_send path is untouched when layerwise is off).
|
||||
self.lw_send: dict[ReqId, LWSendMeta] = {}
|
||||
# Hash table sync: scheduler → worker (for direct RDMA read)
|
||||
self.hash_table_updates: dict[str, int] = {} # hex hash → block_id
|
||||
self.hash_table_removals: set[str] = set()
|
||||
@@ -270,13 +290,13 @@ class MooncakeConnector(KVConnectorBase_V1):
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""LAYERWISE: record that this layer's KV is computed so the sender
|
||||
can push it during prefill. No-op unless MOONCAKE_LAYERWISE=1."""
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.note_layer_computed(layer_name)
|
||||
"""MooncakeConnector does not save explicitly."""
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
    """Forward the bound connector metadata to the worker's layer-wise hook.

    Does nothing when this connector has no worker. Otherwise the current
    connector metadata is handed to ``lw_wait_for_save``.
    """
    worker = self.connector_worker
    if worker is None:
        return
    # Type narrowing only: the worker hook expects Mooncake metadata.
    assert isinstance(self._connector_metadata, MooncakeConnectorMetadata)
    worker.lw_wait_for_save(self._connector_metadata)
|
||||
|
||||
|
||||
class MooncakeConnectorScheduler:
|
||||
@@ -301,9 +321,11 @@ class MooncakeConnectorScheduler:
|
||||
self._reqs_not_processed: set[TransferId] = set()
|
||||
self._req_block_hashes: dict[ReqId, list[bytes]] = {}
|
||||
self._req_token_ids: dict[ReqId, list[int]] = {}
|
||||
# LAYERWISE: capture producer block_ids at alloc (before prefill done).
|
||||
# LAYERWISE producer tracking: req_id -> dict(request, transfer_id,
|
||||
# block_ids[list, grows], total_blocks). Emitted each step in lw_send.
|
||||
self._lw_enabled = os.environ.get("MOONCAKE_LAYERWISE", "0") == "1"
|
||||
self._lw_sent_once: set[ReqId] = set()
|
||||
self._lw_prod: dict[ReqId, dict] = {}
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
|
||||
def set_block_pool(self, block_pool):
|
||||
self._block_pool = block_pool
|
||||
@@ -414,16 +436,22 @@ class MooncakeConnectorScheduler:
|
||||
if not params.get("transfer_id"):
|
||||
logger.warning("Missing transfer_id in kv_transfer_params from router!")
|
||||
elif self._lw_enabled:
|
||||
# LAYERWISE: capture the producer block_ids NOW (at alloc), so
|
||||
# the worker learns local block_ids and sets `ready` before
|
||||
# prefill finishes — enabling per-layer writes during prefill.
|
||||
# LAYERWISE: register producer; emit incremental progress each
|
||||
# step in build_connector_meta. Block list accumulates from
|
||||
# scheduler_output across chunked-prefill steps.
|
||||
try:
|
||||
block_groups = blocks.get_block_ids()
|
||||
local_block_ids = list(block_groups[0]) if block_groups else []
|
||||
except Exception as e:
|
||||
logger.warning("LAYERWISE: failed to get block_ids at alloc: %s", e)
|
||||
local_block_ids = []
|
||||
self._reqs_need_send[request.request_id] = (request, local_block_ids)
|
||||
bg = blocks.get_block_ids()
|
||||
block_ids = list(bg[0]) if bg else []
|
||||
except Exception:
|
||||
block_ids = []
|
||||
total_blocks = math.ceil(
|
||||
request.num_prompt_tokens / self._block_size)
|
||||
self._lw_prod[request.request_id] = {
|
||||
"request": request,
|
||||
"transfer_id": params["transfer_id"],
|
||||
"block_ids": block_ids,
|
||||
"total_blocks": total_blocks,
|
||||
}
|
||||
else:
|
||||
# Add an empty list to worker to create event.
|
||||
self._reqs_need_send[request.request_id] = (request, [])
|
||||
@@ -485,8 +513,53 @@ class MooncakeConnectorScheduler:
|
||||
meta.reqs_not_processed = self._reqs_not_processed
|
||||
self._reqs_not_processed = set()
|
||||
|
||||
if self._lw_enabled and self._lw_prod:
|
||||
self._lw_emit(scheduler_output, meta)
|
||||
|
||||
return meta
|
||||
|
||||
def _lw_emit(self, scheduler_output, meta):
    """LAYERWISE: refresh producer block lists and publish step progress.

    First, each tracked producer's ``block_ids`` is updated from
    ``scheduler_output``: replaced for newly scheduled or resumed requests,
    extended for other cached requests. Then one ``LWSendMeta`` is written
    to ``meta.lw_send`` for each producer that has tokens scheduled this
    step or whose cumulative computed tokens reach the prompt length.
    Producers in the latter state are dropped from tracking afterwards.
    """
    producers = self._lw_prod

    # Phase 1a: newly scheduled requests carry their full block list.
    for new_req in scheduler_output.scheduled_new_reqs:
        state = producers.get(new_req.req_id)
        if state is None or not new_req.block_ids:
            continue
        state["block_ids"] = list(new_req.block_ids[0])

    # Phase 1b: cached requests carry only this step's new blocks, except
    # resumed requests, whose list replaces the tracked one.
    cached_reqs = scheduler_output.scheduled_cached_reqs
    for idx, rid in enumerate(cached_reqs.req_ids):
        state = producers.get(rid)
        if state is None:
            continue
        fresh = cached_reqs.new_block_ids[idx]
        if fresh is None or not fresh[0]:
            continue
        if rid in cached_reqs.resumed_req_ids:
            state["block_ids"] = list(fresh[0])
        else:
            state["block_ids"].extend(fresh[0])

    # Phase 2: publish progress. Removal is deferred so the dict is not
    # mutated while it is being iterated.
    scheduled_tokens = scheduler_output.num_scheduled_tokens
    done_ids = []
    for rid, state in producers.items():
        request = state["request"]
        step_tokens = scheduled_tokens.get(rid, 0)
        cumulative = request.num_computed_tokens + step_tokens
        prefill_done = cumulative >= request.num_prompt_tokens
        if not (step_tokens != 0 or prefill_done):
            # Nothing scheduled and prefill not complete: no progress.
            continue
        meta.lw_send[rid] = LWSendMeta(
            transfer_id=state["transfer_id"],
            p_req_id=rid,
            local_block_ids=list(state["block_ids"]),
            total_blocks=state["total_blocks"],
            computed_tokens=cumulative,
            is_last=prefill_done,
        )
        if prefill_done:
            done_ids.append(rid)

    for rid in done_ids:
        producers.pop(rid, None)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
@@ -539,9 +612,10 @@ class MooncakeConnectorScheduler:
|
||||
delay_free_blocks = len(send_block_ids) > 0
|
||||
|
||||
if self._lw_enabled:
|
||||
# LAYERWISE: the transfer was already driven from the alloc-time
|
||||
# block_ids during prefill; do NOT re-enqueue (would double-send).
|
||||
# Keep blocks alive until the worker signals finished_sending.
|
||||
# LAYERWISE: the worker drives the transfer incrementally during
|
||||
# prefill and signals finished_sending when done; just keep the
|
||||
# blocks alive until then. Do NOT re-enqueue (no post-hoc send).
|
||||
self._lw_prod.pop(request.request_id, None)
|
||||
return delay_free_blocks, None
|
||||
|
||||
if delay_free_blocks:
|
||||
@@ -613,24 +687,15 @@ class MooncakeConnectorWorker:
|
||||
self.device_kv_caches: dict[str, torch.Tensor] = {}
|
||||
self.reqs_need_send: dict[TransferId, SendBlockMeta] = {}
|
||||
|
||||
# --- LAYERWISE (opt-in via MOONCAKE_LAYERWISE=1) ---------------------
|
||||
# Push KV per-layer as prefill computes it, so the RDMA write overlaps
|
||||
# the remaining prefill compute instead of being a post-hoc full
|
||||
# transfer. Off by default => byte-identical upstream behaviour.
|
||||
# --- LAYERWISE worker state (opt-in via MOONCAKE_LAYERWISE=1) --------
|
||||
# Per-transfer incremental-shipping state. wait_for_save (main thread)
|
||||
# enqueues per-step (LWSendMeta, cuda_event); the ship loop on the
|
||||
# sender_loop drains it once the dst handshake provides dst addrs.
|
||||
self._lw_enabled = os.environ.get("MOONCAKE_LAYERWISE", "0") == "1"
|
||||
self._lw_layer_pos: dict[str, int] = {} # layer_name -> 0..N-1
|
||||
self._lw_addr_idx: dict[int, list[int]] = {} # layer_pos -> base-addr idxs
|
||||
self._lw_num_layers: int = 0
|
||||
# Single-producer-transfer-at-a-time (sufficient for the microbench):
|
||||
# _lw_gmax is the highest layer position whose KV is computed in the
|
||||
# current producer prefill; reset in record_send_reqs (which runs when
|
||||
# the producer request is scheduled, before its forward starts).
|
||||
self._lw_gmax: int = -1
|
||||
self._lw_producing: bool = False
|
||||
self._lw_events: dict[int, Any] = {} # layer_pos -> cuda Event (HBM-ready)
|
||||
self._lw_xfers: dict[str, dict] = {}
|
||||
self._lw_lock = threading.Lock()
|
||||
if self._lw_enabled:
|
||||
logger.info("Mooncake LAYERWISE mode ENABLED")
|
||||
logger.info("Mooncake LAYERWISE worker ENABLED")
|
||||
|
||||
# For kv_both, we will act both prefiller and decoder.
|
||||
if not self.is_kv_consumer:
|
||||
@@ -813,6 +878,9 @@ class MooncakeConnectorWorker:
|
||||
async def send_kv_to_decode(
|
||||
self, identity: bytes, sock: zmq.asyncio.Socket, meta: MooncakeXferMetadata
|
||||
):
|
||||
if self._lw_enabled:
|
||||
await self._lw_send_kv(identity, sock, meta)
|
||||
return
|
||||
pending_reqs: dict[ReqId, SendBlockMeta] = {}
|
||||
remote_tp_ranks = self.kv_topo.get_target_remote_ranks(meta.remote_tp_size)
|
||||
if self.tp_rank not in remote_tp_ranks:
|
||||
@@ -837,11 +905,6 @@ class MooncakeConnectorWorker:
|
||||
send_meta = self.reqs_need_send[transfer_id]
|
||||
pending_reqs[d_req_id] = send_meta
|
||||
|
||||
if self._lw_enabled:
|
||||
await self._send_kv_layerwise(
|
||||
identity, sock, meta, pending_reqs, remote_tp_ranks)
|
||||
return
|
||||
|
||||
async def wait_and_ret(
|
||||
d_req_id: ReqId, send_meta: SendBlockMeta
|
||||
) -> tuple[ReqId, SendBlockMeta]:
|
||||
@@ -962,123 +1025,118 @@ class MooncakeConnectorWorker:
|
||||
"Mooncake: Heterogeneous TP is not supported yet."
|
||||
)
|
||||
|
||||
def note_layer_computed(self, layer_name: str):
|
||||
"""LAYERWISE: called from save_kv_layer after layer L's attention runs.
|
||||
|
||||
Records a CUDA event so the sender can wait until L's KV is actually
|
||||
in HBM before RDMA-reading it, and bumps the per-transfer high-water
|
||||
mark of computed layers.
|
||||
"""
|
||||
if not self._lw_enabled or not self._lw_producing:
|
||||
return
|
||||
pos = self._lw_layer_pos.get(layer_name)
|
||||
if pos is None:
|
||||
# ---------------- LAYERWISE worker methods ----------------
|
||||
def lw_wait_for_save(self, metadata: "MooncakeConnectorMetadata"):
    """Record this step's CUDA event and queue progress for each producer.

    Called at end-of-forward. A single CUDA event is recorded on the current
    stream and shared by every producer in ``metadata.lw_send``; each
    producer's ``(LWSendMeta, event)`` pair is appended to its per-transfer
    ``pending`` queue, which the ship loop drains. A transfer record is
    created on first sight of a transfer id.

    No-op when layer-wise mode is disabled or there is nothing to send.
    """
    if not self._lw_enabled or not metadata.lw_send:
        return

    # One event marks the KV written by this whole step.
    ev = torch.cuda.Event()
    ev.record(torch.cuda.current_stream())

    with self._lw_lock:
        for lw in metadata.lw_send.values():
            tid = lw.transfer_id
            x = self._lw_xfers.get(tid)
            if x is None:
                # Producer progress arrived before the dst handshake:
                # dst-side fields stay None until the handshake fills them.
                x = {"pending": deque(), "dst_meta": None,
                     "remote_block_ids": None, "num_remote": None,
                     "d_req_id": None, "shipped": 0, "last_seen": False,
                     "p_req_id": lw.p_req_id}
                self._lw_xfers[tid] = x
            x["pending"].append((lw, ev))
            if lw.is_last:
                x["last_seen"] = True
|
||||
|
||||
def _build_layer_params(self, d_req_id, send_meta, agent_meta, layer_pos):
|
||||
"""Like _build_transfer_params but only this layer's base-addr slots."""
|
||||
async def _lw_send_kv(self, identity, sock, meta: MooncakeXferMetadata):
    """Handle a dst handshake: register dst addrs, drain, then respond.

    For every ``(d_req_id, (transfer_id, remote_blocks))`` in
    ``meta.req_blocks`` the per-transfer record is created if missing and
    filled with the dst metadata and remote block list. Each transfer is
    then drained in turn against one shared deadline. A single FINISH
    response is sent back on ``sock``; requests whose drain reported
    failure are listed in ``err_reqs`` instead of ``ok_reqs``.
    """
    remote_session = f"{meta.remote_hostname}:{meta.remote_port}"
    targets = []
    for d_req_id, (tid, remote_blocks) in meta.req_blocks.items():
        with self._lw_lock:
            x = self._lw_xfers.get(tid)
            if x is None:
                # Handshake arrived before any producer progress.
                x = {"pending": deque(), "shipped": 0, "last_seen": False,
                     "p_req_id": None}
                self._lw_xfers[tid] = x
            x["dst_meta"] = meta
            x["d_req_id"] = d_req_id
            x["remote_block_ids"] = list(remote_blocks)
            x["num_remote"] = len(remote_blocks)
        targets.append((d_req_id, tid))

    ok_reqs = []
    err_reqs = []
    deadline = time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
    for d_req_id, tid in targets:
        drained = await self._lw_drain_one(tid, remote_session, deadline)
        # Only an explicit False counts as failure, so a drain that returns
        # nothing is still treated as success.
        if drained is False:
            err_reqs.append(d_req_id)
        else:
            ok_reqs.append(d_req_id)

    if err_reqs:
        response = MooncakeXferResponse(
            status=MooncakeXferResponseStatus.FINISH,
            ok_reqs=ok_reqs,
            err_reqs=err_reqs,
            err_msg="layerwise: transfer incomplete")
    else:
        response = MooncakeXferResponse(
            status=MooncakeXferResponseStatus.FINISH, ok_reqs=ok_reqs)
    await sock.send_multipart((identity, self._encoder.encode(response)))


async def _lw_drain_one(self, tid, remote_session, deadline):
    """Ship one transfer's computed blocks as producer steps are queued.

    Pops ``(LWSendMeta, event)`` entries from the transfer's ``pending``
    queue, waits for the step's CUDA event, and ships the blocks between
    the current cursor and the number of blocks computed so far. Only the
    trailing ``num_remote`` blocks of the request are candidates. The loop
    ends when all remote blocks are shipped or ``deadline`` passes.

    On exit the transfer record is removed and, if any step was seen, the
    producer request id is added to ``finished_sending_reqs`` (this also
    happens on failure, as before).

    Returns:
        True if the drain finished without a timeout or send error,
        False otherwise.
    """
    x = self._lw_xfers[tid]
    num_remote = x["num_remote"]
    bs = self.block_size
    p_req_id = None
    ok = True
    while True:
        with self._lw_lock:
            entry = x["pending"].popleft() if x["pending"] else None
            last_seen = x["last_seen"]
        if entry is None:
            if x["shipped"] >= num_remote:
                break
            if time.perf_counter() > deadline:
                logger.error("lw drain timeout %s shipped=%d/%d",
                             tid, x["shipped"], num_remote)
                ok = False
                break
            await asyncio.sleep(0.0005)
            continue
        lw, ev = entry
        p_req_id = lw.p_req_id
        # Wait for this step's KV off the event loop: synchronize() blocks
        # the calling thread, which would otherwise stall every other
        # coroutine on the sender loop.
        await self.sender_loop.run_in_executor(
            self._sender_executor, ev.synchronize)
        total = lw.total_blocks
        # The dst wants the last num_remote blocks of the request.
        start = total - num_remote
        computed_blocks = total if lw.is_last else (lw.computed_tokens // bs)
        ship_end = min(computed_blocks, total)
        cursor = start + x["shipped"]
        if ship_end > cursor and ship_end <= len(lw.local_block_ids):
            local_slice = lw.local_block_ids[cursor:ship_end]
            remote_slice = x["remote_block_ids"][
                x["shipped"]: x["shipped"] + (ship_end - cursor)]
            if (local_slice and remote_slice
                    and len(local_slice) == len(remote_slice)):
                ret = await self.sender_loop.run_in_executor(
                    self._sender_executor, self._lw_build_send,
                    x["dst_meta"], local_slice, remote_slice, remote_session)
                if ret != 0:
                    logger.error("lw send ret=%d tid=%s", ret, tid)
                    # Keep draining so cleanup runs as before, but make
                    # sure the caller does not report this as success.
                    ok = False
                x["shipped"] += (ship_end - cursor)
        if last_seen and x["shipped"] >= num_remote and not x["pending"]:
            break
    with self._lw_lock:
        self._lw_xfers.pop(tid, None)
    self.reqs_need_send.pop(tid, None)
    if p_req_id is not None:
        self.finished_sending_reqs.add(p_req_id)
    logger.info("lw: transfer %s done (shipped %d/%d blocks)",
                tid, x["shipped"], num_remote)
    return ok
|
||||
|
||||
def _lw_build_send(self, dst_meta, local_slice, remote_slice, remote_session):
|
||||
"""Ship local_slice -> remote_slice for ALL layers (one RDMA batch)."""
|
||||
g_local, g_remote = group_concurrent_contiguous(local_slice, remote_slice)
|
||||
block_len = self.block_len
|
||||
src_ptrs: list[int] = []
|
||||
dst_ptrs: list[int] = []
|
||||
lengths: list[int] = []
|
||||
_, remote_block_ids = agent_meta.req_blocks[d_req_id]
|
||||
num_remote = len(remote_block_ids)
|
||||
if num_remote == 0:
|
||||
return src_ptrs, dst_ptrs, lengths
|
||||
local_block_ids = send_meta.local_block_ids
|
||||
if len(local_block_ids) < num_remote:
|
||||
logger.error("layerwise %s: local blocks(%d) < remote(%d)",
|
||||
d_req_id, len(local_block_ids), num_remote)
|
||||
return src_ptrs, dst_ptrs, lengths
|
||||
if len(local_block_ids) > num_remote:
|
||||
local_block_ids = local_block_ids[-num_remote:]
|
||||
g_local, g_remote = group_concurrent_contiguous(
|
||||
local_block_ids, remote_block_ids)
|
||||
block_len = self.block_len
|
||||
for addr_idx in self._lw_addr_idx[layer_pos]:
|
||||
local_layer_addr = self.kv_caches_base_addr[addr_idx]
|
||||
remote_layer_addr = agent_meta.kv_caches_base_addr[addr_idx]
|
||||
for laddr, raddr in zip(self.kv_caches_base_addr,
|
||||
dst_meta.kv_caches_base_addr):
|
||||
for gl, gr in zip(g_local, g_remote):
|
||||
src_ptrs.append(local_layer_addr + gl[0] * block_len)
|
||||
dst_ptrs.append(remote_layer_addr + gr[0] * block_len)
|
||||
src_ptrs.append(laddr + gl[0] * block_len)
|
||||
dst_ptrs.append(raddr + gr[0] * block_len)
|
||||
lengths.append(block_len * len(gl))
|
||||
return src_ptrs, dst_ptrs, lengths
|
||||
|
||||
async def _send_kv_layerwise(
|
||||
self, identity, sock, meta, pending_reqs, remote_tp_ranks,
|
||||
):
|
||||
"""Write each layer's KV as soon as prefill computes it (write mode)."""
|
||||
ready_reqs: list[tuple[ReqId, SendBlockMeta]] = []
|
||||
for d_req_id, send_meta in pending_reqs.items():
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
send_meta.ready.wait(),
|
||||
timeout=envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("layerwise: timeout waiting block_ids for %s", d_req_id)
|
||||
continue
|
||||
send_meta.sending += 1
|
||||
if not send_meta.need_send:
|
||||
self.resolve_need_send(send_meta, remote_tp_ranks)
|
||||
ready_reqs.append((d_req_id, send_meta))
|
||||
|
||||
if not ready_reqs:
|
||||
response = MooncakeXferResponse(
|
||||
status=MooncakeXferResponseStatus.FINISH,
|
||||
err_reqs=list(pending_reqs), err_msg="layerwise: no ready reqs")
|
||||
await sock.send_multipart((identity, self._encoder.encode(response)))
|
||||
return
|
||||
|
||||
remote_session = f"{meta.remote_hostname}:{meta.remote_port}"
|
||||
t0 = time.perf_counter()
|
||||
deadline = t0 + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
|
||||
for layer_pos in range(self._lw_num_layers):
|
||||
# Wait until this layer's KV is computed (poll; layers are ms).
|
||||
while True:
|
||||
with self._lw_lock:
|
||||
done = self._lw_gmax
|
||||
if done >= layer_pos:
|
||||
break
|
||||
if time.perf_counter() > deadline:
|
||||
logger.error("layerwise: timeout at layer %d (gmax=%d)",
|
||||
layer_pos, done)
|
||||
break
|
||||
await asyncio.sleep(0.0005)
|
||||
# Ensure layer's KV write to HBM is complete before RDMA read.
|
||||
with self._lw_lock:
|
||||
ev = self._lw_events.get(layer_pos)
|
||||
if ev is not None:
|
||||
ev.synchronize()
|
||||
for d_req_id, send_meta in ready_reqs:
|
||||
src_ptrs, dst_ptrs, lengths = self._build_layer_params(
|
||||
d_req_id, send_meta, meta, layer_pos)
|
||||
if src_ptrs:
|
||||
ret = await self.sender_loop.run_in_executor(
|
||||
self._sender_executor, self._send_blocks,
|
||||
remote_session, src_ptrs, dst_ptrs, lengths)
|
||||
if ret != 0:
|
||||
logger.error("layerwise: _send_blocks ret=%d layer=%d",
|
||||
ret, layer_pos)
|
||||
|
||||
logger.info("layerwise: transfer done in %.3fs (%d layers, %d reqs)",
|
||||
time.perf_counter() - t0, self._lw_num_layers, len(ready_reqs))
|
||||
with self._lw_lock:
|
||||
self._lw_producing = False
|
||||
for d_req_id, send_meta in ready_reqs:
|
||||
send_meta.sending -= 1
|
||||
send_meta.sent += 1
|
||||
if send_meta.sent == send_meta.need_send:
|
||||
self.reqs_need_send.pop(send_meta.transfer_id, None)
|
||||
self.finished_sending_reqs.add(send_meta.p_req_id)
|
||||
response = MooncakeXferResponse(
|
||||
status=MooncakeXferResponseStatus.FINISH,
|
||||
ok_reqs=[d for d, _ in ready_reqs])
|
||||
await sock.send_multipart((identity, self._encoder.encode(response)))
|
||||
return self._send_blocks(remote_session, src_ptrs, dst_ptrs, lengths)
|
||||
|
||||
async def _build_transfer_params(
|
||||
self,
|
||||
@@ -1214,18 +1272,6 @@ class MooncakeConnectorWorker:
|
||||
"registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len
|
||||
)
|
||||
|
||||
# --- LAYERWISE: map layer_name -> position, position -> base-addr idxs.
|
||||
if self._lw_enabled:
|
||||
n_per = 2 if split_k_and_v else 1
|
||||
layer_names = list(kv_caches.keys())
|
||||
self._lw_num_layers = len(layer_names)
|
||||
for pos, ln in enumerate(layer_names):
|
||||
self._lw_layer_pos[ln] = pos
|
||||
self._lw_addr_idx[pos] = [pos * n_per + j for j in range(n_per)]
|
||||
logger.info("LAYERWISE: %d layers, %d base-addrs (split_k_and_v=%s)",
|
||||
self._lw_num_layers, len(self.kv_caches_base_addr),
|
||||
split_k_and_v)
|
||||
|
||||
if self.is_kv_consumer:
|
||||
return
|
||||
|
||||
@@ -1546,32 +1592,14 @@ class MooncakeConnectorWorker:
|
||||
async def record_send_reqs(self, metadata: MooncakeConnectorMetadata):
|
||||
for p_req_id, (transfer_id, block_ids) in metadata.reqs_to_send.items():
|
||||
if block_ids:
|
||||
# Already gone through request_finished() — OR, in LAYERWISE
|
||||
# mode, block_ids arrive at alloc time and the SendBlockMeta
|
||||
# may not exist yet, so create it on demand.
|
||||
send_meta = self.reqs_need_send.get(transfer_id)
|
||||
if send_meta is None:
|
||||
send_meta = SendBlockMeta(
|
||||
p_req_id=p_req_id,
|
||||
transfer_id=transfer_id,
|
||||
local_block_ids=[],
|
||||
ready=asyncio.Event(),
|
||||
)
|
||||
self.reqs_need_send[transfer_id] = send_meta
|
||||
# Already gone through request_finished()
|
||||
send_meta = self.reqs_need_send[transfer_id]
|
||||
send_meta.p_req_id = p_req_id
|
||||
send_meta.local_block_ids = block_ids
|
||||
send_meta.expire_time = (
|
||||
time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
|
||||
)
|
||||
send_meta.ready.set()
|
||||
if self._lw_enabled:
|
||||
# Producer scheduled (before its prefill forward) — reset the
|
||||
# layer high-water mark so note_layer_computed tracks THIS
|
||||
# request's layers from scratch.
|
||||
with self._lw_lock:
|
||||
self._lw_gmax = -1
|
||||
self._lw_events.clear()
|
||||
self._lw_producing = True
|
||||
else:
|
||||
# From update_state_after_alloc(),
|
||||
# but not reach request_finished() yet
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user