# Files
# agentic-pd-hybrid/src/agentic_pd_hybrid/pd_router.py
#
# 302 lines
# 9.6 KiB
# Python

from __future__ import annotations
import argparse
import asyncio
import random
import urllib.parse
from dataclasses import dataclass
from http import HTTPStatus
from itertools import chain
from typing import AsyncIterator
import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
# Maximum size, in bytes, of each chunk relayed from the decode worker when
# streaming a response back to the client (64 KiB).
_STREAM_CHUNK_SIZE = 1024 * 64
@dataclass
class RouterConfig:
    """Static settings for the router, populated from the command line."""

    # Address and port the router's own HTTP server binds to.
    host: str
    port: int

    # One (base URL, bootstrap port) pair per prefill worker.
    prefill_urls: list[tuple[str, int]]

    # Base URL of each decode worker.
    decode_urls: list[str]

    # Stored as given; the selection logic in this file always walks the
    # prefill workers round-robin and does not consult this field.
    prefill_policy: str = "round_robin"

    # "manual" pins a routing key to one decode worker, "consistent_hashing"
    # honours an explicit target index; anything else is plain round-robin.
    decode_policy: str = "manual"

    # Total timeout, in seconds, applied to each forwarded backend request.
    request_timeout_s: float = 1800.0
class RouterState:
    """Mutable routing state: round-robin cursors and sticky decode choices.

    Raises:
        ValueError: if the configuration lists no prefill or no decode worker.
    """

    def __init__(self, config: RouterConfig):
        if not config.prefill_urls:
            raise ValueError("At least one prefill worker is required")
        if not config.decode_urls:
            raise ValueError("At least one decode worker is required")
        self.config = config
        self.prefill_cursor = 0
        self.decode_cursor = 0
        # Routing key -> decode worker index, used by the "manual" policy.
        # NOTE(review): entries are never evicted, so this grows by one entry
        # per distinct routing key for the lifetime of the process.
        self.sticky_decode_map: dict[str, int] = {}

    def select_pair(self, headers: dict[str, str]) -> tuple[str, int, str]:
        """Pick the workers that will serve one request.

        Args:
            headers: Request headers with lower-cased names.

        Returns:
            ``(prefill_url, bootstrap_port, decode_url)``. The prefill worker
            is always chosen round-robin; the decode worker depends on the
            configured decode policy.
        """
        prefill_urls = self.config.prefill_urls
        prefill_url, bootstrap_port = prefill_urls[
            self.prefill_cursor % len(prefill_urls)
        ]
        self.prefill_cursor += 1
        decode_index = self._select_decode_index(headers)
        return prefill_url, bootstrap_port, self.config.decode_urls[decode_index]

    def _select_decode_index(self, headers: dict[str, str]) -> int:
        """Return the index into ``decode_urls`` for this request.

        * ``consistent_hashing``: use the ``x-smg-target-worker`` header when
          it holds a valid in-range index.
        * ``manual``: the first request for an ``x-smg-routing-key`` gets the
          next round-robin worker; later requests with that key reuse it.
        * Otherwise (including a missing, malformed or out-of-range header):
          plain round-robin.
        """
        policy = self.config.decode_policy

        if policy == "consistent_hashing":
            idx = self._parse_target_worker(headers.get("x-smg-target-worker"))
            if idx is not None and 0 <= idx < len(self.config.decode_urls):
                return idx
        elif policy == "manual":
            routing_key = headers.get("x-smg-routing-key")
            if routing_key:
                cached = self.sticky_decode_map.get(routing_key)
                if cached is not None:
                    return cached
                idx = self._next_decode_index()
                self.sticky_decode_map[routing_key] = idx
                return idx

        return self._next_decode_index()

    def _next_decode_index(self) -> int:
        """Advance the decode cursor and return the index it pointed at."""
        idx = self.decode_cursor % len(self.config.decode_urls)
        self.decode_cursor += 1
        return idx

    @staticmethod
    def _parse_target_worker(raw: str | None) -> int | None:
        """Parse a target-worker header value; ``None`` if absent or invalid."""
        if raw is None:
            return None
        try:
            return int(raw)
        except ValueError:
            # The header is client-supplied: a malformed value must not crash
            # the request, so it is treated the same as an out-of-range index.
            return None
# FastAPI application on which the router's HTTP endpoints are registered.
app = FastAPI()
# Process-wide routing state; stays None until main() installs it, and
# _require_state() reports an error while it is unset.
router_state: RouterState | None = None
@app.get("/health")
async def health() -> Response:
    """Liveness probe for the router itself.

    Always answers 200 with an empty body and does not contact any worker.
    """
    ok_status = 200
    return Response(status_code=ok_status)
@app.get("/health_generate")
async def health_generate() -> Response:
    """Probe ``/health_generate`` on every prefill and decode worker.

    Returns:
        200 when every worker answers 200; 503 when any worker answers with
        another status or cannot be reached. (Previously worker status codes
        were discarded and the endpoint reported 200 regardless.)
    """
    state = _require_state()
    servers = chain(
        (url for url, _ in state.config.prefill_urls),
        state.config.decode_urls,
    )

    async def _probe(session: aiohttp.ClientSession, server: str) -> bool:
        """Return True only if *server* answers its health endpoint with 200."""
        try:
            async with session.get(f"{server}/health_generate") as response:
                return response.status == HTTPStatus.OK
        except (aiohttp.ClientError, asyncio.TimeoutError):
            # An unreachable worker is unhealthy, not an internal router error.
            return False

    async with aiohttp.ClientSession() as session:
        # Probe all workers concurrently; each probe releases its own response.
        results = await asyncio.gather(
            *(_probe(session, server) for server in servers)
        )

    if not all(results):
        return Response(status_code=503)
    return Response(status_code=200)
@app.get("/v1/models")
async def models() -> ORJSONResponse:
    """Relay ``/v1/models`` from the first configured prefill worker.

    The worker's JSON body and HTTP status are passed through unchanged.
    """
    first_prefill_url = _require_state().config.prefill_urls[0][0]
    models_url = f"{first_prefill_url}/v1/models"
    async with aiohttp.ClientSession() as session, session.get(models_url) as upstream:
        body = await upstream.json()
        upstream_status = upstream.status
    return ORJSONResponse(body, status_code=upstream_status)
@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Response:
    """Forward a chat-completions request to a prefill/decode worker pair."""
    body = await request.json()
    # Header names are lower-cased so routing lookups can use fixed keys.
    lowered_headers: dict[str, str] = {}
    for name, value in request.headers.items():
        lowered_headers[name.lower()] = value
    return await _forward_to_backend(
        request_data=body,
        headers=lowered_headers,
        endpoint_name="v1/chat/completions",
    )
@app.post("/v1/completions")
async def completions(request: Request) -> Response:
    """Forward a completions request to a prefill/decode worker pair."""
    body = await request.json()
    # Header names are lower-cased so routing lookups can use fixed keys.
    lowered_headers: dict[str, str] = {}
    for name, value in request.headers.items():
        lowered_headers[name.lower()] = value
    return await _forward_to_backend(
        request_data=body,
        headers=lowered_headers,
        endpoint_name="v1/completions",
    )
@app.post("/generate")
async def generate(request: Request) -> Response:
    """Forward a ``/generate`` request to a prefill/decode worker pair."""
    body = await request.json()
    # Header names are lower-cased so routing lookups can use fixed keys.
    lowered_headers: dict[str, str] = {}
    for name, value in request.headers.items():
        lowered_headers[name.lower()] = value
    return await _forward_to_backend(
        request_data=body,
        headers=lowered_headers,
        endpoint_name="generate",
    )
async def _forward_to_backend(
    *,
    request_data: dict,
    headers: dict[str, str],
    endpoint_name: str,
) -> Response:
    """Send one client request to a prefill worker and a decode worker.

    Args:
        request_data: Parsed JSON body received from the client.
        headers: Request headers with lower-cased names, used for routing.
        endpoint_name: Path, without a leading slash, to call on both workers.

    Returns:
        For ``stream`` requests, a ``text/event-stream`` response relaying the
        decode worker's output. Otherwise a response carrying the decode
        worker's body, status and content type; the prefill worker's body is
        read and discarded.
    """
    state = _require_state()
    prefill_server, bootstrap_port, decode_server = state.select_pair(headers)
    prefill_request, decode_request = _build_backend_requests(
        request_data=request_data,
        prefill_server=prefill_server,
        bootstrap_port=bootstrap_port,
    )
    timeout_s = state.config.request_timeout_s

    if request_data.get("stream", False):
        chunks = _stream_generate(
            prefill_request=prefill_request,
            decode_request=decode_request,
            prefill_server=prefill_server,
            decode_server=decode_server,
            endpoint_name=endpoint_name,
            timeout_s=timeout_s,
        )
        return StreamingResponse(chunks, media_type="text/event-stream")

    prefill_url = f"{prefill_server}/{endpoint_name}"
    decode_url = f"{decode_server}/{endpoint_name}"
    client_timeout = aiohttp.ClientTimeout(total=timeout_s)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        # Both requests are issued concurrently.
        prefill_response, decode_response = await asyncio.gather(
            session.post(prefill_url, json=prefill_request),
            session.post(decode_url, json=decode_request),
        )
        async with prefill_response:
            await prefill_response.read()
        async with decode_response:
            body = await decode_response.read()
        decode_status = decode_response.status
        decode_content_type = decode_response.content_type
    return Response(
        content=body,
        status_code=decode_status,
        media_type=decode_content_type,
    )
async def _stream_generate(
    *,
    prefill_request: dict,
    decode_request: dict,
    prefill_server: str,
    decode_server: str,
    endpoint_name: str,
    timeout_s: float,
) -> AsyncIterator[bytes]:
    """Yield the decode worker's response body as it arrives.

    Both workers are called concurrently. If the decode worker answers with a
    status other than 200, its whole body is yielded as one chunk and the
    stream ends; otherwise the body is relayed in chunks of at most
    ``_STREAM_CHUNK_SIZE`` bytes. The prefill worker's body is never read.
    """
    prefill_url = f"{prefill_server}/{endpoint_name}"
    decode_url = f"{decode_server}/{endpoint_name}"
    client_timeout = aiohttp.ClientTimeout(total=timeout_s)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        prefill_response, decode_response = await asyncio.gather(
            session.post(prefill_url, json=prefill_request),
            session.post(decode_url, json=decode_request),
        )
        async with prefill_response, decode_response:
            if decode_response.status == HTTPStatus.OK:
                body_stream = decode_response.content
                async for chunk in body_stream.iter_chunked(_STREAM_CHUNK_SIZE):
                    yield chunk
            else:
                # Error bodies are passed through in one piece.
                error_body = await decode_response.read()
                yield error_body
def _build_bootstrap_payload(prefill_server: str, bootstrap_port: int) -> dict[str, object]:
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname
if hostname is None:
raise HTTPException(
status_code=500,
detail=f"Unable to parse prefill hostname from {prefill_server}",
)
return {
"bootstrap_host": hostname,
"bootstrap_port": bootstrap_port,
"bootstrap_room": random.randint(0, 2**63 - 1),
}
def _build_backend_requests(
    *,
    request_data: dict,
    prefill_server: str,
    bootstrap_port: int,
) -> tuple[dict, dict]:
    """Derive the prefill and decode request bodies from a client request.

    Both bodies are shallow copies of *request_data* with the router-internal
    ``smg_*_priority`` fields removed and the same bootstrap fields added.
    When present, ``smg_prefill_priority`` / ``smg_decode_priority`` become
    the ``priority`` field of the matching body.

    Args:
        request_data: Parsed JSON body received from the client.
        prefill_server: Base URL of the chosen prefill worker.
        bootstrap_port: Bootstrap port configured for that worker.

    Returns:
        ``(prefill_request, decode_request)``.

    Raises:
        HTTPException: 500 if the prefill hostname cannot be parsed; 400 if a
            supplied priority cannot be converted to an integer.
    """
    bootstrap_payload = _build_bootstrap_payload(prefill_server, bootstrap_port)

    prefill_request = _strip_internal_fields(request_data)
    decode_request = _strip_internal_fields(request_data)
    prefill_request.update(bootstrap_payload)
    decode_request.update(bootstrap_payload)

    priority_targets = (
        ("smg_prefill_priority", prefill_request),
        ("smg_decode_priority", decode_request),
    )
    for field_name, target in priority_targets:
        raw_priority = request_data.get(field_name)
        if raw_priority is None:
            continue
        try:
            target["priority"] = int(raw_priority)
        except (TypeError, ValueError, OverflowError) as exc:
            # The value is client input: report a bad request rather than
            # letting the conversion error surface as a 500.
            raise HTTPException(
                status_code=400,
                detail=f"{field_name} must be an integer",
            ) from exc

    return prefill_request, decode_request
def _strip_internal_fields(request_data: dict) -> dict:
cleaned = request_data.copy()
cleaned.pop("smg_prefill_priority", None)
cleaned.pop("smg_decode_priority", None)
return cleaned
def _require_state() -> RouterState:
    """Return the installed router state.

    Raises:
        HTTPException: 500 if the module-level ``router_state`` is still None.
    """
    state = router_state
    if state is not None:
        return state
    raise HTTPException(status_code=500, detail="router not initialized")
def main() -> None:
    """Parse command-line options, install the router state and serve.

    ``--prefill URL BOOTSTRAP_PORT`` and ``--decode URL`` may each be repeated
    and are both required. The resulting state is stored in the module-level
    ``router_state`` before uvicorn starts serving ``app``.
    """
    global router_state

    parser = argparse.ArgumentParser(description="Minimal local PD router")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--prefill",
        nargs=2,
        metavar=("URL", "BOOTSTRAP_PORT"),
        action="append",
        required=True,
    )
    parser.add_argument("--decode", action="append", required=True)
    parser.add_argument("--prefill-policy", default="round_robin")
    parser.add_argument("--decode-policy", default="manual")
    parser.add_argument("--request-timeout-s", type=float, default=1800.0)
    args = parser.parse_args()

    # Each --prefill occurrence arrives as a [url, port-string] pair.
    prefill_urls = []
    for url, port in args.prefill:
        prefill_urls.append((url, int(port)))

    config = RouterConfig(
        host=args.host,
        port=args.port,
        prefill_urls=prefill_urls,
        decode_urls=list(args.decode),
        prefill_policy=args.prefill_policy,
        decode_policy=args.decode_policy,
        request_timeout_s=args.request_timeout_s,
    )
    router_state = RouterState(config)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()