Add kvcache-centric profiling and admission controls

This commit is contained in:
2026-04-25 16:00:52 +00:00
parent 08b13d22bc
commit 13bb31a446
9 changed files with 1044 additions and 34 deletions

View File

@@ -27,10 +27,20 @@ class BenchmarkConfig:
time_scale: float = 1.0 time_scale: float = 1.0
concurrency_limit: int = 32 concurrency_limit: int = 32
timeout_s: float = 1200.0 timeout_s: float = 1200.0
request_timeout_s: float | None = None
stream: bool = True stream: bool = True
stream_idle_timeout_s: float | None = 900.0 stream_idle_timeout_s: float | None = 900.0
kvcache_direct_max_uncached_tokens: int = 2048 kvcache_direct_max_uncached_tokens: int = 2048
kvcache_admission_mode: str = "router" kvcache_admission_mode: str = "router"
kvcache_seed_max_resident_tokens: int | None = None
kvcache_seed_max_output_tokens: int | None = None
kvcache_seed_min_turn_id: int = 1
kvcache_seed_only_multiturn_sessions: bool = False
kvcache_prefill_backup_policy: str = "release-after-transfer"
kvcache_seed_max_inflight_decode: int | None = 3
kvcache_prefill_priority_eviction: bool = False
kvcache_prefill_direct_priority: int = -100
kvcache_prefill_normal_priority: int = 100
sample_profile: str = "default" sample_profile: str = "default"
min_initial_input_tokens: int | None = None min_initial_input_tokens: int | None = None
max_initial_input_tokens: int | None = None max_initial_input_tokens: int | None = None
@@ -59,10 +69,17 @@ def run_live_benchmark(config: BenchmarkConfig) -> BenchmarkArtifacts:
topology = config.topology topology = config.topology
if config.mechanism_name == "kvcache-centric": if config.mechanism_name == "kvcache-centric":
prefill_extra_server_args = topology.prefill_extra_server_args + (
"--enable-streaming-session",
)
if config.kvcache_prefill_priority_eviction:
prefill_extra_server_args = prefill_extra_server_args + (
"--radix-eviction-policy",
"priority",
)
topology = replace( topology = replace(
topology, topology,
prefill_extra_server_args=topology.prefill_extra_server_args prefill_extra_server_args=prefill_extra_server_args,
+ ("--enable-streaming-session",),
decode_extra_server_args=topology.decode_extra_server_args decode_extra_server_args=topology.decode_extra_server_args
+ ( + (
"--enable-streaming-session", "--enable-streaming-session",
@@ -107,6 +124,11 @@ def run_live_benchmark(config: BenchmarkConfig) -> BenchmarkArtifacts:
prefill_policy="round_robin", prefill_policy="round_robin",
decode_policy=_decode_policy_for(config.policy_name), decode_policy=_decode_policy_for(config.policy_name),
timeout_s=config.timeout_s, timeout_s=config.timeout_s,
router_request_timeout_s=(
config.request_timeout_s
if config.request_timeout_s is not None
else config.timeout_s
),
include_router=( include_router=(
config.mechanism_name in {"pd-disaggregation", "kvcache-centric"} config.mechanism_name in {"pd-disaggregation", "kvcache-centric"}
), ),
@@ -142,7 +164,27 @@ def run_live_benchmark(config: BenchmarkConfig) -> BenchmarkArtifacts:
stream_idle_timeout_s=config.stream_idle_timeout_s, stream_idle_timeout_s=config.stream_idle_timeout_s,
kvcache_direct_max_uncached_tokens=config.kvcache_direct_max_uncached_tokens, kvcache_direct_max_uncached_tokens=config.kvcache_direct_max_uncached_tokens,
kvcache_admission_mode=config.kvcache_admission_mode, # type: ignore[arg-type] kvcache_admission_mode=config.kvcache_admission_mode, # type: ignore[arg-type]
kvcache_seed_max_resident_tokens=config.kvcache_seed_max_resident_tokens,
kvcache_seed_max_output_tokens=config.kvcache_seed_max_output_tokens,
kvcache_seed_min_turn_id=config.kvcache_seed_min_turn_id,
kvcache_seed_only_multiturn_sessions=(
config.kvcache_seed_only_multiturn_sessions
),
kvcache_prefill_backup_policy=config.kvcache_prefill_backup_policy, # type: ignore[arg-type]
kvcache_seed_max_inflight_decode=(
config.kvcache_seed_max_inflight_decode
),
kvcache_prefill_priority_eviction=(
config.kvcache_prefill_priority_eviction
),
kvcache_prefill_direct_priority=config.kvcache_prefill_direct_priority,
kvcache_prefill_normal_priority=config.kvcache_prefill_normal_priority,
) )
if config.request_timeout_s is not None:
replay_config = replace(
replay_config,
timeout_s=config.request_timeout_s,
)
asyncio.run(replay_trace(replay_config)) asyncio.run(replay_trace(replay_config))
finally: finally:
signal.signal(signal.SIGINT, previous_sigint) signal.signal(signal.SIGINT, previous_sigint)
@@ -163,10 +205,30 @@ def run_live_benchmark(config: BenchmarkConfig) -> BenchmarkArtifacts:
"time_scale": config.time_scale, "time_scale": config.time_scale,
"concurrency_limit": config.concurrency_limit, "concurrency_limit": config.concurrency_limit,
"timeout_s": config.timeout_s, "timeout_s": config.timeout_s,
"request_timeout_s": config.request_timeout_s,
"stream": config.stream, "stream": config.stream,
"stream_idle_timeout_s": config.stream_idle_timeout_s, "stream_idle_timeout_s": config.stream_idle_timeout_s,
"kvcache_direct_max_uncached_tokens": config.kvcache_direct_max_uncached_tokens, "kvcache_direct_max_uncached_tokens": config.kvcache_direct_max_uncached_tokens,
"kvcache_admission_mode": config.kvcache_admission_mode, "kvcache_admission_mode": config.kvcache_admission_mode,
"kvcache_seed_max_resident_tokens": config.kvcache_seed_max_resident_tokens,
"kvcache_seed_max_output_tokens": config.kvcache_seed_max_output_tokens,
"kvcache_seed_min_turn_id": config.kvcache_seed_min_turn_id,
"kvcache_seed_only_multiturn_sessions": (
config.kvcache_seed_only_multiturn_sessions
),
"kvcache_prefill_backup_policy": config.kvcache_prefill_backup_policy,
"kvcache_seed_max_inflight_decode": (
config.kvcache_seed_max_inflight_decode
),
"kvcache_prefill_priority_eviction": (
config.kvcache_prefill_priority_eviction
),
"kvcache_prefill_direct_priority": (
config.kvcache_prefill_direct_priority
),
"kvcache_prefill_normal_priority": (
config.kvcache_prefill_normal_priority
),
"sample_profile": config.sample_profile, "sample_profile": config.sample_profile,
"min_initial_input_tokens": config.min_initial_input_tokens, "min_initial_input_tokens": config.min_initial_input_tokens,
"max_initial_input_tokens": config.max_initial_input_tokens, "max_initial_input_tokens": config.max_initial_input_tokens,

View File

@@ -7,6 +7,7 @@ from pathlib import Path
from agentic_pd_hybrid.benchmark import BenchmarkConfig, run_live_benchmark from agentic_pd_hybrid.benchmark import BenchmarkConfig, run_live_benchmark
from agentic_pd_hybrid.launcher import build_launch_plan from agentic_pd_hybrid.launcher import build_launch_plan
from agentic_pd_hybrid.microbench import SmallAppendTraceConfig, write_small_append_trace from agentic_pd_hybrid.microbench import SmallAppendTraceConfig, write_small_append_trace
from agentic_pd_hybrid.profile import ProfileConfig, print_profile_summary, write_profile
from agentic_pd_hybrid.replay import ReplayConfig, replay_trace from agentic_pd_hybrid.replay import ReplayConfig, replay_trace
from agentic_pd_hybrid.sampling import SessionSampleConfig, sample_trace_sessions from agentic_pd_hybrid.sampling import SessionSampleConfig, sample_trace_sessions
from agentic_pd_hybrid.trace_profiles import ( from agentic_pd_hybrid.trace_profiles import (
@@ -142,6 +143,71 @@ def main() -> None:
"or query the decode worker on the critical path." "or query the decode worker on the critical path."
), ),
) )
replay.add_argument(
"--kvcache-seed-max-resident-tokens",
type=int,
default=None,
help=(
"For kvcache-centric routing, do not seed/reseed a decode session "
"when input+output tokens exceed this value."
),
)
replay.add_argument(
"--kvcache-seed-max-output-tokens",
type=int,
default=None,
help=(
"For kvcache-centric routing, do not seed/reseed a decode session "
"when output tokens exceed this value."
),
)
replay.add_argument(
"--kvcache-seed-min-turn-id",
type=int,
default=1,
help=(
"For kvcache-centric routing, do not seed/reseed a decode session "
"before this turn id."
),
)
replay.add_argument(
"--kvcache-seed-only-multiturn-sessions",
action="store_true",
help=(
"Oracle ablation for kvcache-centric routing: only seed sessions "
"that have more than one turn in the replay trace."
),
)
replay.add_argument(
"--kvcache-prefill-backup-policy",
choices=["release-after-transfer", "capacity-backup"],
default="release-after-transfer",
help=(
"For kvcache-centric seed/reseed, release the P-side session after "
"P->D transfer or keep a capacity-limited P-side backup."
),
)
replay.add_argument(
"--kvcache-seed-max-inflight-decode",
type=int,
default=3,
help=(
"For kvcache-centric routing, skip seed/reseed when the router "
"shadow inflight decode load assigned to the target D exceeds this value. "
"Use a negative value to disable this filter."
),
)
replay.add_argument(
"--kvcache-prefill-priority-eviction",
action="store_true",
help=(
"For kvcache-centric routing, mark P-side prefixes predicted to move "
"direct-to-D as lower priority. Requires P workers to use "
"--radix-eviction-policy priority."
),
)
replay.add_argument("--kvcache-prefill-direct-priority", type=int, default=-100)
replay.add_argument("--kvcache-prefill-normal-priority", type=int, default=100)
sample = subparsers.add_parser( sample = subparsers.add_parser(
"sample-sessions", "sample-sessions",
@@ -177,6 +243,17 @@ def main() -> None:
normalize.add_argument("--output-length", type=int, default=1_000) normalize.add_argument("--output-length", type=int, default=1_000)
normalize.add_argument("--max-requests", type=int, default=None) normalize.add_argument("--max-requests", type=int, default=None)
profile = subparsers.add_parser(
"profile",
help="Profile a trace and optional request metrics for routing analysis",
)
profile.add_argument("--trace", type=Path, required=True)
profile.add_argument("--output", type=Path, default=None)
profile.add_argument("--metrics", type=Path, default=None)
profile.add_argument("--baseline-metrics", type=Path, default=None)
profile.add_argument("--candidate-metrics", type=Path, default=None)
profile.add_argument("--direct-max-uncached-tokens", type=int, default=2048)
micro = subparsers.add_parser( micro = subparsers.add_parser(
"make-small-append-trace", "make-small-append-trace",
help="Generate a synthetic multi-turn trace with small turn2+ appends", help="Generate a synthetic multi-turn trace with small turn2+ appends",
@@ -223,6 +300,15 @@ def main() -> None:
benchmark.add_argument("--time-scale", type=float, default=1.0) benchmark.add_argument("--time-scale", type=float, default=1.0)
benchmark.add_argument("--concurrency-limit", type=int, default=32) benchmark.add_argument("--concurrency-limit", type=int, default=32)
benchmark.add_argument("--timeout-s", type=float, default=1200.0) benchmark.add_argument("--timeout-s", type=float, default=1200.0)
benchmark.add_argument(
"--request-timeout-s",
type=float,
default=None,
help=(
"Per-request replay/router timeout. If unset, --timeout-s is used. "
"--timeout-s still controls stack startup."
),
)
benchmark.add_argument( benchmark.add_argument(
"--no-stream", "--no-stream",
action="store_true", action="store_true",
@@ -249,6 +335,70 @@ def main() -> None:
"or query the decode worker on the critical path." "or query the decode worker on the critical path."
), ),
) )
benchmark.add_argument(
"--kvcache-seed-max-resident-tokens",
type=int,
default=None,
help=(
"For kvcache-centric routing, do not seed/reseed a decode session "
"when input+output tokens exceed this value."
),
)
benchmark.add_argument(
"--kvcache-seed-max-output-tokens",
type=int,
default=None,
help=(
"For kvcache-centric routing, do not seed/reseed a decode session "
"when output tokens exceed this value."
),
)
benchmark.add_argument(
"--kvcache-seed-min-turn-id",
type=int,
default=1,
help=(
"For kvcache-centric routing, do not seed/reseed a decode session "
"before this turn id."
),
)
benchmark.add_argument(
"--kvcache-seed-only-multiturn-sessions",
action="store_true",
help=(
"Oracle ablation for kvcache-centric routing: only seed sessions "
"that have more than one turn in the replay trace."
),
)
benchmark.add_argument(
"--kvcache-prefill-backup-policy",
choices=["release-after-transfer", "capacity-backup"],
default="release-after-transfer",
help=(
"For kvcache-centric seed/reseed, release the P-side session after "
"P->D transfer or keep a capacity-limited P-side backup."
),
)
benchmark.add_argument(
"--kvcache-seed-max-inflight-decode",
type=int,
default=3,
help=(
"For kvcache-centric routing, skip seed/reseed when the router "
"shadow inflight decode load assigned to the target D exceeds this value. "
"Use a negative value to disable this filter."
),
)
benchmark.add_argument(
"--kvcache-prefill-priority-eviction",
action="store_true",
help=(
"For kvcache-centric benchmark-live, launch P workers with priority "
"radix eviction and mark direct-to-D predicted prefixes lower priority."
),
)
benchmark.add_argument("--kvcache-prefill-direct-priority", type=int, default=-100)
benchmark.add_argument("--kvcache-prefill-normal-priority", type=int, default=100)
benchmark.add_argument( benchmark.add_argument(
"--sample-profile", "--sample-profile",
choices=["default", "small-append"], choices=["default", "small-append"],
@@ -294,6 +444,23 @@ def main() -> None:
stream_idle_timeout_s=args.stream_idle_timeout_s, stream_idle_timeout_s=args.stream_idle_timeout_s,
kvcache_direct_max_uncached_tokens=args.kvcache_direct_max_uncached_tokens, kvcache_direct_max_uncached_tokens=args.kvcache_direct_max_uncached_tokens,
kvcache_admission_mode=args.kvcache_admission_mode, kvcache_admission_mode=args.kvcache_admission_mode,
kvcache_seed_max_resident_tokens=args.kvcache_seed_max_resident_tokens,
kvcache_seed_max_output_tokens=args.kvcache_seed_max_output_tokens,
kvcache_seed_min_turn_id=args.kvcache_seed_min_turn_id,
kvcache_seed_only_multiturn_sessions=(
args.kvcache_seed_only_multiturn_sessions
),
kvcache_prefill_backup_policy=args.kvcache_prefill_backup_policy,
kvcache_seed_max_inflight_decode=(
None
if args.kvcache_seed_max_inflight_decode < 0
else args.kvcache_seed_max_inflight_decode
),
kvcache_prefill_priority_eviction=(
args.kvcache_prefill_priority_eviction
),
kvcache_prefill_direct_priority=args.kvcache_prefill_direct_priority,
kvcache_prefill_normal_priority=args.kvcache_prefill_normal_priority,
) )
results = asyncio.run(replay_trace(config)) results = asyncio.run(replay_trace(config))
print( print(
@@ -302,6 +469,26 @@ def main() -> None:
) )
return return
if args.command == "profile":
if (args.baseline_metrics is None) != (args.candidate_metrics is None):
raise ValueError(
"--baseline-metrics and --candidate-metrics must be provided together"
)
report = write_profile(
ProfileConfig(
trace_path=args.trace,
output_path=args.output,
metrics_path=args.metrics,
baseline_metrics_path=args.baseline_metrics,
candidate_metrics_path=args.candidate_metrics,
direct_max_uncached_tokens=args.direct_max_uncached_tokens,
)
)
print_profile_summary(report)
if args.output is not None:
print(f"wrote profile to {args.output}")
return
if args.command == "sample-sessions": if args.command == "sample-sessions":
summary = sample_trace_sessions( summary = sample_trace_sessions(
SessionSampleConfig( SessionSampleConfig(
@@ -378,10 +565,32 @@ def main() -> None:
time_scale=args.time_scale, time_scale=args.time_scale,
concurrency_limit=args.concurrency_limit, concurrency_limit=args.concurrency_limit,
timeout_s=args.timeout_s, timeout_s=args.timeout_s,
request_timeout_s=args.request_timeout_s,
stream=not args.no_stream, stream=not args.no_stream,
stream_idle_timeout_s=args.stream_idle_timeout_s, stream_idle_timeout_s=args.stream_idle_timeout_s,
kvcache_direct_max_uncached_tokens=args.kvcache_direct_max_uncached_tokens, kvcache_direct_max_uncached_tokens=args.kvcache_direct_max_uncached_tokens,
kvcache_admission_mode=args.kvcache_admission_mode, kvcache_admission_mode=args.kvcache_admission_mode,
kvcache_seed_max_resident_tokens=args.kvcache_seed_max_resident_tokens,
kvcache_seed_max_output_tokens=args.kvcache_seed_max_output_tokens,
kvcache_seed_min_turn_id=args.kvcache_seed_min_turn_id,
kvcache_seed_only_multiturn_sessions=(
args.kvcache_seed_only_multiturn_sessions
),
kvcache_prefill_backup_policy=args.kvcache_prefill_backup_policy,
kvcache_seed_max_inflight_decode=(
None
if args.kvcache_seed_max_inflight_decode < 0
else args.kvcache_seed_max_inflight_decode
),
kvcache_prefill_priority_eviction=(
args.kvcache_prefill_priority_eviction
),
kvcache_prefill_direct_priority=(
args.kvcache_prefill_direct_priority
),
kvcache_prefill_normal_priority=(
args.kvcache_prefill_normal_priority
),
sample_profile=args.sample_profile, sample_profile=args.sample_profile,
min_initial_input_tokens=args.min_initial_input_tokens, min_initial_input_tokens=args.min_initial_input_tokens,
max_initial_input_tokens=args.max_initial_input_tokens, max_initial_input_tokens=args.max_initial_input_tokens,

View File

@@ -33,6 +33,7 @@ def build_launch_plan(
prefill_policy: str = "round_robin", prefill_policy: str = "round_robin",
decode_policy: str = "manual", decode_policy: str = "manual",
include_router: bool = True, include_router: bool = True,
router_request_timeout_s: float | None = None,
) -> LaunchPlan: ) -> LaunchPlan:
return LaunchPlan( return LaunchPlan(
prefill_commands=tuple( prefill_commands=tuple(
@@ -49,6 +50,7 @@ def build_launch_plan(
topology, topology,
prefill_policy=prefill_policy, prefill_policy=prefill_policy,
decode_policy=decode_policy, decode_policy=decode_policy,
request_timeout_s=router_request_timeout_s,
) )
if include_router and topology.prefill_workers and topology.decode_workers if include_router and topology.prefill_workers and topology.decode_workers
else None else None
@@ -105,6 +107,7 @@ def _build_router_command(
*, *,
prefill_policy: str, prefill_policy: str,
decode_policy: str, decode_policy: str,
request_timeout_s: float | None,
) -> tuple[str, ...]: ) -> tuple[str, ...]:
command: list[str] = [ command: list[str] = [
sys.executable, sys.executable,
@@ -121,6 +124,8 @@ def _build_router_command(
"--decode-policy", "--decode-policy",
decode_policy, decode_policy,
] ]
if request_timeout_s is not None:
command.extend(["--request-timeout-s", str(request_timeout_s)])
for worker in topology.prefill_workers: for worker in topology.prefill_workers:
command.extend( command.extend(
["--prefill", worker.url, str(worker.bootstrap_port or topology.router_port)] ["--prefill", worker.url, str(worker.bootstrap_port or topology.router_port)]

View File

@@ -33,6 +33,8 @@ class RequestMetrics:
kv_transfer_blocks: int kv_transfer_blocks: int
actual_kv_transfer_blocks: int actual_kv_transfer_blocks: int
cached_tokens: int cached_tokens: int
prefill_request_priority: int | None
decode_request_priority: int | None
re_prefill_required: bool re_prefill_required: bool
effective_input_length: int | None effective_input_length: int | None
session_reused: bool session_reused: bool
@@ -58,6 +60,8 @@ class RequestMetrics:
latency_s: float | None, latency_s: float | None,
ttft_s: float | None, ttft_s: float | None,
tpot_s: float | None, tpot_s: float | None,
prefill_request_priority: int | None = None,
decode_request_priority: int | None = None,
error: str | None = None, error: str | None = None,
) -> "RequestMetrics": ) -> "RequestMetrics":
return cls( return cls(
@@ -81,6 +85,8 @@ class RequestMetrics:
kv_transfer_blocks=decision.kv_transfer_blocks, kv_transfer_blocks=decision.kv_transfer_blocks,
actual_kv_transfer_blocks=actual_kv_transfer_blocks, actual_kv_transfer_blocks=actual_kv_transfer_blocks,
cached_tokens=cached_tokens, cached_tokens=cached_tokens,
prefill_request_priority=prefill_request_priority,
decode_request_priority=decode_request_priority,
re_prefill_required=decision.re_prefill_required, re_prefill_required=decision.re_prefill_required,
effective_input_length=effective_input_length, effective_input_length=effective_input_length,
session_reused=session_reused, session_reused=session_reused,
@@ -111,6 +117,16 @@ def write_summary_json(
tpots = [row.tpot_s for row in rows if row.tpot_s is not None] tpots = [row.tpot_s for row in rows if row.tpot_s is not None]
per_decode_load = Counter(row.assigned_decode_node for row in rows) per_decode_load = Counter(row.assigned_decode_node for row in rows)
per_prefill_load = Counter(row.assigned_prefill_node for row in rows) per_prefill_load = Counter(row.assigned_prefill_node for row in rows)
prefill_priorities = Counter(
row.prefill_request_priority
for row in rows
if row.prefill_request_priority is not None
)
decode_priorities = Counter(
row.decode_request_priority
for row in rows
if row.decode_request_priority is not None
)
summary: dict[str, Any] = { summary: dict[str, Any] = {
"trace_path": str(trace_path), "trace_path": str(trace_path),
@@ -135,6 +151,12 @@ def write_summary_json(
), ),
"per_decode_load": dict(sorted(per_decode_load.items())), "per_decode_load": dict(sorted(per_decode_load.items())),
"per_prefill_load": dict(sorted(per_prefill_load.items())), "per_prefill_load": dict(sorted(per_prefill_load.items())),
"prefill_request_priorities": {
str(key): value for key, value in sorted(prefill_priorities.items())
},
"decode_request_priorities": {
str(key): value for key, value in sorted(decode_priorities.items())
},
"error_count": sum(1 for row in rows if row.error is not None), "error_count": sum(1 for row in rows if row.error is not None),
} }
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)

View File

@@ -149,13 +149,17 @@ async def _forward_to_backend(
) -> Response: ) -> Response:
state = _require_state() state = _require_state()
prefill_server, bootstrap_port, decode_server = state.select_pair(headers) prefill_server, bootstrap_port, decode_server = state.select_pair(headers)
modified_request = request_data.copy() prefill_request, decode_request = _build_backend_requests(
modified_request.update(_build_bootstrap_payload(prefill_server, bootstrap_port)) request_data=request_data,
prefill_server=prefill_server,
bootstrap_port=bootstrap_port,
)
if request_data.get("stream", False): if request_data.get("stream", False):
return StreamingResponse( return StreamingResponse(
_stream_generate( _stream_generate(
modified_request=modified_request, prefill_request=prefill_request,
decode_request=decode_request,
prefill_server=prefill_server, prefill_server=prefill_server,
decode_server=decode_server, decode_server=decode_server,
endpoint_name=endpoint_name, endpoint_name=endpoint_name,
@@ -168,8 +172,8 @@ async def _forward_to_backend(
timeout=aiohttp.ClientTimeout(total=state.config.request_timeout_s) timeout=aiohttp.ClientTimeout(total=state.config.request_timeout_s)
) as session: ) as session:
prefill_response, decode_response = await asyncio.gather( prefill_response, decode_response = await asyncio.gather(
session.post(f"{prefill_server}/{endpoint_name}", json=modified_request), session.post(f"{prefill_server}/{endpoint_name}", json=prefill_request),
session.post(f"{decode_server}/{endpoint_name}", json=modified_request), session.post(f"{decode_server}/{endpoint_name}", json=decode_request),
) )
async with prefill_response: async with prefill_response:
await prefill_response.read() await prefill_response.read()
@@ -184,7 +188,8 @@ async def _forward_to_backend(
async def _stream_generate( async def _stream_generate(
*, *,
modified_request: dict, prefill_request: dict,
decode_request: dict,
prefill_server: str, prefill_server: str,
decode_server: str, decode_server: str,
endpoint_name: str, endpoint_name: str,
@@ -194,8 +199,8 @@ async def _stream_generate(
timeout=aiohttp.ClientTimeout(total=timeout_s) timeout=aiohttp.ClientTimeout(total=timeout_s)
) as session: ) as session:
prefill_response, decode_response = await asyncio.gather( prefill_response, decode_response = await asyncio.gather(
session.post(f"{prefill_server}/{endpoint_name}", json=modified_request), session.post(f"{prefill_server}/{endpoint_name}", json=prefill_request),
session.post(f"{decode_server}/{endpoint_name}", json=modified_request), session.post(f"{decode_server}/{endpoint_name}", json=decode_request),
) )
async with prefill_response, decode_response: async with prefill_response, decode_response:
if decode_response.status != HTTPStatus.OK: if decode_response.status != HTTPStatus.OK:
@@ -221,6 +226,35 @@ def _build_bootstrap_payload(prefill_server: str, bootstrap_port: int) -> dict[s
} }
def _build_backend_requests(
    *,
    request_data: dict,
    prefill_server: str,
    bootstrap_port: int,
) -> tuple[dict, dict]:
    """Derive the prefill-side and decode-side payloads from one client request.

    Both payloads start as a copy of ``request_data`` without the
    router-internal ``smg_*_priority`` keys and are extended with the same
    bootstrap payload. When the client supplied ``smg_prefill_priority`` or
    ``smg_decode_priority``, the matching payload gets that value (coerced
    with ``int``) under the ``priority`` key.

    Returns:
        ``(prefill_request, decode_request)``.
    """
    # Built once so both workers receive the identical bootstrap payload.
    bootstrap_payload = _build_bootstrap_payload(prefill_server, bootstrap_port)

    per_role: list[dict] = []
    for priority_field in ("smg_prefill_priority", "smg_decode_priority"):
        backend_request = _strip_internal_fields(request_data)
        backend_request.update(bootstrap_payload)
        requested_priority = request_data.get(priority_field)
        if requested_priority is not None:
            backend_request["priority"] = int(requested_priority)
        per_role.append(backend_request)

    prefill_request, decode_request = per_role
    return prefill_request, decode_request
def _strip_internal_fields(request_data: dict) -> dict:
cleaned = request_data.copy()
cleaned.pop("smg_prefill_priority", None)
cleaned.pop("smg_decode_priority", None)
return cleaned
def _require_state() -> RouterState: def _require_state() -> RouterState:
if router_state is None: if router_state is None:
raise HTTPException(status_code=500, detail="router not initialized") raise HTTPException(status_code=500, detail="router not initialized")

View File

@@ -0,0 +1,511 @@
from __future__ import annotations
import json
import statistics
from collections import Counter, defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable
from agentic_pd_hybrid.trace import TraceRequest, load_trace
@dataclass(frozen=True)
class ProfileConfig:
    """Immutable inputs for one profiling run.

    Only ``trace_path`` is required. Each optional path left as ``None``
    causes the corresponding report section (or the file write) to be skipped.
    """

    # Trace file passed to ``load_trace``.
    trace_path: Path
    # Where ``write_profile`` stores the JSON report; ``None`` writes nothing.
    output_path: Path | None = None
    # Optional metrics JSONL profiled on its own ("metrics_profile").
    metrics_path: Path | None = None
    # Baseline/candidate metrics JSONL files; the paired comparison is only
    # produced when both are set.
    baseline_metrics_path: Path | None = None
    candidate_metrics_path: Path | None = None
    # Upper bound on appended input tokens for a turn to count as
    # direct-eligible.
    direct_max_uncached_tokens: int = 2048
def write_profile(config: ProfileConfig) -> dict[str, Any]:
    """Build the profile report and, if requested, persist it as JSON.

    When ``config.output_path`` is set, its parent directories are created
    as needed and the report is written there (UTF-8, 2-space indent,
    sorted keys). The report is returned in either case.
    """
    report = build_profile(config)
    destination = config.output_path
    if destination is None:
        return report

    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as sink:
        json.dump(report, sink, indent=2, sort_keys=True)
    return report
def build_profile(config: ProfileConfig) -> dict[str, Any]:
    """Assemble the profile report for the trace named in ``config``.

    The report always contains the trace path, the direct-eligibility
    threshold and the trace profile. A ``metrics_profile`` section is added
    when ``config.metrics_path`` is set. Baseline/candidate profiles and the
    paired comparison are added only when BOTH of those paths are set; a
    lone baseline or candidate path is ignored here.
    """
    trace_requests = load_trace(config.trace_path)
    feature_index = _build_trace_features(
        trace_requests,
        direct_max_uncached_tokens=config.direct_max_uncached_tokens,
    )

    report: dict[str, Any] = {}
    report["trace_path"] = str(config.trace_path)
    report["direct_max_uncached_tokens"] = config.direct_max_uncached_tokens
    report["trace_profile"] = _trace_profile(trace_requests, feature_index)

    metrics_path = config.metrics_path
    if metrics_path is not None:
        metric_rows = _load_jsonl(metrics_path)
        report["metrics_path"] = str(metrics_path)
        report["metrics_profile"] = _metrics_profile(metric_rows, feature_index)

    baseline_path = config.baseline_metrics_path
    candidate_path = config.candidate_metrics_path
    if baseline_path is None or candidate_path is None:
        return report

    baseline_rows = _load_jsonl(baseline_path)
    candidate_rows = _load_jsonl(candidate_path)
    report.update(
        {
            "baseline_metrics_path": str(baseline_path),
            "candidate_metrics_path": str(candidate_path),
            "baseline_profile": _metrics_profile(baseline_rows, feature_index),
            "candidate_profile": _metrics_profile(candidate_rows, feature_index),
            "paired_comparison": _paired_comparison(
                baseline=baseline_rows,
                candidate=candidate_rows,
                features=feature_index,
            ),
        }
    )
    return report
def print_profile_summary(report: dict[str, Any]) -> None:
    """Print a short human-readable digest of a profile report to stdout.

    Always prints the trace totals and the direct-eligible ratio. The token
    statistics lines are printed only when the matching stats entry is
    present and not ``None``. The paired-comparison lines are printed only
    when ``report["paired_comparison"]`` is a dict.
    """
    trace_section = report["trace_profile"]
    print(
        f"trace: {trace_section['request_count']} requests, "
        f"{trace_section['session_count']} sessions, "
        f"{trace_section['multi_turn_session_count']} multi-turn sessions"
    )

    eligible_count = trace_section["direct_eligible_turn2plus_count"]
    turn2plus_count = trace_section["turn2plus_count"]
    eligible_ratio = trace_section["direct_eligible_turn2plus_ratio"]
    print(
        f"direct-eligible turns: {eligible_count}/{turn2plus_count} "
        f"({eligible_ratio:.3f})"
    )

    # Both token-stat lines share one layout; only label and key differ.
    percentile_fields = ("mean", "p50", "p90", "p99")
    for label, stats_key in (
        ("append tokens", "append_input_tokens_stats"),
        ("output tokens", "output_tokens_stats"),
    ):
        stats = trace_section.get(stats_key)
        if stats is None:
            continue
        rendered = " ".join(
            f"{field}={stats[field]:.1f}" for field in percentile_fields
        )
        print(f"{label}: {rendered}")

    paired = report.get("paired_comparison")
    if not isinstance(paired, dict):
        return

    overall = paired.get("overall", {})
    latency_delta = overall.get("latency_delta_s_stats")
    if latency_delta is not None:
        print(
            "candidate - baseline E2E: "
            f"mean={latency_delta['mean']:.3f}s "
            f"p50={latency_delta['p50']:.3f}s "
            f"p90={latency_delta['p90']:.3f}s"
        )
    faster = overall.get("candidate_faster_count", 0)
    slower = overall.get("candidate_slower_count", 0)
    paired_count = overall.get("paired_count", 0)
    print(
        f"paired wins/losses: {faster} faster, {slower} slower, "
        f"{paired_count} paired"
    )
@dataclass(frozen=True)
class _TraceFeature:
    """Per-request features derived from a trace by ``_build_trace_features``.

    "Previous" below means the request that precedes this one within the same
    session when the session is ordered by (timestamp_s, turn_id, chat_id).
    The previous-dependent fields are ``None`` for the first request of a
    session.
    """

    # Copied verbatim from the trace request.
    request_id: str
    session_id: str
    turn_id: int
    input_length: int
    output_length: int
    # input_length + output_length of this request.
    resident_tokens: int
    # input_length minus the previous request's input_length + output_length;
    # can be zero or negative.
    append_input_tokens: int | None
    # timestamp_s difference to the previous request.
    inter_turn_gap_s: float | None
    # Number of this request's hash_ids entries also present in the previous
    # request's hash_ids.
    overlap_blocks_with_previous: int | None
    # overlap_blocks_with_previous / len(hash_ids); 0.0 when hash_ids is empty.
    overlap_ratio_with_previous: float | None
    # True when 0 < append_input_tokens <= direct_max_uncached_tokens and at
    # least one block overlaps with the previous request.
    direct_eligible: bool
    # "turn1" when turn_id <= 1, otherwise "turn2plus".
    turn_class: str
    # Bucket labels produced by ``_token_bin`` for the respective token counts.
    append_bin: str
    input_bin: str
    output_bin: str
    resident_bin: str
def _build_trace_features(
    requests: list[TraceRequest],
    *,
    direct_max_uncached_tokens: int,
) -> dict[str, _TraceFeature]:
    """Compute one ``_TraceFeature`` per request, keyed by ``request_id``.

    Each request is paired with its predecessor in the same session, where a
    session is ordered by (timestamp_s, turn_id, chat_id). Append size,
    inter-turn gap and hash-block overlap are measured against that
    predecessor and stay ``None`` for the first request of a session.

    A request is direct-eligible when it has a predecessor, appends between
    1 and ``direct_max_uncached_tokens`` tokens (inclusive), and shares at
    least one hash block with the predecessor.

    The returned dict follows the order of ``requests``.
    """
    by_session: dict[str, list[TraceRequest]] = defaultdict(list)
    for item in requests:
        by_session[item.session_id].append(item)

    # Map every request id to the request that precedes it in its session.
    predecessor_of: dict[str, TraceRequest | None] = {}
    for members in by_session.values():
        timeline = sorted(
            members,
            key=lambda item: (item.timestamp_s, item.turn_id, item.chat_id),
        )
        for earlier, current in zip([None, *timeline], timeline):
            predecessor_of[current.request_id] = earlier

    features: dict[str, _TraceFeature] = {}
    for item in requests:
        earlier = predecessor_of.get(item.request_id)
        if earlier is None:
            append_tokens = None
            gap_s = None
            shared_blocks = None
            shared_ratio = None
            eligible = False
        else:
            append_tokens = item.input_length - (
                earlier.input_length + earlier.output_length
            )
            gap_s = item.timestamp_s - earlier.timestamp_s
            earlier_blocks = set(earlier.hash_ids)
            shared_blocks = sum(
                1 for block in item.hash_ids if block in earlier_blocks
            )
            if item.hash_ids:
                shared_ratio = shared_blocks / len(item.hash_ids)
            else:
                shared_ratio = 0.0
            eligible = (
                0 < append_tokens <= direct_max_uncached_tokens
                and shared_blocks > 0
            )

        resident = item.input_length + item.output_length
        features[item.request_id] = _TraceFeature(
            request_id=item.request_id,
            session_id=item.session_id,
            turn_id=item.turn_id,
            input_length=item.input_length,
            output_length=item.output_length,
            resident_tokens=resident,
            append_input_tokens=append_tokens,
            inter_turn_gap_s=gap_s,
            overlap_blocks_with_previous=shared_blocks,
            overlap_ratio_with_previous=shared_ratio,
            direct_eligible=eligible,
            turn_class="turn1" if item.turn_id <= 1 else "turn2plus",
            append_bin=_token_bin(append_tokens),
            input_bin=_token_bin(item.input_length),
            output_bin=_token_bin(item.output_length),
            resident_bin=_token_bin(resident),
        )
    return features
def _trace_profile(
    requests: list[TraceRequest],
    features: dict[str, _TraceFeature],
) -> dict[str, Any]:
    """Summarize the shape of a trace as a plain dictionary.

    The result mixes simple counts, frequency tables with sorted keys, and
    ``_stats`` summaries computed from ``requests`` and from the
    per-request ``features``.  Append, overlap, gap and direct-eligibility
    figures only consider features whose ``turn_id`` is greater than 1.
    """

    def distribution(values):
        # Frequency table whose keys come out in sorted order.
        return dict(sorted(Counter(values).items()))

    def float_stats(values):
        return _stats([float(value) for value in values])

    every_feature = list(features.values())
    later_turns = [feature for feature in every_feature if feature.turn_id > 1]
    eligible_count = sum(1 for feature in later_turns if feature.direct_eligible)

    # Optional per-feature measurements; missing values are dropped up front.
    appends = [
        feature.append_input_tokens
        for feature in later_turns
        if feature.append_input_tokens is not None
    ]
    positive_appends = [tokens for tokens in appends if tokens > 0]
    overlaps = [
        feature.overlap_ratio_with_previous
        for feature in later_turns
        if feature.overlap_ratio_with_previous is not None
    ]
    turn_gaps = [
        feature.inter_turn_gap_s
        for feature in later_turns
        if feature.inter_turn_gap_s is not None
    ]

    turns_per_session = Counter(request.session_id for request in requests)
    multi_turn_sessions = sum(1 for turns in turns_per_session.values() if turns > 1)
    eligible_ratio = eligible_count / len(later_turns) if later_turns else 0.0

    return {
        "request_count": len(requests),
        "session_count": len(turns_per_session),
        "multi_turn_session_count": multi_turn_sessions,
        "turn2plus_count": len(later_turns),
        "direct_eligible_turn2plus_count": eligible_count,
        "direct_eligible_turn2plus_ratio": eligible_ratio,
        "turn_count_distribution": distribution(turns_per_session.values()),
        "request_type_distribution": distribution(
            request.request_type for request in requests
        ),
        "turn_id_distribution": distribution(request.turn_id for request in requests),
        "append_bin_distribution": distribution(
            feature.append_bin for feature in later_turns
        ),
        "input_bin_distribution": distribution(
            feature.input_bin for feature in every_feature
        ),
        "output_bin_distribution": distribution(
            feature.output_bin for feature in every_feature
        ),
        "resident_bin_distribution": distribution(
            feature.resident_bin for feature in every_feature
        ),
        "input_tokens_stats": float_stats(request.input_length for request in requests),
        "output_tokens_stats": float_stats(
            request.output_length for request in requests
        ),
        "resident_tokens_stats": float_stats(
            feature.resident_tokens for feature in every_feature
        ),
        "append_input_tokens_stats": float_stats(appends),
        "positive_append_input_tokens_stats": float_stats(positive_appends),
        "inter_turn_gap_s_stats": float_stats(turn_gaps),
        "overlap_ratio_stats": float_stats(overlaps),
        "non_positive_append_count": sum(1 for tokens in appends if tokens <= 0),
    }
def _metrics_profile(
    rows: list[dict[str, Any]],
    features: dict[str, _TraceFeature],
) -> dict[str, Any]:
    """Summarize one run's metric rows, overall and broken down by feature.

    ``rows`` are per-request metric dictionaries.  ``features`` is passed to
    ``_group_metrics`` together with a key function to build each ``by_*``
    breakdown.
    """

    def value_counts(field: str) -> dict[str, int]:
        # Values are stringified, so a missing field is counted as "None".
        return dict(sorted(Counter(str(row.get(field)) for row in rows).items()))

    def field_stats(field: str):
        return _stats(_numeric_values(rows, field))

    def truthy_count(field: str) -> int:
        return sum(1 for row in rows if row.get(field))

    def breakdown(key_fn):
        return _group_metrics(rows, features, key_fn)

    def eligibility_label(feature, _row):
        return "eligible" if feature.direct_eligible else "not_eligible"

    failed_rows = [row for row in rows if row.get("error") is not None]

    return {
        "request_count": len(rows),
        "mechanism_distribution": value_counts("mechanism_name"),
        "execution_mode_distribution": value_counts("execution_mode"),
        "latency_s_stats": field_stats("latency_s"),
        "ttft_s_stats": field_stats("ttft_s"),
        "tpot_s_stats": field_stats("tpot_s"),
        "cached_tokens_stats": field_stats("cached_tokens"),
        "actual_kv_transfer_blocks_stats": field_stats("actual_kv_transfer_blocks"),
        "session_reused_count": truthy_count("session_reused"),
        "session_reset_count": truthy_count("session_reset"),
        "error_count": len(failed_rows),
        "by_turn_class": breakdown(lambda feature, _row: feature.turn_class),
        "by_direct_eligible": breakdown(eligibility_label),
        "by_append_bin": breakdown(lambda feature, _row: feature.append_bin),
        "by_resident_bin": breakdown(lambda feature, _row: feature.resident_bin),
        "by_execution_mode": breakdown(
            lambda _feature, row: str(row.get("execution_mode"))
        ),
    }
def _paired_comparison(
    *,
    baseline: list[dict[str, Any]],
    candidate: list[dict[str, Any]],
    features: dict[str, _TraceFeature],
) -> dict[str, Any]:
    """Compare two runs request-by-request and summarize the deltas.

    A request is paired only when both runs contain a row for it with a
    non-None ``latency_s`` and ``features`` has an entry for its id.  The
    pairs are summarized overall and per feature-derived group.
    """

    def timed_rows_by_id(rows: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
        # Later rows with the same request id replace earlier ones.
        indexed: dict[str, dict[str, Any]] = {}
        for row in rows:
            if row.get("latency_s") is not None:
                indexed[str(row.get("request_id"))] = row
        return indexed

    def eligibility_label(feature, _base, _cand):
        return "eligible" if feature.direct_eligible else "not_eligible"

    base_lookup = timed_rows_by_id(baseline)
    cand_lookup = timed_rows_by_id(candidate)

    pairs = []
    for request_id in sorted(base_lookup.keys() & cand_lookup.keys()):
        feature = features.get(request_id)
        if feature is None:
            continue
        pairs.append((base_lookup[request_id], cand_lookup[request_id], feature))

    comparison: dict[str, Any] = {"overall": _delta_summary(pairs)}
    comparison["by_turn_class"] = _group_deltas(
        pairs, lambda feature, _base, _cand: feature.turn_class
    )
    comparison["by_direct_eligible"] = _group_deltas(pairs, eligibility_label)
    comparison["by_append_bin"] = _group_deltas(
        pairs, lambda feature, _base, _cand: feature.append_bin
    )
    comparison["by_resident_bin"] = _group_deltas(
        pairs, lambda feature, _base, _cand: feature.resident_bin
    )
    comparison["by_candidate_execution_mode"] = _group_deltas(
        pairs, lambda _feature, _base, cand: str(cand.get("execution_mode"))
    )
    return comparison
def _group_metrics(
    rows: list[dict[str, Any]],
    features: dict[str, _TraceFeature],
    key_fn,
) -> dict[str, Any]:
    """Bucket metric rows by ``key_fn(feature, row)`` and summarize each bucket.

    Rows whose stringified ``request_id`` has no entry in ``features`` are
    skipped.  Bucket labels are stringified and emitted in sorted order.
    """
    buckets: dict[str, list[dict[str, Any]]] = {}
    for row in rows:
        feature = features.get(str(row.get("request_id")))
        if feature is not None:
            buckets.setdefault(str(key_fn(feature, row)), []).append(row)

    summary: dict[str, Any] = {}
    for label in sorted(buckets):
        members = buckets[label]
        summary[label] = {
            "count": len(members),
            "latency_s_stats": _stats(_numeric_values(members, "latency_s")),
            "ttft_s_stats": _stats(_numeric_values(members, "ttft_s")),
            "tpot_s_stats": _stats(_numeric_values(members, "tpot_s")),
            "session_reused_count": sum(
                1 for member in members if member.get("session_reused")
            ),
            "error_count": sum(
                1 for member in members if member.get("error") is not None
            ),
        }
    return summary
def _group_deltas(
    pairs: list[tuple[dict[str, Any], dict[str, Any], _TraceFeature | None]],
    key_fn,
) -> dict[str, Any]:
    """Bucket paired rows by ``key_fn(feature, base, cand)`` and summarize each.

    Pairs without a feature are ignored.  Bucket labels are stringified and
    the result lists them in sorted order, each mapped to its
    ``_delta_summary``.
    """
    buckets: dict[str, list] = {}
    for baseline_row, candidate_row, feature in pairs:
        if feature is None:
            continue
        label = str(key_fn(feature, baseline_row, candidate_row))
        buckets.setdefault(label, []).append((baseline_row, candidate_row, feature))
    return {label: _delta_summary(buckets[label]) for label in sorted(buckets)}
def _delta_summary(
    pairs: list[tuple[dict[str, Any], dict[str, Any], _TraceFeature | None]],
) -> dict[str, Any]:
    """Summarize candidate-minus-baseline latency and TTFT over paired rows.

    A pair contributes to the latency figures only when both rows carry a
    non-None ``latency_s``; the TTFT delta applies the same rule to
    ``ttft_s`` independently.  Negative deltas mean the candidate was faster.
    """
    # (candidate, baseline) latencies for every pair where both are present.
    latency_pairs = [
        (float(cand["latency_s"]), float(base["latency_s"]))
        for base, cand, _feature in pairs
        if base.get("latency_s") is not None and cand.get("latency_s") is not None
    ]
    latency_deltas = [cand_s - base_s for cand_s, base_s in latency_pairs]
    ttft_deltas = [
        float(cand["ttft_s"]) - float(base["ttft_s"])
        for base, cand, _feature in pairs
        if base.get("ttft_s") is not None and cand.get("ttft_s") is not None
    ]

    faster = sum(1 for delta in latency_deltas if delta < 0)
    slower = sum(1 for delta in latency_deltas if delta > 0)
    baseline_latencies = [base_s for _cand_s, base_s in latency_pairs]
    candidate_latencies = [cand_s for cand_s, _base_s in latency_pairs]

    return {
        "paired_count": len(latency_deltas),
        "candidate_faster_count": faster,
        "candidate_slower_count": slower,
        "latency_delta_s_stats": _stats(latency_deltas),
        "ttft_delta_s_stats": _stats(ttft_deltas),
        "total_latency_delta_s": sum(latency_deltas),
        "mean_baseline_latency_s": _mean(baseline_latencies),
        "mean_candidate_latency_s": _mean(candidate_latencies),
    }
def _load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Parse a UTF-8 JSON Lines file into a list of decoded objects.

    Lines that are empty or contain only whitespace are skipped.
    """
    with path.open("r", encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]
def _numeric_values(rows: Iterable[dict[str, Any]], key: str) -> list[float]:
    """Collect ``row[key]`` as floats, skipping rows where it is absent or None."""
    candidates = (row.get(key) for row in rows)
    return [float(candidate) for candidate in candidates if candidate is not None]
def _stats(values: list[float]) -> dict[str, float] | None:
    """Return count/mean/p50/p90/p99/min/max for ``values``.

    ``None`` entries are discarded first; if nothing is left the result is
    ``None``.  The input list is not modified.  ``count`` is reported as a
    float, like every other entry.
    """
    ordered = sorted(value for value in values if value is not None)
    if not ordered:
        return None
    summary = {"count": float(len(ordered)), "mean": statistics.fmean(ordered)}
    summary["p50"] = _percentile(ordered, 0.50)
    summary["p90"] = _percentile(ordered, 0.90)
    summary["p99"] = _percentile(ordered, 0.99)
    summary["min"] = ordered[0]
    summary["max"] = ordered[-1]
    return summary
def _percentile(sorted_values: list[float], percentile: float) -> float:
    """Return the nearest-rank element of ``sorted_values`` for ``percentile``.

    Args:
        sorted_values: Values already sorted in ascending order.
        percentile: Fraction in the range 0.0-1.0 (for example 0.99 for p99).

    Returns:
        The element whose index is closest to ``(len - 1) * percentile``.
        When that position falls exactly halfway between two indexes, the
        upper one is chosen.

    Raises:
        IndexError: If ``sorted_values`` is empty or the computed index falls
            outside the list.
    """
    if len(sorted_values) == 1:
        return sorted_values[0]
    position = (len(sorted_values) - 1) * percentile
    # The built-in round() rounds halves to the nearest even integer, which
    # made ties pick the lower neighbour for some sample sizes and the upper
    # one for others (e.g. p50 of 2, 4 and 6 values).  Adding 0.5 and
    # truncating rounds halves up consistently for the non-negative
    # positions produced by percentiles in 0.0-1.0.
    index = int(position + 0.5)
    return sorted_values[index]
def _mean(values: list[float]) -> float | None:
    """Return the arithmetic mean of ``values``, or ``None`` when it is empty."""
    return statistics.fmean(values) if values else None
def _token_bin(value: int | None) -> str:
    """Map a token count to the label of the size bucket it falls in.

    ``None`` maps to ``"none"``.  Otherwise the first bucket whose inclusive
    upper bound is not exceeded wins, and anything above the largest bound is
    labelled ``">32768"``.
    """
    if value is None:
        return "none"
    # (inclusive upper bound, label), in ascending order.
    buckets = (
        (0, "<=0"),
        (512, "1-512"),
        (2048, "513-2048"),
        (8192, "2049-8192"),
        (32768, "8193-32768"),
    )
    for upper_bound, label in buckets:
        if value <= upper_bound:
            return label
    return ">32768"

View File

@@ -3,7 +3,8 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import time import time
from dataclasses import dataclass, field from collections import Counter
from dataclasses import dataclass, field, replace
from pathlib import Path from pathlib import Path
from typing import Any, Literal from typing import Any, Literal
@@ -26,6 +27,8 @@ from agentic_pd_hybrid.trace import (
HeaderMode = Literal["none", "routing-key", "target-worker", "auto"] HeaderMode = Literal["none", "routing-key", "target-worker", "auto"]
KvCacheAdmissionMode = Literal["router", "worker"] KvCacheAdmissionMode = Literal["router", "worker"]
KvCachePrefillBackupPolicy = Literal["release-after-transfer", "capacity-backup"]
_ADMISSION_PROBE_TIMEOUT_S = 2.0
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -47,6 +50,18 @@ class ReplayConfig:
stream_idle_timeout_s: float | None = 900.0 stream_idle_timeout_s: float | None = 900.0
kvcache_direct_max_uncached_tokens: int = 2048 kvcache_direct_max_uncached_tokens: int = 2048
kvcache_admission_mode: KvCacheAdmissionMode = "router" kvcache_admission_mode: KvCacheAdmissionMode = "router"
kvcache_seed_max_resident_tokens: int | None = None
kvcache_seed_max_output_tokens: int | None = None
kvcache_seed_min_turn_id: int = 1
kvcache_seed_only_multiturn_sessions: bool = False
kvcache_seed_allowed_session_ids: frozenset[str] | None = None
kvcache_prefill_backup_policy: KvCachePrefillBackupPolicy = (
"release-after-transfer"
)
kvcache_seed_max_inflight_decode: int | None = 3
kvcache_prefill_priority_eviction: bool = False
kvcache_prefill_direct_priority: int = -100
kvcache_prefill_normal_priority: int = 100
@dataclass @dataclass
@@ -104,11 +119,23 @@ class ExecutionResult:
latency_s: float | None latency_s: float | None
ttft_s: float | None ttft_s: float | None
tpot_s: float | None tpot_s: float | None
prefill_request_priority: int | None = None
decode_request_priority: int | None = None
error: str | None = None error: str | None = None
async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]: async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]:
requests = load_trace(config.trace_path, request_limit=config.request_limit) requests = load_trace(config.trace_path, request_limit=config.request_limit)
if config.kvcache_seed_only_multiturn_sessions:
session_turns = Counter(request.session_id for request in requests)
config = replace(
config,
kvcache_seed_allowed_session_ids=frozenset(
session_id
for session_id, turn_count in session_turns.items()
if turn_count > 1
),
)
policy = create_policy(config.policy_name) policy = create_policy(config.policy_name)
state = RoutingState.create(config.topology) state = RoutingState.create(config.topology)
state_lock = asyncio.Lock() state_lock = asyncio.Lock()
@@ -242,6 +269,8 @@ async def _run_request(
latency_s=execution.latency_s, latency_s=execution.latency_s,
ttft_s=execution.ttft_s, ttft_s=execution.ttft_s,
tpot_s=execution.tpot_s, tpot_s=execution.tpot_s,
prefill_request_priority=execution.prefill_request_priority,
decode_request_priority=execution.decode_request_priority,
error=execution.error, error=execution.error,
) )
@@ -253,6 +282,8 @@ async def _invoke_router(
config: ReplayConfig, config: ReplayConfig,
decode_worker_index: int, decode_worker_index: int,
session_id: str | None = None, session_id: str | None = None,
prefill_request_priority: int | None = None,
decode_request_priority: int | None = None,
) -> tuple[float, float | None, float | None, int]: ) -> tuple[float, float | None, float | None, int]:
headers = _build_headers( headers = _build_headers(
request=request, request=request,
@@ -274,6 +305,10 @@ async def _invoke_router(
} }
if session_id is not None: if session_id is not None:
payload["session_params"] = {"id": session_id} payload["session_params"] = {"id": session_id}
if prefill_request_priority is not None:
payload["smg_prefill_priority"] = prefill_request_priority
if decode_request_priority is not None:
payload["smg_decode_priority"] = decode_request_priority
return await _invoke_generate( return await _invoke_generate(
client=client, client=client,
@@ -462,6 +497,7 @@ async def _open_streaming_session(
"session_id": session_id, "session_id": session_id,
"streaming": True, "streaming": True,
}, },
timeout=_ADMISSION_PROBE_TIMEOUT_S,
) )
response.raise_for_status() response.raise_for_status()
opened_session_id = response.json() opened_session_id = response.json()
@@ -481,6 +517,7 @@ async def _close_streaming_session(
response = await client.post( response = await client.post(
f"{server_url.rstrip('/')}/close_session", f"{server_url.rstrip('/')}/close_session",
json={"session_id": session_id}, json={"session_id": session_id},
timeout=_ADMISSION_PROBE_TIMEOUT_S,
) )
if response.is_success: if response.is_success:
return return
@@ -538,7 +575,10 @@ async def _fetch_decode_server_state(
server_url: str, server_url: str,
) -> tuple[dict[str, Any], int, int]: ) -> tuple[dict[str, Any], int, int]:
try: try:
response = await client.get(f"{server_url.rstrip('/')}/server_info") response = await client.get(
f"{server_url.rstrip('/')}/server_info",
timeout=_ADMISSION_PROBE_TIMEOUT_S,
)
response.raise_for_status() response.raise_for_status()
payload = response.json() payload = response.json()
except Exception: except Exception:
@@ -567,6 +607,7 @@ async def _query_decode_direct_admission(
"uncached_input_tokens": max(0, uncached_input_tokens), "uncached_input_tokens": max(0, uncached_input_tokens),
"output_tokens": max(0, output_tokens), "output_tokens": max(0, output_tokens),
}, },
timeout=_ADMISSION_PROBE_TIMEOUT_S,
) )
response.raise_for_status() response.raise_for_status()
payload = response.json() payload = response.json()
@@ -643,6 +684,51 @@ def _estimate_session_resident_tokens(request: TraceRequest) -> int:
return request.input_length + request.output_length return request.input_length + request.output_length
def _seed_filter_reason(
    *,
    request: TraceRequest,
    config: ReplayConfig,
    inflight_decode_load: int | None = None,
) -> str | None:
    """Return the name of the first seed filter that rejects ``request``.

    Filters are checked in a fixed order: minimum turn id, in-flight decode
    load, allowed session ids, resident-token cap, output-token cap.  Limits
    configured as ``None`` are not enforced, and the load check is also
    skipped when ``inflight_decode_load`` is ``None``.  ``None`` is returned
    when no filter rejects the request.
    """
    if request.turn_id < config.kvcache_seed_min_turn_id:
        return "seed-filter-early-turn"

    inflight_cap = config.kvcache_seed_max_inflight_decode
    if inflight_cap is not None and inflight_decode_load is not None:
        if inflight_decode_load > inflight_cap:
            return "seed-filter-inflight-decode-load"

    allowed_sessions = config.kvcache_seed_allowed_session_ids
    if allowed_sessions is not None and request.session_id not in allowed_sessions:
        return "seed-filter-single-turn-session"

    resident_tokens = _estimate_session_resident_tokens(request)
    resident_cap = config.kvcache_seed_max_resident_tokens
    if resident_cap is not None and resident_tokens > resident_cap:
        return "seed-filter-resident-tokens"

    output_cap = config.kvcache_seed_max_output_tokens
    if output_cap is not None and request.output_length > output_cap:
        return "seed-filter-output-tokens"

    return None
def _prefill_priority_for_router_request(
    *,
    config: ReplayConfig,
    direct_to_d_predicted: bool,
) -> int | None:
    """Pick the prefill priority to attach to a router request.

    Returns ``None`` when ``config.kvcache_prefill_priority_eviction`` is
    off.  Otherwise returns the configured direct priority when
    ``direct_to_d_predicted`` is true and the normal priority when it is not.
    """
    if config.kvcache_prefill_priority_eviction:
        return (
            config.kvcache_prefill_direct_priority
            if direct_to_d_predicted
            else config.kvcache_prefill_normal_priority
        )
    return None
def _inspect_direct_request( def _inspect_direct_request(
*, *,
request: TraceRequest, request: TraceRequest,
@@ -802,6 +888,7 @@ async def _fetch_decode_load_snapshot(
response = await client.get( response = await client.get(
f"{server_url.rstrip('/')}/v1/loads", f"{server_url.rstrip('/')}/v1/loads",
params={"include": "core,disagg"}, params={"include": "core,disagg"},
timeout=_ADMISSION_PROBE_TIMEOUT_S,
) )
response.raise_for_status() response.raise_for_status()
payload = response.json() payload = response.json()
@@ -865,6 +952,13 @@ def _is_decode_backpressure_reason(reason: str | None) -> bool:
} }
def _is_stale_decode_session_error(exc: Exception) -> bool:
    """Report whether ``exc`` is an ``httpx.HTTPStatusError`` with status 400.

    Any HTTP 400 response is classified this way; the response body is not
    inspected.
    """
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    return exc.response.status_code == 400
def _dynamic_decode_headroom_tokens( def _dynamic_decode_headroom_tokens(
*, *,
residency: DecodeResidencyState, residency: DecodeResidencyState,
@@ -1469,17 +1563,23 @@ async def _invoke_plain_router(
decision, decision,
execution_mode: str, execution_mode: str,
) -> ExecutionResult: ) -> ExecutionResult:
prefill_priority = _prefill_priority_for_router_request(
config=config,
direct_to_d_predicted=False,
)
latency_s, ttft_s, tpot_s, cached_tokens = await _invoke_router( latency_s, ttft_s, tpot_s, cached_tokens = await _invoke_router(
client=client, client=client,
request=request, request=request,
config=config, config=config,
decode_worker_index=decision.decode_worker_index, decode_worker_index=decision.decode_worker_index,
prefill_request_priority=prefill_priority,
) )
return ExecutionResult( return ExecutionResult(
execution_mode=execution_mode, execution_mode=execution_mode,
actual_kv_transfer_blocks=decision.kv_transfer_blocks, actual_kv_transfer_blocks=decision.kv_transfer_blocks,
effective_input_length=request.input_length, effective_input_length=request.input_length,
cached_tokens=cached_tokens, cached_tokens=cached_tokens,
prefill_request_priority=prefill_priority,
session_reused=False, session_reused=False,
session_reset=False, session_reset=False,
latency_s=latency_s, latency_s=latency_s,
@@ -1502,17 +1602,20 @@ async def _invoke_kvcache_seeded_router(
reserved_tokens: int, reserved_tokens: int,
execution_mode: str, execution_mode: str,
) -> ExecutionResult: ) -> ExecutionResult:
keep_prefill_backup = False
prefill_reserved_tokens = 0
async with direct_session_lock: async with direct_session_lock:
keep_prefill_backup, prefill_reserved_tokens, _prefill_evicted = ( if config.kvcache_prefill_backup_policy == "capacity-backup":
await _reserve_prefill_backup_capacity( keep_prefill_backup, prefill_reserved_tokens, _prefill_evicted = (
client=client, await _reserve_prefill_backup_capacity(
request=request, client=client,
prefill_url=prefill_url, request=request,
session=decode_session, prefill_url=prefill_url,
direct_sessions=direct_sessions, session=decode_session,
residency=decode_residency, direct_sessions=direct_sessions,
residency=decode_residency,
)
) )
)
if ( if (
decode_session.prefill_opened decode_session.prefill_opened
and decode_session.prefill_server_url != prefill_url and decode_session.prefill_server_url != prefill_url
@@ -1538,6 +1641,10 @@ async def _invoke_kvcache_seeded_router(
decode_session_newly_opened = False decode_session_newly_opened = False
try: try:
prefill_priority = _prefill_priority_for_router_request(
config=config,
direct_to_d_predicted=True,
)
async with direct_session_lock: async with direct_session_lock:
if not decode_session.opened: if not decode_session.opened:
await _open_streaming_session( await _open_streaming_session(
@@ -1555,6 +1662,7 @@ async def _invoke_kvcache_seeded_router(
config=config, config=config,
decode_worker_index=decision.decode_worker_index, decode_worker_index=decision.decode_worker_index,
session_id=request.session_id, session_id=request.session_id,
prefill_request_priority=prefill_priority,
) )
except Exception: except Exception:
async with direct_session_lock: async with direct_session_lock:
@@ -1615,6 +1723,7 @@ async def _invoke_kvcache_seeded_router(
actual_kv_transfer_blocks=decision.kv_transfer_blocks, actual_kv_transfer_blocks=decision.kv_transfer_blocks,
effective_input_length=request.input_length, effective_input_length=request.input_length,
cached_tokens=cached_tokens, cached_tokens=cached_tokens,
prefill_request_priority=prefill_priority,
session_reused=False, session_reused=False,
session_reset=False, session_reset=False,
latency_s=latency_s, latency_s=latency_s,
@@ -1697,6 +1806,19 @@ async def _execute_request(
) )
if request.turn_id == 1: if request.turn_id == 1:
seed_filter_reason = _seed_filter_reason(
request=request,
config=config,
inflight_decode_load=decision.inflight_decode_load,
)
if seed_filter_reason is not None:
return await _invoke_plain_router(
client=client,
request=request,
config=config,
decision=decision,
execution_mode=f"pd-router-turn1-{seed_filter_reason}",
)
async with direct_session_lock: async with direct_session_lock:
admit_new_decode_session = _should_admit_new_decode_session( admit_new_decode_session = _should_admit_new_decode_session(
residency=decode_residency, residency=decode_residency,
@@ -1800,16 +1922,33 @@ async def _execute_request(
) )
) )
if can_direct: if can_direct:
return await _invoke_decode_session_direct( try:
client=client, return await _invoke_decode_session_direct(
request=request, client=client,
config=config, request=request,
decision=decision, config=config,
direct_sessions=direct_sessions, decision=decision,
direct_session_lock=direct_session_lock, direct_sessions=direct_sessions,
decode_residency=decode_residency, direct_session_lock=direct_session_lock,
reserved_tokens=direct_reserved_tokens, decode_residency=decode_residency,
) reserved_tokens=direct_reserved_tokens,
)
except Exception as exc:
if not _is_stale_decode_session_error(exc):
raise
async with direct_session_lock:
await _close_decode_session(
client=client,
session=decode_session,
residency=decode_residency,
)
return await _invoke_plain_router(
client=client,
request=request,
config=config,
decision=decision,
execution_mode="pd-router-fallback-stale-d-session",
)
if _is_decode_backpressure_reason(direct_reason): if _is_decode_backpressure_reason(direct_reason):
return await _invoke_plain_router( return await _invoke_plain_router(
client=client, client=client,
@@ -1819,6 +1958,19 @@ async def _execute_request(
execution_mode="pd-router-fallback-d-backpressure", execution_mode="pd-router-fallback-d-backpressure",
) )
seed_filter_reason = _seed_filter_reason(
request=request,
config=config,
inflight_decode_load=decision.inflight_decode_load,
)
if seed_filter_reason is not None:
return await _invoke_plain_router(
client=client,
request=request,
config=config,
decision=decision,
execution_mode=f"pd-router-fallback-{seed_filter_reason}",
)
async with direct_session_lock: async with direct_session_lock:
admit_new_decode_session = _should_admit_new_decode_session( admit_new_decode_session = _should_admit_new_decode_session(
residency=decode_residency, residency=decode_residency,
@@ -1894,6 +2046,19 @@ async def _execute_request(
), ),
) )
seed_filter_reason = _seed_filter_reason(
request=request,
config=config,
inflight_decode_load=decision.inflight_decode_load,
)
if seed_filter_reason is not None:
return await _invoke_plain_router(
request=request,
client=client,
config=config,
decision=decision,
execution_mode=f"pd-router-fallback-large-append-{seed_filter_reason}",
)
async with direct_session_lock: async with direct_session_lock:
admit_new_decode_session = _should_admit_new_decode_session( admit_new_decode_session = _should_admit_new_decode_session(
residency=decode_residency, residency=decode_residency,

View File

@@ -64,6 +64,7 @@ def launch_pd_stack(
prefill_policy: str, prefill_policy: str,
decode_policy: str, decode_policy: str,
timeout_s: float = 1200.0, timeout_s: float = 1200.0,
router_request_timeout_s: float | None = None,
include_router: bool = True, include_router: bool = True,
) -> ManagedPdStack: ) -> ManagedPdStack:
run_dir.mkdir(parents=True, exist_ok=True) run_dir.mkdir(parents=True, exist_ok=True)
@@ -75,6 +76,7 @@ def launch_pd_stack(
prefill_policy=prefill_policy, prefill_policy=prefill_policy,
decode_policy=decode_policy, decode_policy=decode_policy,
include_router=include_router, include_router=include_router,
router_request_timeout_s=router_request_timeout_s,
) )
prefill_processes = [ prefill_processes = [
@@ -186,7 +188,7 @@ def _build_process_env(topology: SingleNodeTopology) -> dict[str, str]:
env["NO_PROXY"] = "*" env["NO_PROXY"] = "*"
env["no_proxy"] = "*" env["no_proxy"] = "*"
env.setdefault("SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", "600") env.setdefault("SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", "600")
env.setdefault("SGLANG_DISAGGREGATION_WAITING_TIMEOUT", "60") env.setdefault("SGLANG_DISAGGREGATION_WAITING_TIMEOUT", "600")
env.setdefault("FLASHINFER_DISABLE_VERSION_CHECK", "1") env.setdefault("FLASHINFER_DISABLE_VERSION_CHECK", "1")
if topology.force_rdma: if topology.force_rdma:
env["MOONCAKE_PROTOCOL"] = "rdma" env["MOONCAKE_PROTOCOL"] = "rdma"

View File

@@ -176,7 +176,7 @@ NSA_CHOICES = [
"trtllm", "trtllm",
] ]
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu", "slru"] RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu", "slru", "priority"]
RL_ON_POLICY_TARGET_CHOICES = ["fsdp"] RL_ON_POLICY_TARGET_CHOICES = ["fsdp"]
@@ -4049,7 +4049,7 @@ class ServerArgs:
type=str, type=str,
choices=RADIX_EVICTION_POLICY_CHOICES, choices=RADIX_EVICTION_POLICY_CHOICES,
default=ServerArgs.radix_eviction_policy, default=ServerArgs.radix_eviction_policy,
help="The eviction policy of radix trees. 'lru' stands for Least Recently Used, 'lfu' stands for Least Frequently Used, and 'slru' stands for Segmented Least Recently Used.", help="The eviction policy of radix trees. 'lru' stands for Least Recently Used, 'lfu' stands for Least Frequently Used, 'slru' stands for Segmented Least Recently Used, and 'priority' evicts lower request priority values first.",
) )
parser.add_argument( parser.add_argument(
"--enable-prefill-delayer", "--enable-prefill-delayer",