"""Minimal local prefill/decode (PD) router.

Every generation request is sent to one "prefill" worker and one "decode"
worker at the same time.  Both copies of the request carry the same
``bootstrap_host`` / ``bootstrap_port`` / ``bootstrap_room`` fields; the decode
worker's response is what is returned to the client.
"""

from __future__ import annotations

import argparse
import asyncio
import random
import urllib.parse
from dataclasses import dataclass
from http import HTTPStatus
from itertools import chain
from typing import AsyncIterator

import aiohttp
import orjson  # noqa: F401  (kept: ORJSONResponse relies on orjson being installed)
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import ORJSONResponse, Response, StreamingResponse

# Chunk size (bytes) used when relaying a streamed decode response.
_STREAM_CHUNK_SIZE = 1024 * 64


@dataclass
class RouterConfig:
    """Static configuration of the router."""

    host: str
    port: int
    # (base URL, bootstrap port) for every prefill worker.
    prefill_urls: list[tuple[str, int]]
    # Base URL of every decode worker.
    decode_urls: list[str]
    # NOTE(review): prefill selection is always round-robin in this file;
    # this field is stored but never read.
    prefill_policy: str = "round_robin"
    # "manual" and "consistent_hashing" get special handling; any other value
    # results in plain round-robin decode selection.
    decode_policy: str = "manual"
    # Total timeout (seconds) applied to each proxied generation request.
    request_timeout_s: float = 1800.0


class RouterState:
    """Mutable routing state: round-robin cursors and the sticky-key map."""

    def __init__(self, config: RouterConfig):
        """Store *config* and reset all cursors.

        Raises:
            ValueError: if no prefill worker or no decode worker is configured.
        """
        if not config.prefill_urls:
            raise ValueError("At least one prefill worker is required")
        if not config.decode_urls:
            raise ValueError("At least one decode worker is required")
        self.config = config
        self.prefill_cursor = 0
        self.decode_cursor = 0
        # routing key -> decode worker index (used by the "manual" policy).
        # NOTE(review): entries are never evicted, so this grows with the
        # number of distinct routing keys seen.
        self.sticky_decode_map: dict[str, int] = {}

    def select_pair(self, headers: dict[str, str]) -> tuple[str, int, str]:
        """Pick the workers for one request.

        Args:
            headers: request headers with lower-cased names.

        Returns:
            ``(prefill_url, bootstrap_port, decode_url)``.  The prefill worker
            is always chosen round-robin; the decode worker depends on the
            configured decode policy.
        """
        prefill_workers = self.config.prefill_urls
        prefill_url, bootstrap_port = prefill_workers[
            self.prefill_cursor % len(prefill_workers)
        ]
        self.prefill_cursor += 1
        decode_index = self._select_decode_index(headers)
        return prefill_url, bootstrap_port, self.config.decode_urls[decode_index]

    def _select_decode_index(self, headers: dict[str, str]) -> int:
        """Return the index of the decode worker for this request.

        * ``consistent_hashing``: the ``x-smg-target-worker`` header names the
          index directly; a missing, non-numeric or out-of-range value falls
          back to round-robin.
        * ``manual``: the first request for an ``x-smg-routing-key`` is placed
          round-robin and every later request with that key reuses the worker.
        * anything else: round-robin.
        """
        worker_count = len(self.config.decode_urls)
        policy = self.config.decode_policy
        target_worker = headers.get("x-smg-target-worker")
        routing_key = headers.get("x-smg-routing-key")

        if policy == "consistent_hashing" and target_worker is not None:
            # The header is client-controlled: a non-numeric value must not
            # crash the request, it is treated like an out-of-range index.
            try:
                requested = int(target_worker)
            except ValueError:
                requested = -1
            if 0 <= requested < worker_count:
                return requested

        if policy == "manual" and routing_key:
            cached = self.sticky_decode_map.get(routing_key)
            if cached is not None:
                return cached
            chosen = self._next_decode_index()
            self.sticky_decode_map[routing_key] = chosen
            return chosen

        return self._next_decode_index()

    def _next_decode_index(self) -> int:
        """Advance the decode round-robin cursor and return the selected index."""
        chosen = self.decode_cursor % len(self.config.decode_urls)
        self.decode_cursor += 1
        return chosen


app = FastAPI()
# Set by main() before the server starts serving requests.
router_state: RouterState | None = None


@app.get("/health")
async def health() -> Response:
    """Liveness probe for the router itself; always answers 200."""
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate() -> Response:
    """Probe ``/health_generate`` on every prefill and decode worker.

    Returns 200 only when every worker answers 200; otherwise 503.
    """
    state = _require_state()
    servers = chain(
        (url for url, _ in state.config.prefill_urls),
        state.config.decode_urls,
    )
    async with aiohttp.ClientSession() as session:
        results = await asyncio.gather(
            *(_probe_health(session, server) for server in servers)
        )
    return Response(status_code=200 if all(results) else 503)


async def _probe_health(session: aiohttp.ClientSession, server: str) -> bool:
    """Return True when *server* answers ``/health_generate`` with HTTP 200."""
    try:
        async with session.get(f"{server}/health_generate") as response:
            return response.status == HTTPStatus.OK
    except (aiohttp.ClientError, asyncio.TimeoutError):
        # An unreachable worker is reported as unhealthy, not as a router crash.
        return False


@app.get("/v1/models")
async def models() -> ORJSONResponse:
    """Relay ``/v1/models`` from the first configured prefill worker."""
    state = _require_state()
    first_prefill_url = state.config.prefill_urls[0][0]
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{first_prefill_url}/v1/models") as response:
            payload = await response.json()
            return ORJSONResponse(payload, status_code=response.status)


@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Response:
    """Route a ``/v1/chat/completions`` request to a prefill/decode pair."""
    return await _handle_generation_request(request, "v1/chat/completions")


@app.post("/v1/completions")
async def completions(request: Request) -> Response:
    """Route a ``/v1/completions`` request to a prefill/decode pair."""
    return await _handle_generation_request(request, "v1/completions")


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Route a ``/generate`` request to a prefill/decode pair."""
    return await _handle_generation_request(request, "generate")


async def _handle_generation_request(request: Request, endpoint_name: str) -> Response:
    """Parse and validate the request body, then forward it to the workers.

    Raises:
        HTTPException: 400 when the body is not valid JSON or not a JSON object.
    """
    try:
        request_data = await request.json()
    except ValueError as exc:  # json.JSONDecodeError / UnicodeDecodeError
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from exc
    if not isinstance(request_data, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )
    headers = {key.lower(): value for key, value in request.headers.items()}
    return await _forward_to_backend(
        request_data=request_data,
        headers=headers,
        endpoint_name=endpoint_name,
    )


async def _forward_to_backend(
    *,
    request_data: dict,
    headers: dict[str, str],
    endpoint_name: str,
) -> Response:
    """Send the request to one prefill and one decode worker.

    Args:
        request_data: parsed JSON body of the client request.
        headers: client headers with lower-cased names (used for routing).
        endpoint_name: path, without leading slash, to call on both workers.

    Returns:
        A ``StreamingResponse`` relaying the decode worker's body when the
        request has a truthy ``"stream"`` field, otherwise a ``Response``
        carrying the decode worker's body, status and content type.
    """
    state = _require_state()
    prefill_server, bootstrap_port, decode_server = state.select_pair(headers)
    prefill_request, decode_request = _build_backend_requests(
        request_data=request_data,
        prefill_server=prefill_server,
        bootstrap_port=bootstrap_port,
    )
    timeout_s = state.config.request_timeout_s

    if request_data.get("stream", False):
        return StreamingResponse(
            _stream_generate(
                prefill_request=prefill_request,
                decode_request=decode_request,
                prefill_server=prefill_server,
                decode_server=decode_server,
                endpoint_name=endpoint_name,
                timeout_s=timeout_s,
            ),
            media_type="text/event-stream",
        )

    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=timeout_s)
    ) as session:
        prefill_response, decode_response = await asyncio.gather(
            session.post(f"{prefill_server}/{endpoint_name}", json=prefill_request),
            session.post(f"{decode_server}/{endpoint_name}", json=decode_request),
        )
        # Both responses are released together, even if reading one of them fails.
        async with prefill_response, decode_response:
            # The prefill body is drained but discarded; only the decode
            # worker's answer is returned to the client.
            await prefill_response.read()
            body = await decode_response.read()
            return Response(
                content=body,
                status_code=decode_response.status,
                media_type=decode_response.content_type,
            )


async def _stream_generate(
    *,
    prefill_request: dict,
    decode_request: dict,
    prefill_server: str,
    decode_server: str,
    endpoint_name: str,
    timeout_s: float,
) -> AsyncIterator[bytes]:
    """Post to both workers and yield the decode worker's body in chunks.

    When the decode worker does not answer 200, its whole body is yielded as a
    single chunk and the stream ends.
    """
    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=timeout_s)
    ) as session:
        prefill_response, decode_response = await asyncio.gather(
            session.post(f"{prefill_server}/{endpoint_name}", json=prefill_request),
            session.post(f"{decode_server}/{endpoint_name}", json=decode_request),
        )
        async with prefill_response, decode_response:
            if decode_response.status != HTTPStatus.OK:
                yield await decode_response.read()
                return
            async for chunk in decode_response.content.iter_chunked(_STREAM_CHUNK_SIZE):
                yield chunk


def _build_bootstrap_payload(prefill_server: str, bootstrap_port: int) -> dict[str, object]:
    """Build the bootstrap fields that are merged into both worker requests.

    The host is the hostname parsed from *prefill_server*; the room is a fresh
    random integer in ``[0, 2**63 - 1]``.

    Raises:
        HTTPException: 500 when no hostname can be parsed from *prefill_server*.
    """
    hostname = urllib.parse.urlparse(prefill_server).hostname
    if hostname is None:
        raise HTTPException(
            status_code=500,
            detail=f"Unable to parse prefill hostname from {prefill_server}",
        )
    return {
        "bootstrap_host": hostname,
        "bootstrap_port": bootstrap_port,
        "bootstrap_room": random.randint(0, 2**63 - 1),
    }


def _parse_priority(value: object, field_name: str) -> int:
    """Convert a client-supplied priority to ``int``, or fail with HTTP 400."""
    try:
        return int(value)  # type: ignore[call-overload]
    except (TypeError, ValueError, OverflowError) as exc:
        raise HTTPException(
            status_code=400, detail=f"{field_name} must be an integer"
        ) from exc


def _build_backend_requests(
    *,
    request_data: dict,
    prefill_server: str,
    bootstrap_port: int,
) -> tuple[dict, dict]:
    """Derive the prefill and decode request bodies from the client body.

    Both bodies are shallow copies of *request_data* without the router-only
    ``smg_*_priority`` fields, extended with one shared bootstrap payload.  A
    supplied ``smg_prefill_priority`` / ``smg_decode_priority`` becomes the
    ``priority`` field of the corresponding request.

    Raises:
        HTTPException: 400 when a supplied priority is not an integer; 500 when
            the prefill hostname cannot be parsed.
    """
    prefill_priority = request_data.get("smg_prefill_priority")
    decode_priority = request_data.get("smg_decode_priority")

    prefill_request = _strip_internal_fields(request_data)
    decode_request = _strip_internal_fields(request_data)

    # One payload for both requests, so they carry the same bootstrap_room.
    bootstrap_payload = _build_bootstrap_payload(prefill_server, bootstrap_port)
    prefill_request.update(bootstrap_payload)
    decode_request.update(bootstrap_payload)

    if prefill_priority is not None:
        prefill_request["priority"] = _parse_priority(
            prefill_priority, "smg_prefill_priority"
        )
    if decode_priority is not None:
        decode_request["priority"] = _parse_priority(
            decode_priority, "smg_decode_priority"
        )
    return prefill_request, decode_request


def _strip_internal_fields(request_data: dict) -> dict:
    """Return a shallow copy of *request_data* without router-only fields."""
    cleaned = request_data.copy()
    cleaned.pop("smg_prefill_priority", None)
    cleaned.pop("smg_decode_priority", None)
    return cleaned


def _require_state() -> RouterState:
    """Return the global router state.

    Raises:
        HTTPException: 500 when main() has not initialised the state yet.
    """
    if router_state is None:
        raise HTTPException(status_code=500, detail="router not initialized")
    return router_state


def main() -> None:
    """Parse the command line, initialise the global state and run the server."""
    parser = argparse.ArgumentParser(description="Minimal local PD router")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--prefill",
        nargs=2,
        metavar=("URL", "BOOTSTRAP_PORT"),
        action="append",
        required=True,
    )
    parser.add_argument(
        "--decode",
        action="append",
        required=True,
    )
    parser.add_argument("--prefill-policy", default="round_robin")
    parser.add_argument("--decode-policy", default="manual")
    parser.add_argument("--request-timeout-s", type=float, default=1800.0)
    args = parser.parse_args()

    # Worker URLs are later joined as f"{url}/{endpoint}"; dropping a trailing
    # slash avoids requesting "//endpoint".
    try:
        prefill_urls = [(url.rstrip("/"), int(port)) for url, port in args.prefill]
    except ValueError:
        parser.error("--prefill BOOTSTRAP_PORT must be an integer")
    decode_urls = [url.rstrip("/") for url in args.decode]

    global router_state
    router_state = RouterState(
        RouterConfig(
            host=args.host,
            port=args.port,
            prefill_urls=prefill_urls,
            decode_urls=decode_urls,
            prefill_policy=args.prefill_policy,
            decode_policy=args.decode_policy,
            request_timeout_s=args.request_timeout_s,
        )
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()