Scheduler tracks per-producer block_ids (accumulated from scheduler_output) and emits per-step LWSendMeta with cumulative computed_tokens. Worker lw_wait_for_save records a CUDA event per step and enqueues progress; the sender-loop ship loop drains it, shipping only computed+dst-wanted+unshipped blocks in order (correct under chunked prefill). Per-transfer state = concurrent-safe. Keeps v1 single-transfer version as reference. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1714 lines
69 KiB
Python
1714 lines
69 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import asyncio
|
|
import math
|
|
import os
|
|
import threading
|
|
import time
|
|
from collections import defaultdict, deque
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from dataclasses import dataclass, field
|
|
from enum import IntEnum
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import httpx
|
|
import msgspec
|
|
import numpy as np
|
|
import torch
|
|
import zmq
|
|
import zmq.asyncio
|
|
|
|
from vllm import envs
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
|
EngineId,
|
|
TpKVTopology,
|
|
get_current_attn_backend,
|
|
)
|
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
|
KVConnectorBase_V1,
|
|
KVConnectorMetadata,
|
|
KVConnectorRole,
|
|
)
|
|
from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_utils import (
|
|
MooncakeBootstrapServer,
|
|
RegisterWorkerPayload,
|
|
)
|
|
from vllm.distributed.parallel_state import (
|
|
get_pp_group,
|
|
get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size,
|
|
is_local_first_rank,
|
|
)
|
|
from vllm.forward_context import ForwardContext
|
|
from vllm.logger import init_logger
|
|
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
|
|
from vllm.v1.attention.backend import AttentionMetadata
|
|
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
|
from vllm.v1.core.sched.output import SchedulerOutput
|
|
from vllm.v1.request import RequestStatus
|
|
|
|
try:
|
|
from mooncake.engine import TransferEngine
|
|
except ImportError as e:
|
|
raise ImportError(
|
|
"Please install mooncake by following the instructions at "
|
|
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "
|
|
"to run VLLM with MooncakeTransferEngine."
|
|
) from e
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
|
from vllm.v1.request import Request
|
|
|
|
ReqId = str # Internal scheduler request ID
|
|
TransferId = str # KV transfer coordination ID (shared by P/D)
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
# Module-level block pool for bootstrap server access (kv_both same-process only)
|
|
_shared_block_pool = None
|
|
|
|
def _set_shared_block_pool(bp):
|
|
global _shared_block_pool
|
|
_shared_block_pool = bp
|
|
|
|
|
|
class MooncakeXferMetadata(
|
|
msgspec.Struct,
|
|
omit_defaults=True, # type: ignore[call-arg]
|
|
):
|
|
remote_hostname: str
|
|
remote_port: int
|
|
remote_tp_size: int
|
|
remote_tp_rank: int
|
|
req_blocks: dict[ReqId, tuple[TransferId, list[int]]]
|
|
kv_caches_base_addr: list[int]
|
|
|
|
|
|
class MooncakeXferResponseStatus(IntEnum):
|
|
# Transfer finished
|
|
FINISH = 0
|
|
# Continue to receive
|
|
CONTINUE = 1
|
|
# Something wrong, see err_msg
|
|
ERROR = 2
|
|
|
|
|
|
class MooncakeXferResponse(
|
|
msgspec.Struct,
|
|
omit_defaults=True, # type: ignore[call-arg]
|
|
):
|
|
status: MooncakeXferResponseStatus
|
|
ok_reqs: list[ReqId] | None = None
|
|
err_reqs: list[ReqId] | None = None
|
|
err_msg: str | None = None
|
|
|
|
|
|
@dataclass
|
|
class PullReqMeta:
|
|
d_req_id: ReqId
|
|
transfer_id: TransferId
|
|
local_block_ids: list[int]
|
|
remote_engine_id: EngineId
|
|
remote_bootstrap_addr: str
|
|
# Set expire time to avoid infinitely sending requests.
|
|
expire_time: float = float("inf")
|
|
# Designed for one D pairing to multiple P
|
|
pull_tasks_count: int = 0
|
|
# Direct RDMA read: D reads from C's GPU memory without C's scheduler
|
|
direct_read: bool = False
|
|
block_hashes: list[bytes] = field(default_factory=list)
|
|
prompt_token_ids: list[int] = field(default_factory=list)
|
|
remote_num_tokens: int = 0
|
|
|
|
|
|
@dataclass
|
|
class SendBlockMeta:
|
|
p_req_id: ReqId
|
|
transfer_id: TransferId
|
|
local_block_ids: list[int]
|
|
ready: asyncio.Event
|
|
expire_time: float = float("inf")
|
|
need_send: int = 0
|
|
sent: int = 0
|
|
sending: int = 0
|
|
|
|
|
|
@dataclass
|
|
class LWSendMeta:
|
|
"""LAYERWISE: per-step incremental producer state (scheduler -> worker).
|
|
|
|
Emitted every step a producer prefill makes progress. The worker ships
|
|
only blocks whose KV is computed this step (chunk-safe), matched to the
|
|
dst's requested blocks, in block order, exactly once each.
|
|
"""
|
|
transfer_id: TransferId
|
|
p_req_id: ReqId
|
|
local_block_ids: list[int] # full current src block list for the request
|
|
total_blocks: int # ceil(num_prompt_tokens / block_size)
|
|
computed_tokens: int # cumulative tokens computed AFTER this step
|
|
is_last: bool # request prefill finished this step
|
|
|
|
|
|
class MooncakeConnectorMetadata(KVConnectorMetadata):
|
|
def __init__(self):
|
|
self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
|
|
self.reqs_to_send: dict[ReqId, tuple[TransferId, list[int]]] = {}
|
|
self.reqs_not_processed: set[TransferId] = set()
|
|
# LAYERWISE: per-step incremental producer progress (separate channel
|
|
# so the stock reqs_to_send path is untouched when layerwise is off).
|
|
self.lw_send: dict[ReqId, LWSendMeta] = {}
|
|
# Hash table sync: scheduler → worker (for direct RDMA read)
|
|
self.hash_table_updates: dict[str, int] = {} # hex hash → block_id
|
|
self.hash_table_removals: set[str] = set()
|
|
self.token_hash_updates: dict[str, int] = {} # str(hash(tokens)) → block_id
|
|
|
|
def add_new_req(
|
|
self,
|
|
request_id: ReqId,
|
|
local_block_ids: list[int],
|
|
kv_transfer_params: dict[str, Any],
|
|
load_remote_cache: bool = True,
|
|
block_hashes: list[bytes] | None = None,
|
|
prompt_token_ids: list[int] | None = None,
|
|
):
|
|
transfer_id = kv_transfer_params["transfer_id"]
|
|
if load_remote_cache:
|
|
remote_engine_id = kv_transfer_params["remote_engine_id"]
|
|
remote_num = kv_transfer_params.get("remote_num_tokens", 0)
|
|
self.reqs_to_recv[remote_engine_id][request_id] = PullReqMeta(
|
|
d_req_id=request_id,
|
|
local_block_ids=local_block_ids,
|
|
remote_engine_id=remote_engine_id,
|
|
remote_bootstrap_addr=kv_transfer_params["remote_bootstrap_addr"],
|
|
transfer_id=transfer_id,
|
|
direct_read=bool(kv_transfer_params.get("direct_read")),
|
|
block_hashes=block_hashes or [],
|
|
prompt_token_ids=prompt_token_ids or [],
|
|
remote_num_tokens=remote_num,
|
|
)
|
|
else:
|
|
self.reqs_to_send[request_id] = (transfer_id, local_block_ids)
|
|
|
|
|
|
class MooncakeConnector(KVConnectorBase_V1):
|
|
def __init__(
|
|
self,
|
|
vllm_config: VllmConfig,
|
|
role: KVConnectorRole,
|
|
kv_cache_config: "KVCacheConfig | None" = None,
|
|
):
|
|
super().__init__(vllm_config, role, kv_cache_config)
|
|
|
|
assert vllm_config.kv_transfer_config is not None
|
|
assert vllm_config.kv_transfer_config.engine_id is not None
|
|
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
|
|
|
|
if role == KVConnectorRole.SCHEDULER:
|
|
self.connector_scheduler: MooncakeConnectorScheduler | None = (
|
|
MooncakeConnectorScheduler(vllm_config, self.engine_id)
|
|
)
|
|
self.connector_worker: MooncakeConnectorWorker | None = None
|
|
elif role == KVConnectorRole.WORKER:
|
|
self.connector_scheduler = None
|
|
self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)
|
|
|
|
def set_block_pool(self, block_pool):
|
|
if self.connector_scheduler is not None:
|
|
self.connector_scheduler.set_block_pool(block_pool)
|
|
# Also store module-level for bootstrap server access (same process for kv_both TP=1)
|
|
_set_shared_block_pool(block_pool)
|
|
|
|
############################################################
|
|
# Scheduler Side Methods
|
|
############################################################
|
|
|
|
def get_num_new_matched_tokens(
|
|
self, request: "Request", num_computed_tokens: int
|
|
) -> tuple[int, bool]:
|
|
assert self.connector_scheduler is not None
|
|
return self.connector_scheduler.get_num_new_matched_tokens(
|
|
request, num_computed_tokens
|
|
)
|
|
|
|
def update_state_after_alloc(
|
|
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
|
):
|
|
assert self.connector_scheduler is not None
|
|
return self.connector_scheduler.update_state_after_alloc(
|
|
request, blocks, num_external_tokens
|
|
)
|
|
|
|
def build_connector_meta(
|
|
self,
|
|
scheduler_output: SchedulerOutput,
|
|
) -> KVConnectorMetadata:
|
|
assert self.connector_scheduler is not None
|
|
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
|
|
|
def request_finished(
|
|
self,
|
|
request: "Request",
|
|
block_ids: list[int],
|
|
) -> tuple[bool, dict[str, Any] | None]:
|
|
assert self.connector_scheduler is not None
|
|
return self.connector_scheduler.request_finished(request, block_ids)
|
|
|
|
############################################################
|
|
# Worker Side Methods
|
|
############################################################
|
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
|
assert self.connector_worker is not None
|
|
self.connector_worker.register_kv_caches(kv_caches)
|
|
|
|
def get_finished(
|
|
self, finished_req_ids: set[str]
|
|
) -> tuple[set[str] | None, set[str] | None]:
|
|
"""Get the finished recving and sending requests."""
|
|
assert self.connector_worker is not None
|
|
return self.connector_worker.get_finished()
|
|
|
|
def get_block_ids_with_load_errors(self) -> set[int]:
|
|
assert self.connector_worker is not None
|
|
return self.connector_worker.get_block_ids_with_load_errors()
|
|
|
|
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
|
assert self.connector_worker is not None
|
|
assert isinstance(self._connector_metadata, MooncakeConnectorMetadata)
|
|
self.connector_worker.start_load_kv(self._connector_metadata)
|
|
|
|
def wait_for_layer_load(self, layer_name: str) -> None:
|
|
"""MooncakeConnector does not do layerwise saving."""
|
|
pass
|
|
|
|
def save_kv_layer(
|
|
self,
|
|
layer_name: str,
|
|
kv_layer: torch.Tensor,
|
|
attn_metadata: AttentionMetadata,
|
|
**kwargs,
|
|
) -> None:
|
|
"""MooncakeConnector does not save explicitly."""
|
|
pass
|
|
|
|
def wait_for_save(self):
|
|
if self.connector_worker is not None:
|
|
assert isinstance(self._connector_metadata, MooncakeConnectorMetadata)
|
|
self.connector_worker.lw_wait_for_save(self._connector_metadata)
|
|
|
|
|
|
class MooncakeConnectorScheduler:
|
|
"""Implementation of Scheduler side methods"""
|
|
|
|
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
|
self.vllm_config = vllm_config
|
|
self._block_pool = None
|
|
self._known_hash_keys: set = set()
|
|
|
|
assert vllm_config.kv_transfer_config
|
|
self.is_kv_producer: bool = (
|
|
vllm_config.kv_transfer_config.kv_role == "kv_producer"
|
|
)
|
|
self.is_kv_consumer: bool = (
|
|
vllm_config.kv_transfer_config.kv_role == "kv_consumer"
|
|
)
|
|
logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id)
|
|
|
|
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
|
|
self._reqs_need_send: dict[ReqId, tuple[Request, list[int]]] = {}
|
|
self._reqs_not_processed: set[TransferId] = set()
|
|
self._req_block_hashes: dict[ReqId, list[bytes]] = {}
|
|
self._req_token_ids: dict[ReqId, list[int]] = {}
|
|
# LAYERWISE producer tracking: req_id -> dict(request, transfer_id,
|
|
# block_ids[list, grows], total_blocks). Emitted each step in lw_send.
|
|
self._lw_enabled = os.environ.get("MOONCAKE_LAYERWISE", "0") == "1"
|
|
self._lw_prod: dict[ReqId, dict] = {}
|
|
self._block_size = vllm_config.cache_config.block_size
|
|
|
|
def set_block_pool(self, block_pool):
|
|
self._block_pool = block_pool
|
|
|
|
def get_num_new_matched_tokens(
|
|
self, request: "Request", num_computed_tokens: int
|
|
) -> tuple[int, bool]:
|
|
"""
|
|
For remote prefill, pull all prompt blocks from remote
|
|
asynchronously relative to engine execution.
|
|
|
|
Args:
|
|
request (Request): the request object.
|
|
num_computed_tokens (int): the number of locally
|
|
computed tokens for this request
|
|
Returns:
|
|
* the number of tokens that can be loaded from the
|
|
external KV cache beyond what is already computed.
|
|
* true if the external KV cache tokens will be loaded
|
|
asynchronously (between scheduler steps).
|
|
"""
|
|
|
|
params = request.kv_transfer_params
|
|
logger.debug(
|
|
"MooncakeConnector get_num_new_matched_tokens: "
|
|
"num_computed_tokens=%s, kv_transfer_params=%s",
|
|
num_computed_tokens,
|
|
params,
|
|
)
|
|
|
|
if not params:
|
|
return 0, False
|
|
|
|
if params.get("do_remote_prefill"):
|
|
assert not self.is_kv_producer
|
|
token_ids = request.prompt_token_ids or []
|
|
# Partial remote prefill: only pull remote_num_tokens from remote,
|
|
# compute the rest locally. Falls back to full remote prefill
|
|
# when remote_num_tokens is not set.
|
|
remote_total = params.get("remote_num_tokens", len(token_ids))
|
|
remote_total = min(remote_total, len(token_ids))
|
|
count = max(0, remote_total - num_computed_tokens)
|
|
if count > 0:
|
|
return count, True
|
|
|
|
return 0, False
|
|
|
|
def update_state_after_alloc(
|
|
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
|
):
|
|
params = request.kv_transfer_params
|
|
logger.debug(
|
|
"MooncakeConnector update_state_after_alloc: "
|
|
"req_id=%s num_external_tokens=%s, kv_transfer_params=%s",
|
|
request.request_id,
|
|
num_external_tokens,
|
|
params,
|
|
)
|
|
|
|
if not params:
|
|
return
|
|
|
|
if params.get("do_remote_prefill"):
|
|
assert not self.is_kv_producer
|
|
if all(
|
|
p in params
|
|
for p in ("remote_engine_id", "remote_bootstrap_addr", "transfer_id")
|
|
):
|
|
if num_external_tokens > 0:
|
|
all_unhashed = blocks.get_unhashed_block_ids()
|
|
# Partial remote prefill: only receive blocks for the
|
|
# external portion, leave the rest for local compute.
|
|
if params.get("remote_num_tokens") is not None:
|
|
block_size = self.vllm_config.cache_config.block_size
|
|
num_remote_blocks = (
|
|
(num_external_tokens + block_size - 1) // block_size
|
|
)
|
|
local_block_ids = all_unhashed[:num_remote_blocks]
|
|
else:
|
|
local_block_ids = all_unhashed
|
|
else:
|
|
local_block_ids = []
|
|
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
|
|
if params.get("direct_read"):
|
|
block_size = self.vllm_config.cache_config.block_size
|
|
num_remote_blocks = (
|
|
(num_external_tokens + block_size - 1) // block_size
|
|
)
|
|
if hasattr(request, "block_hashes"):
|
|
self._req_block_hashes[request.request_id] = [
|
|
bytes(h) for h in request.block_hashes[:num_remote_blocks]
|
|
]
|
|
# Store prompt token_ids for token-based lookup on C
|
|
if hasattr(request, "prompt_token_ids") and request.prompt_token_ids:
|
|
self._req_token_ids[request.request_id] = list(
|
|
request.prompt_token_ids[:num_remote_blocks * block_size]
|
|
)
|
|
else:
|
|
logger.warning(
|
|
"Got invalid KVTransferParams: %s. This "
|
|
"request will not utilize KVTransfer",
|
|
params,
|
|
)
|
|
params["do_remote_prefill"] = False
|
|
|
|
if params.get("do_remote_decode"):
|
|
assert not self.is_kv_consumer
|
|
if not params.get("transfer_id"):
|
|
logger.warning("Missing transfer_id in kv_transfer_params from router!")
|
|
elif self._lw_enabled:
|
|
# LAYERWISE: register producer; emit incremental progress each
|
|
# step in build_connector_meta. Block list accumulates from
|
|
# scheduler_output across chunked-prefill steps.
|
|
try:
|
|
bg = blocks.get_block_ids()
|
|
block_ids = list(bg[0]) if bg else []
|
|
except Exception:
|
|
block_ids = []
|
|
total_blocks = math.ceil(
|
|
request.num_prompt_tokens / self._block_size)
|
|
self._lw_prod[request.request_id] = {
|
|
"request": request,
|
|
"transfer_id": params["transfer_id"],
|
|
"block_ids": block_ids,
|
|
"total_blocks": total_blocks,
|
|
}
|
|
else:
|
|
# Add an empty list to worker to create event.
|
|
self._reqs_need_send[request.request_id] = (request, [])
|
|
|
|
def build_connector_meta(
|
|
self,
|
|
scheduler_output: SchedulerOutput,
|
|
) -> KVConnectorMetadata:
|
|
meta = MooncakeConnectorMetadata()
|
|
|
|
if not self.is_kv_producer:
|
|
for req_id, (req, block_ids) in self._reqs_need_recv.items():
|
|
assert req.kv_transfer_params is not None
|
|
meta.add_new_req(
|
|
request_id=req_id,
|
|
local_block_ids=block_ids,
|
|
kv_transfer_params=req.kv_transfer_params,
|
|
block_hashes=self._req_block_hashes.pop(req_id, None),
|
|
prompt_token_ids=self._req_token_ids.pop(req_id, None),
|
|
)
|
|
self._reqs_need_recv.clear()
|
|
|
|
# Sync hash table to worker for direct RDMA read block lookups
|
|
if self._block_pool is not None:
|
|
cache = self._block_pool.cached_block_hash_to_block._cache
|
|
current_keys = set(cache.keys())
|
|
new_keys = current_keys - self._known_hash_keys
|
|
removed_keys = self._known_hash_keys - current_keys
|
|
if new_keys or removed_keys:
|
|
from vllm.v1.core.kv_cache_utils import get_block_hash
|
|
for k in new_keys:
|
|
block = cache[k]
|
|
if isinstance(block, dict):
|
|
bid = next(iter(block.values())).block_id
|
|
else:
|
|
bid = block.block_id
|
|
meta.hash_table_updates[get_block_hash(k).hex()] = bid
|
|
meta.hash_table_removals = {
|
|
get_block_hash(k).hex() for k in removed_keys
|
|
}
|
|
self._known_hash_keys = current_keys.copy()
|
|
logger.info("hash_table_sync: +%d -%d (total known=%d)",
|
|
len(new_keys), len(removed_keys), len(self._known_hash_keys))
|
|
else:
|
|
if not hasattr(self, '_bp_warned'):
|
|
logger.warning("_block_pool is None, hash table sync disabled")
|
|
self._bp_warned = True
|
|
|
|
if not self.is_kv_consumer:
|
|
for req_id, (req, block_ids) in self._reqs_need_send.items():
|
|
assert req.kv_transfer_params is not None
|
|
meta.add_new_req(
|
|
request_id=req_id,
|
|
local_block_ids=block_ids,
|
|
kv_transfer_params=req.kv_transfer_params,
|
|
load_remote_cache=False,
|
|
)
|
|
self._reqs_need_send.clear()
|
|
meta.reqs_not_processed = self._reqs_not_processed
|
|
self._reqs_not_processed = set()
|
|
|
|
if self._lw_enabled and self._lw_prod:
|
|
self._lw_emit(scheduler_output, meta)
|
|
|
|
return meta
|
|
|
|
def _lw_emit(self, scheduler_output, meta):
|
|
"""LAYERWISE: accumulate block_ids from scheduler_output and emit
|
|
per-step incremental producer progress for each active producer."""
|
|
# 1) accumulate newly-allocated blocks this step.
|
|
for nr in scheduler_output.scheduled_new_reqs:
|
|
st = self._lw_prod.get(nr.req_id)
|
|
if st is not None and nr.block_ids:
|
|
st["block_ids"] = list(nr.block_ids[0])
|
|
cached = scheduler_output.scheduled_cached_reqs
|
|
for i, req_id in enumerate(cached.req_ids):
|
|
st = self._lw_prod.get(req_id)
|
|
if st is None:
|
|
continue
|
|
nb = cached.new_block_ids[i]
|
|
if nb is not None and nb[0]:
|
|
if req_id in cached.resumed_req_ids:
|
|
st["block_ids"] = list(nb[0])
|
|
else:
|
|
st["block_ids"].extend(nb[0])
|
|
# 2) emit progress; drop finished producers.
|
|
nsched = scheduler_output.num_scheduled_tokens
|
|
finished = []
|
|
for req_id, st in self._lw_prod.items():
|
|
req = st["request"]
|
|
scheduled = nsched.get(req_id, 0)
|
|
computed_after = req.num_computed_tokens + scheduled
|
|
is_last = computed_after >= req.num_prompt_tokens
|
|
if scheduled == 0 and not is_last:
|
|
continue # no progress this step
|
|
meta.lw_send[req_id] = LWSendMeta(
|
|
transfer_id=st["transfer_id"],
|
|
p_req_id=req_id,
|
|
local_block_ids=list(st["block_ids"]),
|
|
total_blocks=st["total_blocks"],
|
|
computed_tokens=computed_after,
|
|
is_last=is_last,
|
|
)
|
|
if is_last:
|
|
finished.append(req_id)
|
|
for req_id in finished:
|
|
self._lw_prod.pop(req_id, None)
|
|
|
|
def request_finished(
|
|
self,
|
|
request: "Request",
|
|
block_ids: list[int],
|
|
) -> tuple[bool, dict[str, Any] | None]:
|
|
"""
|
|
Once a request is finished, determine whether request blocks
|
|
should be freed now or will be sent asynchronously and freed later.
|
|
"""
|
|
|
|
params = request.kv_transfer_params
|
|
logger.debug(
|
|
"MooncakeConnector request_finished, req_id=%s, request_status=%s, "
|
|
"kv_transfer_params=%s",
|
|
request.request_id,
|
|
request.status,
|
|
params,
|
|
)
|
|
if not params or not params.get("transfer_id"):
|
|
return False, None
|
|
|
|
if params.get("do_remote_prefill"):
|
|
# If do_remote_prefill is still True when the request is finished,
|
|
# update_state_after_alloc must not have been called (the request
|
|
# must have been aborted before it was scheduled).
|
|
# To avoid stranding the prefill blocks in the prefill instance,
|
|
# we must add empty block_ids to _reqs_need_recv so that our
|
|
# worker side will notify and free blocks in the prefill instance.
|
|
assert not self.is_kv_producer
|
|
self._reqs_need_recv[request.request_id] = (request, [])
|
|
params["do_remote_prefill"] = False
|
|
return False, None
|
|
|
|
if not params.get("do_remote_decode"):
|
|
return False, None
|
|
|
|
assert not self.is_kv_consumer
|
|
|
|
if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
|
|
# Also include the case of a P/D Prefill request with immediate
|
|
# block free (eg abort). Stop tracking this request.
|
|
self._reqs_not_processed.add(params["transfer_id"])
|
|
return False, None
|
|
|
|
# TODO: check whether block_ids actually ever be 0. If not we could
|
|
# remove the conditional below
|
|
block_size = self.vllm_config.cache_config.block_size
|
|
prompt_blocks = (request.num_prompt_tokens + block_size - 1) // block_size
|
|
send_block_ids = block_ids[:prompt_blocks]
|
|
delay_free_blocks = len(send_block_ids) > 0
|
|
|
|
if self._lw_enabled:
|
|
# LAYERWISE: the worker drives the transfer incrementally during
|
|
# prefill and signals finished_sending when done; just keep the
|
|
# blocks alive until then. Do NOT re-enqueue (no post-hoc send).
|
|
self._lw_prod.pop(request.request_id, None)
|
|
return delay_free_blocks, None
|
|
|
|
if delay_free_blocks:
|
|
self._reqs_need_send[request.request_id] = (request, send_block_ids)
|
|
|
|
return delay_free_blocks, None
|
|
|
|
|
|
class MooncakeConnectorWorker:
|
|
"""Implementation of Worker side methods"""
|
|
|
|
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
|
logger.info("Initializing Mooncake Transfer Engine worker %s", engine_id)
|
|
|
|
self.vllm_config = vllm_config
|
|
|
|
self.engine = TransferEngine()
|
|
self.hostname = get_ip()
|
|
|
|
assert (kv_transfer_config := vllm_config.kv_transfer_config)
|
|
self.is_kv_producer: bool = kv_transfer_config.kv_role == "kv_producer"
|
|
self.is_kv_consumer: bool = kv_transfer_config.kv_role == "kv_consumer"
|
|
self.num_sender_workers = kv_transfer_config.kv_connector_extra_config.get(
|
|
"num_workers", 10
|
|
)
|
|
# Create more tasks than workers to keep the thread pool saturated.
|
|
# Tasks can await async events, so a surplus (2x is a robust heuristic)
|
|
# prevents workers from idling.
|
|
self.num_sender_tasks = self.num_sender_workers * 2
|
|
protocol = kv_transfer_config.kv_connector_extra_config.get( # type: ignore[union-attr]
|
|
"mooncake_protocol", "rdma"
|
|
)
|
|
logger.info(
|
|
"The Mooncake Transfer Engine is using %s as its protocol.", protocol
|
|
)
|
|
ret_value = self.engine.initialize(self.hostname, "P2PHANDSHAKE", protocol, "")
|
|
if ret_value != 0:
|
|
raise RuntimeError("Mooncake Transfer Engine initialization failed.")
|
|
|
|
self.rpc_port = self.engine.get_rpc_port()
|
|
|
|
logger.debug(
|
|
"Mooncake Transfer Engine initialized at %s:%d",
|
|
self.hostname,
|
|
self.rpc_port,
|
|
)
|
|
|
|
self._remote_agents: dict[EngineId, dict[int, dict[int, str]]] = {}
|
|
self._pending_bootstrap_queries: dict[str, asyncio.Event] = {}
|
|
self.side_channel_port: int = 0 # we will bind it in register_kv_caches()
|
|
self.engine_id: EngineId = engine_id
|
|
self.tp_rank = get_tensor_model_parallel_rank()
|
|
self.tp_size = get_tensor_model_parallel_world_size()
|
|
self.num_blocks = 0
|
|
self.bootstrap_server = None
|
|
|
|
assert (parallel_config := vllm_config.parallel_config)
|
|
dp_rank = parallel_config.data_parallel_index
|
|
dp_local_rank = parallel_config.data_parallel_rank_local
|
|
self.dp_rank = dp_local_rank if parallel_config.local_engines_only else dp_rank
|
|
pp_size = vllm_config.parallel_config.pipeline_parallel_size
|
|
if pp_size > 1:
|
|
raise ValueError(
|
|
"Mooncake Transfer Engine does not support pipeline parallelism yet."
|
|
)
|
|
self.pp_rank = get_pp_group().rank_in_group
|
|
|
|
self.kv_caches_base_addr: list[int] = []
|
|
self.device_kv_caches: dict[str, torch.Tensor] = {}
|
|
self.reqs_need_send: dict[TransferId, SendBlockMeta] = {}
|
|
|
|
# --- LAYERWISE worker state (opt-in via MOONCAKE_LAYERWISE=1) --------
|
|
# Per-transfer incremental-shipping state. wait_for_save (main thread)
|
|
# enqueues per-step (LWSendMeta, cuda_event); the ship loop on the
|
|
# sender_loop drains it once the dst handshake provides dst addrs.
|
|
self._lw_enabled = os.environ.get("MOONCAKE_LAYERWISE", "0") == "1"
|
|
self._lw_xfers: dict[str, dict] = {}
|
|
self._lw_lock = threading.Lock()
|
|
if self._lw_enabled:
|
|
logger.info("Mooncake LAYERWISE worker ENABLED")
|
|
|
|
# For kv_both, we will act both prefiller and decoder.
|
|
if not self.is_kv_consumer:
|
|
# Background threads for sending kvcaches to D.
|
|
self._sender_executor = ThreadPoolExecutor(
|
|
max_workers=self.num_sender_workers,
|
|
thread_name_prefix="vllm-mooncake-sender",
|
|
)
|
|
logger.debug(
|
|
"Mooncake Prefiller: use %d workers to send kvcaches",
|
|
self.num_sender_workers,
|
|
)
|
|
# An asyncio queue to buffer incoming requests for the sender
|
|
self.sender_worker_queue = asyncio.Queue[tuple[bytes, bytes]]()
|
|
self.sender_loop = asyncio.new_event_loop()
|
|
# Background thread for processing new sending requests.
|
|
self._sender_listener_t = threading.Thread(
|
|
target=_async_loop, args=(self.sender_loop,), daemon=True
|
|
)
|
|
self._sender_listener_t.start()
|
|
|
|
# Start bootstrap server on global rank 0.
|
|
if should_launch_bootstrap_server(vllm_config):
|
|
_, port = get_mooncake_bootstrap_addr(vllm_config)
|
|
self.bootstrap_server = MooncakeBootstrapServer(
|
|
vllm_config, "0.0.0.0", port
|
|
)
|
|
self.bootstrap_server.start()
|
|
|
|
if not self.is_kv_producer:
|
|
self.receiver_loop = asyncio.new_event_loop()
|
|
self._mooncake_receiver_t = threading.Thread(
|
|
target=_async_loop, args=(self.receiver_loop,), daemon=True
|
|
)
|
|
self._mooncake_receiver_t.start()
|
|
logger.debug("Mooncake Decoder: start receiver thread")
|
|
|
|
self.finished_sending_reqs: set[ReqId] = set()
|
|
self.finished_recving_reqs: set[ReqId] = set()
|
|
self.failed_recving_block_ids: set[int] = set()
|
|
|
|
self.block_size = vllm_config.cache_config.block_size
|
|
self.model_config = vllm_config.model_config
|
|
self.cache_config = vllm_config.cache_config
|
|
self.use_mla = self.model_config.use_mla
|
|
|
|
# Get the attention backend from the first layer
|
|
# NOTE (NickLucche) models with multiple backends are not supported yet
|
|
backend = get_current_attn_backend(vllm_config)
|
|
self.backend_name = backend.get_name()
|
|
self.kv_cache_layout = get_kv_cache_layout()
|
|
logger.debug("Detected attention backend %s", self.backend_name)
|
|
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
|
|
|
|
self._tp_size: dict[EngineId, int] = {self.engine_id: self.tp_size}
|
|
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
|
|
self.kv_topo = TpKVTopology(
|
|
tp_rank=self.tp_rank,
|
|
engine_id=self.engine_id,
|
|
remote_tp_size=self._tp_size, # shared state
|
|
remote_block_size=self._block_size, # shared state
|
|
is_mla=self.use_mla,
|
|
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
|
attn_backends=[backend],
|
|
)
|
|
|
|
self.async_zmq_ctx = zmq.asyncio.Context()
|
|
self._encoder = msgspec.msgpack.Encoder()
|
|
self._xfer_meta_decoder = msgspec.msgpack.Decoder(MooncakeXferMetadata)
|
|
self._xfer_resp_decoder = msgspec.msgpack.Decoder(MooncakeXferResponse)
|
|
|
|
def __del__(self):
|
|
self.shutdown()
|
|
|
|
def shutdown(self):
|
|
"""Cleanup background threads on destruction."""
|
|
self.async_zmq_ctx.term()
|
|
if not self.is_kv_consumer:
|
|
self._sender_executor.shutdown(wait=False)
|
|
if self.sender_loop.is_running():
|
|
self.sender_loop.call_soon_threadsafe(self.sender_loop.stop)
|
|
self._sender_listener_t.join()
|
|
if should_launch_bootstrap_server(self.vllm_config):
|
|
self.bootstrap_server.shutdown()
|
|
if not self.is_kv_producer and self.receiver_loop.is_running():
|
|
self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop)
|
|
self._mooncake_receiver_t.join()
|
|
|
|
async def register_worker_with_bootstrap(self):
|
|
host, port = get_mooncake_bootstrap_addr(self.vllm_config)
|
|
url = make_zmq_path("http", host, port) + "/register"
|
|
worker_addr = make_zmq_path("tcp", self.hostname, self.side_channel_port)
|
|
payload = RegisterWorkerPayload(
|
|
engine_id=self.engine_id,
|
|
dp_rank=self.dp_rank,
|
|
tp_rank=self.tp_rank,
|
|
pp_rank=self.pp_rank,
|
|
addr=worker_addr,
|
|
)
|
|
while True:
|
|
try:
|
|
async with httpx.AsyncClient() as client:
|
|
response = await client.post(url, json=payload.model_dump())
|
|
response.raise_for_status()
|
|
logger.debug("Successfully registered with bootstrap server at %s", url)
|
|
break
|
|
except httpx.ConnectError:
|
|
# Bootstrap server not ready, wait for a while and retry.
|
|
await asyncio.sleep(1)
|
|
except Exception as e:
|
|
err_msg = (
|
|
e.response.text if isinstance(e, httpx.HTTPStatusError) else str(e)
|
|
)
|
|
logger.error(
|
|
"Error registering %s with bootstrap server: %s", payload, err_msg
|
|
)
|
|
raise e
|
|
|
|
async def _mooncake_sender_listener(self, ready_event: threading.Event):
|
|
"""
|
|
Background thread that listens for Mooncake requests, dispatches them
|
|
to a thread pool, and sends acknowledgments upon completion.
|
|
"""
|
|
|
|
sock = self.async_zmq_ctx.socket(zmq.ROUTER)
|
|
self.side_channel_port = sock.bind_to_random_port(f"tcp://{self.hostname}")
|
|
logger.debug(
|
|
"Mooncake sender starting listening on path: tcp://%s:%d",
|
|
self.hostname,
|
|
self.side_channel_port,
|
|
)
|
|
|
|
await self.register_worker_with_bootstrap()
|
|
|
|
# Create async worker tasks that process items from the queue
|
|
sender_tasks = [
|
|
asyncio.create_task(self._sender_worker(sock))
|
|
for _ in range(self.num_sender_tasks)
|
|
]
|
|
|
|
ready_event.set()
|
|
|
|
try:
|
|
while True:
|
|
identity, metadata_bytes = await sock.recv_multipart()
|
|
await self.sender_worker_queue.put((identity, metadata_bytes))
|
|
except zmq.ContextTerminated:
|
|
logger.debug("ZMQ context terminated, exiting Mooncake sender thread.")
|
|
except Exception as e:
|
|
logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e))
|
|
finally:
|
|
# Clean up worker tasks
|
|
for task in sender_tasks:
|
|
task.cancel()
|
|
await asyncio.gather(*sender_tasks, return_exceptions=True)
|
|
sock.close()
|
|
|
|
async def _sender_worker(self, sock: zmq.asyncio.Socket):
|
|
while True:
|
|
try:
|
|
identity, metadata_bytes = await self.sender_worker_queue.get()
|
|
try:
|
|
metadata = self._xfer_meta_decoder.decode(metadata_bytes)
|
|
await self.send_kv_to_decode(identity, sock, metadata)
|
|
except Exception as e:
|
|
logger.error("Error processing Mooncake xfer request: %s", e)
|
|
error_response = MooncakeXferResponse(
|
|
status=MooncakeXferResponseStatus.ERROR, err_msg=str(e)
|
|
)
|
|
await sock.send_multipart(
|
|
(identity, self._encoder.encode(error_response))
|
|
)
|
|
finally:
|
|
self.sender_worker_queue.task_done()
|
|
except asyncio.CancelledError:
|
|
break
|
|
except Exception as e:
|
|
logger.error("Error in _sender_worker: %s", e)
|
|
|
|
async def send_kv_to_decode(
|
|
self, identity: bytes, sock: zmq.asyncio.Socket, meta: MooncakeXferMetadata
|
|
):
|
|
if self._lw_enabled:
|
|
await self._lw_send_kv(identity, sock, meta)
|
|
return
|
|
pending_reqs: dict[ReqId, SendBlockMeta] = {}
|
|
remote_tp_ranks = self.kv_topo.get_target_remote_ranks(meta.remote_tp_size)
|
|
if self.tp_rank not in remote_tp_ranks:
|
|
# This D worker does not pair with the P worker.
|
|
msg = f"This P tp_rank {self.tp_rank} not in remote D target ranks {remote_tp_ranks}" # noqa: E501
|
|
logger.error(msg)
|
|
response = MooncakeXferResponse(
|
|
status=MooncakeXferResponseStatus.ERROR,
|
|
err_msg=msg,
|
|
)
|
|
await sock.send_multipart((identity, self._encoder.encode(response)))
|
|
return
|
|
for d_req_id, (transfer_id, _) in meta.req_blocks.items():
|
|
if transfer_id not in self.reqs_need_send:
|
|
# This req is not enqueued in P side yet, create it here.
|
|
self.reqs_need_send[transfer_id] = SendBlockMeta(
|
|
p_req_id="",
|
|
transfer_id=transfer_id,
|
|
local_block_ids=[],
|
|
ready=asyncio.Event(),
|
|
)
|
|
send_meta = self.reqs_need_send[transfer_id]
|
|
pending_reqs[d_req_id] = send_meta
|
|
|
|
async def wait_and_ret(
|
|
d_req_id: ReqId, send_meta: SendBlockMeta
|
|
) -> tuple[ReqId, SendBlockMeta]:
|
|
await send_meta.ready.wait()
|
|
return d_req_id, send_meta
|
|
|
|
wait_tasks = [
|
|
asyncio.create_task(wait_and_ret(d_req_id, send_meta))
|
|
for d_req_id, send_meta in pending_reqs.items()
|
|
]
|
|
|
|
while wait_tasks:
|
|
done, pending = await asyncio.wait(
|
|
wait_tasks,
|
|
timeout=envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
|
|
return_when=asyncio.FIRST_COMPLETED,
|
|
)
|
|
|
|
if not done:
|
|
# Timeout, abort all pending requests.
|
|
for task in wait_tasks:
|
|
task.cancel()
|
|
logger.warning(
|
|
"Timeout waiting for P side ready: %s", list(pending_reqs)
|
|
)
|
|
response = MooncakeXferResponse(
|
|
status=MooncakeXferResponseStatus.FINISH,
|
|
err_reqs=list(pending_reqs),
|
|
err_msg="Timeout waiting for P side ready.",
|
|
)
|
|
await sock.send_multipart((identity, self._encoder.encode(response)))
|
|
break
|
|
|
|
wait_tasks = list(pending)
|
|
response_status = (
|
|
MooncakeXferResponseStatus.CONTINUE
|
|
if wait_tasks
|
|
else MooncakeXferResponseStatus.FINISH
|
|
)
|
|
ready_reqs: list[tuple[ReqId, SendBlockMeta]] = []
|
|
for task in done:
|
|
d_req_id, send_meta = task.result()
|
|
del pending_reqs[d_req_id]
|
|
# Do we still in reqs_need_send (not expired)?
|
|
if send_meta.transfer_id in self.reqs_need_send:
|
|
# Mark it sending to avoid expiration.
|
|
send_meta.sending += 1
|
|
if not send_meta.need_send:
|
|
self.resolve_need_send(send_meta, remote_tp_ranks)
|
|
ready_reqs.append((d_req_id, send_meta))
|
|
else:
|
|
# Otherwise (expired, very unlikely), just forget it.
|
|
logger.warning(
|
|
"Request %s expired before sending on P side.", d_req_id
|
|
)
|
|
|
|
src_ptrs, dst_ptrs, lengths, err_reqs = await self._build_transfer_params(
|
|
ready_reqs, meta
|
|
)
|
|
|
|
if err_reqs:
|
|
response = MooncakeXferResponse(
|
|
status=response_status,
|
|
err_reqs=err_reqs,
|
|
err_msg="P num blocks less than D",
|
|
)
|
|
await sock.send_multipart((identity, self._encoder.encode(response)))
|
|
|
|
if src_ptrs:
|
|
remote_session = f"{meta.remote_hostname}:{meta.remote_port}"
|
|
ret_value = await self.sender_loop.run_in_executor(
|
|
self._sender_executor,
|
|
self._send_blocks,
|
|
remote_session,
|
|
src_ptrs,
|
|
dst_ptrs,
|
|
lengths,
|
|
)
|
|
|
|
if ret_value != 0:
|
|
err_reqs = []
|
|
for d_req_id, send_meta in ready_reqs:
|
|
send_meta.sending -= 1
|
|
err_reqs.append(d_req_id)
|
|
# Do best effort to transfer the remaining reqs.
|
|
response = MooncakeXferResponse(
|
|
status=response_status,
|
|
err_reqs=err_reqs,
|
|
err_msg=f"Mooncake transfer engine returned {ret_value}",
|
|
)
|
|
await sock.send_multipart(
|
|
(identity, self._encoder.encode(response))
|
|
)
|
|
continue
|
|
|
|
for d_req_id, send_meta in ready_reqs:
|
|
# TODO: for heterogeneous TP (one P pairs to multiple D),
|
|
# we need to check whether all headers are sent.
|
|
# If not, we should set expire_time to normal and skip the below.
|
|
send_meta.sending -= 1
|
|
send_meta.sent += 1
|
|
if send_meta.sent == send_meta.need_send:
|
|
del self.reqs_need_send[send_meta.transfer_id]
|
|
self.finished_sending_reqs.add(send_meta.p_req_id)
|
|
|
|
response = MooncakeXferResponse(
|
|
status=response_status,
|
|
ok_reqs=[d_req_id for d_req_id, _ in ready_reqs],
|
|
)
|
|
await sock.send_multipart((identity, self._encoder.encode(response)))
|
|
|
|
def resolve_need_send(self, send_meta: SendBlockMeta, remote_tp_ranks: list[int]):
|
|
# Prepare for heterogeneous TP (one P pairs to multiple D)
|
|
send_meta.need_send = len(remote_tp_ranks)
|
|
if send_meta.need_send != 1:
|
|
logger.error("Mooncake: Heterogeneous TP is not supported yet.")
|
|
raise NotImplementedError(
|
|
"Mooncake: Heterogeneous TP is not supported yet."
|
|
)
|
|
|
|
# ---------------- LAYERWISE worker methods ----------------
|
|
def lw_wait_for_save(self, metadata: "MooncakeConnectorMetadata"):
|
|
"""End-of-forward (main thread): record a CUDA event marking this
|
|
step's KV as in-HBM, and enqueue the step's progress for each active
|
|
producer. The ship loop (sender_loop) drains it."""
|
|
if not self._lw_enabled or not metadata.lw_send:
|
|
return
|
|
ev = torch.cuda.Event()
|
|
ev.record(torch.cuda.current_stream())
|
|
with self._lw_lock:
|
|
for req_id, lw in metadata.lw_send.items():
|
|
tid = lw.transfer_id
|
|
x = self._lw_xfers.get(tid)
|
|
if x is None:
|
|
x = {"pending": deque(), "dst_meta": None,
|
|
"remote_block_ids": None, "num_remote": None,
|
|
"d_req_id": None, "shipped": 0, "last_seen": False,
|
|
"p_req_id": lw.p_req_id}
|
|
self._lw_xfers[tid] = x
|
|
x["pending"].append((lw, ev))
|
|
if lw.is_last:
|
|
x["last_seen"] = True
|
|
|
|
async def _lw_send_kv(self, identity, sock, meta: MooncakeXferMetadata):
|
|
"""Handshake arrived: register dst addrs, then drain each transfer's
|
|
per-step queue, shipping computed blocks during prefill."""
|
|
remote_session = f"{meta.remote_hostname}:{meta.remote_port}"
|
|
targets = []
|
|
for d_req_id, (tid, remote_blocks) in meta.req_blocks.items():
|
|
with self._lw_lock:
|
|
x = self._lw_xfers.get(tid)
|
|
if x is None:
|
|
x = {"pending": deque(), "shipped": 0, "last_seen": False,
|
|
"p_req_id": None}
|
|
self._lw_xfers[tid] = x
|
|
x["dst_meta"] = meta
|
|
x["d_req_id"] = d_req_id
|
|
x["remote_block_ids"] = list(remote_blocks)
|
|
x["num_remote"] = len(remote_blocks)
|
|
targets.append((d_req_id, tid))
|
|
|
|
ok_reqs = []
|
|
deadline = time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
|
|
for d_req_id, tid in targets:
|
|
await self._lw_drain_one(tid, remote_session, deadline)
|
|
ok_reqs.append(d_req_id)
|
|
response = MooncakeXferResponse(
|
|
status=MooncakeXferResponseStatus.FINISH, ok_reqs=ok_reqs)
|
|
await sock.send_multipart((identity, self._encoder.encode(response)))
|
|
|
|
async def _lw_drain_one(self, tid, remote_session, deadline):
|
|
x = self._lw_xfers[tid]
|
|
num_remote = x["num_remote"]
|
|
bs = self.block_size
|
|
p_req_id = None
|
|
while True:
|
|
with self._lw_lock:
|
|
entry = x["pending"].popleft() if x["pending"] else None
|
|
last_seen = x["last_seen"]
|
|
if entry is None:
|
|
if x["shipped"] >= num_remote:
|
|
break
|
|
if time.perf_counter() > deadline:
|
|
logger.error("lw drain timeout %s shipped=%d/%d",
|
|
tid, x["shipped"], num_remote)
|
|
break
|
|
await asyncio.sleep(0.0005)
|
|
continue
|
|
lw, ev = entry
|
|
p_req_id = lw.p_req_id
|
|
ev.synchronize() # this step's KV is now in HBM
|
|
total = lw.total_blocks
|
|
start = total - num_remote
|
|
computed_blocks = total if lw.is_last else (lw.computed_tokens // bs)
|
|
ship_end = min(computed_blocks, total)
|
|
cursor = start + x["shipped"]
|
|
if ship_end > cursor and ship_end <= len(lw.local_block_ids):
|
|
local_slice = lw.local_block_ids[cursor:ship_end]
|
|
remote_slice = x["remote_block_ids"][
|
|
x["shipped"]: x["shipped"] + (ship_end - cursor)]
|
|
if (local_slice and remote_slice
|
|
and len(local_slice) == len(remote_slice)):
|
|
ret = await self.sender_loop.run_in_executor(
|
|
self._sender_executor, self._lw_build_send,
|
|
x["dst_meta"], local_slice, remote_slice, remote_session)
|
|
if ret != 0:
|
|
logger.error("lw send ret=%d tid=%s", ret, tid)
|
|
x["shipped"] += (ship_end - cursor)
|
|
if last_seen and x["shipped"] >= num_remote and not x["pending"]:
|
|
break
|
|
with self._lw_lock:
|
|
self._lw_xfers.pop(tid, None)
|
|
self.reqs_need_send.pop(tid, None)
|
|
if p_req_id is not None:
|
|
self.finished_sending_reqs.add(p_req_id)
|
|
logger.info("lw: transfer %s done (shipped %d/%d blocks)",
|
|
tid, x["shipped"], num_remote)
|
|
|
|
def _lw_build_send(self, dst_meta, local_slice, remote_slice, remote_session):
|
|
"""Ship local_slice -> remote_slice for ALL layers (one RDMA batch)."""
|
|
g_local, g_remote = group_concurrent_contiguous(local_slice, remote_slice)
|
|
block_len = self.block_len
|
|
src_ptrs: list[int] = []
|
|
dst_ptrs: list[int] = []
|
|
lengths: list[int] = []
|
|
for laddr, raddr in zip(self.kv_caches_base_addr,
|
|
dst_meta.kv_caches_base_addr):
|
|
for gl, gr in zip(g_local, g_remote):
|
|
src_ptrs.append(laddr + gl[0] * block_len)
|
|
dst_ptrs.append(raddr + gr[0] * block_len)
|
|
lengths.append(block_len * len(gl))
|
|
return self._send_blocks(remote_session, src_ptrs, dst_ptrs, lengths)
|
|
|
|
async def _build_transfer_params(
|
|
self,
|
|
ready_reqs: list[tuple[ReqId, SendBlockMeta]],
|
|
agent_meta: MooncakeXferMetadata,
|
|
) -> tuple[list[int], list[int], list[int], list[ReqId]]:
|
|
src_ptrs = []
|
|
dst_ptrs = []
|
|
lengths = []
|
|
err_reqs: list[ReqId] = []
|
|
local_base_addr = self.kv_caches_base_addr
|
|
remote_base_addr = agent_meta.kv_caches_base_addr
|
|
block_len = self.block_len
|
|
remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}"
|
|
|
|
for d_req_id, send_meta in ready_reqs:
|
|
_, remote_block_ids = agent_meta.req_blocks[d_req_id]
|
|
num_remote_blocks = len(remote_block_ids)
|
|
if num_remote_blocks == 0:
|
|
continue
|
|
|
|
local_block_ids = send_meta.local_block_ids
|
|
# Partial prefix cache hit: just read uncomputed blocks.
|
|
num_local_blocks = len(local_block_ids)
|
|
if num_local_blocks < num_remote_blocks:
|
|
logger.error(
|
|
"req %s: local blocks(%d) less than remote blocks(%d)!",
|
|
d_req_id,
|
|
num_local_blocks,
|
|
num_remote_blocks,
|
|
)
|
|
err_reqs.append(d_req_id)
|
|
continue
|
|
if num_local_blocks > num_remote_blocks:
|
|
local_block_ids = local_block_ids[-num_remote_blocks:]
|
|
|
|
# Group by indices
|
|
group_local_block_ids, group_remote_block_ids = group_concurrent_contiguous(
|
|
local_block_ids, remote_block_ids
|
|
)
|
|
|
|
for local_layer_addr, remote_layer_addr in zip(
|
|
local_base_addr, remote_base_addr
|
|
):
|
|
for group_local_block_id, group_remote_block_id in zip(
|
|
group_local_block_ids, group_remote_block_ids
|
|
):
|
|
src_ptrs.append(
|
|
local_layer_addr + group_local_block_id[0] * block_len
|
|
)
|
|
dst_ptrs.append(
|
|
remote_layer_addr + group_remote_block_id[0] * block_len
|
|
)
|
|
lengths.append(block_len * len(group_local_block_id))
|
|
|
|
logger.debug(
|
|
"Sending kv_caches for request %s (%d blocks) to %s",
|
|
d_req_id,
|
|
num_remote_blocks,
|
|
remote_session,
|
|
)
|
|
|
|
return src_ptrs, dst_ptrs, lengths, err_reqs
|
|
|
|
def _send_blocks(
|
|
self,
|
|
remote_session: str,
|
|
src_ptrs: list[int],
|
|
dst_ptrs: list[int],
|
|
lengths: list[int],
|
|
) -> int:
|
|
start_time = time.perf_counter()
|
|
ret_value = self.engine.batch_transfer_sync_write(
|
|
remote_session, src_ptrs, dst_ptrs, lengths
|
|
)
|
|
if ret_value == 0:
|
|
logger.debug(
|
|
"Sending to %s done, took %s",
|
|
remote_session,
|
|
time.perf_counter() - start_time,
|
|
)
|
|
return ret_value
|
|
|
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
|
"""Register the KV Cache data in mooncake."""
|
|
|
|
logger.info("Registering KV_Caches. use_mla: %s", self.use_mla)
|
|
|
|
kv_data_ptrs = []
|
|
kv_data_lens = []
|
|
seen_base_addresses = []
|
|
|
|
split_k_and_v = self.kv_topo.split_k_and_v
|
|
tensor_size_bytes = None
|
|
for layer_name, cache_or_caches in kv_caches.items():
|
|
logger.debug(
|
|
"registering layer %s with shape %s", layer_name, cache_or_caches.shape
|
|
)
|
|
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
|
|
|
|
for cache in cache_list:
|
|
base_addr = cache.data_ptr()
|
|
if base_addr in seen_base_addresses:
|
|
continue
|
|
|
|
seen_base_addresses.append(base_addr)
|
|
curr_tensor_size_bytes = cache.nbytes
|
|
|
|
if tensor_size_bytes is None:
|
|
tensor_size_bytes = curr_tensor_size_bytes
|
|
self.num_blocks = cache.shape[0]
|
|
|
|
assert tensor_size_bytes == curr_tensor_size_bytes, (
|
|
"All kv cache tensors must have the same size"
|
|
)
|
|
kernel_block_size = cache.shape[-2 if self.use_mla else -3]
|
|
assert self.block_size == kernel_block_size
|
|
kv_data_ptrs.append(base_addr)
|
|
kv_data_lens.append(tensor_size_bytes)
|
|
|
|
self.kv_caches_base_addr = seen_base_addresses
|
|
|
|
ret_value = self.engine.batch_register_memory(kv_data_ptrs, kv_data_lens)
|
|
if ret_value != 0:
|
|
raise RuntimeError("Mooncake batch memory registration failed.")
|
|
|
|
assert tensor_size_bytes is not None
|
|
assert self.num_blocks != 0
|
|
assert tensor_size_bytes % self.num_blocks == 0
|
|
self.block_len = tensor_size_bytes // self.num_blocks
|
|
self.device_kv_caches = kv_caches
|
|
logger.debug(
|
|
"registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len
|
|
)
|
|
|
|
if self.is_kv_consumer:
|
|
return
|
|
|
|
ready_event = threading.Event()
|
|
asyncio.run_coroutine_threadsafe(
|
|
self._mooncake_sender_listener(ready_event), self.sender_loop
|
|
)
|
|
ready_event.wait()
|
|
|
|
if self.bootstrap_server is not None:
|
|
self.bootstrap_server.set_worker_kv_info(
|
|
self.kv_caches_base_addr, self.block_len,
|
|
self.block_size, self.hostname, self.rpc_port,
|
|
transfer_engine=self.engine,
|
|
)
|
|
if _shared_block_pool is not None:
|
|
self.bootstrap_server.set_block_pool(_shared_block_pool)
|
|
|
|
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
|
|
finished_recving_reqs = self.finished_recving_reqs
|
|
self.finished_recving_reqs = set()
|
|
return finished_recving_reqs
|
|
|
|
def get_block_ids_with_load_errors(self) -> set[int]:
|
|
failed = self.failed_recving_block_ids
|
|
self.failed_recving_block_ids = set()
|
|
return failed
|
|
|
|
async def fetch_finished_sending_reqs(self) -> set[ReqId]:
|
|
finished_sending_reqs = self.finished_sending_reqs
|
|
self.finished_sending_reqs = set()
|
|
|
|
# Handle timeout to avoid stranding blocks on remote.
|
|
now = time.perf_counter()
|
|
|
|
expired_transfer_id = []
|
|
for transfer_id, send_meta in self.reqs_need_send.items():
|
|
if (
|
|
send_meta.p_req_id
|
|
and send_meta.expire_time < now
|
|
and send_meta.sending == 0
|
|
):
|
|
logger.warning(
|
|
"Request %s timed out after %d seconds without "
|
|
"being sent. Freeing its blocks on the producer side.",
|
|
send_meta.p_req_id,
|
|
envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
|
|
)
|
|
finished_sending_reqs.add(send_meta.p_req_id)
|
|
expired_transfer_id.append(transfer_id)
|
|
|
|
for transfer_id in expired_transfer_id:
|
|
del self.reqs_need_send[transfer_id]
|
|
|
|
return finished_sending_reqs
|
|
|
|
def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
|
|
"""
|
|
Get requests that are done sending or recving on this specific worker.
|
|
The scheduler process (via the MultiprocExecutor) will use this output
|
|
to track which workers are done.
|
|
"""
|
|
recv_fut = None
|
|
send_fut = None
|
|
if not self.is_kv_producer:
|
|
recv_fut = asyncio.run_coroutine_threadsafe(
|
|
self.fetch_finished_recving_reqs(), self.receiver_loop
|
|
)
|
|
|
|
if not self.is_kv_consumer:
|
|
send_fut = asyncio.run_coroutine_threadsafe(
|
|
self.fetch_finished_sending_reqs(), self.sender_loop
|
|
)
|
|
|
|
finished_recving_reqs = recv_fut.result() if recv_fut else set()
|
|
finished_sending_reqs = send_fut.result() if send_fut else set()
|
|
|
|
if finished_sending_reqs or finished_recving_reqs:
|
|
logger.debug(
|
|
"Rank %s, get_finished: %s requests done sending "
|
|
"and %s requests done recving",
|
|
self.tp_rank,
|
|
len(finished_sending_reqs),
|
|
len(finished_recving_reqs),
|
|
)
|
|
|
|
return finished_sending_reqs or None, finished_recving_reqs or None
|
|
|
|
async def receive_kv_from_single_worker(
|
|
self,
|
|
worker_addr: str,
|
|
pull_metas: dict[ReqId, PullReqMeta],
|
|
):
|
|
req_ids = set(pull_metas)
|
|
metadata = MooncakeXferMetadata(
|
|
remote_hostname=self.hostname,
|
|
remote_port=self.rpc_port,
|
|
remote_tp_size=self.tp_size,
|
|
remote_tp_rank=self.tp_rank,
|
|
req_blocks={
|
|
req_id: (pull_meta.transfer_id, pull_meta.local_block_ids)
|
|
for req_id, pull_meta in pull_metas.items()
|
|
},
|
|
kv_caches_base_addr=self.kv_caches_base_addr,
|
|
)
|
|
|
|
encoded_data = self._encoder.encode(metadata)
|
|
logger.debug(
|
|
"Size of encoded MooncakeXferMetadata: %d bytes", len(encoded_data)
|
|
)
|
|
logger.debug(
|
|
"Sending kv transfer request for %s on path: %s", req_ids, worker_addr
|
|
)
|
|
|
|
# Send query for the request.
|
|
try:
|
|
with make_zmq_socket(
|
|
self.async_zmq_ctx, worker_addr, zmq.DEALER, bind=False, linger=0
|
|
) as sock:
|
|
# If something goes wrong, let P wait timeout first (in asyncio.wait()).
|
|
sock.setsockopt(
|
|
zmq.RCVTIMEO, (envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT + 60) * 1000
|
|
)
|
|
await sock.send(encoded_data)
|
|
while True:
|
|
ret_msg = await sock.recv()
|
|
response = self._xfer_resp_decoder.decode(ret_msg)
|
|
if response.status == MooncakeXferResponseStatus.ERROR:
|
|
logger.error(
|
|
"Error happens during transferring kvcache for %s: %s",
|
|
req_ids,
|
|
response.err_msg,
|
|
)
|
|
return
|
|
self.process_pulling_result(response, pull_metas)
|
|
if response.status == MooncakeXferResponseStatus.FINISH:
|
|
break
|
|
except zmq.ContextTerminated:
|
|
logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.")
|
|
except Exception as e:
|
|
logger.error("MooncakeXferMetadata transfer failed for %s: %s", req_ids, e)
|
|
for req_id in req_ids:
|
|
pull_meta = pull_metas[req_id]
|
|
self.failed_recving_block_ids.update(pull_meta.local_block_ids)
|
|
self.finished_recving_reqs.add(pull_meta.d_req_id)
|
|
return
|
|
|
|
def process_pulling_result(
|
|
self,
|
|
response: MooncakeXferResponse,
|
|
pull_metas: dict[ReqId, PullReqMeta],
|
|
):
|
|
ok_reqs: list[ReqId] = response.ok_reqs or []
|
|
|
|
for req_id in ok_reqs:
|
|
pull_meta = pull_metas[req_id]
|
|
# No race because we are in async loop.
|
|
pull_meta.pull_tasks_count -= 1
|
|
if pull_meta.pull_tasks_count == 0:
|
|
self.finished_recving_reqs.add(pull_meta.d_req_id)
|
|
|
|
if ok_reqs:
|
|
logger.debug("pulling kv_caches for %s finished", ok_reqs)
|
|
|
|
if response.err_reqs:
|
|
logger.error(
|
|
"pulling kv_caches for %s failed: %s",
|
|
response.err_reqs,
|
|
response.err_msg,
|
|
)
|
|
for req_id in response.err_reqs:
|
|
pull_meta = pull_metas.get(req_id)
|
|
if pull_meta is None:
|
|
continue
|
|
self.failed_recving_block_ids.update(pull_meta.local_block_ids)
|
|
self.finished_recving_reqs.add(pull_meta.d_req_id)
|
|
|
|
async def _connect_to_prefiller_bootstrap(self, remote_bootstrap_addr: str):
|
|
url = remote_bootstrap_addr + "/query"
|
|
try:
|
|
async with httpx.AsyncClient() as client:
|
|
response = await client.get(url)
|
|
response.raise_for_status()
|
|
data: dict = response.json()
|
|
for _, dp_entry in data.items():
|
|
remote_engine_id = dp_entry["engine_id"]
|
|
self._remote_agents[remote_engine_id] = {
|
|
int(tp_rank): {
|
|
int(pp_rank): worker_addr
|
|
for pp_rank, worker_addr in tp_entry.items()
|
|
}
|
|
for tp_rank, tp_entry in dp_entry["worker_addr"].items()
|
|
}
|
|
self._tp_size[remote_engine_id] = len(dp_entry["worker_addr"])
|
|
except Exception as e:
|
|
logger.error(
|
|
"Failed to connect to bootstrap server %s: %s",
|
|
remote_bootstrap_addr,
|
|
e,
|
|
)
|
|
|
|
# Always notify others regardless of connection success or failure.
|
|
self._pending_bootstrap_queries[remote_bootstrap_addr].set()
|
|
del self._pending_bootstrap_queries[remote_bootstrap_addr]
|
|
|
|
def receive_kv(
|
|
self,
|
|
remote_engine_id: EngineId,
|
|
pull_metas: dict[ReqId, PullReqMeta],
|
|
):
|
|
remote_tp_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
|
|
remote_engine_id
|
|
)
|
|
count = len(remote_tp_ranks)
|
|
if count != 1:
|
|
logger.error("Mooncake: Heterogeneous TP is not supported yet.")
|
|
raise NotImplementedError(
|
|
"Mooncake: Heterogeneous TP is not supported yet."
|
|
)
|
|
for pull_meta in pull_metas.values():
|
|
pull_meta.pull_tasks_count = count
|
|
for remote_tp_rank in remote_tp_ranks:
|
|
worker_addr = self._remote_agents[remote_engine_id][remote_tp_rank][0]
|
|
asyncio.create_task(
|
|
self.receive_kv_from_single_worker(worker_addr, pull_metas)
|
|
)
|
|
|
|
async def handle_new_engine_id(
|
|
self,
|
|
remote_engine_id: EngineId,
|
|
pull_metas: dict[ReqId, PullReqMeta],
|
|
):
|
|
remote_bootstrap_addr = next(iter(pull_metas.values())).remote_bootstrap_addr
|
|
if remote_bootstrap_addr not in self._pending_bootstrap_queries:
|
|
self._pending_bootstrap_queries[remote_bootstrap_addr] = asyncio.Event()
|
|
await self._connect_to_prefiller_bootstrap(remote_bootstrap_addr)
|
|
else:
|
|
await self._pending_bootstrap_queries[remote_bootstrap_addr].wait()
|
|
|
|
if remote_engine_id not in self._remote_agents:
|
|
logger.error(
|
|
"Failed to find remote engine_id %s from bootstrap server %s",
|
|
remote_engine_id,
|
|
remote_bootstrap_addr,
|
|
)
|
|
return
|
|
|
|
self.receive_kv(remote_engine_id, pull_metas)
|
|
|
|
async def _start_direct_read(
|
|
self, reqs_by_engine: dict[EngineId, dict[ReqId, PullReqMeta]]
|
|
):
|
|
"""Direct RDMA read: D reads cached KV blocks from C's GPU memory
|
|
without involving C's scheduler.
|
|
"""
|
|
for _engine_id, pull_metas in reqs_by_engine.items():
|
|
for req_id, pm in pull_metas.items():
|
|
asyncio.create_task(
|
|
self._direct_read_single(req_id, pm)
|
|
)
|
|
|
|
async def _direct_read_single(self, req_id: ReqId, pm: PullReqMeta):
|
|
"""Bootstrap-triggered PUSH: D asks C's bootstrap to push matched blocks.
|
|
|
|
C's bootstrap looks up cached blocks by token_ids, then uses C's
|
|
TransferEngine to RDMA WRITE (push) them directly into D's GPU memory.
|
|
C's scheduler is NOT involved.
|
|
"""
|
|
bootstrap_url = pm.remote_bootstrap_addr
|
|
num_remote_tokens = pm.remote_num_tokens or len(pm.prompt_token_ids)
|
|
|
|
try:
|
|
local_block_ids = pm.local_block_ids
|
|
d_session = f"{self.hostname}:{self.rpc_port}"
|
|
|
|
async with httpx.AsyncClient(timeout=60) as client:
|
|
resp = await client.post(
|
|
f"{bootstrap_url}/push_blocks",
|
|
json={
|
|
"token_ids": pm.prompt_token_ids,
|
|
"num_tokens": num_remote_tokens,
|
|
"dst_block_ids": local_block_ids,
|
|
"dst_base_addrs": self.kv_caches_base_addr,
|
|
"dst_block_len": self.block_len,
|
|
"dst_session": d_session,
|
|
},
|
|
)
|
|
resp.raise_for_status()
|
|
result = resp.json()
|
|
|
|
matched = result.get("matched", 0)
|
|
pushed = result.get("pushed", False)
|
|
|
|
if matched > 0 and pushed:
|
|
logger.info("direct_push %s: %d blocks pushed from C", req_id, matched)
|
|
else:
|
|
logger.debug("direct_push %s: %d matched, pushed=%s", req_id, matched, pushed)
|
|
self.failed_recving_block_ids.update(local_block_ids)
|
|
|
|
self.finished_recving_reqs.add(req_id)
|
|
|
|
except Exception as e:
|
|
logger.error("direct_push %s failed: %s", req_id, e)
|
|
self.failed_recving_block_ids.update(pm.local_block_ids)
|
|
self.finished_recving_reqs.add(req_id)
|
|
|
|
async def _start_load_kv(
|
|
self, reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]]
|
|
):
|
|
for remote_engine_id, pull_metas in reqs_to_recv.items():
|
|
if remote_engine_id not in self._remote_agents:
|
|
asyncio.create_task(
|
|
self.handle_new_engine_id(remote_engine_id, pull_metas)
|
|
)
|
|
else:
|
|
self.receive_kv(remote_engine_id, pull_metas)
|
|
|
|
async def record_send_reqs(self, metadata: MooncakeConnectorMetadata):
|
|
for p_req_id, (transfer_id, block_ids) in metadata.reqs_to_send.items():
|
|
if block_ids:
|
|
# Already gone through request_finished()
|
|
send_meta = self.reqs_need_send[transfer_id]
|
|
send_meta.p_req_id = p_req_id
|
|
send_meta.local_block_ids = block_ids
|
|
send_meta.expire_time = (
|
|
time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
|
|
)
|
|
send_meta.ready.set()
|
|
else:
|
|
# From update_state_after_alloc(),
|
|
# but not reach request_finished() yet
|
|
# This may be already created by send_kv_to_decode()
|
|
# when D is sending MooncakeXferMetadata.
|
|
if transfer_id not in self.reqs_need_send:
|
|
self.reqs_need_send[transfer_id] = SendBlockMeta(
|
|
p_req_id=p_req_id,
|
|
transfer_id=transfer_id,
|
|
local_block_ids=[],
|
|
ready=asyncio.Event(),
|
|
)
|
|
for transfer_id in metadata.reqs_not_processed:
|
|
send_meta = self.reqs_need_send.pop(transfer_id)
|
|
if send_meta:
|
|
assert not send_meta.ready.is_set()
|
|
|
|
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
|
|
# Sync hash table to bootstrap server (for direct RDMA read queries)
|
|
if self.bootstrap_server is not None and (
|
|
metadata.hash_table_updates or metadata.hash_table_removals
|
|
):
|
|
self.bootstrap_server.update_hash_table(
|
|
metadata.hash_table_updates, metadata.hash_table_removals
|
|
)
|
|
|
|
if not self.is_kv_producer and metadata.reqs_to_recv:
|
|
# Split direct_read vs normal pull requests
|
|
direct_reqs: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
|
|
normal_reqs: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
|
|
for engine_id, pull_metas in metadata.reqs_to_recv.items():
|
|
for req_id, pm in pull_metas.items():
|
|
if pm.direct_read:
|
|
direct_reqs[engine_id][req_id] = pm
|
|
else:
|
|
normal_reqs[engine_id][req_id] = pm
|
|
|
|
if normal_reqs:
|
|
asyncio.run_coroutine_threadsafe(
|
|
self._start_load_kv(normal_reqs), self.receiver_loop
|
|
)
|
|
if direct_reqs:
|
|
asyncio.run_coroutine_threadsafe(
|
|
self._start_direct_read(direct_reqs), self.receiver_loop
|
|
)
|
|
|
|
if not self.is_kv_consumer and (
|
|
metadata.reqs_to_send or metadata.reqs_not_processed
|
|
):
|
|
asyncio.run_coroutine_threadsafe(
|
|
self.record_send_reqs(metadata), self.sender_loop
|
|
)
|
|
|
|
|
|
def group_concurrent_contiguous(
|
|
src_indices: list[int], dst_indices: list[int]
|
|
) -> tuple[list[list[int]], list[list[int]]]:
|
|
"""Vectorised NumPy implementation."""
|
|
if len(src_indices) == 0:
|
|
return [], []
|
|
|
|
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
|
|
src_groups = np.split(src_indices, brk)
|
|
dst_groups = np.split(dst_indices, brk)
|
|
|
|
src_groups = [g.tolist() for g in src_groups]
|
|
dst_groups = [g.tolist() for g in dst_groups]
|
|
|
|
return src_groups, dst_groups
|
|
|
|
|
|
def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
|
|
# This logic is now centralized
|
|
return (
|
|
envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
|
|
+ vllm_config.parallel_config.data_parallel_index
|
|
* vllm_config.parallel_config.tensor_parallel_size
|
|
)
|
|
|
|
|
|
def _async_loop(loop: asyncio.AbstractEventLoop):
|
|
asyncio.set_event_loop(loop)
|
|
loop.run_forever()
|
|
|
|
|
|
def should_launch_bootstrap_server(vllm_config: VllmConfig) -> bool:
|
|
assert (parallel_config := vllm_config.parallel_config)
|
|
# In hybrid or external LB mode,
|
|
# each instance should have its own bootstrap server.
|
|
#
|
|
# In internal LB mode,
|
|
# only the real global first rank need to launch the bootstrap server.
|
|
return is_local_first_rank() and (
|
|
parallel_config.local_engines_only or parallel_config.data_parallel_index == 0
|
|
)
|
|
|
|
|
|
def get_mooncake_bootstrap_addr(vllm_config: VllmConfig) -> tuple[str, int]:
|
|
"""
|
|
Returns the address of the Mooncake bootstrap server.
|
|
This is only used by prefillers to register workers.
|
|
Decoders should get addr from kv_transfer_params.
|
|
"""
|
|
assert (parallel_config := vllm_config.parallel_config)
|
|
if parallel_config.local_engines_only:
|
|
# In hybrid or external LB mode, connect to local server.
|
|
host = "127.0.0.1"
|
|
else:
|
|
host = parallel_config.data_parallel_master_ip
|
|
port = envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
|
|
return (host, port)
|