Add kvcache-centric profiling and admission controls
This commit is contained in:
@@ -149,13 +149,17 @@ async def _forward_to_backend(
|
||||
) -> Response:
|
||||
state = _require_state()
|
||||
prefill_server, bootstrap_port, decode_server = state.select_pair(headers)
|
||||
modified_request = request_data.copy()
|
||||
modified_request.update(_build_bootstrap_payload(prefill_server, bootstrap_port))
|
||||
prefill_request, decode_request = _build_backend_requests(
|
||||
request_data=request_data,
|
||||
prefill_server=prefill_server,
|
||||
bootstrap_port=bootstrap_port,
|
||||
)
|
||||
|
||||
if request_data.get("stream", False):
|
||||
return StreamingResponse(
|
||||
_stream_generate(
|
||||
modified_request=modified_request,
|
||||
prefill_request=prefill_request,
|
||||
decode_request=decode_request,
|
||||
prefill_server=prefill_server,
|
||||
decode_server=decode_server,
|
||||
endpoint_name=endpoint_name,
|
||||
@@ -168,8 +172,8 @@ async def _forward_to_backend(
|
||||
timeout=aiohttp.ClientTimeout(total=state.config.request_timeout_s)
|
||||
) as session:
|
||||
prefill_response, decode_response = await asyncio.gather(
|
||||
session.post(f"{prefill_server}/{endpoint_name}", json=modified_request),
|
||||
session.post(f"{decode_server}/{endpoint_name}", json=modified_request),
|
||||
session.post(f"{prefill_server}/{endpoint_name}", json=prefill_request),
|
||||
session.post(f"{decode_server}/{endpoint_name}", json=decode_request),
|
||||
)
|
||||
async with prefill_response:
|
||||
await prefill_response.read()
|
||||
@@ -184,7 +188,8 @@ async def _forward_to_backend(
|
||||
|
||||
async def _stream_generate(
|
||||
*,
|
||||
modified_request: dict,
|
||||
prefill_request: dict,
|
||||
decode_request: dict,
|
||||
prefill_server: str,
|
||||
decode_server: str,
|
||||
endpoint_name: str,
|
||||
@@ -194,8 +199,8 @@ async def _stream_generate(
|
||||
timeout=aiohttp.ClientTimeout(total=timeout_s)
|
||||
) as session:
|
||||
prefill_response, decode_response = await asyncio.gather(
|
||||
session.post(f"{prefill_server}/{endpoint_name}", json=modified_request),
|
||||
session.post(f"{decode_server}/{endpoint_name}", json=modified_request),
|
||||
session.post(f"{prefill_server}/{endpoint_name}", json=prefill_request),
|
||||
session.post(f"{decode_server}/{endpoint_name}", json=decode_request),
|
||||
)
|
||||
async with prefill_response, decode_response:
|
||||
if decode_response.status != HTTPStatus.OK:
|
||||
@@ -221,6 +226,35 @@ def _build_bootstrap_payload(prefill_server: str, bootstrap_port: int) -> dict[s
|
||||
}
|
||||
|
||||
|
||||
def _build_backend_requests(
    *,
    request_data: dict,
    prefill_server: str,
    bootstrap_port: int,
) -> tuple[dict, dict]:
    """Derive the prefill and decode payloads from one client request.

    Both payloads start as copies of ``request_data`` with the internal
    ``smg_prefill_priority`` / ``smg_decode_priority`` fields removed, and
    both receive the same bootstrap payload built from ``prefill_server``
    and ``bootstrap_port``. When an internal priority field is present
    (not ``None``) it is forwarded to the matching backend under the
    ``"priority"`` key, converted with ``int()``.

    Args:
        request_data: Incoming request body; it is not modified.
        prefill_server: Passed through to ``_build_bootstrap_payload``.
        bootstrap_port: Passed through to ``_build_bootstrap_payload``.

    Returns:
        A ``(prefill_request, decode_request)`` tuple of independent dicts.

    Raises:
        ValueError, TypeError: Propagated from ``int()`` when a supplied
            priority value cannot be converted to an integer.
    """
    # Read the per-stage priorities before the internal keys are stripped.
    requested_prefill = request_data.get("smg_prefill_priority")
    requested_decode = request_data.get("smg_decode_priority")

    prefill_request = _strip_internal_fields(request_data)
    decode_request = _strip_internal_fields(request_data)
    shared_bootstrap = _build_bootstrap_payload(prefill_server, bootstrap_port)

    stage_plan = (
        (prefill_request, requested_prefill),
        (decode_request, requested_decode),
    )
    for backend_request, requested_priority in stage_plan:
        backend_request.update(shared_bootstrap)
        if requested_priority is None:
            continue
        # Written after the bootstrap merge so the explicit priority wins.
        backend_request["priority"] = int(requested_priority)

    return prefill_request, decode_request
|
||||
|
||||
|
||||
def _strip_internal_fields(request_data: dict) -> dict:
|
||||
cleaned = request_data.copy()
|
||||
cleaned.pop("smg_prefill_priority", None)
|
||||
cleaned.pop("smg_decode_priority", None)
|
||||
return cleaned
|
||||
|
||||
|
||||
def _require_state() -> RouterState:
|
||||
if router_state is None:
|
||||
raise HTTPException(status_code=500, detail="router not initialized")
|
||||
|
||||
Reference in New Issue
Block a user