Add kvcache-centric profiling and admission controls

This commit is contained in:
2026-04-25 16:00:52 +00:00
parent 08b13d22bc
commit 13bb31a446
9 changed files with 1044 additions and 34 deletions

View File

@@ -149,13 +149,17 @@ async def _forward_to_backend(
) -> Response:
state = _require_state()
prefill_server, bootstrap_port, decode_server = state.select_pair(headers)
modified_request = request_data.copy()
modified_request.update(_build_bootstrap_payload(prefill_server, bootstrap_port))
prefill_request, decode_request = _build_backend_requests(
request_data=request_data,
prefill_server=prefill_server,
bootstrap_port=bootstrap_port,
)
if request_data.get("stream", False):
return StreamingResponse(
_stream_generate(
modified_request=modified_request,
prefill_request=prefill_request,
decode_request=decode_request,
prefill_server=prefill_server,
decode_server=decode_server,
endpoint_name=endpoint_name,
@@ -168,8 +172,8 @@ async def _forward_to_backend(
timeout=aiohttp.ClientTimeout(total=state.config.request_timeout_s)
) as session:
prefill_response, decode_response = await asyncio.gather(
session.post(f"{prefill_server}/{endpoint_name}", json=modified_request),
session.post(f"{decode_server}/{endpoint_name}", json=modified_request),
session.post(f"{prefill_server}/{endpoint_name}", json=prefill_request),
session.post(f"{decode_server}/{endpoint_name}", json=decode_request),
)
async with prefill_response:
await prefill_response.read()
@@ -184,7 +188,8 @@ async def _forward_to_backend(
async def _stream_generate(
*,
modified_request: dict,
prefill_request: dict,
decode_request: dict,
prefill_server: str,
decode_server: str,
endpoint_name: str,
@@ -194,8 +199,8 @@ async def _stream_generate(
timeout=aiohttp.ClientTimeout(total=timeout_s)
) as session:
prefill_response, decode_response = await asyncio.gather(
session.post(f"{prefill_server}/{endpoint_name}", json=modified_request),
session.post(f"{decode_server}/{endpoint_name}", json=modified_request),
session.post(f"{prefill_server}/{endpoint_name}", json=prefill_request),
session.post(f"{decode_server}/{endpoint_name}", json=decode_request),
)
async with prefill_response, decode_response:
if decode_response.status != HTTPStatus.OK:
@@ -221,6 +226,35 @@ def _build_bootstrap_payload(prefill_server: str, bootstrap_port: int) -> dict[s
}
def _build_backend_requests(
*,
request_data: dict,
prefill_server: str,
bootstrap_port: int,
) -> tuple[dict, dict]:
prefill_priority = request_data.get("smg_prefill_priority")
decode_priority = request_data.get("smg_decode_priority")
prefill_request = _strip_internal_fields(request_data)
decode_request = _strip_internal_fields(request_data)
bootstrap_payload = _build_bootstrap_payload(prefill_server, bootstrap_port)
prefill_request.update(bootstrap_payload)
decode_request.update(bootstrap_payload)
if prefill_priority is not None:
prefill_request["priority"] = int(prefill_priority)
if decode_priority is not None:
decode_request["priority"] = int(decode_priority)
return prefill_request, decode_request
def _strip_internal_fields(request_data: dict) -> dict:
cleaned = request_data.copy()
cleaned.pop("smg_prefill_priority", None)
cleaned.pop("smg_decode_priority", None)
return cleaned
def _require_state() -> RouterState:
if router_state is None:
raise HTTPException(status_code=500, detail="router not initialized")