The scheduler tracks per-producer block_ids (accumulated from scheduler_output) and emits a per-step LWSendMeta carrying the cumulative computed_tokens. The worker's lw_wait_for_save records a CUDA event per step and enqueues the progress; the ship loop running on the sender event loop drains that queue and ships, in order, only blocks that are computed, wanted by the destination, and not yet shipped (which keeps it correct under chunked prefill). Keeping state per transfer makes concurrent transfers safe. The v1 single-transfer version is kept for reference.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1686 lines
68 KiB
Python
1686 lines
68 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import asyncio
|
|
import os
|
|
import threading
|
|
import time
|
|
from collections import defaultdict
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from dataclasses import dataclass, field
|
|
from enum import IntEnum
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import httpx
|
|
import msgspec
|
|
import numpy as np
|
|
import torch
|
|
import zmq
|
|
import zmq.asyncio
|
|
|
|
from vllm import envs
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
|
EngineId,
|
|
TpKVTopology,
|
|
get_current_attn_backend,
|
|
)
|
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
|
KVConnectorBase_V1,
|
|
KVConnectorMetadata,
|
|
KVConnectorRole,
|
|
)
|
|
from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_utils import (
|
|
MooncakeBootstrapServer,
|
|
RegisterWorkerPayload,
|
|
)
|
|
from vllm.distributed.parallel_state import (
|
|
get_pp_group,
|
|
get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size,
|
|
is_local_first_rank,
|
|
)
|
|
from vllm.forward_context import ForwardContext
|
|
from vllm.logger import init_logger
|
|
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
|
|
from vllm.v1.attention.backend import AttentionMetadata
|
|
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
|
from vllm.v1.core.sched.output import SchedulerOutput
|
|
from vllm.v1.request import RequestStatus
|
|
|
|
try:
|
|
from mooncake.engine import TransferEngine
|
|
except ImportError as e:
|
|
raise ImportError(
|
|
"Please install mooncake by following the instructions at "
|
|
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "
|
|
"to run VLLM with MooncakeTransferEngine."
|
|
) from e
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
|
from vllm.v1.request import Request
|
|
|
|
# Identifier the scheduler assigns to a request inside this engine.
ReqId = str
# Identifier that coordinates one KV transfer; both P and D sides share it.
TransferId = str
|
|
|
logger = init_logger(__name__)

# Block pool published at module scope so the bootstrap server can reach it.
# Only works when both live in the same process (kv_both). Assigned through
# _set_shared_block_pool(); None until then.
_shared_block_pool = None
|
|
|
|
def _set_shared_block_pool(bp):
    """Store ``bp`` in the module-level ``_shared_block_pool`` slot.

    The bootstrap server reads the pool from that module global rather than
    from the connector instance.
    """
    global _shared_block_pool
    _shared_block_pool = bp
|
|
|
|
|
|
class MooncakeXferMetadata(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
):
    """Transfer request a producer worker receives on its side channel.

    The ``remote_*`` fields describe the requesting (decode) worker. The
    producer checks ``remote_tp_size`` to decide whether it is one of that
    worker's target ranks. ``req_blocks`` maps each decode-side request id to
    a ``(transfer_id, block_ids)`` pair.
    """

    # Network location of the requesting worker.
    remote_hostname: str
    remote_port: int
    # Tensor-parallel layout of the requesting side.
    remote_tp_size: int
    remote_tp_rank: int
    # decode request id -> (transfer id, block ids -- presumably the
    # requester's destination blocks; confirm against the transfer code)
    req_blocks: dict[ReqId, tuple[TransferId, list[int]]]
    # Base addresses of the requester's KV cache regions (presumably the
    # RDMA write targets -- TODO confirm).
    kv_caches_base_addr: list[int]
|
|
|
|
|
|
class MooncakeXferResponseStatus(IntEnum):
    """Outcome code carried in every ``MooncakeXferResponse``."""

    FINISH = 0  # the transfer round is complete; no more replies follow
    CONTINUE = 1  # further replies will follow for the same request
    ERROR = 2  # something failed; see the response's err_msg
|
|
|
|
|
|
class MooncakeXferResponse(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
):
    """Reply a producer worker sends back over the side channel."""

    status: MooncakeXferResponseStatus
    # Request ids that completed successfully in this reply, if any.
    ok_reqs: list[ReqId] | None = None
    # Request ids that failed in this reply, if any.
    err_reqs: list[ReqId] | None = None
    # Human-readable explanation accompanying ERROR status or err_reqs.
    err_msg: str | None = None
|
|
|
|
|
|
@dataclass
class PullReqMeta:
    """Description of one request whose KV must be pulled from a remote engine.

    Instances are created by ``MooncakeConnectorMetadata.add_new_req`` when
    ``load_remote_cache`` is true.
    """

    # Request id on this (pulling) engine.
    d_req_id: ReqId
    # Transfer id shared with the remote side.
    transfer_id: TransferId
    # Local block ids reserved for the pulled KV.
    local_block_ids: list[int]
    # The engine holding the KV and the address of its bootstrap server.
    remote_engine_id: EngineId
    remote_bootstrap_addr: str
    # Deadline that stops the request from being retried forever.
    expire_time: float = float("inf")
    # Supports one D pairing with several P instances.
    pull_tasks_count: int = 0
    # Direct RDMA read: D reads from C's GPU memory without involving C's
    # scheduler.
    direct_read: bool = False
    # Lookup keys used by direct reads.
    block_hashes: list[bytes] = field(default_factory=list)
    prompt_token_ids: list[int] = field(default_factory=list)
    # Prompt tokens to pull for a partial remote prefill; 0 when unset.
    remote_num_tokens: int = 0
|
|
|
|
|
|
@dataclass
class SendBlockMeta:
    """Producer-side bookkeeping for one outgoing transfer, keyed by transfer id.

    A decode handshake can arrive before the local request is registered. In
    that case an entry is created with an empty ``p_req_id`` and no blocks,
    and senders wait on ``ready`` (presumably set once the local request is
    registered -- the setter is not shown here).
    """

    # Local (producer) request id; "" while still unknown.
    p_req_id: ReqId
    transfer_id: TransferId
    # Local block ids whose KV is to be sent.
    local_block_ids: list[int]
    # Signalled when the entry becomes usable for sending.
    ready: asyncio.Event
    expire_time: float = float("inf")
    # When still 0, the sender calls resolve_need_send() to fill it in.
    need_send: int = 0
    # Progress counter; its consumer is not visible here.
    sent: int = 0
    # Incremented while a send is in flight so the entry is not expired.
    sending: int = 0
|
|
|
|
|
|
class MooncakeConnectorMetadata(KVConnectorMetadata):
    """Per-step work list the scheduler-side connector hands to the workers.

    Produced by ``MooncakeConnectorScheduler.build_connector_meta``. It holds:

    * requests whose KV must be pulled, grouped by remote engine;
    * requests whose KV must be pushed;
    * transfer ids to stop tracking;
    * incremental hash-table changes used by direct RDMA reads.
    """

    def __init__(self):
        # remote engine id -> {local request id -> pull description}
        self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
        # local request id -> (transfer id, local block ids to send)
        self.reqs_to_send: dict[ReqId, tuple[TransferId, list[int]]] = {}
        # Transfer ids of producer requests that finished without a send
        # (e.g. aborted); the worker should stop tracking them.
        self.reqs_not_processed: set[TransferId] = set()
        # Hash table sync: scheduler -> worker (for direct RDMA read)
        self.hash_table_updates: dict[str, int] = {}  # hex hash -> block_id
        self.hash_table_removals: set[str] = set()  # hex hashes no longer cached
        self.token_hash_updates: dict[str, int] = {}  # str(hash(tokens)) -> block_id

    def add_new_req(
        self,
        request_id: ReqId,
        local_block_ids: list[int],
        kv_transfer_params: dict[str, Any],
        load_remote_cache: bool = True,
        block_hashes: list[bytes] | None = None,
        prompt_token_ids: list[int] | None = None,
    ):
        """Add one request to this step's pull list or push list.

        Args:
            request_id: local scheduler request id.
            local_block_ids: local KV block ids involved in the transfer.
            kv_transfer_params: the request's transfer params. Must contain
                "transfer_id". When ``load_remote_cache`` is true it must also
                contain "remote_engine_id" and "remote_bootstrap_addr".
            load_remote_cache: True to pull KV from a remote engine; False to
                register the request for sending.
            block_hashes, prompt_token_ids: direct-read lookup keys. They are
                only used when pulling.

        Raises:
            KeyError: if a required key is missing from kv_transfer_params.
        """
        transfer_id = kv_transfer_params["transfer_id"]
        if not load_remote_cache:
            self.reqs_to_send[request_id] = (transfer_id, local_block_ids)
            return

        remote_engine_id = kv_transfer_params["remote_engine_id"]
        # "remote_num_tokens" may be present with an explicit None (the
        # scheduler treats None as "pull the whole prompt"). Normalize it to
        # the field's int default instead of storing None in an int field.
        remote_num_tokens = kv_transfer_params.get("remote_num_tokens")
        if remote_num_tokens is None:
            remote_num_tokens = 0
        self.reqs_to_recv[remote_engine_id][request_id] = PullReqMeta(
            d_req_id=request_id,
            local_block_ids=local_block_ids,
            remote_engine_id=remote_engine_id,
            remote_bootstrap_addr=kv_transfer_params["remote_bootstrap_addr"],
            transfer_id=transfer_id,
            direct_read=bool(kv_transfer_params.get("direct_read")),
            block_hashes=block_hashes or [],
            prompt_token_ids=prompt_token_ids or [],
            remote_num_tokens=remote_num_tokens,
        )
|
|
|
|
|
|
class MooncakeConnector(KVConnectorBase_V1):
    """KV connector that moves KV cache between engines with Mooncake.

    For the scheduler role only ``connector_scheduler`` is built. For the
    worker role only ``connector_worker`` is built. Every hook below forwards
    to whichever side this instance was created for.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig | None" = None,
    ):
        super().__init__(vllm_config, role, kv_cache_config)

        transfer_config = vllm_config.kv_transfer_config
        assert transfer_config is not None
        assert transfer_config.engine_id is not None
        self.engine_id: EngineId = transfer_config.engine_id

        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler: MooncakeConnectorScheduler | None = (
                MooncakeConnectorScheduler(vllm_config, self.engine_id)
            )
            self.connector_worker: MooncakeConnectorWorker | None = None
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            self.connector_worker = MooncakeConnectorWorker(
                vllm_config, self.engine_id
            )

    def set_block_pool(self, block_pool):
        """Give the connector access to the scheduler's block pool.

        The pool is forwarded to the scheduler side (when present). It is
        also published module-wide so the bootstrap server can reach it; that
        only works when both live in one process (kv_both, TP=1).
        """
        scheduler = self.connector_scheduler
        if scheduler is not None:
            scheduler.set_block_pool(block_pool)
        _set_shared_block_pool(block_pool)

    # ------------------------------------------------------------------
    # Scheduler-side hooks
    # ------------------------------------------------------------------

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int, bool]:
        """Delegate to ``MooncakeConnectorScheduler.get_num_new_matched_tokens``."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.get_num_new_matched_tokens(request, num_computed_tokens)

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """Delegate to ``MooncakeConnectorScheduler.update_state_after_alloc``."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens
        )

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Delegate to ``MooncakeConnectorScheduler.build_connector_meta``."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Delegate to ``MooncakeConnectorScheduler.request_finished``."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.request_finished(request, block_ids)

    # ------------------------------------------------------------------
    # Worker-side hooks
    # ------------------------------------------------------------------

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Hand the per-layer KV cache tensors to the worker."""
        worker = self.connector_worker
        assert worker is not None
        worker.register_kv_caches(kv_caches)

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """Report requests whose asynchronous KV send/receive has completed.

        ``finished_req_ids`` is accepted for interface compatibility but is
        not used; the answer comes entirely from the worker.
        """
        worker = self.connector_worker
        assert worker is not None
        return worker.get_finished()

    def get_block_ids_with_load_errors(self) -> set[int]:
        """Ask the worker which local block ids failed to load."""
        worker = self.connector_worker
        assert worker is not None
        return worker.get_block_ids_with_load_errors()

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        """Pass this step's MooncakeConnectorMetadata to the worker.

        ``forward_context`` and ``kwargs`` are not used.
        """
        worker = self.connector_worker
        assert worker is not None
        metadata = self._connector_metadata
        assert isinstance(metadata, MooncakeConnectorMetadata)
        worker.start_load_kv(metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """No-op: KV is never loaded layer by layer, so there is nothing to wait for."""
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> None:
        """LAYERWISE hook: tell the worker that ``layer_name``'s KV is computed.

        This lets the sender push that layer during prefill. ``kv_layer`` and
        ``attn_metadata`` are not used. Per the original design this is a
        no-op unless MOONCAKE_LAYERWISE=1; that check lives in the worker's
        note_layer_computed(), which is not shown here.
        """
        worker = self.connector_worker
        if worker is None:
            return
        worker.note_layer_computed(layer_name)

    def wait_for_save(self):
        """No-op: nothing is waited on at the end of the forward pass."""
        pass
|
|
|
|
|
|
class MooncakeConnectorScheduler:
    """Scheduler-side half of the Mooncake connector.

    It records which requests need KV pulled from, or pushed to, a remote
    engine. Once per scheduler step it packs that work into a
    ``MooncakeConnectorMetadata``. It also mirrors the block pool's cached
    block-hash table to the worker so direct RDMA reads can find blocks by
    hash.
    """

    def __init__(self, vllm_config: VllmConfig, engine_id: str):
        self.vllm_config = vllm_config
        # Set later via set_block_pool(); hash-table sync is skipped until then.
        self._block_pool = None
        # Cache keys already reported to the worker.
        self._known_hash_keys: set = set()
        # Ensures the "block pool missing" warning is logged only once.
        self._bp_warned: bool = False

        assert vllm_config.kv_transfer_config
        self.is_kv_producer: bool = (
            vllm_config.kv_transfer_config.kv_role == "kv_producer"
        )
        self.is_kv_consumer: bool = (
            vllm_config.kv_transfer_config.kv_role == "kv_consumer"
        )
        logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id)

        # Pending work, drained into the metadata by build_connector_meta().
        self._reqs_need_recv: dict[ReqId, tuple["Request", list[int]]] = {}
        self._reqs_need_send: dict[ReqId, tuple["Request", list[int]]] = {}
        self._reqs_not_processed: set[TransferId] = set()
        # Direct-read lookup keys, consumed once per request.
        self._req_block_hashes: dict[ReqId, list[bytes]] = {}
        self._req_token_ids: dict[ReqId, list[int]] = {}
        # LAYERWISE: capture producer block_ids at alloc (before prefill done).
        self._lw_enabled = os.environ.get("MOONCAKE_LAYERWISE", "0") == "1"
        # Not read anywhere in this class.
        self._lw_sent_once: set[ReqId] = set()

    def _num_blocks(self, num_tokens: int) -> int:
        """Return how many KV blocks ``num_tokens`` tokens occupy (ceil division)."""
        block_size = self.vllm_config.cache_config.block_size
        return (num_tokens + block_size - 1) // block_size

    def set_block_pool(self, block_pool):
        """Remember the block pool whose hash table is synced to the worker."""
        self._block_pool = block_pool

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int, bool]:
        """
        For remote prefill, pull all prompt blocks from remote
        asynchronously relative to engine execution.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request
        Returns:
            * the number of tokens that can be loaded from the
              external KV cache beyond what is already computed.
            * true if the external KV cache tokens will be loaded
              asynchronously (between scheduler steps).
        """
        params = request.kv_transfer_params
        logger.debug(
            "MooncakeConnector get_num_new_matched_tokens: "
            "num_computed_tokens=%s, kv_transfer_params=%s",
            num_computed_tokens,
            params,
        )

        if not params:
            return 0, False

        if params.get("do_remote_prefill"):
            assert not self.is_kv_producer
            token_ids = request.prompt_token_ids or []
            # Partial remote prefill: only pull remote_num_tokens from remote,
            # compute the rest locally. A missing *or None* value means "pull
            # the whole prompt"; update_state_after_alloc() treats None the
            # same way, and min() would raise TypeError on None.
            remote_total = params.get("remote_num_tokens")
            if remote_total is None:
                remote_total = len(token_ids)
            remote_total = min(remote_total, len(token_ids))
            count = max(0, remote_total - num_computed_tokens)
            if count > 0:
                return count, True

        return 0, False

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """Record the local blocks a request will receive into or send from.

        Consumer path (``do_remote_prefill``): register the blocks to pull
        into, plus direct-read lookup keys when ``direct_read`` is set. Then
        clear the flag so only one pull is ever triggered.

        Producer path (``do_remote_decode``): register the request for
        sending. In LAYERWISE mode the block ids are captured now; otherwise
        an empty list is registered.
        """
        params = request.kv_transfer_params
        logger.debug(
            "MooncakeConnector update_state_after_alloc: "
            "req_id=%s num_external_tokens=%s, kv_transfer_params=%s",
            request.request_id,
            num_external_tokens,
            params,
        )

        if not params:
            return

        if params.get("do_remote_prefill"):
            assert not self.is_kv_producer
            if all(
                p in params
                for p in ("remote_engine_id", "remote_bootstrap_addr", "transfer_id")
            ):
                if num_external_tokens > 0:
                    all_unhashed = blocks.get_unhashed_block_ids()
                    if params.get("remote_num_tokens") is not None:
                        # Partial remote prefill: only receive blocks for the
                        # external portion, leave the rest for local compute.
                        local_block_ids = all_unhashed[
                            : self._num_blocks(num_external_tokens)
                        ]
                    else:
                        local_block_ids = all_unhashed
                else:
                    # Nothing to pull. Register with no blocks anyway so the
                    # worker still notifies the remote side (see
                    # request_finished for why that matters).
                    local_block_ids = []
                self._reqs_need_recv[request.request_id] = (request, local_block_ids)

                if params.get("direct_read"):
                    num_remote_blocks = self._num_blocks(num_external_tokens)
                    if hasattr(request, "block_hashes"):
                        self._req_block_hashes[request.request_id] = [
                            bytes(h) for h in request.block_hashes[:num_remote_blocks]
                        ]
                    # Store prompt token_ids for token-based lookup on C
                    if hasattr(request, "prompt_token_ids") and request.prompt_token_ids:
                        block_size = self.vllm_config.cache_config.block_size
                        self._req_token_ids[request.request_id] = list(
                            request.prompt_token_ids[: num_remote_blocks * block_size]
                        )
            else:
                logger.warning(
                    "Got invalid KVTransferParams: %s. This "
                    "request will not utilize KVTransfer",
                    params,
                )
            # Only trigger one KV pull per request.
            params["do_remote_prefill"] = False

        if params.get("do_remote_decode"):
            assert not self.is_kv_consumer
            if not params.get("transfer_id"):
                logger.warning("Missing transfer_id in kv_transfer_params from router!")
            elif self._lw_enabled:
                # LAYERWISE: capture the producer block_ids NOW (at alloc), so
                # the worker learns local block_ids and sets `ready` before
                # prefill finishes -- enabling per-layer writes during prefill.
                # NOTE(review): under chunked prefill the blocks allocated at
                # this point may not cover the whole prompt -- confirm the
                # worker picks up later blocks.
                try:
                    block_groups = blocks.get_block_ids()
                    local_block_ids = list(block_groups[0]) if block_groups else []
                except Exception as e:
                    logger.warning("LAYERWISE: failed to get block_ids at alloc: %s", e)
                    local_block_ids = []
                self._reqs_need_send[request.request_id] = (request, local_block_ids)
            else:
                # Add an empty list to worker to create event.
                self._reqs_need_send[request.request_id] = (request, [])

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Drain all pending work into a fresh ``MooncakeConnectorMetadata``.

        ``scheduler_output`` is not used.
        """
        meta = MooncakeConnectorMetadata()

        if not self.is_kv_producer:
            for req_id, (req, block_ids) in self._reqs_need_recv.items():
                assert req.kv_transfer_params is not None
                meta.add_new_req(
                    request_id=req_id,
                    local_block_ids=block_ids,
                    kv_transfer_params=req.kv_transfer_params,
                    block_hashes=self._req_block_hashes.pop(req_id, None),
                    prompt_token_ids=self._req_token_ids.pop(req_id, None),
                )
            self._reqs_need_recv.clear()

        # Sync hash table to worker for direct RDMA read block lookups.
        self._sync_hash_table(meta)

        if not self.is_kv_consumer:
            for req_id, (req, block_ids) in self._reqs_need_send.items():
                assert req.kv_transfer_params is not None
                meta.add_new_req(
                    request_id=req_id,
                    local_block_ids=block_ids,
                    kv_transfer_params=req.kv_transfer_params,
                    load_remote_cache=False,
                )
            self._reqs_need_send.clear()

        meta.reqs_not_processed = self._reqs_not_processed
        self._reqs_not_processed = set()

        return meta

    def _sync_hash_table(self, meta: MooncakeConnectorMetadata) -> None:
        """Write the block-hash-table delta since the last call into ``meta``.

        New keys go to ``meta.hash_table_updates`` (hex hash -> block_id).
        Vanished keys go to ``meta.hash_table_removals``. Without a block pool
        this logs a one-time warning and does nothing.
        """
        if self._block_pool is None:
            if not self._bp_warned:
                logger.warning("_block_pool is None, hash table sync disabled")
                self._bp_warned = True
            return

        # NOTE: reaches into a private attribute of the block pool's hash map.
        cache = self._block_pool.cached_block_hash_to_block._cache
        current_keys = set(cache.keys())
        new_keys = current_keys - self._known_hash_keys
        removed_keys = self._known_hash_keys - current_keys
        if not (new_keys or removed_keys):
            return

        from vllm.v1.core.kv_cache_utils import get_block_hash

        for key in new_keys:
            block = cache[key]
            # An entry may hold several blocks in a dict; report any one of them.
            if isinstance(block, dict):
                block_id = next(iter(block.values())).block_id
            else:
                block_id = block.block_id
            meta.hash_table_updates[get_block_hash(key).hex()] = block_id
        meta.hash_table_removals = {get_block_hash(key).hex() for key in removed_keys}
        # current_keys is a freshly built set, so no defensive copy is needed.
        self._known_hash_keys = current_keys
        logger.info(
            "hash_table_sync: +%d -%d (total known=%d)",
            len(new_keys),
            len(removed_keys),
            len(self._known_hash_keys),
        )

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Once a request is finished, determine whether request blocks
        should be freed now or will be sent asynchronously and freed later.

        Returns:
            (delay_free_blocks, None). The second element is always None.
        """
        params = request.kv_transfer_params
        logger.debug(
            "MooncakeConnector request_finished, req_id=%s, request_status=%s, "
            "kv_transfer_params=%s",
            request.request_id,
            request.status,
            params,
        )
        if not params or not params.get("transfer_id"):
            return False, None

        if params.get("do_remote_prefill"):
            # If do_remote_prefill is still True when the request is finished,
            # update_state_after_alloc must not have been called (the request
            # must have been aborted before it was scheduled).
            # To avoid stranding the prefill blocks in the prefill instance,
            # we must add empty block_ids to _reqs_need_recv so that our
            # worker side will notify and free blocks in the prefill instance.
            assert not self.is_kv_producer
            self._reqs_need_recv[request.request_id] = (request, [])
            params["do_remote_prefill"] = False
            return False, None

        if not params.get("do_remote_decode"):
            return False, None

        assert not self.is_kv_consumer

        if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
            # Also include the case of a P/D Prefill request with immediate
            # block free (eg abort). Stop tracking this request.
            self._reqs_not_processed.add(params["transfer_id"])
            return False, None

        # TODO: check whether block_ids actually ever be 0. If not we could
        # remove the conditional below
        # Only the prompt's blocks are sent.
        send_block_ids = block_ids[: self._num_blocks(request.num_prompt_tokens)]
        delay_free_blocks = len(send_block_ids) > 0

        if self._lw_enabled:
            # LAYERWISE: the transfer was already driven from the alloc-time
            # block_ids during prefill; do NOT re-enqueue (would double-send).
            # Keep blocks alive until the worker signals finished_sending.
            return delay_free_blocks, None

        if delay_free_blocks:
            self._reqs_need_send[request.request_id] = (request, send_block_ids)

        return delay_free_blocks, None
|
|
|
|
|
|
class MooncakeConnectorWorker:
|
|
"""Implementation of Worker side methods"""
|
|
|
|
def __init__(self, vllm_config: VllmConfig, engine_id: str):
    """Initialize the Mooncake TransferEngine and the worker's background machinery.

    Steps:

    * Start the TransferEngine in P2P-handshake mode. The protocol comes from
      ``kv_connector_extra_config["mooncake_protocol"]`` (default ``"rdma"``).
    * Record the worker's TP/DP/PP placement.
    * Start the threads this engine's role needs:
      - unless the role is ``kv_consumer``: a sender event loop, a thread
        pool, and the bootstrap server when ``should_launch_bootstrap_server``
        says so;
      - unless the role is ``kv_producer``: a receiver event loop.

      ``kv_both`` gets both.

    Raises:
        RuntimeError: if the Mooncake TransferEngine fails to initialize.
        ValueError: if pipeline parallelism (pp_size > 1) is configured.
    """
    logger.info("Initializing Mooncake Transfer Engine worker %s", engine_id)

    self.vllm_config = vllm_config

    self.engine = TransferEngine()
    self.hostname = get_ip()

    # Bind first, then assert. An `assert (x := ...)` is stripped entirely
    # under `python -O`, binding and all, leaving the name undefined below.
    kv_transfer_config = vllm_config.kv_transfer_config
    assert kv_transfer_config
    self.is_kv_producer: bool = kv_transfer_config.kv_role == "kv_producer"
    self.is_kv_consumer: bool = kv_transfer_config.kv_role == "kv_consumer"
    self.num_sender_workers = kv_transfer_config.kv_connector_extra_config.get(
        "num_workers", 10
    )
    # Create more tasks than workers to keep the thread pool saturated.
    # Tasks can await async events, so a surplus (2x is a robust heuristic)
    # prevents workers from idling.
    self.num_sender_tasks = self.num_sender_workers * 2
    protocol = kv_transfer_config.kv_connector_extra_config.get(  # type: ignore[union-attr]
        "mooncake_protocol", "rdma"
    )
    logger.info(
        "The Mooncake Transfer Engine is using %s as its protocol.", protocol
    )
    ret_value = self.engine.initialize(self.hostname, "P2PHANDSHAKE", protocol, "")
    if ret_value != 0:
        raise RuntimeError("Mooncake Transfer Engine initialization failed.")

    self.rpc_port = self.engine.get_rpc_port()

    logger.debug(
        "Mooncake Transfer Engine initialized at %s:%d",
        self.hostname,
        self.rpc_port,
    )

    self._remote_agents: dict[EngineId, dict[int, dict[int, str]]] = {}
    self._pending_bootstrap_queries: dict[str, asyncio.Event] = {}
    # Bound later by _mooncake_sender_listener() (random free port).
    self.side_channel_port: int = 0
    self.engine_id: EngineId = engine_id
    self.tp_rank = get_tensor_model_parallel_rank()
    self.tp_size = get_tensor_model_parallel_world_size()
    self.num_blocks = 0
    # Stays None unless this worker launches the bootstrap server below.
    self.bootstrap_server = None

    # Same -O-safe pattern as above.
    parallel_config = vllm_config.parallel_config
    assert parallel_config
    dp_rank = parallel_config.data_parallel_index
    dp_local_rank = parallel_config.data_parallel_rank_local
    self.dp_rank = dp_local_rank if parallel_config.local_engines_only else dp_rank
    pp_size = vllm_config.parallel_config.pipeline_parallel_size
    if pp_size > 1:
        raise ValueError(
            "Mooncake Transfer Engine does not support pipeline parallelism yet."
        )
    self.pp_rank = get_pp_group().rank_in_group

    self.kv_caches_base_addr: list[int] = []
    self.device_kv_caches: dict[str, torch.Tensor] = {}
    # Producer-side state for each outgoing transfer, keyed by transfer id.
    self.reqs_need_send: dict[TransferId, SendBlockMeta] = {}

    # --- LAYERWISE (opt-in via MOONCAKE_LAYERWISE=1) ---------------------
    # Push KV per-layer as prefill computes it, so the RDMA write overlaps
    # the remaining prefill compute instead of being a post-hoc full
    # transfer. Off by default => byte-identical upstream behaviour.
    self._lw_enabled = os.environ.get("MOONCAKE_LAYERWISE", "0") == "1"
    self._lw_layer_pos: dict[str, int] = {}  # layer_name -> 0..N-1
    self._lw_addr_idx: dict[int, list[int]] = {}  # layer_pos -> base-addr idxs
    self._lw_num_layers: int = 0
    # Single-producer-transfer-at-a-time (sufficient for the microbench):
    # _lw_gmax is the highest layer position whose KV is computed in the
    # current producer prefill; reset in record_send_reqs (which runs when
    # the producer request is scheduled, before its forward starts).
    self._lw_gmax: int = -1
    self._lw_producing: bool = False
    self._lw_events: dict[int, Any] = {}  # layer_pos -> cuda Event (HBM-ready)
    self._lw_lock = threading.Lock()
    if self._lw_enabled:
        logger.info("Mooncake LAYERWISE mode ENABLED")

    # For kv_both, we will act both prefiller and decoder.
    if not self.is_kv_consumer:
        # Background threads for sending kvcaches to D.
        self._sender_executor = ThreadPoolExecutor(
            max_workers=self.num_sender_workers,
            thread_name_prefix="vllm-mooncake-sender",
        )
        logger.debug(
            "Mooncake Prefiller: use %d workers to send kvcaches",
            self.num_sender_workers,
        )
        # An asyncio queue to buffer incoming requests for the sender
        self.sender_worker_queue = asyncio.Queue[tuple[bytes, bytes]]()
        self.sender_loop = asyncio.new_event_loop()
        # Background thread for processing new sending requests.
        self._sender_listener_t = threading.Thread(
            target=_async_loop, args=(self.sender_loop,), daemon=True
        )
        self._sender_listener_t.start()

        # Start bootstrap server on global rank 0.
        if should_launch_bootstrap_server(vllm_config):
            _, port = get_mooncake_bootstrap_addr(vllm_config)
            self.bootstrap_server = MooncakeBootstrapServer(
                vllm_config, "0.0.0.0", port
            )
            self.bootstrap_server.start()

    if not self.is_kv_producer:
        self.receiver_loop = asyncio.new_event_loop()
        self._mooncake_receiver_t = threading.Thread(
            target=_async_loop, args=(self.receiver_loop,), daemon=True
        )
        self._mooncake_receiver_t.start()
        logger.debug("Mooncake Decoder: start receiver thread")

    self.finished_sending_reqs: set[ReqId] = set()
    self.finished_recving_reqs: set[ReqId] = set()
    self.failed_recving_block_ids: set[int] = set()

    self.block_size = vllm_config.cache_config.block_size
    self.model_config = vllm_config.model_config
    self.cache_config = vllm_config.cache_config
    self.use_mla = self.model_config.use_mla

    # Get the attention backend from the first layer
    # NOTE (NickLucche) models with multiple backends are not supported yet
    backend = get_current_attn_backend(vllm_config)
    self.backend_name = backend.get_name()
    self.kv_cache_layout = get_kv_cache_layout()
    logger.debug("Detected attention backend %s", self.backend_name)
    logger.debug("Detected kv cache layout %s", self.kv_cache_layout)

    self._tp_size: dict[EngineId, int] = {self.engine_id: self.tp_size}
    self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
    self.kv_topo = TpKVTopology(
        tp_rank=self.tp_rank,
        engine_id=self.engine_id,
        remote_tp_size=self._tp_size,  # shared state
        remote_block_size=self._block_size,  # shared state
        is_mla=self.use_mla,
        total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
        attn_backends=[backend],
    )

    self.async_zmq_ctx = zmq.asyncio.Context()
    self._encoder = msgspec.msgpack.Encoder()
    self._xfer_meta_decoder = msgspec.msgpack.Decoder(MooncakeXferMetadata)
    self._xfer_resp_decoder = msgspec.msgpack.Decoder(MooncakeXferResponse)
|
|
|
|
def __del__(self):
    """Best-effort cleanup when the worker is garbage collected.

    If ``__init__`` did not finish (e.g. the TransferEngine failed to
    initialize or pipeline parallelism was rejected), the attributes
    ``shutdown`` relies on do not exist. Calling it would only produce an
    "Exception ignored in __del__" AttributeError, so cleanup is skipped.
    """
    # async_zmq_ctx is one of the last attributes __init__ assigns.
    if not hasattr(self, "async_zmq_ctx"):
        return
    self.shutdown()
|
|
|
|
def shutdown(self):
    """Cleanup background threads on destruction.

    Steps:

    * Terminate the async ZMQ context.
    * Unless this engine is a pure ``kv_consumer``: stop the sender thread
      pool and the sender event loop, then join its thread.
    * Shut down the bootstrap server if this worker started one.
    * Unless this engine is a pure ``kv_producer``: stop the receiver event
      loop and join its thread.

    Only the first call does work. ``__del__`` calls this again after any
    explicit shutdown, and repeating the teardown is not needed.
    """
    if getattr(self, "_shutdown_done", False):
        return
    self._shutdown_done = True

    self.async_zmq_ctx.term()
    if not self.is_kv_consumer:
        self._sender_executor.shutdown(wait=False)
        if self.sender_loop.is_running():
            self.sender_loop.call_soon_threadsafe(self.sender_loop.stop)
            self._sender_listener_t.join()
    # Use the attribute __init__ maintains (None unless a server was started)
    # instead of re-deriving the decision from the config.
    if self.bootstrap_server is not None:
        self.bootstrap_server.shutdown()
    if not self.is_kv_producer and self.receiver_loop.is_running():
        self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop)
        self._mooncake_receiver_t.join()
|
|
|
|
async def register_worker_with_bootstrap(self):
    """Announce this worker's side-channel address to the bootstrap server.

    POSTs a ``RegisterWorkerPayload`` (engine id, DP/TP/PP ranks, and the
    ``tcp://host:side_channel_port`` address) to ``<bootstrap>/register``.
    While the server refuses connections it retries once per second,
    indefinitely. Any other failure is logged (with the server's response
    body for HTTP errors) and re-raised.
    """
    bootstrap_host, bootstrap_port = get_mooncake_bootstrap_addr(self.vllm_config)
    register_url = make_zmq_path("http", bootstrap_host, bootstrap_port) + "/register"
    payload = RegisterWorkerPayload(
        engine_id=self.engine_id,
        dp_rank=self.dp_rank,
        tp_rank=self.tp_rank,
        pp_rank=self.pp_rank,
        addr=make_zmq_path("tcp", self.hostname, self.side_channel_port),
    )

    while True:
        try:
            async with httpx.AsyncClient() as client:
                resp = await client.post(register_url, json=payload.model_dump())
                resp.raise_for_status()
        except httpx.ConnectError:
            # Server not up yet: back off briefly and try again.
            await asyncio.sleep(1)
            continue
        except Exception as exc:
            if isinstance(exc, httpx.HTTPStatusError):
                detail = exc.response.text
            else:
                detail = str(exc)
            logger.error(
                "Error registering %s with bootstrap server: %s", payload, detail
            )
            raise
        logger.debug("Successfully registered with bootstrap server at %s", register_url)
        return
|
|
|
|
async def _mooncake_sender_listener(self, ready_event: threading.Event):
    """Serve the producer's side channel.

    Steps:

    * Bind a ZMQ ROUTER socket to a random port and publish it as
      ``self.side_channel_port``.
    * Register with the bootstrap server.
    * Spawn ``num_sender_tasks`` ``_sender_worker`` tasks and set
      ``ready_event``.
    * Forward each incoming ``(identity, payload)`` message to
      ``sender_worker_queue``.

    The loop runs until the ZMQ context is terminated or an unexpected error
    occurs; the worker tasks are then cancelled and the socket is closed.

    If setup (bind or registration) fails, the socket is closed and the
    exception propagates; ``ready_event`` is not set in that case.
    """
    sock = self.async_zmq_ctx.socket(zmq.ROUTER)
    try:
        self.side_channel_port = sock.bind_to_random_port(f"tcp://{self.hostname}")
        logger.debug(
            "Mooncake sender starting listening on path: tcp://%s:%d",
            self.hostname,
            self.side_channel_port,
        )
        await self.register_worker_with_bootstrap()
    except BaseException:
        # Previously the socket leaked here; close it, then propagate as before.
        sock.close()
        raise

    # Create async worker tasks that process items from the queue
    sender_tasks = [
        asyncio.create_task(self._sender_worker(sock))
        for _ in range(self.num_sender_tasks)
    ]

    ready_event.set()

    try:
        while True:
            frames = await sock.recv_multipart()
            # A ROUTER socket yields [peer identity, *message frames]; a valid
            # request carries exactly one payload frame. Drop anything else
            # instead of letting the unpacking error end the whole listener.
            if len(frames) != 2:
                logger.warning(
                    "Dropping malformed Mooncake xfer request with %d frames",
                    len(frames),
                )
                continue
            identity, metadata_bytes = frames
            await self.sender_worker_queue.put((identity, metadata_bytes))
    except zmq.ContextTerminated:
        logger.debug("ZMQ context terminated, exiting Mooncake sender thread.")
    except Exception as e:
        logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e))
    finally:
        # Clean up worker tasks
        for task in sender_tasks:
            task.cancel()
        await asyncio.gather(*sender_tasks, return_exceptions=True)
        sock.close()
|
|
|
|
async def _sender_worker(self, sock: zmq.asyncio.Socket):
    """Consume queued transfer requests until cancelled, answering on ``sock``.

    Each queued payload is decoded as ``MooncakeXferMetadata`` and handed to
    ``send_kv_to_decode``. If handling an item fails, the requesting peer gets
    an ``ERROR`` response and the worker moves on to the next item.
    """

    async def serve_one(identity: bytes, payload: bytes) -> None:
        # Handle one request; report any failure back to the peer.
        try:
            xfer_meta = self._xfer_meta_decoder.decode(payload)
            await self.send_kv_to_decode(identity, sock, xfer_meta)
        except Exception as exc:
            logger.error("Error processing Mooncake xfer request: %s", exc)
            reply = MooncakeXferResponse(
                status=MooncakeXferResponseStatus.ERROR, err_msg=str(exc)
            )
            await sock.send_multipart((identity, self._encoder.encode(reply)))

    while True:
        try:
            identity, payload = await self.sender_worker_queue.get()
            try:
                await serve_one(identity, payload)
            finally:
                # Balance the get() above however handling ended.
                self.sender_worker_queue.task_done()
        except asyncio.CancelledError:
            return
        except Exception as exc:
            # E.g. the error reply itself could not be sent.
            logger.error("Error in _sender_worker: %s", exc)
|
|
|
|
async def send_kv_to_decode(
    self, identity: bytes, sock: zmq.asyncio.Socket, meta: MooncakeXferMetadata
):
    """Serve one D-side transfer request by RDMA-writing KV blocks to D.

    Each request in ``meta.req_blocks`` is matched by transfer id with its
    ``SendBlockMeta`` in ``self.reqs_need_send``; a placeholder is created if
    P has not enqueued the request yet. In layerwise mode the work is handed
    to ``_send_kv_layerwise``. Otherwise requests are sent in batches as their
    ``ready`` events fire. Every batch ends with exactly one final response:
    CONTINUE while requests are still pending, FINISH for the last batch. Any
    error response sent earlier within the same batch always uses CONTINUE,
    because D stops reading at the first FINISH it receives.

    Every request D asked for is reported back, either in ``ok_reqs`` or in
    ``err_reqs``, so D never waits on a request P silently dropped.

    Args:
        identity: ZMQ routing identity of the requesting D worker.
        sock: socket used to reply to D.
        meta: decoded transfer request from D.
    """
    pending_reqs: dict[ReqId, SendBlockMeta] = {}
    remote_tp_ranks = self.kv_topo.get_target_remote_ranks(meta.remote_tp_size)
    if self.tp_rank not in remote_tp_ranks:
        # This D worker does not pair with the P worker.
        msg = f"This P tp_rank {self.tp_rank} not in remote D target ranks {remote_tp_ranks}"  # noqa: E501
        logger.error(msg)
        response = MooncakeXferResponse(
            status=MooncakeXferResponseStatus.ERROR,
            err_msg=msg,
        )
        await sock.send_multipart((identity, self._encoder.encode(response)))
        return

    for d_req_id, (transfer_id, _) in meta.req_blocks.items():
        if transfer_id not in self.reqs_need_send:
            # This req is not enqueued in P side yet, create it here.
            self.reqs_need_send[transfer_id] = SendBlockMeta(
                p_req_id="",
                transfer_id=transfer_id,
                local_block_ids=[],
                ready=asyncio.Event(),
            )
        pending_reqs[d_req_id] = self.reqs_need_send[transfer_id]

    if self._lw_enabled:
        await self._send_kv_layerwise(
            identity, sock, meta, pending_reqs, remote_tp_ranks
        )
        return

    async def wait_and_ret(
        d_req_id: ReqId, send_meta: SendBlockMeta
    ) -> tuple[ReqId, SendBlockMeta]:
        await send_meta.ready.wait()
        return d_req_id, send_meta

    wait_tasks = [
        asyncio.create_task(wait_and_ret(d_req_id, send_meta))
        for d_req_id, send_meta in pending_reqs.items()
    ]

    while wait_tasks:
        done, pending = await asyncio.wait(
            wait_tasks,
            timeout=envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
            return_when=asyncio.FIRST_COMPLETED,
        )

        if not done:
            # Timeout, abort all pending requests.
            for task in wait_tasks:
                task.cancel()
            logger.warning(
                "Timeout waiting for P side ready: %s", list(pending_reqs)
            )
            response = MooncakeXferResponse(
                status=MooncakeXferResponseStatus.FINISH,
                err_reqs=list(pending_reqs),
                err_msg="Timeout waiting for P side ready.",
            )
            await sock.send_multipart((identity, self._encoder.encode(response)))
            break

        wait_tasks = list(pending)
        # Status of the final response of this batch.
        response_status = (
            MooncakeXferResponseStatus.CONTINUE
            if wait_tasks
            else MooncakeXferResponseStatus.FINISH
        )

        ready_reqs: list[tuple[ReqId, SendBlockMeta]] = []
        expired_reqs: list[ReqId] = []
        for task in done:
            d_req_id, send_meta = task.result()
            del pending_reqs[d_req_id]
            # Is it still in reqs_need_send (not expired)?
            if send_meta.transfer_id in self.reqs_need_send:
                # Mark it sending to avoid expiration.
                send_meta.sending += 1
                if not send_meta.need_send:
                    self.resolve_need_send(send_meta, remote_tp_ranks)
                ready_reqs.append((d_req_id, send_meta))
            else:
                # Expired on P (very unlikely): nothing left to send, but D
                # must still be told or it waits for this request forever.
                logger.warning(
                    "Request %s expired before sending on P side.", d_req_id
                )
                expired_reqs.append(d_req_id)

        if expired_reqs:
            # CONTINUE: the final response of this batch is sent below.
            response = MooncakeXferResponse(
                status=MooncakeXferResponseStatus.CONTINUE,
                err_reqs=expired_reqs,
                err_msg="Request expired before sending on P side.",
            )
            await sock.send_multipart((identity, self._encoder.encode(response)))

        src_ptrs, dst_ptrs, lengths, err_reqs = await self._build_transfer_params(
            ready_reqs, meta
        )

        if err_reqs:
            # CONTINUE, not response_status: the final response of this batch
            # follows below, and D stops reading at the first FINISH.
            response = MooncakeXferResponse(
                status=MooncakeXferResponseStatus.CONTINUE,
                err_reqs=err_reqs,
                err_msg="P num blocks less than D",
            )
            await sock.send_multipart((identity, self._encoder.encode(response)))

        if src_ptrs:
            remote_session = f"{meta.remote_hostname}:{meta.remote_port}"
            ret_value = await self.sender_loop.run_in_executor(
                self._sender_executor,
                self._send_blocks,
                remote_session,
                src_ptrs,
                dst_ptrs,
                lengths,
            )

            if ret_value != 0:
                failed_reqs = []
                for d_req_id, send_meta in ready_reqs:
                    send_meta.sending -= 1
                    failed_reqs.append(d_req_id)
                # Do best effort to transfer the remaining reqs.
                response = MooncakeXferResponse(
                    status=response_status,
                    err_reqs=failed_reqs,
                    err_msg=f"Mooncake transfer engine returned {ret_value}",
                )
                await sock.send_multipart(
                    (identity, self._encoder.encode(response))
                )
                continue

        for d_req_id, send_meta in ready_reqs:
            # TODO: for heterogeneous TP (one P pairs to multiple D),
            # we need to check whether all headers are sent.
            # If not, we should set expire_time to normal and skip the below.
            send_meta.sending -= 1
            send_meta.sent += 1
            if send_meta.sent == send_meta.need_send:
                self.reqs_need_send.pop(send_meta.transfer_id, None)
                self.finished_sending_reqs.add(send_meta.p_req_id)

        # Requests already reported in err_reqs must not also be reported ok.
        failed = set(err_reqs)
        response = MooncakeXferResponse(
            status=response_status,
            ok_reqs=[d_req_id for d_req_id, _ in ready_reqs if d_req_id not in failed],
        )
        await sock.send_multipart((identity, self._encoder.encode(response)))
|
|
|
|
def resolve_need_send(self, send_meta: SendBlockMeta, remote_tp_ranks: list[int]):
    """Store how many D ranks must receive ``send_meta`` before it is done.

    The count is the number of paired D TP ranks. It is written to
    ``send_meta.need_send`` first, then validated: anything other than exactly
    one rank (heterogeneous TP, one P feeding several D) raises
    ``NotImplementedError``.
    """
    target_count = len(remote_tp_ranks)
    send_meta.need_send = target_count
    if target_count == 1:
        return
    logger.error("Mooncake: Heterogeneous TP is not supported yet.")
    raise NotImplementedError(
        "Mooncake: Heterogeneous TP is not supported yet."
    )
|
|
|
|
def note_layer_computed(self, layer_name: str):
    """LAYERWISE: mark ``layer_name`` as computed for the in-flight prefill.

    Called from save_kv_layer after the layer's attention runs. Records a CUDA
    event on the current stream so the sender can wait until the layer's KV
    is actually in HBM before RDMA-reading it, and raises the shared
    high-water mark ``_lw_gmax`` to this layer's position.

    No-op when layerwise mode is disabled, when no producer transfer is
    active, or when the layer name is not registered.
    """
    if not (self._lw_enabled and self._lw_producing):
        return
    layer_pos = self._lw_layer_pos.get(layer_name)
    if layer_pos is None:
        return
    event = torch.cuda.Event()
    event.record(torch.cuda.current_stream())
    with self._lw_lock:
        self._lw_events[layer_pos] = event
        self._lw_gmax = max(self._lw_gmax, layer_pos)
|
|
|
|
def _build_layer_params(self, d_req_id, send_meta, agent_meta, layer_pos):
    """LAYERWISE: RDMA write descriptors for a single layer of one request.

    Same pairing as ``_build_transfer_params`` but restricted to the base
    address slots of ``layer_pos`` (``self._lw_addr_idx[layer_pos]``). When P
    holds more blocks than D requested, only the trailing local blocks are
    used; when it holds fewer, an error is logged and nothing is produced.

    Returns:
        ``(src_ptrs, dst_ptrs, lengths)``: parallel lists with one entry per
        contiguous block run per address slot.
    """
    src_ptrs: list[int] = []
    dst_ptrs: list[int] = []
    lengths: list[int] = []

    _transfer_id, remote_ids = agent_meta.req_blocks[d_req_id]
    n_remote = len(remote_ids)
    if n_remote == 0:
        return src_ptrs, dst_ptrs, lengths

    local_ids = send_meta.local_block_ids
    n_local = len(local_ids)
    if n_local < n_remote:
        logger.error("layerwise %s: local blocks(%d) < remote(%d)",
                     d_req_id, n_local, n_remote)
        return src_ptrs, dst_ptrs, lengths
    if n_local > n_remote:
        local_ids = local_ids[-n_remote:]

    local_runs, remote_runs = group_concurrent_contiguous(local_ids, remote_ids)
    stride = self.block_len
    for slot in self._lw_addr_idx[layer_pos]:
        local_base = self.kv_caches_base_addr[slot]
        remote_base = agent_meta.kv_caches_base_addr[slot]
        for local_run, remote_run in zip(local_runs, remote_runs):
            src_ptrs.append(local_base + local_run[0] * stride)
            dst_ptrs.append(remote_base + remote_run[0] * stride)
            lengths.append(stride * len(local_run))
    return src_ptrs, dst_ptrs, lengths
|
|
|
|
async def _send_kv_layerwise(
    self, identity, sock, meta, pending_reqs, remote_tp_ranks,
):
    """Write each layer's KV as soon as prefill computes it (write mode).

    Steps:
      1. Wait (up to VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT per request) for
         every request's local block ids. Requests that time out are
         reported in ``err_reqs``; requests whose P side holds fewer blocks
         than D asked for are sent nothing and also reported as failed.
      2. For each layer in order: poll the computed-layer high-water mark,
         wait for the layer's CUDA event without blocking the event loop,
         and RDMA-write the layer for every still-healthy request. If the
         deadline passes before a layer is computed, no further layers are
         sent and every request is reported as failed.
      3. Send exactly one FINISH response listing ``ok_reqs``/``err_reqs``.

    Every request that entered the send phase still has ``sending`` released
    and ``sent`` bumped at the end, so its P-side blocks are freed even when
    D is told the transfer failed.
    """
    timeout = envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
    ready_reqs: list[tuple[ReqId, SendBlockMeta]] = []
    # D request ids reported as failed in the final response.
    err_reqs: list[ReqId] = []
    # Members of ready_reqs that must not be reported as ok.
    failed: set[ReqId] = set()

    for d_req_id, send_meta in pending_reqs.items():
        try:
            await asyncio.wait_for(send_meta.ready.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            logger.warning("layerwise: timeout waiting block_ids for %s", d_req_id)
            # Report it, otherwise D waits for this request forever.
            err_reqs.append(d_req_id)
            continue
        send_meta.sending += 1
        if not send_meta.need_send:
            self.resolve_need_send(send_meta, remote_tp_ranks)
        ready_reqs.append((d_req_id, send_meta))
        _, remote_block_ids = meta.req_blocks[d_req_id]
        num_local = len(send_meta.local_block_ids)
        if num_local < len(remote_block_ids):
            # _build_layer_params would send nothing for this request; make
            # sure D is told it failed instead of ok.
            logger.error("layerwise %s: local blocks(%d) < remote(%d)",
                         d_req_id, num_local, len(remote_block_ids))
            failed.add(d_req_id)

    if not ready_reqs:
        response = MooncakeXferResponse(
            status=MooncakeXferResponseStatus.FINISH,
            err_reqs=list(pending_reqs), err_msg="layerwise: no ready reqs")
        await sock.send_multipart((identity, self._encoder.encode(response)))
        return

    remote_session = f"{meta.remote_hostname}:{meta.remote_port}"
    t0 = time.perf_counter()
    deadline = t0 + timeout
    for layer_pos in range(self._lw_num_layers):
        # Wait until this layer's KV is computed (poll; layers are ms).
        while True:
            with self._lw_lock:
                done = self._lw_gmax
            if done >= layer_pos or time.perf_counter() > deadline:
                break
            await asyncio.sleep(0.0005)
        if done < layer_pos:
            logger.error("layerwise: timeout at layer %d (gmax=%d)",
                         layer_pos, done)
            # This and later layers were never computed: sending them would
            # hand D stale KV, so stop here and fail every request.
            failed.update(d_req_id for d_req_id, _ in ready_reqs)
            break

        # Ensure the layer's KV write to HBM is complete before it is read
        # for RDMA. Poll instead of ev.synchronize() so the sender loop keeps
        # serving other coroutines meanwhile.
        with self._lw_lock:
            ev = self._lw_events.get(layer_pos)
        if ev is not None:
            while not ev.query():
                await asyncio.sleep(0.0005)

        for d_req_id, send_meta in ready_reqs:
            if d_req_id in failed:
                continue
            src_ptrs, dst_ptrs, lengths = self._build_layer_params(
                d_req_id, send_meta, meta, layer_pos)
            if not src_ptrs:
                continue
            ret = await self.sender_loop.run_in_executor(
                self._sender_executor, self._send_blocks,
                remote_session, src_ptrs, dst_ptrs, lengths)
            if ret != 0:
                logger.error("layerwise: _send_blocks ret=%d layer=%d",
                             ret, layer_pos)
                # A request missing any layer is unusable on D.
                failed.add(d_req_id)

    logger.info("layerwise: transfer done in %.3fs (%d layers, %d reqs)",
                time.perf_counter() - t0, self._lw_num_layers, len(ready_reqs))
    with self._lw_lock:
        self._lw_producing = False
    for d_req_id, send_meta in ready_reqs:
        send_meta.sending -= 1
        send_meta.sent += 1
        if send_meta.sent == send_meta.need_send:
            self.reqs_need_send.pop(send_meta.transfer_id, None)
            self.finished_sending_reqs.add(send_meta.p_req_id)

    ok_reqs = [d for d, _ in ready_reqs if d not in failed]
    err_reqs.extend(d for d, _ in ready_reqs if d in failed)
    if err_reqs:
        response = MooncakeXferResponse(
            status=MooncakeXferResponseStatus.FINISH,
            ok_reqs=ok_reqs,
            err_reqs=err_reqs,
            err_msg="layerwise: transfer failed",
        )
    else:
        response = MooncakeXferResponse(
            status=MooncakeXferResponseStatus.FINISH,
            ok_reqs=ok_reqs)
    await sock.send_multipart((identity, self._encoder.encode(response)))
|
|
|
|
async def _build_transfer_params(
    self,
    ready_reqs: list[tuple[ReqId, SendBlockMeta]],
    agent_meta: MooncakeXferMetadata,
) -> tuple[list[int], list[int], list[int], list[ReqId]]:
    """Turn ready requests into batched RDMA write descriptors.

    Remote block ids come from ``agent_meta.req_blocks``; requests with none
    are skipped. With a partial prefix-cache hit (P holds more blocks than D
    asked for) only the trailing local blocks are sent. If P holds fewer, the
    request is reported in ``err_reqs`` and nothing is sent for it. Runs that
    are contiguous on both sides become a single descriptor per registered
    base address.

    Returns:
        ``(src_ptrs, dst_ptrs, lengths, err_reqs)``.
    """
    src_ptrs = []
    dst_ptrs = []
    lengths = []
    err_reqs: list[ReqId] = []
    local_bases = self.kv_caches_base_addr
    remote_bases = agent_meta.kv_caches_base_addr
    stride = self.block_len
    remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}"

    for d_req_id, send_meta in ready_reqs:
        _transfer_id, remote_ids = agent_meta.req_blocks[d_req_id]
        wanted = len(remote_ids)
        if wanted == 0:
            continue

        local_ids = send_meta.local_block_ids
        available = len(local_ids)
        if available < wanted:
            logger.error(
                "req %s: local blocks(%d) less than remote blocks(%d)!",
                d_req_id,
                available,
                wanted,
            )
            err_reqs.append(d_req_id)
            continue
        if available > wanted:
            # Partial prefix cache hit: only the trailing blocks are needed.
            local_ids = local_ids[-wanted:]

        local_runs, remote_runs = group_concurrent_contiguous(local_ids, remote_ids)
        for local_base, remote_base in zip(local_bases, remote_bases):
            for local_run, remote_run in zip(local_runs, remote_runs):
                src_ptrs.append(local_base + local_run[0] * stride)
                dst_ptrs.append(remote_base + remote_run[0] * stride)
                lengths.append(stride * len(local_run))

        logger.debug(
            "Sending kv_caches for request %s (%d blocks) to %s",
            d_req_id,
            wanted,
            remote_session,
        )

    return src_ptrs, dst_ptrs, lengths, err_reqs
|
|
|
|
def _send_blocks(
    self,
    remote_session: str,
    src_ptrs: list[int],
    dst_ptrs: list[int],
    lengths: list[int],
) -> int:
    """Synchronously RDMA-write a batch of regions to ``remote_session``.

    Blocking; the callers in this file run it on ``self._sender_executor``.
    Returns the transfer engine's status code, where 0 means success.
    """
    started = time.perf_counter()
    status = self.engine.batch_transfer_sync_write(
        remote_session, src_ptrs, dst_ptrs, lengths
    )
    if status != 0:
        return status
    logger.debug(
        "Sending to %s done, took %s",
        remote_session,
        time.perf_counter() - started,
    )
    return status
|
|
|
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Register the KV Cache data in mooncake.

    Collects every distinct KV tensor (K and V separately when the backend
    splits them), registers them with the transfer engine in one batch, and
    derives ``num_blocks`` and ``block_len``. In layerwise mode it also builds
    the layer-name -> position map and, per position, the indices into
    ``self.kv_caches_base_addr`` registered by that layer. Producers then start
    the sender listener and, when ``self.bootstrap_server`` is set, publish
    their KV info to it.

    Raises:
        RuntimeError: if the transfer engine rejects the registration.
    """
    logger.info("Registering KV_Caches. use_mla: %s", self.use_mla)

    kv_data_ptrs = []
    kv_data_lens = []
    seen_base_addresses = []
    # Layer name -> indices into seen_base_addresses first registered by it.
    layer_addr_indices: dict[str, list[int]] = {}

    split_k_and_v = self.kv_topo.split_k_and_v
    tensor_size_bytes = None
    for layer_name, cache_or_caches in kv_caches.items():
        logger.debug(
            "registering layer %s with shape %s", layer_name, cache_or_caches.shape
        )
        cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
        layer_indices: list[int] = []
        layer_addr_indices[layer_name] = layer_indices

        for cache in cache_list:
            base_addr = cache.data_ptr()
            if base_addr in seen_base_addresses:
                # Aliased tensor that is already registered.
                continue

            layer_indices.append(len(seen_base_addresses))
            seen_base_addresses.append(base_addr)
            curr_tensor_size_bytes = cache.nbytes

            if tensor_size_bytes is None:
                tensor_size_bytes = curr_tensor_size_bytes
                self.num_blocks = cache.shape[0]

            assert tensor_size_bytes == curr_tensor_size_bytes, (
                "All kv cache tensors must have the same size"
            )
            kernel_block_size = cache.shape[-2 if self.use_mla else -3]
            assert self.block_size == kernel_block_size
            kv_data_ptrs.append(base_addr)
            kv_data_lens.append(tensor_size_bytes)

    self.kv_caches_base_addr = seen_base_addresses

    ret_value = self.engine.batch_register_memory(kv_data_ptrs, kv_data_lens)
    if ret_value != 0:
        raise RuntimeError("Mooncake batch memory registration failed.")

    assert tensor_size_bytes is not None
    assert self.num_blocks != 0
    assert tensor_size_bytes % self.num_blocks == 0
    self.block_len = tensor_size_bytes // self.num_blocks
    self.device_kv_caches = kv_caches
    logger.debug(
        "registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len
    )

    # --- LAYERWISE: map layer_name -> position, position -> base-addr idxs.
    if self._lw_enabled:
        layer_names = list(kv_caches.keys())
        self._lw_num_layers = len(layer_names)
        for pos, ln in enumerate(layer_names):
            self._lw_layer_pos[ln] = pos
            # Use the indices recorded during registration. Computing them as
            # pos * n_per + j assumes no tensor was deduplicated above and
            # points later layers at the wrong (or out-of-range) addresses
            # as soon as one was.
            self._lw_addr_idx[pos] = layer_addr_indices[ln]
        logger.info("LAYERWISE: %d layers, %d base-addrs (split_k_and_v=%s)",
                    self._lw_num_layers, len(self.kv_caches_base_addr),
                    split_k_and_v)

    if self.is_kv_consumer:
        return

    ready_event = threading.Event()
    asyncio.run_coroutine_threadsafe(
        self._mooncake_sender_listener(ready_event), self.sender_loop
    )
    ready_event.wait()

    if self.bootstrap_server is not None:
        self.bootstrap_server.set_worker_kv_info(
            self.kv_caches_base_addr, self.block_len,
            self.block_size, self.hostname, self.rpc_port,
            transfer_engine=self.engine,
        )
        if _shared_block_pool is not None:
            self.bootstrap_server.set_block_pool(_shared_block_pool)
|
|
|
|
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
    """Hand over the requests whose KV has been received, then reset the set."""
    drained, self.finished_recving_reqs = self.finished_recving_reqs, set()
    return drained
|
|
|
|
def get_block_ids_with_load_errors(self) -> set[int]:
    """Return the local block ids whose KV failed to load, and clear them."""
    failed_ids, self.failed_recving_block_ids = self.failed_recving_block_ids, set()
    return failed_ids
|
|
|
|
async def fetch_finished_sending_reqs(self) -> set[ReqId]:
    """Hand over finished sends, plus requests that expired without being sent.

    An entry expires when P has enqueued it (``p_req_id`` is set), its
    ``expire_time`` has passed and no transfer is using it (``sending == 0``).
    Expired entries are removed from ``reqs_need_send`` and reported as
    finished, so the producer frees their blocks instead of stranding them.
    """
    result, self.finished_sending_reqs = self.finished_sending_reqs, set()

    # Handle timeout to avoid stranding blocks on remote.
    now = time.perf_counter()
    expired = [
        (transfer_id, entry)
        for transfer_id, entry in self.reqs_need_send.items()
        if entry.p_req_id and entry.expire_time < now and entry.sending == 0
    ]
    for transfer_id, entry in expired:
        logger.warning(
            "Request %s timed out after %d seconds without "
            "being sent. Freeing its blocks on the producer side.",
            entry.p_req_id,
            envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
        )
        result.add(entry.p_req_id)
        del self.reqs_need_send[transfer_id]

    return result
|
|
|
|
def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
|
|
"""
|
|
Get requests that are done sending or recving on this specific worker.
|
|
The scheduler process (via the MultiprocExecutor) will use this output
|
|
to track which workers are done.
|
|
"""
|
|
recv_fut = None
|
|
send_fut = None
|
|
if not self.is_kv_producer:
|
|
recv_fut = asyncio.run_coroutine_threadsafe(
|
|
self.fetch_finished_recving_reqs(), self.receiver_loop
|
|
)
|
|
|
|
if not self.is_kv_consumer:
|
|
send_fut = asyncio.run_coroutine_threadsafe(
|
|
self.fetch_finished_sending_reqs(), self.sender_loop
|
|
)
|
|
|
|
finished_recving_reqs = recv_fut.result() if recv_fut else set()
|
|
finished_sending_reqs = send_fut.result() if send_fut else set()
|
|
|
|
if finished_sending_reqs or finished_recving_reqs:
|
|
logger.debug(
|
|
"Rank %s, get_finished: %s requests done sending "
|
|
"and %s requests done recving",
|
|
self.tp_rank,
|
|
len(finished_sending_reqs),
|
|
len(finished_recving_reqs),
|
|
)
|
|
|
|
return finished_sending_reqs or None, finished_recving_reqs or None
|
|
|
|
async def receive_kv_from_single_worker(
    self,
    worker_addr: str,
    pull_metas: dict[ReqId, PullReqMeta],
):
    """Ask one P worker to push KV for ``pull_metas`` and track its replies.

    Sends a ``MooncakeXferMetadata`` describing this worker's destination
    blocks over a DEALER socket, then consumes responses until FINISH. Each
    request is settled in exactly one of these ways:

    * an ok or err entry in a response, handled by ``process_pulling_result``;
    * an ERROR response;
    * a socket or decoding failure.

    In the last two cases only the requests that no earlier response settled
    are marked failed. Without that, ERROR replies would leave requests
    pending forever, and failures would discard KV that was already received.
    """
    req_ids = set(pull_metas)
    # Requests already settled by an ok/err entry in a response from P.
    resolved: set[ReqId] = set()

    def fail_unresolved() -> None:
        # Report every request P has not settled as finished with failed
        # blocks, so it does not stay pending forever.
        for req_id in req_ids - resolved:
            pull_meta = pull_metas[req_id]
            self.failed_recving_block_ids.update(pull_meta.local_block_ids)
            self.finished_recving_reqs.add(pull_meta.d_req_id)

    metadata = MooncakeXferMetadata(
        remote_hostname=self.hostname,
        remote_port=self.rpc_port,
        remote_tp_size=self.tp_size,
        remote_tp_rank=self.tp_rank,
        req_blocks={
            req_id: (pull_meta.transfer_id, pull_meta.local_block_ids)
            for req_id, pull_meta in pull_metas.items()
        },
        kv_caches_base_addr=self.kv_caches_base_addr,
    )

    encoded_data = self._encoder.encode(metadata)
    logger.debug(
        "Size of encoded MooncakeXferMetadata: %d bytes", len(encoded_data)
    )
    logger.debug(
        "Sending kv transfer request for %s on path: %s", req_ids, worker_addr
    )

    # Send query for the request.
    try:
        with make_zmq_socket(
            self.async_zmq_ctx, worker_addr, zmq.DEALER, bind=False, linger=0
        ) as sock:
            # If something goes wrong, let P wait timeout first (in asyncio.wait()).
            sock.setsockopt(
                zmq.RCVTIMEO, (envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT + 60) * 1000
            )
            await sock.send(encoded_data)
            while True:
                ret_msg = await sock.recv()
                response = self._xfer_resp_decoder.decode(ret_msg)
                if response.status == MooncakeXferResponseStatus.ERROR:
                    logger.error(
                        "Error happens during transferring kvcache for %s: %s",
                        req_ids,
                        response.err_msg,
                    )
                    fail_unresolved()
                    return
                self.process_pulling_result(response, pull_metas)
                resolved.update(response.ok_reqs or ())
                resolved.update(response.err_reqs or ())
                if response.status == MooncakeXferResponseStatus.FINISH:
                    break
    except zmq.ContextTerminated:
        logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.")
    except Exception as e:
        logger.error("MooncakeXferMetadata transfer failed for %s: %s", req_ids, e)
        fail_unresolved()
        return
|
|
|
|
def process_pulling_result(
    self,
    response: MooncakeXferResponse,
    pull_metas: dict[ReqId, PullReqMeta],
):
    """Apply one P response to the local pull bookkeeping.

    For each id in ``ok_reqs``, the request's outstanding pull-task count is
    decremented, and the request is reported finished once the count reaches
    zero. For each id in ``err_reqs``, the request's local blocks are marked
    failed and the request is reported finished immediately.

    Ids that are not in ``pull_metas`` are skipped on both paths, so a single
    stray id cannot abort processing of the rest of the response.
    """
    ok_reqs: list[ReqId] = response.ok_reqs or []

    for req_id in ok_reqs:
        pull_meta = pull_metas.get(req_id)
        if pull_meta is None:
            logger.warning("pulling kv_caches: unknown ok req %s, ignored", req_id)
            continue
        # No race because we are in async loop.
        pull_meta.pull_tasks_count -= 1
        if pull_meta.pull_tasks_count == 0:
            self.finished_recving_reqs.add(pull_meta.d_req_id)

    if ok_reqs:
        logger.debug("pulling kv_caches for %s finished", ok_reqs)

    if response.err_reqs:
        logger.error(
            "pulling kv_caches for %s failed: %s",
            response.err_reqs,
            response.err_msg,
        )
        for req_id in response.err_reqs:
            pull_meta = pull_metas.get(req_id)
            if pull_meta is None:
                continue
            self.failed_recving_block_ids.update(pull_meta.local_block_ids)
            self.finished_recving_reqs.add(pull_meta.d_req_id)
|
|
|
|
async def _connect_to_prefiller_bootstrap(self, remote_bootstrap_addr: str):
    """Query a P bootstrap server and cache the worker addresses it reports.

    For each DP entry in the ``/query`` reply, this fills
    ``self._remote_agents[engine_id][tp_rank][pp_rank]`` with a worker
    address and sets ``self._tp_size[engine_id]``. Errors are logged, not
    raised.

    The pending-query event for ``remote_bootstrap_addr`` is always removed
    and set afterwards, even when this coroutine is cancelled. This keeps
    concurrent ``handle_new_engine_id`` callers waiting on it from blocking
    forever.
    """
    url = remote_bootstrap_addr + "/query"
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
            response.raise_for_status()
            data: dict = response.json()
            for _, dp_entry in data.items():
                remote_engine_id = dp_entry["engine_id"]
                self._remote_agents[remote_engine_id] = {
                    int(tp_rank): {
                        int(pp_rank): worker_addr
                        for pp_rank, worker_addr in tp_entry.items()
                    }
                    for tp_rank, tp_entry in dp_entry["worker_addr"].items()
                }
                self._tp_size[remote_engine_id] = len(dp_entry["worker_addr"])
    except Exception as e:
        logger.error(
            "Failed to connect to bootstrap server %s: %s",
            remote_bootstrap_addr,
            e,
        )
    finally:
        # Always notify others regardless of connection success, failure or
        # cancellation (CancelledError is not caught by `except Exception`).
        pending = self._pending_bootstrap_queries.pop(remote_bootstrap_addr, None)
        if pending is not None:
            pending.set()
|
|
|
|
def receive_kv(
    self,
    remote_engine_id: EngineId,
    pull_metas: dict[ReqId, PullReqMeta],
):
    """Start pulling ``pull_metas`` from the P worker(s) paired with this rank.

    Must be called from inside a running event loop (it uses
    ``asyncio.create_task``). Sets each request's ``pull_tasks_count`` to the
    number of paired P ranks and spawns one ``receive_kv_from_single_worker``
    task per rank, addressed to that rank's pp-rank-0 worker.

    Raises:
        NotImplementedError: if the pairing is not exactly one P rank
            (heterogeneous TP).
    """
    target_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
        remote_engine_id
    )
    num_targets = len(target_ranks)
    if num_targets != 1:
        logger.error("Mooncake: Heterogeneous TP is not supported yet.")
        raise NotImplementedError(
            "Mooncake: Heterogeneous TP is not supported yet."
        )

    for meta in pull_metas.values():
        meta.pull_tasks_count = num_targets

    engine_agents = self._remote_agents[remote_engine_id]
    for rank in target_ranks:
        worker_addr = engine_agents[rank][0]
        asyncio.create_task(
            self.receive_kv_from_single_worker(worker_addr, pull_metas)
        )
|
|
|
|
async def handle_new_engine_id(
    self,
    remote_engine_id: EngineId,
    pull_metas: dict[ReqId, PullReqMeta],
):
    """Resolve an unknown P engine through its bootstrap server, then pull.

    Only the first caller for a bootstrap address issues the HTTP query;
    concurrent callers wait on the shared event. If the engine is still
    unknown afterwards, every request in ``pull_metas`` is reported finished
    with its local blocks marked failed, since no pull will ever run for it.
    """
    remote_bootstrap_addr = next(iter(pull_metas.values())).remote_bootstrap_addr
    if remote_bootstrap_addr not in self._pending_bootstrap_queries:
        self._pending_bootstrap_queries[remote_bootstrap_addr] = asyncio.Event()
        await self._connect_to_prefiller_bootstrap(remote_bootstrap_addr)
    else:
        await self._pending_bootstrap_queries[remote_bootstrap_addr].wait()

    if remote_engine_id not in self._remote_agents:
        logger.error(
            "Failed to find remote engine_id %s from bootstrap server %s",
            remote_engine_id,
            remote_bootstrap_addr,
        )
        # Nothing will ever pull these requests: fail them so they are
        # reported finished instead of waiting forever.
        for pull_meta in pull_metas.values():
            self.failed_recving_block_ids.update(pull_meta.local_block_ids)
            self.finished_recving_reqs.add(pull_meta.d_req_id)
        return

    self.receive_kv(remote_engine_id, pull_metas)
|
|
|
|
async def _start_direct_read(
    self, reqs_by_engine: dict[EngineId, dict[ReqId, PullReqMeta]]
):
    """Direct RDMA read: spawn one ``_direct_read_single`` task per request.

    C's scheduler is not involved. The engine id is not needed here because
    each request carries its own bootstrap address.
    """
    for pull_metas in reqs_by_engine.values():
        for req_id, pm in pull_metas.items():
            asyncio.create_task(self._direct_read_single(req_id, pm))
|
|
|
|
async def _direct_read_single(self, req_id: ReqId, pm: PullReqMeta):
    """Bootstrap-triggered PUSH: D asks C's bootstrap to push matched blocks.

    C's bootstrap looks up cached blocks by token_ids, then uses C's
    TransferEngine to RDMA WRITE (push) them directly into D's GPU memory.
    C's scheduler is NOT involved.

    The request is always added to ``finished_recving_reqs``. Its local
    blocks are marked failed when nothing was pushed, or when any step raises,
    including reading the request's own fields.
    """
    try:
        # Keep everything inside the try: an exception escaping this task
        # would leave the request unfinished forever.
        bootstrap_url = pm.remote_bootstrap_addr
        num_remote_tokens = pm.remote_num_tokens or len(pm.prompt_token_ids)
        local_block_ids = pm.local_block_ids
        d_session = f"{self.hostname}:{self.rpc_port}"

        async with httpx.AsyncClient(timeout=60) as client:
            resp = await client.post(
                f"{bootstrap_url}/push_blocks",
                json={
                    "token_ids": pm.prompt_token_ids,
                    "num_tokens": num_remote_tokens,
                    "dst_block_ids": local_block_ids,
                    "dst_base_addrs": self.kv_caches_base_addr,
                    "dst_block_len": self.block_len,
                    "dst_session": d_session,
                },
            )
            resp.raise_for_status()
            result = resp.json()

        matched = result.get("matched", 0)
        pushed = result.get("pushed", False)

        if matched > 0 and pushed:
            logger.info("direct_push %s: %d blocks pushed from C", req_id, matched)
        else:
            logger.debug("direct_push %s: %d matched, pushed=%s", req_id, matched, pushed)
            self.failed_recving_block_ids.update(local_block_ids)

        self.finished_recving_reqs.add(req_id)

    except Exception as e:
        logger.error("direct_push %s failed: %s", req_id, e)
        self.failed_recving_block_ids.update(pm.local_block_ids)
        self.finished_recving_reqs.add(req_id)
|
|
|
|
async def _start_load_kv(
    self, reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]]
):
    """Start pulling each engine's requests.

    Engines already in ``self._remote_agents`` are pulled from right away.
    Unknown engines are first resolved via their bootstrap server in a
    background task.
    """
    for remote_engine_id, pull_metas in reqs_to_recv.items():
        if remote_engine_id in self._remote_agents:
            self.receive_kv(remote_engine_id, pull_metas)
            continue
        asyncio.create_task(
            self.handle_new_engine_id(remote_engine_id, pull_metas)
        )
|
|
|
|
async def record_send_reqs(self, metadata: MooncakeConnectorMetadata):
    """Mirror the scheduler's send bookkeeping into ``reqs_need_send``.

    For each ``p_req_id -> (transfer_id, block_ids)`` in
    ``metadata.reqs_to_send``:

    * Non-empty ``block_ids``: the entry is created if missing, filled in,
      given a fresh expiry deadline and marked ready. In LAYERWISE mode the
      layer progress tracking is also reset for this request.
    * Empty ``block_ids``: only a placeholder entry is ensured.

    Transfer ids in ``metadata.reqs_not_processed`` are dropped; ids that are
    no longer tracked are ignored.
    """
    for p_req_id, (transfer_id, block_ids) in metadata.reqs_to_send.items():
        if block_ids:
            # Already gone through request_finished() — OR, in LAYERWISE
            # mode, block_ids arrive at alloc time and the SendBlockMeta
            # may not exist yet, so create it on demand.
            send_meta = self.reqs_need_send.get(transfer_id)
            if send_meta is None:
                send_meta = SendBlockMeta(
                    p_req_id=p_req_id,
                    transfer_id=transfer_id,
                    local_block_ids=[],
                    ready=asyncio.Event(),
                )
                self.reqs_need_send[transfer_id] = send_meta
            send_meta.p_req_id = p_req_id
            send_meta.local_block_ids = block_ids
            send_meta.expire_time = (
                time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
            )
            send_meta.ready.set()
            if self._lw_enabled:
                # Producer scheduled (before its prefill forward) — reset the
                # layer high-water mark so note_layer_computed tracks THIS
                # request's layers from scratch.
                with self._lw_lock:
                    self._lw_gmax = -1
                    self._lw_events.clear()
                    self._lw_producing = True
        else:
            # From update_state_after_alloc(),
            # but not reach request_finished() yet
            # This may be already created by send_kv_to_decode()
            # when D is sending MooncakeXferMetadata.
            if transfer_id not in self.reqs_need_send:
                self.reqs_need_send[transfer_id] = SendBlockMeta(
                    p_req_id=p_req_id,
                    transfer_id=transfer_id,
                    local_block_ids=[],
                    ready=asyncio.Event(),
                )

    for transfer_id in metadata.reqs_not_processed:
        # pop() with a default: an untracked id must not raise KeyError and
        # abort the rest of this update.
        send_meta = self.reqs_need_send.pop(transfer_id, None)
        if send_meta is not None:
            assert not send_meta.ready.is_set()
|
|
|
|
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
    """Dispatch one step's connector metadata to the background loops.

    * Hash-table updates/removals go to ``self.bootstrap_server`` when it is
      set; they are used for direct RDMA read queries.
    * Non-producers split the requests to receive into direct-read and normal
      pull requests, and schedule both groups on the receiver loop.
    * Non-consumers schedule ``record_send_reqs`` on the sender loop.

    The scheduled coroutines are not awaited here.
    """
    server = self.bootstrap_server
    has_hash_changes = metadata.hash_table_updates or metadata.hash_table_removals
    if server is not None and has_hash_changes:
        server.update_hash_table(
            metadata.hash_table_updates, metadata.hash_table_removals
        )

    if not self.is_kv_producer and metadata.reqs_to_recv:
        direct_reqs: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
        normal_reqs: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
        for engine_id, pull_metas in metadata.reqs_to_recv.items():
            for req_id, pm in pull_metas.items():
                bucket = direct_reqs if pm.direct_read else normal_reqs
                bucket[engine_id][req_id] = pm

        if normal_reqs:
            asyncio.run_coroutine_threadsafe(
                self._start_load_kv(normal_reqs), self.receiver_loop
            )
        if direct_reqs:
            asyncio.run_coroutine_threadsafe(
                self._start_direct_read(direct_reqs), self.receiver_loop
            )

    has_send_work = metadata.reqs_to_send or metadata.reqs_not_processed
    if not self.is_kv_consumer and has_send_work:
        asyncio.run_coroutine_threadsafe(
            self.record_send_reqs(metadata), self.sender_loop
        )
|
|
|
|
|
|
def group_concurrent_contiguous(
    src_indices: list[int], dst_indices: list[int]
) -> tuple[list[list[int]], list[list[int]]]:
    """Split parallel index lists into runs that are contiguous on both sides.

    A new run starts wherever either the source or the destination index does
    not increase by exactly 1 from the previous position. Both inputs must
    have the same length. Vectorised with NumPy.

    Example: ``([1, 2, 3, 5], [10, 11, 12, 20])`` gives
    ``([[1, 2, 3], [5]], [[10, 11, 12], [20]])``.
    """
    if len(src_indices) == 0:
        return [], []

    src = np.asarray(src_indices)
    dst = np.asarray(dst_indices)
    discontinuity = (np.diff(src) != 1) | (np.diff(dst) != 1)
    cut_points = np.flatnonzero(discontinuity) + 1

    src_runs = [run.tolist() for run in np.split(src, cut_points)]
    dst_runs = [run.tolist() for run in np.split(dst, cut_points)]
    return src_runs, dst_runs
|
|
|
|
|
|
def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
    """Side-channel port for this DP rank.

    The base is ``VLLM_MOONCAKE_BOOTSTRAP_PORT``, offset by
    ``data_parallel_index * tensor_parallel_size``.
    """
    base_port = envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
    parallel = vllm_config.parallel_config
    offset = parallel.data_parallel_index * parallel.tensor_parallel_size
    return base_port + offset
|
|
|
|
|
|
def _async_loop(loop: asyncio.AbstractEventLoop):
    """Thread target: install ``loop`` as this thread's loop and run it until stopped."""
    # Install first so code scheduled on this thread can look the loop up.
    asyncio.set_event_loop(loop)
    return loop.run_forever()
|
|
|
|
|
|
def should_launch_bootstrap_server(vllm_config: VllmConfig) -> bool:
    """Whether this process should host the Mooncake bootstrap server.

    In hybrid or external LB mode (``local_engines_only``), each instance has
    its own bootstrap server, launched by its local first rank. In internal LB
    mode, only the real global first rank (DP index 0) launches it.
    """
    # Bind outside the assert: asserts are stripped under `python -O`, and a
    # walrus inside one would then leave `parallel_config` undefined.
    parallel_config = vllm_config.parallel_config
    assert parallel_config
    return is_local_first_rank() and (
        parallel_config.local_engines_only or parallel_config.data_parallel_index == 0
    )
|
|
|
|
|
|
def get_mooncake_bootstrap_addr(vllm_config: VllmConfig) -> tuple[str, int]:
    """
    Returns the address of the Mooncake bootstrap server.

    This is only used by prefillers to register workers.
    Decoders should get addr from kv_transfer_params.

    In hybrid or external LB mode (``local_engines_only``), the host is
    ``127.0.0.1``; otherwise it is the DP master IP. The port is always
    ``VLLM_MOONCAKE_BOOTSTRAP_PORT``.
    """
    # Bind outside the assert: asserts are stripped under `python -O`, and a
    # walrus inside one would then leave `parallel_config` undefined.
    parallel_config = vllm_config.parallel_config
    assert parallel_config
    if parallel_config.local_engines_only:
        # In hybrid or external LB mode, connect to local server.
        host = "127.0.0.1"
    else:
        host = parallel_config.data_parallel_master_ip
    port = envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
    return (host, port)
|