Add kvcache-centric profiling and admission controls
This commit is contained in:
@@ -27,10 +27,20 @@ class BenchmarkConfig:
|
||||
time_scale: float = 1.0
|
||||
concurrency_limit: int = 32
|
||||
timeout_s: float = 1200.0
|
||||
request_timeout_s: float | None = None
|
||||
stream: bool = True
|
||||
stream_idle_timeout_s: float | None = 900.0
|
||||
kvcache_direct_max_uncached_tokens: int = 2048
|
||||
kvcache_admission_mode: str = "router"
|
||||
kvcache_seed_max_resident_tokens: int | None = None
|
||||
kvcache_seed_max_output_tokens: int | None = None
|
||||
kvcache_seed_min_turn_id: int = 1
|
||||
kvcache_seed_only_multiturn_sessions: bool = False
|
||||
kvcache_prefill_backup_policy: str = "release-after-transfer"
|
||||
kvcache_seed_max_inflight_decode: int | None = 3
|
||||
kvcache_prefill_priority_eviction: bool = False
|
||||
kvcache_prefill_direct_priority: int = -100
|
||||
kvcache_prefill_normal_priority: int = 100
|
||||
sample_profile: str = "default"
|
||||
min_initial_input_tokens: int | None = None
|
||||
max_initial_input_tokens: int | None = None
|
||||
@@ -59,10 +69,17 @@ def run_live_benchmark(config: BenchmarkConfig) -> BenchmarkArtifacts:
|
||||
|
||||
topology = config.topology
|
||||
if config.mechanism_name == "kvcache-centric":
|
||||
prefill_extra_server_args = topology.prefill_extra_server_args + (
|
||||
"--enable-streaming-session",
|
||||
)
|
||||
if config.kvcache_prefill_priority_eviction:
|
||||
prefill_extra_server_args = prefill_extra_server_args + (
|
||||
"--radix-eviction-policy",
|
||||
"priority",
|
||||
)
|
||||
topology = replace(
|
||||
topology,
|
||||
prefill_extra_server_args=topology.prefill_extra_server_args
|
||||
+ ("--enable-streaming-session",),
|
||||
prefill_extra_server_args=prefill_extra_server_args,
|
||||
decode_extra_server_args=topology.decode_extra_server_args
|
||||
+ (
|
||||
"--enable-streaming-session",
|
||||
@@ -107,6 +124,11 @@ def run_live_benchmark(config: BenchmarkConfig) -> BenchmarkArtifacts:
|
||||
prefill_policy="round_robin",
|
||||
decode_policy=_decode_policy_for(config.policy_name),
|
||||
timeout_s=config.timeout_s,
|
||||
router_request_timeout_s=(
|
||||
config.request_timeout_s
|
||||
if config.request_timeout_s is not None
|
||||
else config.timeout_s
|
||||
),
|
||||
include_router=(
|
||||
config.mechanism_name in {"pd-disaggregation", "kvcache-centric"}
|
||||
),
|
||||
@@ -142,7 +164,27 @@ def run_live_benchmark(config: BenchmarkConfig) -> BenchmarkArtifacts:
|
||||
stream_idle_timeout_s=config.stream_idle_timeout_s,
|
||||
kvcache_direct_max_uncached_tokens=config.kvcache_direct_max_uncached_tokens,
|
||||
kvcache_admission_mode=config.kvcache_admission_mode, # type: ignore[arg-type]
|
||||
kvcache_seed_max_resident_tokens=config.kvcache_seed_max_resident_tokens,
|
||||
kvcache_seed_max_output_tokens=config.kvcache_seed_max_output_tokens,
|
||||
kvcache_seed_min_turn_id=config.kvcache_seed_min_turn_id,
|
||||
kvcache_seed_only_multiturn_sessions=(
|
||||
config.kvcache_seed_only_multiturn_sessions
|
||||
),
|
||||
kvcache_prefill_backup_policy=config.kvcache_prefill_backup_policy, # type: ignore[arg-type]
|
||||
kvcache_seed_max_inflight_decode=(
|
||||
config.kvcache_seed_max_inflight_decode
|
||||
),
|
||||
kvcache_prefill_priority_eviction=(
|
||||
config.kvcache_prefill_priority_eviction
|
||||
),
|
||||
kvcache_prefill_direct_priority=config.kvcache_prefill_direct_priority,
|
||||
kvcache_prefill_normal_priority=config.kvcache_prefill_normal_priority,
|
||||
)
|
||||
if config.request_timeout_s is not None:
|
||||
replay_config = replace(
|
||||
replay_config,
|
||||
timeout_s=config.request_timeout_s,
|
||||
)
|
||||
asyncio.run(replay_trace(replay_config))
|
||||
finally:
|
||||
signal.signal(signal.SIGINT, previous_sigint)
|
||||
@@ -163,10 +205,30 @@ def run_live_benchmark(config: BenchmarkConfig) -> BenchmarkArtifacts:
|
||||
"time_scale": config.time_scale,
|
||||
"concurrency_limit": config.concurrency_limit,
|
||||
"timeout_s": config.timeout_s,
|
||||
"request_timeout_s": config.request_timeout_s,
|
||||
"stream": config.stream,
|
||||
"stream_idle_timeout_s": config.stream_idle_timeout_s,
|
||||
"kvcache_direct_max_uncached_tokens": config.kvcache_direct_max_uncached_tokens,
|
||||
"kvcache_admission_mode": config.kvcache_admission_mode,
|
||||
"kvcache_seed_max_resident_tokens": config.kvcache_seed_max_resident_tokens,
|
||||
"kvcache_seed_max_output_tokens": config.kvcache_seed_max_output_tokens,
|
||||
"kvcache_seed_min_turn_id": config.kvcache_seed_min_turn_id,
|
||||
"kvcache_seed_only_multiturn_sessions": (
|
||||
config.kvcache_seed_only_multiturn_sessions
|
||||
),
|
||||
"kvcache_prefill_backup_policy": config.kvcache_prefill_backup_policy,
|
||||
"kvcache_seed_max_inflight_decode": (
|
||||
config.kvcache_seed_max_inflight_decode
|
||||
),
|
||||
"kvcache_prefill_priority_eviction": (
|
||||
config.kvcache_prefill_priority_eviction
|
||||
),
|
||||
"kvcache_prefill_direct_priority": (
|
||||
config.kvcache_prefill_direct_priority
|
||||
),
|
||||
"kvcache_prefill_normal_priority": (
|
||||
config.kvcache_prefill_normal_priority
|
||||
),
|
||||
"sample_profile": config.sample_profile,
|
||||
"min_initial_input_tokens": config.min_initial_input_tokens,
|
||||
"max_initial_input_tokens": config.max_initial_input_tokens,
|
||||
|
||||
@@ -7,6 +7,7 @@ from pathlib import Path
|
||||
from agentic_pd_hybrid.benchmark import BenchmarkConfig, run_live_benchmark
|
||||
from agentic_pd_hybrid.launcher import build_launch_plan
|
||||
from agentic_pd_hybrid.microbench import SmallAppendTraceConfig, write_small_append_trace
|
||||
from agentic_pd_hybrid.profile import ProfileConfig, print_profile_summary, write_profile
|
||||
from agentic_pd_hybrid.replay import ReplayConfig, replay_trace
|
||||
from agentic_pd_hybrid.sampling import SessionSampleConfig, sample_trace_sessions
|
||||
from agentic_pd_hybrid.trace_profiles import (
|
||||
@@ -142,6 +143,71 @@ def main() -> None:
|
||||
"or query the decode worker on the critical path."
|
||||
),
|
||||
)
|
||||
replay.add_argument(
|
||||
"--kvcache-seed-max-resident-tokens",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"For kvcache-centric routing, do not seed/reseed a decode session "
|
||||
"when input+output tokens exceed this value."
|
||||
),
|
||||
)
|
||||
replay.add_argument(
|
||||
"--kvcache-seed-max-output-tokens",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"For kvcache-centric routing, do not seed/reseed a decode session "
|
||||
"when output tokens exceed this value."
|
||||
),
|
||||
)
|
||||
replay.add_argument(
|
||||
"--kvcache-seed-min-turn-id",
|
||||
type=int,
|
||||
default=1,
|
||||
help=(
|
||||
"For kvcache-centric routing, do not seed/reseed a decode session "
|
||||
"before this turn id."
|
||||
),
|
||||
)
|
||||
replay.add_argument(
|
||||
"--kvcache-seed-only-multiturn-sessions",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Oracle ablation for kvcache-centric routing: only seed sessions "
|
||||
"that have more than one turn in the replay trace."
|
||||
),
|
||||
)
|
||||
replay.add_argument(
|
||||
"--kvcache-prefill-backup-policy",
|
||||
choices=["release-after-transfer", "capacity-backup"],
|
||||
default="release-after-transfer",
|
||||
help=(
|
||||
"For kvcache-centric seed/reseed, release the P-side session after "
|
||||
"P->D transfer or keep a capacity-limited P-side backup."
|
||||
),
|
||||
)
|
||||
replay.add_argument(
|
||||
"--kvcache-seed-max-inflight-decode",
|
||||
type=int,
|
||||
default=3,
|
||||
help=(
|
||||
"For kvcache-centric routing, skip seed/reseed when the router "
|
||||
"shadow inflight decode load assigned to the target D exceeds this value. "
|
||||
"Use a negative value to disable this filter."
|
||||
),
|
||||
)
|
||||
replay.add_argument(
|
||||
"--kvcache-prefill-priority-eviction",
|
||||
action="store_true",
|
||||
help=(
|
||||
"For kvcache-centric routing, mark P-side prefixes predicted to move "
|
||||
"direct-to-D as lower priority. Requires P workers to use "
|
||||
"--radix-eviction-policy priority."
|
||||
),
|
||||
)
|
||||
replay.add_argument("--kvcache-prefill-direct-priority", type=int, default=-100)
|
||||
replay.add_argument("--kvcache-prefill-normal-priority", type=int, default=100)
|
||||
|
||||
sample = subparsers.add_parser(
|
||||
"sample-sessions",
|
||||
@@ -177,6 +243,17 @@ def main() -> None:
|
||||
normalize.add_argument("--output-length", type=int, default=1_000)
|
||||
normalize.add_argument("--max-requests", type=int, default=None)
|
||||
|
||||
profile = subparsers.add_parser(
|
||||
"profile",
|
||||
help="Profile a trace and optional request metrics for routing analysis",
|
||||
)
|
||||
profile.add_argument("--trace", type=Path, required=True)
|
||||
profile.add_argument("--output", type=Path, default=None)
|
||||
profile.add_argument("--metrics", type=Path, default=None)
|
||||
profile.add_argument("--baseline-metrics", type=Path, default=None)
|
||||
profile.add_argument("--candidate-metrics", type=Path, default=None)
|
||||
profile.add_argument("--direct-max-uncached-tokens", type=int, default=2048)
|
||||
|
||||
micro = subparsers.add_parser(
|
||||
"make-small-append-trace",
|
||||
help="Generate a synthetic multi-turn trace with small turn2+ appends",
|
||||
@@ -223,6 +300,15 @@ def main() -> None:
|
||||
benchmark.add_argument("--time-scale", type=float, default=1.0)
|
||||
benchmark.add_argument("--concurrency-limit", type=int, default=32)
|
||||
benchmark.add_argument("--timeout-s", type=float, default=1200.0)
|
||||
benchmark.add_argument(
|
||||
"--request-timeout-s",
|
||||
type=float,
|
||||
default=None,
|
||||
help=(
|
||||
"Per-request replay/router timeout. If unset, --timeout-s is used. "
|
||||
"--timeout-s still controls stack startup."
|
||||
),
|
||||
)
|
||||
benchmark.add_argument(
|
||||
"--no-stream",
|
||||
action="store_true",
|
||||
@@ -249,6 +335,70 @@ def main() -> None:
|
||||
"or query the decode worker on the critical path."
|
||||
),
|
||||
)
|
||||
benchmark.add_argument(
|
||||
"--kvcache-seed-max-resident-tokens",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"For kvcache-centric routing, do not seed/reseed a decode session "
|
||||
"when input+output tokens exceed this value."
|
||||
),
|
||||
)
|
||||
benchmark.add_argument(
|
||||
"--kvcache-seed-max-output-tokens",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"For kvcache-centric routing, do not seed/reseed a decode session "
|
||||
"when output tokens exceed this value."
|
||||
),
|
||||
)
|
||||
benchmark.add_argument(
|
||||
"--kvcache-seed-min-turn-id",
|
||||
type=int,
|
||||
default=1,
|
||||
help=(
|
||||
"For kvcache-centric routing, do not seed/reseed a decode session "
|
||||
"before this turn id."
|
||||
),
|
||||
)
|
||||
benchmark.add_argument(
|
||||
"--kvcache-seed-only-multiturn-sessions",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Oracle ablation for kvcache-centric routing: only seed sessions "
|
||||
"that have more than one turn in the replay trace."
|
||||
),
|
||||
)
|
||||
benchmark.add_argument(
|
||||
"--kvcache-prefill-backup-policy",
|
||||
choices=["release-after-transfer", "capacity-backup"],
|
||||
default="release-after-transfer",
|
||||
help=(
|
||||
"For kvcache-centric seed/reseed, release the P-side session after "
|
||||
"P->D transfer or keep a capacity-limited P-side backup."
|
||||
),
|
||||
)
|
||||
benchmark.add_argument(
|
||||
"--kvcache-seed-max-inflight-decode",
|
||||
type=int,
|
||||
default=3,
|
||||
help=(
|
||||
"For kvcache-centric routing, skip seed/reseed when the router "
|
||||
"shadow inflight decode load assigned to the target D exceeds this value. "
|
||||
"Use a negative value to disable this filter."
|
||||
),
|
||||
)
|
||||
benchmark.add_argument(
|
||||
"--kvcache-prefill-priority-eviction",
|
||||
action="store_true",
|
||||
help=(
|
||||
"For kvcache-centric benchmark-live, launch P workers with priority "
|
||||
"radix eviction and mark direct-to-D predicted prefixes lower priority."
|
||||
),
|
||||
)
|
||||
benchmark.add_argument("--kvcache-prefill-direct-priority", type=int, default=-100)
|
||||
benchmark.add_argument("--kvcache-prefill-normal-priority", type=int, default=100)
|
||||
benchmark.add_argument(
|
||||
"--sample-profile",
|
||||
choices=["default", "small-append"],
|
||||
@@ -294,6 +444,23 @@ def main() -> None:
|
||||
stream_idle_timeout_s=args.stream_idle_timeout_s,
|
||||
kvcache_direct_max_uncached_tokens=args.kvcache_direct_max_uncached_tokens,
|
||||
kvcache_admission_mode=args.kvcache_admission_mode,
|
||||
kvcache_seed_max_resident_tokens=args.kvcache_seed_max_resident_tokens,
|
||||
kvcache_seed_max_output_tokens=args.kvcache_seed_max_output_tokens,
|
||||
kvcache_seed_min_turn_id=args.kvcache_seed_min_turn_id,
|
||||
kvcache_seed_only_multiturn_sessions=(
|
||||
args.kvcache_seed_only_multiturn_sessions
|
||||
),
|
||||
kvcache_prefill_backup_policy=args.kvcache_prefill_backup_policy,
|
||||
kvcache_seed_max_inflight_decode=(
|
||||
None
|
||||
if args.kvcache_seed_max_inflight_decode < 0
|
||||
else args.kvcache_seed_max_inflight_decode
|
||||
),
|
||||
kvcache_prefill_priority_eviction=(
|
||||
args.kvcache_prefill_priority_eviction
|
||||
),
|
||||
kvcache_prefill_direct_priority=args.kvcache_prefill_direct_priority,
|
||||
kvcache_prefill_normal_priority=args.kvcache_prefill_normal_priority,
|
||||
)
|
||||
results = asyncio.run(replay_trace(config))
|
||||
print(
|
||||
@@ -302,6 +469,26 @@ def main() -> None:
|
||||
)
|
||||
return
|
||||
|
||||
if args.command == "profile":
|
||||
if (args.baseline_metrics is None) != (args.candidate_metrics is None):
|
||||
raise ValueError(
|
||||
"--baseline-metrics and --candidate-metrics must be provided together"
|
||||
)
|
||||
report = write_profile(
|
||||
ProfileConfig(
|
||||
trace_path=args.trace,
|
||||
output_path=args.output,
|
||||
metrics_path=args.metrics,
|
||||
baseline_metrics_path=args.baseline_metrics,
|
||||
candidate_metrics_path=args.candidate_metrics,
|
||||
direct_max_uncached_tokens=args.direct_max_uncached_tokens,
|
||||
)
|
||||
)
|
||||
print_profile_summary(report)
|
||||
if args.output is not None:
|
||||
print(f"wrote profile to {args.output}")
|
||||
return
|
||||
|
||||
if args.command == "sample-sessions":
|
||||
summary = sample_trace_sessions(
|
||||
SessionSampleConfig(
|
||||
@@ -378,10 +565,32 @@ def main() -> None:
|
||||
time_scale=args.time_scale,
|
||||
concurrency_limit=args.concurrency_limit,
|
||||
timeout_s=args.timeout_s,
|
||||
request_timeout_s=args.request_timeout_s,
|
||||
stream=not args.no_stream,
|
||||
stream_idle_timeout_s=args.stream_idle_timeout_s,
|
||||
kvcache_direct_max_uncached_tokens=args.kvcache_direct_max_uncached_tokens,
|
||||
kvcache_admission_mode=args.kvcache_admission_mode,
|
||||
kvcache_seed_max_resident_tokens=args.kvcache_seed_max_resident_tokens,
|
||||
kvcache_seed_max_output_tokens=args.kvcache_seed_max_output_tokens,
|
||||
kvcache_seed_min_turn_id=args.kvcache_seed_min_turn_id,
|
||||
kvcache_seed_only_multiturn_sessions=(
|
||||
args.kvcache_seed_only_multiturn_sessions
|
||||
),
|
||||
kvcache_prefill_backup_policy=args.kvcache_prefill_backup_policy,
|
||||
kvcache_seed_max_inflight_decode=(
|
||||
None
|
||||
if args.kvcache_seed_max_inflight_decode < 0
|
||||
else args.kvcache_seed_max_inflight_decode
|
||||
),
|
||||
kvcache_prefill_priority_eviction=(
|
||||
args.kvcache_prefill_priority_eviction
|
||||
),
|
||||
kvcache_prefill_direct_priority=(
|
||||
args.kvcache_prefill_direct_priority
|
||||
),
|
||||
kvcache_prefill_normal_priority=(
|
||||
args.kvcache_prefill_normal_priority
|
||||
),
|
||||
sample_profile=args.sample_profile,
|
||||
min_initial_input_tokens=args.min_initial_input_tokens,
|
||||
max_initial_input_tokens=args.max_initial_input_tokens,
|
||||
|
||||
@@ -33,6 +33,7 @@ def build_launch_plan(
|
||||
prefill_policy: str = "round_robin",
|
||||
decode_policy: str = "manual",
|
||||
include_router: bool = True,
|
||||
router_request_timeout_s: float | None = None,
|
||||
) -> LaunchPlan:
|
||||
return LaunchPlan(
|
||||
prefill_commands=tuple(
|
||||
@@ -49,6 +50,7 @@ def build_launch_plan(
|
||||
topology,
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
request_timeout_s=router_request_timeout_s,
|
||||
)
|
||||
if include_router and topology.prefill_workers and topology.decode_workers
|
||||
else None
|
||||
@@ -105,6 +107,7 @@ def _build_router_command(
|
||||
*,
|
||||
prefill_policy: str,
|
||||
decode_policy: str,
|
||||
request_timeout_s: float | None,
|
||||
) -> tuple[str, ...]:
|
||||
command: list[str] = [
|
||||
sys.executable,
|
||||
@@ -121,6 +124,8 @@ def _build_router_command(
|
||||
"--decode-policy",
|
||||
decode_policy,
|
||||
]
|
||||
if request_timeout_s is not None:
|
||||
command.extend(["--request-timeout-s", str(request_timeout_s)])
|
||||
for worker in topology.prefill_workers:
|
||||
command.extend(
|
||||
["--prefill", worker.url, str(worker.bootstrap_port or topology.router_port)]
|
||||
|
||||
@@ -33,6 +33,8 @@ class RequestMetrics:
|
||||
kv_transfer_blocks: int
|
||||
actual_kv_transfer_blocks: int
|
||||
cached_tokens: int
|
||||
prefill_request_priority: int | None
|
||||
decode_request_priority: int | None
|
||||
re_prefill_required: bool
|
||||
effective_input_length: int | None
|
||||
session_reused: bool
|
||||
@@ -58,6 +60,8 @@ class RequestMetrics:
|
||||
latency_s: float | None,
|
||||
ttft_s: float | None,
|
||||
tpot_s: float | None,
|
||||
prefill_request_priority: int | None = None,
|
||||
decode_request_priority: int | None = None,
|
||||
error: str | None = None,
|
||||
) -> "RequestMetrics":
|
||||
return cls(
|
||||
@@ -81,6 +85,8 @@ class RequestMetrics:
|
||||
kv_transfer_blocks=decision.kv_transfer_blocks,
|
||||
actual_kv_transfer_blocks=actual_kv_transfer_blocks,
|
||||
cached_tokens=cached_tokens,
|
||||
prefill_request_priority=prefill_request_priority,
|
||||
decode_request_priority=decode_request_priority,
|
||||
re_prefill_required=decision.re_prefill_required,
|
||||
effective_input_length=effective_input_length,
|
||||
session_reused=session_reused,
|
||||
@@ -111,6 +117,16 @@ def write_summary_json(
|
||||
tpots = [row.tpot_s for row in rows if row.tpot_s is not None]
|
||||
per_decode_load = Counter(row.assigned_decode_node for row in rows)
|
||||
per_prefill_load = Counter(row.assigned_prefill_node for row in rows)
|
||||
prefill_priorities = Counter(
|
||||
row.prefill_request_priority
|
||||
for row in rows
|
||||
if row.prefill_request_priority is not None
|
||||
)
|
||||
decode_priorities = Counter(
|
||||
row.decode_request_priority
|
||||
for row in rows
|
||||
if row.decode_request_priority is not None
|
||||
)
|
||||
|
||||
summary: dict[str, Any] = {
|
||||
"trace_path": str(trace_path),
|
||||
@@ -135,6 +151,12 @@ def write_summary_json(
|
||||
),
|
||||
"per_decode_load": dict(sorted(per_decode_load.items())),
|
||||
"per_prefill_load": dict(sorted(per_prefill_load.items())),
|
||||
"prefill_request_priorities": {
|
||||
str(key): value for key, value in sorted(prefill_priorities.items())
|
||||
},
|
||||
"decode_request_priorities": {
|
||||
str(key): value for key, value in sorted(decode_priorities.items())
|
||||
},
|
||||
"error_count": sum(1 for row in rows if row.error is not None),
|
||||
}
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -149,13 +149,17 @@ async def _forward_to_backend(
|
||||
) -> Response:
|
||||
state = _require_state()
|
||||
prefill_server, bootstrap_port, decode_server = state.select_pair(headers)
|
||||
modified_request = request_data.copy()
|
||||
modified_request.update(_build_bootstrap_payload(prefill_server, bootstrap_port))
|
||||
prefill_request, decode_request = _build_backend_requests(
|
||||
request_data=request_data,
|
||||
prefill_server=prefill_server,
|
||||
bootstrap_port=bootstrap_port,
|
||||
)
|
||||
|
||||
if request_data.get("stream", False):
|
||||
return StreamingResponse(
|
||||
_stream_generate(
|
||||
modified_request=modified_request,
|
||||
prefill_request=prefill_request,
|
||||
decode_request=decode_request,
|
||||
prefill_server=prefill_server,
|
||||
decode_server=decode_server,
|
||||
endpoint_name=endpoint_name,
|
||||
@@ -168,8 +172,8 @@ async def _forward_to_backend(
|
||||
timeout=aiohttp.ClientTimeout(total=state.config.request_timeout_s)
|
||||
) as session:
|
||||
prefill_response, decode_response = await asyncio.gather(
|
||||
session.post(f"{prefill_server}/{endpoint_name}", json=modified_request),
|
||||
session.post(f"{decode_server}/{endpoint_name}", json=modified_request),
|
||||
session.post(f"{prefill_server}/{endpoint_name}", json=prefill_request),
|
||||
session.post(f"{decode_server}/{endpoint_name}", json=decode_request),
|
||||
)
|
||||
async with prefill_response:
|
||||
await prefill_response.read()
|
||||
@@ -184,7 +188,8 @@ async def _forward_to_backend(
|
||||
|
||||
async def _stream_generate(
|
||||
*,
|
||||
modified_request: dict,
|
||||
prefill_request: dict,
|
||||
decode_request: dict,
|
||||
prefill_server: str,
|
||||
decode_server: str,
|
||||
endpoint_name: str,
|
||||
@@ -194,8 +199,8 @@ async def _stream_generate(
|
||||
timeout=aiohttp.ClientTimeout(total=timeout_s)
|
||||
) as session:
|
||||
prefill_response, decode_response = await asyncio.gather(
|
||||
session.post(f"{prefill_server}/{endpoint_name}", json=modified_request),
|
||||
session.post(f"{decode_server}/{endpoint_name}", json=modified_request),
|
||||
session.post(f"{prefill_server}/{endpoint_name}", json=prefill_request),
|
||||
session.post(f"{decode_server}/{endpoint_name}", json=decode_request),
|
||||
)
|
||||
async with prefill_response, decode_response:
|
||||
if decode_response.status != HTTPStatus.OK:
|
||||
@@ -221,6 +226,35 @@ def _build_bootstrap_payload(prefill_server: str, bootstrap_port: int) -> dict[s
|
||||
}
|
||||
|
||||
|
||||
def _build_backend_requests(
|
||||
*,
|
||||
request_data: dict,
|
||||
prefill_server: str,
|
||||
bootstrap_port: int,
|
||||
) -> tuple[dict, dict]:
|
||||
prefill_priority = request_data.get("smg_prefill_priority")
|
||||
decode_priority = request_data.get("smg_decode_priority")
|
||||
|
||||
prefill_request = _strip_internal_fields(request_data)
|
||||
decode_request = _strip_internal_fields(request_data)
|
||||
bootstrap_payload = _build_bootstrap_payload(prefill_server, bootstrap_port)
|
||||
prefill_request.update(bootstrap_payload)
|
||||
decode_request.update(bootstrap_payload)
|
||||
|
||||
if prefill_priority is not None:
|
||||
prefill_request["priority"] = int(prefill_priority)
|
||||
if decode_priority is not None:
|
||||
decode_request["priority"] = int(decode_priority)
|
||||
return prefill_request, decode_request
|
||||
|
||||
|
||||
def _strip_internal_fields(request_data: dict) -> dict:
|
||||
cleaned = request_data.copy()
|
||||
cleaned.pop("smg_prefill_priority", None)
|
||||
cleaned.pop("smg_decode_priority", None)
|
||||
return cleaned
|
||||
|
||||
|
||||
def _require_state() -> RouterState:
|
||||
if router_state is None:
|
||||
raise HTTPException(status_code=500, detail="router not initialized")
|
||||
|
||||
511
src/agentic_pd_hybrid/profile.py
Normal file
511
src/agentic_pd_hybrid/profile.py
Normal file
@@ -0,0 +1,511 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import statistics
|
||||
from collections import Counter, defaultdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
|
||||
from agentic_pd_hybrid.trace import TraceRequest, load_trace
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProfileConfig:
|
||||
trace_path: Path
|
||||
output_path: Path | None = None
|
||||
metrics_path: Path | None = None
|
||||
baseline_metrics_path: Path | None = None
|
||||
candidate_metrics_path: Path | None = None
|
||||
direct_max_uncached_tokens: int = 2048
|
||||
|
||||
|
||||
def write_profile(config: ProfileConfig) -> dict[str, Any]:
|
||||
report = build_profile(config)
|
||||
if config.output_path is not None:
|
||||
config.output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with config.output_path.open("w", encoding="utf-8") as handle:
|
||||
json.dump(report, handle, indent=2, sort_keys=True)
|
||||
return report
|
||||
|
||||
|
||||
def build_profile(config: ProfileConfig) -> dict[str, Any]:
|
||||
requests = load_trace(config.trace_path)
|
||||
features = _build_trace_features(
|
||||
requests,
|
||||
direct_max_uncached_tokens=config.direct_max_uncached_tokens,
|
||||
)
|
||||
report: dict[str, Any] = {
|
||||
"trace_path": str(config.trace_path),
|
||||
"direct_max_uncached_tokens": config.direct_max_uncached_tokens,
|
||||
"trace_profile": _trace_profile(requests, features),
|
||||
}
|
||||
|
||||
if config.metrics_path is not None:
|
||||
metrics = _load_jsonl(config.metrics_path)
|
||||
report["metrics_path"] = str(config.metrics_path)
|
||||
report["metrics_profile"] = _metrics_profile(metrics, features)
|
||||
|
||||
if (
|
||||
config.baseline_metrics_path is not None
|
||||
and config.candidate_metrics_path is not None
|
||||
):
|
||||
baseline = _load_jsonl(config.baseline_metrics_path)
|
||||
candidate = _load_jsonl(config.candidate_metrics_path)
|
||||
report["baseline_metrics_path"] = str(config.baseline_metrics_path)
|
||||
report["candidate_metrics_path"] = str(config.candidate_metrics_path)
|
||||
report["baseline_profile"] = _metrics_profile(baseline, features)
|
||||
report["candidate_profile"] = _metrics_profile(candidate, features)
|
||||
report["paired_comparison"] = _paired_comparison(
|
||||
baseline=baseline,
|
||||
candidate=candidate,
|
||||
features=features,
|
||||
)
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def print_profile_summary(report: dict[str, Any]) -> None:
|
||||
trace = report["trace_profile"]
|
||||
print(
|
||||
"trace: "
|
||||
f"{trace['request_count']} requests, "
|
||||
f"{trace['session_count']} sessions, "
|
||||
f"{trace['multi_turn_session_count']} multi-turn sessions"
|
||||
)
|
||||
print(
|
||||
"direct-eligible turns: "
|
||||
f"{trace['direct_eligible_turn2plus_count']}/"
|
||||
f"{trace['turn2plus_count']} "
|
||||
f"({trace['direct_eligible_turn2plus_ratio']:.3f})"
|
||||
)
|
||||
append_stats = trace.get("append_input_tokens_stats")
|
||||
output_stats = trace.get("output_tokens_stats")
|
||||
if append_stats is not None:
|
||||
print(
|
||||
"append tokens: "
|
||||
f"mean={append_stats['mean']:.1f} "
|
||||
f"p50={append_stats['p50']:.1f} "
|
||||
f"p90={append_stats['p90']:.1f} "
|
||||
f"p99={append_stats['p99']:.1f}"
|
||||
)
|
||||
if output_stats is not None:
|
||||
print(
|
||||
"output tokens: "
|
||||
f"mean={output_stats['mean']:.1f} "
|
||||
f"p50={output_stats['p50']:.1f} "
|
||||
f"p90={output_stats['p90']:.1f} "
|
||||
f"p99={output_stats['p99']:.1f}"
|
||||
)
|
||||
|
||||
comparison = report.get("paired_comparison")
|
||||
if isinstance(comparison, dict):
|
||||
overall = comparison.get("overall", {})
|
||||
delta = overall.get("latency_delta_s_stats")
|
||||
if delta is not None:
|
||||
print(
|
||||
"candidate - baseline E2E: "
|
||||
f"mean={delta['mean']:.3f}s "
|
||||
f"p50={delta['p50']:.3f}s "
|
||||
f"p90={delta['p90']:.3f}s"
|
||||
)
|
||||
print(
|
||||
"paired wins/losses: "
|
||||
f"{overall.get('candidate_faster_count', 0)} faster, "
|
||||
f"{overall.get('candidate_slower_count', 0)} slower, "
|
||||
f"{overall.get('paired_count', 0)} paired"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _TraceFeature:
|
||||
request_id: str
|
||||
session_id: str
|
||||
turn_id: int
|
||||
input_length: int
|
||||
output_length: int
|
||||
resident_tokens: int
|
||||
append_input_tokens: int | None
|
||||
inter_turn_gap_s: float | None
|
||||
overlap_blocks_with_previous: int | None
|
||||
overlap_ratio_with_previous: float | None
|
||||
direct_eligible: bool
|
||||
turn_class: str
|
||||
append_bin: str
|
||||
input_bin: str
|
||||
output_bin: str
|
||||
resident_bin: str
|
||||
|
||||
|
||||
def _build_trace_features(
|
||||
requests: list[TraceRequest],
|
||||
*,
|
||||
direct_max_uncached_tokens: int,
|
||||
) -> dict[str, _TraceFeature]:
|
||||
ordered_by_session: dict[str, list[TraceRequest]] = defaultdict(list)
|
||||
for request in requests:
|
||||
ordered_by_session[request.session_id].append(request)
|
||||
|
||||
previous_by_request_id: dict[str, TraceRequest | None] = {}
|
||||
for session_requests in ordered_by_session.values():
|
||||
ordered = sorted(
|
||||
session_requests,
|
||||
key=lambda request: (request.timestamp_s, request.turn_id, request.chat_id),
|
||||
)
|
||||
previous: TraceRequest | None = None
|
||||
for request in ordered:
|
||||
previous_by_request_id[request.request_id] = previous
|
||||
previous = request
|
||||
|
||||
features: dict[str, _TraceFeature] = {}
|
||||
for request in requests:
|
||||
previous = previous_by_request_id.get(request.request_id)
|
||||
append_input_tokens: int | None = None
|
||||
inter_turn_gap_s: float | None = None
|
||||
overlap_blocks: int | None = None
|
||||
overlap_ratio: float | None = None
|
||||
direct_eligible = False
|
||||
|
||||
if previous is not None:
|
||||
append_input_tokens = request.input_length - (
|
||||
previous.input_length + previous.output_length
|
||||
)
|
||||
inter_turn_gap_s = request.timestamp_s - previous.timestamp_s
|
||||
previous_blocks = set(previous.hash_ids)
|
||||
overlap_blocks = sum(1 for block in request.hash_ids if block in previous_blocks)
|
||||
overlap_ratio = (
|
||||
overlap_blocks / len(request.hash_ids) if request.hash_ids else 0.0
|
||||
)
|
||||
direct_eligible = (
|
||||
append_input_tokens > 0
|
||||
and append_input_tokens <= direct_max_uncached_tokens
|
||||
and overlap_blocks > 0
|
||||
)
|
||||
|
||||
features[request.request_id] = _TraceFeature(
|
||||
request_id=request.request_id,
|
||||
session_id=request.session_id,
|
||||
turn_id=request.turn_id,
|
||||
input_length=request.input_length,
|
||||
output_length=request.output_length,
|
||||
resident_tokens=request.input_length + request.output_length,
|
||||
append_input_tokens=append_input_tokens,
|
||||
inter_turn_gap_s=inter_turn_gap_s,
|
||||
overlap_blocks_with_previous=overlap_blocks,
|
||||
overlap_ratio_with_previous=overlap_ratio,
|
||||
direct_eligible=direct_eligible,
|
||||
turn_class="turn1" if request.turn_id <= 1 else "turn2plus",
|
||||
append_bin=_token_bin(append_input_tokens),
|
||||
input_bin=_token_bin(request.input_length),
|
||||
output_bin=_token_bin(request.output_length),
|
||||
resident_bin=_token_bin(request.input_length + request.output_length),
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
def _trace_profile(
|
||||
requests: list[TraceRequest],
|
||||
features: dict[str, _TraceFeature],
|
||||
) -> dict[str, Any]:
|
||||
session_turns = Counter(request.session_id for request in requests)
|
||||
turn2plus = [feature for feature in features.values() if feature.turn_id > 1]
|
||||
direct_eligible = [feature for feature in turn2plus if feature.direct_eligible]
|
||||
append_values = [
|
||||
feature.append_input_tokens
|
||||
for feature in turn2plus
|
||||
if feature.append_input_tokens is not None
|
||||
]
|
||||
positive_append_values = [
|
||||
value for value in append_values if value is not None and value > 0
|
||||
]
|
||||
overlap_ratios = [
|
||||
feature.overlap_ratio_with_previous
|
||||
for feature in turn2plus
|
||||
if feature.overlap_ratio_with_previous is not None
|
||||
]
|
||||
gaps = [
|
||||
feature.inter_turn_gap_s
|
||||
for feature in turn2plus
|
||||
if feature.inter_turn_gap_s is not None
|
||||
]
|
||||
return {
|
||||
"request_count": len(requests),
|
||||
"session_count": len(session_turns),
|
||||
"multi_turn_session_count": sum(1 for turns in session_turns.values() if turns > 1),
|
||||
"turn2plus_count": len(turn2plus),
|
||||
"direct_eligible_turn2plus_count": len(direct_eligible),
|
||||
"direct_eligible_turn2plus_ratio": (
|
||||
len(direct_eligible) / len(turn2plus) if turn2plus else 0.0
|
||||
),
|
||||
"turn_count_distribution": dict(sorted(Counter(session_turns.values()).items())),
|
||||
"request_type_distribution": dict(
|
||||
sorted(Counter(request.request_type for request in requests).items())
|
||||
),
|
||||
"turn_id_distribution": dict(
|
||||
sorted(Counter(request.turn_id for request in requests).items())
|
||||
),
|
||||
"append_bin_distribution": dict(
|
||||
sorted(Counter(feature.append_bin for feature in turn2plus).items())
|
||||
),
|
||||
"input_bin_distribution": dict(
|
||||
sorted(Counter(feature.input_bin for feature in features.values()).items())
|
||||
),
|
||||
"output_bin_distribution": dict(
|
||||
sorted(Counter(feature.output_bin for feature in features.values()).items())
|
||||
),
|
||||
"resident_bin_distribution": dict(
|
||||
sorted(Counter(feature.resident_bin for feature in features.values()).items())
|
||||
),
|
||||
"input_tokens_stats": _stats(
|
||||
[float(request.input_length) for request in requests]
|
||||
),
|
||||
"output_tokens_stats": _stats(
|
||||
[float(request.output_length) for request in requests]
|
||||
),
|
||||
"resident_tokens_stats": _stats(
|
||||
[float(feature.resident_tokens) for feature in features.values()]
|
||||
),
|
||||
"append_input_tokens_stats": _stats(
|
||||
[float(value) for value in append_values if value is not None]
|
||||
),
|
||||
"positive_append_input_tokens_stats": _stats(
|
||||
[float(value) for value in positive_append_values]
|
||||
),
|
||||
"inter_turn_gap_s_stats": _stats([float(value) for value in gaps]),
|
||||
"overlap_ratio_stats": _stats([float(value) for value in overlap_ratios]),
|
||||
"non_positive_append_count": sum(
|
||||
1 for value in append_values if value is not None and value <= 0
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _metrics_profile(
|
||||
rows: list[dict[str, Any]],
|
||||
features: dict[str, _TraceFeature],
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"request_count": len(rows),
|
||||
"mechanism_distribution": dict(
|
||||
sorted(Counter(str(row.get("mechanism_name")) for row in rows).items())
|
||||
),
|
||||
"execution_mode_distribution": dict(
|
||||
sorted(Counter(str(row.get("execution_mode")) for row in rows).items())
|
||||
),
|
||||
"latency_s_stats": _stats(_numeric_values(rows, "latency_s")),
|
||||
"ttft_s_stats": _stats(_numeric_values(rows, "ttft_s")),
|
||||
"tpot_s_stats": _stats(_numeric_values(rows, "tpot_s")),
|
||||
"cached_tokens_stats": _stats(_numeric_values(rows, "cached_tokens")),
|
||||
"actual_kv_transfer_blocks_stats": _stats(
|
||||
_numeric_values(rows, "actual_kv_transfer_blocks")
|
||||
),
|
||||
"session_reused_count": sum(1 for row in rows if row.get("session_reused")),
|
||||
"session_reset_count": sum(1 for row in rows if row.get("session_reset")),
|
||||
"error_count": sum(1 for row in rows if row.get("error") is not None),
|
||||
"by_turn_class": _group_metrics(rows, features, lambda feature, _row: feature.turn_class),
|
||||
"by_direct_eligible": _group_metrics(
|
||||
rows,
|
||||
features,
|
||||
lambda feature, _row: "eligible" if feature.direct_eligible else "not_eligible",
|
||||
),
|
||||
"by_append_bin": _group_metrics(rows, features, lambda feature, _row: feature.append_bin),
|
||||
"by_resident_bin": _group_metrics(
|
||||
rows,
|
||||
features,
|
||||
lambda feature, _row: feature.resident_bin,
|
||||
),
|
||||
"by_execution_mode": _group_metrics(
|
||||
rows,
|
||||
features,
|
||||
lambda _feature, row: str(row.get("execution_mode")),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _paired_comparison(
|
||||
*,
|
||||
baseline: list[dict[str, Any]],
|
||||
candidate: list[dict[str, Any]],
|
||||
features: dict[str, _TraceFeature],
|
||||
) -> dict[str, Any]:
|
||||
baseline_by_id = {
|
||||
str(row.get("request_id")): row
|
||||
for row in baseline
|
||||
if row.get("latency_s") is not None
|
||||
}
|
||||
candidate_by_id = {
|
||||
str(row.get("request_id")): row
|
||||
for row in candidate
|
||||
if row.get("latency_s") is not None
|
||||
}
|
||||
paired_ids = sorted(set(baseline_by_id) & set(candidate_by_id))
|
||||
pairs = [
|
||||
(baseline_by_id[request_id], candidate_by_id[request_id], features.get(request_id))
|
||||
for request_id in paired_ids
|
||||
]
|
||||
pairs = [pair for pair in pairs if pair[2] is not None]
|
||||
|
||||
return {
|
||||
"overall": _delta_summary(pairs),
|
||||
"by_turn_class": _group_deltas(
|
||||
pairs,
|
||||
lambda feature, _base, _cand: feature.turn_class,
|
||||
),
|
||||
"by_direct_eligible": _group_deltas(
|
||||
pairs,
|
||||
lambda feature, _base, _cand: (
|
||||
"eligible" if feature.direct_eligible else "not_eligible"
|
||||
),
|
||||
),
|
||||
"by_append_bin": _group_deltas(
|
||||
pairs,
|
||||
lambda feature, _base, _cand: feature.append_bin,
|
||||
),
|
||||
"by_resident_bin": _group_deltas(
|
||||
pairs,
|
||||
lambda feature, _base, _cand: feature.resident_bin,
|
||||
),
|
||||
"by_candidate_execution_mode": _group_deltas(
|
||||
pairs,
|
||||
lambda _feature, _base, cand: str(cand.get("execution_mode")),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _group_metrics(
|
||||
rows: list[dict[str, Any]],
|
||||
features: dict[str, _TraceFeature],
|
||||
key_fn,
|
||||
) -> dict[str, Any]:
|
||||
grouped: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||||
for row in rows:
|
||||
feature = features.get(str(row.get("request_id")))
|
||||
if feature is None:
|
||||
continue
|
||||
grouped[str(key_fn(feature, row))].append(row)
|
||||
return {
|
||||
key: {
|
||||
"count": len(group_rows),
|
||||
"latency_s_stats": _stats(_numeric_values(group_rows, "latency_s")),
|
||||
"ttft_s_stats": _stats(_numeric_values(group_rows, "ttft_s")),
|
||||
"tpot_s_stats": _stats(_numeric_values(group_rows, "tpot_s")),
|
||||
"session_reused_count": sum(
|
||||
1 for row in group_rows if row.get("session_reused")
|
||||
),
|
||||
"error_count": sum(1 for row in group_rows if row.get("error") is not None),
|
||||
}
|
||||
for key, group_rows in sorted(grouped.items())
|
||||
}
|
||||
|
||||
|
||||
def _group_deltas(
|
||||
pairs: list[tuple[dict[str, Any], dict[str, Any], _TraceFeature | None]],
|
||||
key_fn,
|
||||
) -> dict[str, Any]:
|
||||
grouped: dict[str, list[tuple[dict[str, Any], dict[str, Any], _TraceFeature | None]]] = (
|
||||
defaultdict(list)
|
||||
)
|
||||
for base, cand, feature in pairs:
|
||||
if feature is None:
|
||||
continue
|
||||
grouped[str(key_fn(feature, base, cand))].append((base, cand, feature))
|
||||
return {key: _delta_summary(group_pairs) for key, group_pairs in sorted(grouped.items())}
|
||||
|
||||
|
||||
def _delta_summary(
|
||||
pairs: list[tuple[dict[str, Any], dict[str, Any], _TraceFeature | None]],
|
||||
) -> dict[str, Any]:
|
||||
latency_deltas = [
|
||||
float(cand["latency_s"]) - float(base["latency_s"])
|
||||
for base, cand, _feature in pairs
|
||||
if base.get("latency_s") is not None and cand.get("latency_s") is not None
|
||||
]
|
||||
ttft_deltas = [
|
||||
float(cand["ttft_s"]) - float(base["ttft_s"])
|
||||
for base, cand, _feature in pairs
|
||||
if base.get("ttft_s") is not None and cand.get("ttft_s") is not None
|
||||
]
|
||||
return {
|
||||
"paired_count": len(latency_deltas),
|
||||
"candidate_faster_count": sum(1 for delta in latency_deltas if delta < 0),
|
||||
"candidate_slower_count": sum(1 for delta in latency_deltas if delta > 0),
|
||||
"latency_delta_s_stats": _stats(latency_deltas),
|
||||
"ttft_delta_s_stats": _stats(ttft_deltas),
|
||||
"total_latency_delta_s": sum(latency_deltas),
|
||||
"mean_baseline_latency_s": _mean(
|
||||
[
|
||||
float(base["latency_s"])
|
||||
for base, cand, _feature in pairs
|
||||
if base.get("latency_s") is not None and cand.get("latency_s") is not None
|
||||
]
|
||||
),
|
||||
"mean_candidate_latency_s": _mean(
|
||||
[
|
||||
float(cand["latency_s"])
|
||||
for base, cand, _feature in pairs
|
||||
if base.get("latency_s") is not None and cand.get("latency_s") is not None
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _load_jsonl(path: Path) -> list[dict[str, Any]]:
|
||||
rows: list[dict[str, Any]] = []
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if line.strip():
|
||||
rows.append(json.loads(line))
|
||||
return rows
|
||||
|
||||
|
||||
def _numeric_values(rows: Iterable[dict[str, Any]], key: str) -> list[float]:
|
||||
values: list[float] = []
|
||||
for row in rows:
|
||||
value = row.get(key)
|
||||
if value is not None:
|
||||
values.append(float(value))
|
||||
return values
|
||||
|
||||
|
||||
def _stats(values: list[float]) -> dict[str, float] | None:
|
||||
clean = [value for value in values if value is not None]
|
||||
if not clean:
|
||||
return None
|
||||
clean.sort()
|
||||
return {
|
||||
"count": float(len(clean)),
|
||||
"mean": statistics.fmean(clean),
|
||||
"p50": _percentile(clean, 0.50),
|
||||
"p90": _percentile(clean, 0.90),
|
||||
"p99": _percentile(clean, 0.99),
|
||||
"min": clean[0],
|
||||
"max": clean[-1],
|
||||
}
|
||||
|
||||
|
||||
def _percentile(sorted_values: list[float], percentile: float) -> float:
|
||||
if len(sorted_values) == 1:
|
||||
return sorted_values[0]
|
||||
index = round((len(sorted_values) - 1) * percentile)
|
||||
return sorted_values[index]
|
||||
|
||||
|
||||
def _mean(values: list[float]) -> float | None:
|
||||
if not values:
|
||||
return None
|
||||
return statistics.fmean(values)
|
||||
|
||||
|
||||
def _token_bin(value: int | None) -> str:
|
||||
if value is None:
|
||||
return "none"
|
||||
if value <= 0:
|
||||
return "<=0"
|
||||
if value <= 512:
|
||||
return "1-512"
|
||||
if value <= 2048:
|
||||
return "513-2048"
|
||||
if value <= 8192:
|
||||
return "2049-8192"
|
||||
if value <= 32768:
|
||||
return "8193-32768"
|
||||
return ">32768"
|
||||
@@ -3,7 +3,8 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field, replace
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
@@ -26,6 +27,8 @@ from agentic_pd_hybrid.trace import (
|
||||
|
||||
HeaderMode = Literal["none", "routing-key", "target-worker", "auto"]
|
||||
KvCacheAdmissionMode = Literal["router", "worker"]
|
||||
KvCachePrefillBackupPolicy = Literal["release-after-transfer", "capacity-backup"]
|
||||
_ADMISSION_PROBE_TIMEOUT_S = 2.0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -47,6 +50,18 @@ class ReplayConfig:
|
||||
stream_idle_timeout_s: float | None = 900.0
|
||||
kvcache_direct_max_uncached_tokens: int = 2048
|
||||
kvcache_admission_mode: KvCacheAdmissionMode = "router"
|
||||
kvcache_seed_max_resident_tokens: int | None = None
|
||||
kvcache_seed_max_output_tokens: int | None = None
|
||||
kvcache_seed_min_turn_id: int = 1
|
||||
kvcache_seed_only_multiturn_sessions: bool = False
|
||||
kvcache_seed_allowed_session_ids: frozenset[str] | None = None
|
||||
kvcache_prefill_backup_policy: KvCachePrefillBackupPolicy = (
|
||||
"release-after-transfer"
|
||||
)
|
||||
kvcache_seed_max_inflight_decode: int | None = 3
|
||||
kvcache_prefill_priority_eviction: bool = False
|
||||
kvcache_prefill_direct_priority: int = -100
|
||||
kvcache_prefill_normal_priority: int = 100
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -104,11 +119,23 @@ class ExecutionResult:
|
||||
latency_s: float | None
|
||||
ttft_s: float | None
|
||||
tpot_s: float | None
|
||||
prefill_request_priority: int | None = None
|
||||
decode_request_priority: int | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]:
|
||||
requests = load_trace(config.trace_path, request_limit=config.request_limit)
|
||||
if config.kvcache_seed_only_multiturn_sessions:
|
||||
session_turns = Counter(request.session_id for request in requests)
|
||||
config = replace(
|
||||
config,
|
||||
kvcache_seed_allowed_session_ids=frozenset(
|
||||
session_id
|
||||
for session_id, turn_count in session_turns.items()
|
||||
if turn_count > 1
|
||||
),
|
||||
)
|
||||
policy = create_policy(config.policy_name)
|
||||
state = RoutingState.create(config.topology)
|
||||
state_lock = asyncio.Lock()
|
||||
@@ -242,6 +269,8 @@ async def _run_request(
|
||||
latency_s=execution.latency_s,
|
||||
ttft_s=execution.ttft_s,
|
||||
tpot_s=execution.tpot_s,
|
||||
prefill_request_priority=execution.prefill_request_priority,
|
||||
decode_request_priority=execution.decode_request_priority,
|
||||
error=execution.error,
|
||||
)
|
||||
|
||||
@@ -253,6 +282,8 @@ async def _invoke_router(
|
||||
config: ReplayConfig,
|
||||
decode_worker_index: int,
|
||||
session_id: str | None = None,
|
||||
prefill_request_priority: int | None = None,
|
||||
decode_request_priority: int | None = None,
|
||||
) -> tuple[float, float | None, float | None, int]:
|
||||
headers = _build_headers(
|
||||
request=request,
|
||||
@@ -274,6 +305,10 @@ async def _invoke_router(
|
||||
}
|
||||
if session_id is not None:
|
||||
payload["session_params"] = {"id": session_id}
|
||||
if prefill_request_priority is not None:
|
||||
payload["smg_prefill_priority"] = prefill_request_priority
|
||||
if decode_request_priority is not None:
|
||||
payload["smg_decode_priority"] = decode_request_priority
|
||||
|
||||
return await _invoke_generate(
|
||||
client=client,
|
||||
@@ -462,6 +497,7 @@ async def _open_streaming_session(
|
||||
"session_id": session_id,
|
||||
"streaming": True,
|
||||
},
|
||||
timeout=_ADMISSION_PROBE_TIMEOUT_S,
|
||||
)
|
||||
response.raise_for_status()
|
||||
opened_session_id = response.json()
|
||||
@@ -481,6 +517,7 @@ async def _close_streaming_session(
|
||||
response = await client.post(
|
||||
f"{server_url.rstrip('/')}/close_session",
|
||||
json={"session_id": session_id},
|
||||
timeout=_ADMISSION_PROBE_TIMEOUT_S,
|
||||
)
|
||||
if response.is_success:
|
||||
return
|
||||
@@ -538,7 +575,10 @@ async def _fetch_decode_server_state(
|
||||
server_url: str,
|
||||
) -> tuple[dict[str, Any], int, int]:
|
||||
try:
|
||||
response = await client.get(f"{server_url.rstrip('/')}/server_info")
|
||||
response = await client.get(
|
||||
f"{server_url.rstrip('/')}/server_info",
|
||||
timeout=_ADMISSION_PROBE_TIMEOUT_S,
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
except Exception:
|
||||
@@ -567,6 +607,7 @@ async def _query_decode_direct_admission(
|
||||
"uncached_input_tokens": max(0, uncached_input_tokens),
|
||||
"output_tokens": max(0, output_tokens),
|
||||
},
|
||||
timeout=_ADMISSION_PROBE_TIMEOUT_S,
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
@@ -643,6 +684,51 @@ def _estimate_session_resident_tokens(request: TraceRequest) -> int:
|
||||
return request.input_length + request.output_length
|
||||
|
||||
|
||||
def _seed_filter_reason(
|
||||
*,
|
||||
request: TraceRequest,
|
||||
config: ReplayConfig,
|
||||
inflight_decode_load: int | None = None,
|
||||
) -> str | None:
|
||||
if request.turn_id < config.kvcache_seed_min_turn_id:
|
||||
return "seed-filter-early-turn"
|
||||
if (
|
||||
config.kvcache_seed_max_inflight_decode is not None
|
||||
and inflight_decode_load is not None
|
||||
and inflight_decode_load > config.kvcache_seed_max_inflight_decode
|
||||
):
|
||||
return "seed-filter-inflight-decode-load"
|
||||
if (
|
||||
config.kvcache_seed_allowed_session_ids is not None
|
||||
and request.session_id not in config.kvcache_seed_allowed_session_ids
|
||||
):
|
||||
return "seed-filter-single-turn-session"
|
||||
resident_tokens = _estimate_session_resident_tokens(request)
|
||||
if (
|
||||
config.kvcache_seed_max_resident_tokens is not None
|
||||
and resident_tokens > config.kvcache_seed_max_resident_tokens
|
||||
):
|
||||
return "seed-filter-resident-tokens"
|
||||
if (
|
||||
config.kvcache_seed_max_output_tokens is not None
|
||||
and request.output_length > config.kvcache_seed_max_output_tokens
|
||||
):
|
||||
return "seed-filter-output-tokens"
|
||||
return None
|
||||
|
||||
|
||||
def _prefill_priority_for_router_request(
|
||||
*,
|
||||
config: ReplayConfig,
|
||||
direct_to_d_predicted: bool,
|
||||
) -> int | None:
|
||||
if not config.kvcache_prefill_priority_eviction:
|
||||
return None
|
||||
if direct_to_d_predicted:
|
||||
return config.kvcache_prefill_direct_priority
|
||||
return config.kvcache_prefill_normal_priority
|
||||
|
||||
|
||||
def _inspect_direct_request(
|
||||
*,
|
||||
request: TraceRequest,
|
||||
@@ -802,6 +888,7 @@ async def _fetch_decode_load_snapshot(
|
||||
response = await client.get(
|
||||
f"{server_url.rstrip('/')}/v1/loads",
|
||||
params={"include": "core,disagg"},
|
||||
timeout=_ADMISSION_PROBE_TIMEOUT_S,
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
@@ -865,6 +952,13 @@ def _is_decode_backpressure_reason(reason: str | None) -> bool:
|
||||
}
|
||||
|
||||
|
||||
def _is_stale_decode_session_error(exc: Exception) -> bool:
|
||||
return (
|
||||
isinstance(exc, httpx.HTTPStatusError)
|
||||
and exc.response.status_code == 400
|
||||
)
|
||||
|
||||
|
||||
def _dynamic_decode_headroom_tokens(
|
||||
*,
|
||||
residency: DecodeResidencyState,
|
||||
@@ -1469,17 +1563,23 @@ async def _invoke_plain_router(
|
||||
decision,
|
||||
execution_mode: str,
|
||||
) -> ExecutionResult:
|
||||
prefill_priority = _prefill_priority_for_router_request(
|
||||
config=config,
|
||||
direct_to_d_predicted=False,
|
||||
)
|
||||
latency_s, ttft_s, tpot_s, cached_tokens = await _invoke_router(
|
||||
client=client,
|
||||
request=request,
|
||||
config=config,
|
||||
decode_worker_index=decision.decode_worker_index,
|
||||
prefill_request_priority=prefill_priority,
|
||||
)
|
||||
return ExecutionResult(
|
||||
execution_mode=execution_mode,
|
||||
actual_kv_transfer_blocks=decision.kv_transfer_blocks,
|
||||
effective_input_length=request.input_length,
|
||||
cached_tokens=cached_tokens,
|
||||
prefill_request_priority=prefill_priority,
|
||||
session_reused=False,
|
||||
session_reset=False,
|
||||
latency_s=latency_s,
|
||||
@@ -1502,17 +1602,20 @@ async def _invoke_kvcache_seeded_router(
|
||||
reserved_tokens: int,
|
||||
execution_mode: str,
|
||||
) -> ExecutionResult:
|
||||
keep_prefill_backup = False
|
||||
prefill_reserved_tokens = 0
|
||||
async with direct_session_lock:
|
||||
keep_prefill_backup, prefill_reserved_tokens, _prefill_evicted = (
|
||||
await _reserve_prefill_backup_capacity(
|
||||
client=client,
|
||||
request=request,
|
||||
prefill_url=prefill_url,
|
||||
session=decode_session,
|
||||
direct_sessions=direct_sessions,
|
||||
residency=decode_residency,
|
||||
if config.kvcache_prefill_backup_policy == "capacity-backup":
|
||||
keep_prefill_backup, prefill_reserved_tokens, _prefill_evicted = (
|
||||
await _reserve_prefill_backup_capacity(
|
||||
client=client,
|
||||
request=request,
|
||||
prefill_url=prefill_url,
|
||||
session=decode_session,
|
||||
direct_sessions=direct_sessions,
|
||||
residency=decode_residency,
|
||||
)
|
||||
)
|
||||
)
|
||||
if (
|
||||
decode_session.prefill_opened
|
||||
and decode_session.prefill_server_url != prefill_url
|
||||
@@ -1538,6 +1641,10 @@ async def _invoke_kvcache_seeded_router(
|
||||
|
||||
decode_session_newly_opened = False
|
||||
try:
|
||||
prefill_priority = _prefill_priority_for_router_request(
|
||||
config=config,
|
||||
direct_to_d_predicted=True,
|
||||
)
|
||||
async with direct_session_lock:
|
||||
if not decode_session.opened:
|
||||
await _open_streaming_session(
|
||||
@@ -1555,6 +1662,7 @@ async def _invoke_kvcache_seeded_router(
|
||||
config=config,
|
||||
decode_worker_index=decision.decode_worker_index,
|
||||
session_id=request.session_id,
|
||||
prefill_request_priority=prefill_priority,
|
||||
)
|
||||
except Exception:
|
||||
async with direct_session_lock:
|
||||
@@ -1615,6 +1723,7 @@ async def _invoke_kvcache_seeded_router(
|
||||
actual_kv_transfer_blocks=decision.kv_transfer_blocks,
|
||||
effective_input_length=request.input_length,
|
||||
cached_tokens=cached_tokens,
|
||||
prefill_request_priority=prefill_priority,
|
||||
session_reused=False,
|
||||
session_reset=False,
|
||||
latency_s=latency_s,
|
||||
@@ -1697,6 +1806,19 @@ async def _execute_request(
|
||||
)
|
||||
|
||||
if request.turn_id == 1:
|
||||
seed_filter_reason = _seed_filter_reason(
|
||||
request=request,
|
||||
config=config,
|
||||
inflight_decode_load=decision.inflight_decode_load,
|
||||
)
|
||||
if seed_filter_reason is not None:
|
||||
return await _invoke_plain_router(
|
||||
client=client,
|
||||
request=request,
|
||||
config=config,
|
||||
decision=decision,
|
||||
execution_mode=f"pd-router-turn1-{seed_filter_reason}",
|
||||
)
|
||||
async with direct_session_lock:
|
||||
admit_new_decode_session = _should_admit_new_decode_session(
|
||||
residency=decode_residency,
|
||||
@@ -1800,16 +1922,33 @@ async def _execute_request(
|
||||
)
|
||||
)
|
||||
if can_direct:
|
||||
return await _invoke_decode_session_direct(
|
||||
client=client,
|
||||
request=request,
|
||||
config=config,
|
||||
decision=decision,
|
||||
direct_sessions=direct_sessions,
|
||||
direct_session_lock=direct_session_lock,
|
||||
decode_residency=decode_residency,
|
||||
reserved_tokens=direct_reserved_tokens,
|
||||
)
|
||||
try:
|
||||
return await _invoke_decode_session_direct(
|
||||
client=client,
|
||||
request=request,
|
||||
config=config,
|
||||
decision=decision,
|
||||
direct_sessions=direct_sessions,
|
||||
direct_session_lock=direct_session_lock,
|
||||
decode_residency=decode_residency,
|
||||
reserved_tokens=direct_reserved_tokens,
|
||||
)
|
||||
except Exception as exc:
|
||||
if not _is_stale_decode_session_error(exc):
|
||||
raise
|
||||
async with direct_session_lock:
|
||||
await _close_decode_session(
|
||||
client=client,
|
||||
session=decode_session,
|
||||
residency=decode_residency,
|
||||
)
|
||||
return await _invoke_plain_router(
|
||||
client=client,
|
||||
request=request,
|
||||
config=config,
|
||||
decision=decision,
|
||||
execution_mode="pd-router-fallback-stale-d-session",
|
||||
)
|
||||
if _is_decode_backpressure_reason(direct_reason):
|
||||
return await _invoke_plain_router(
|
||||
client=client,
|
||||
@@ -1819,6 +1958,19 @@ async def _execute_request(
|
||||
execution_mode="pd-router-fallback-d-backpressure",
|
||||
)
|
||||
|
||||
seed_filter_reason = _seed_filter_reason(
|
||||
request=request,
|
||||
config=config,
|
||||
inflight_decode_load=decision.inflight_decode_load,
|
||||
)
|
||||
if seed_filter_reason is not None:
|
||||
return await _invoke_plain_router(
|
||||
client=client,
|
||||
request=request,
|
||||
config=config,
|
||||
decision=decision,
|
||||
execution_mode=f"pd-router-fallback-{seed_filter_reason}",
|
||||
)
|
||||
async with direct_session_lock:
|
||||
admit_new_decode_session = _should_admit_new_decode_session(
|
||||
residency=decode_residency,
|
||||
@@ -1894,6 +2046,19 @@ async def _execute_request(
|
||||
),
|
||||
)
|
||||
|
||||
seed_filter_reason = _seed_filter_reason(
|
||||
request=request,
|
||||
config=config,
|
||||
inflight_decode_load=decision.inflight_decode_load,
|
||||
)
|
||||
if seed_filter_reason is not None:
|
||||
return await _invoke_plain_router(
|
||||
request=request,
|
||||
client=client,
|
||||
config=config,
|
||||
decision=decision,
|
||||
execution_mode=f"pd-router-fallback-large-append-{seed_filter_reason}",
|
||||
)
|
||||
async with direct_session_lock:
|
||||
admit_new_decode_session = _should_admit_new_decode_session(
|
||||
residency=decode_residency,
|
||||
|
||||
@@ -64,6 +64,7 @@ def launch_pd_stack(
|
||||
prefill_policy: str,
|
||||
decode_policy: str,
|
||||
timeout_s: float = 1200.0,
|
||||
router_request_timeout_s: float | None = None,
|
||||
include_router: bool = True,
|
||||
) -> ManagedPdStack:
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -75,6 +76,7 @@ def launch_pd_stack(
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
include_router=include_router,
|
||||
router_request_timeout_s=router_request_timeout_s,
|
||||
)
|
||||
|
||||
prefill_processes = [
|
||||
@@ -186,7 +188,7 @@ def _build_process_env(topology: SingleNodeTopology) -> dict[str, str]:
|
||||
env["NO_PROXY"] = "*"
|
||||
env["no_proxy"] = "*"
|
||||
env.setdefault("SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", "600")
|
||||
env.setdefault("SGLANG_DISAGGREGATION_WAITING_TIMEOUT", "60")
|
||||
env.setdefault("SGLANG_DISAGGREGATION_WAITING_TIMEOUT", "600")
|
||||
env.setdefault("FLASHINFER_DISABLE_VERSION_CHECK", "1")
|
||||
if topology.force_rdma:
|
||||
env["MOONCAKE_PROTOCOL"] = "rdma"
|
||||
|
||||
@@ -176,7 +176,7 @@ NSA_CHOICES = [
|
||||
"trtllm",
|
||||
]
|
||||
|
||||
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu", "slru"]
|
||||
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu", "slru", "priority"]
|
||||
|
||||
RL_ON_POLICY_TARGET_CHOICES = ["fsdp"]
|
||||
|
||||
@@ -4049,7 +4049,7 @@ class ServerArgs:
|
||||
type=str,
|
||||
choices=RADIX_EVICTION_POLICY_CHOICES,
|
||||
default=ServerArgs.radix_eviction_policy,
|
||||
help="The eviction policy of radix trees. 'lru' stands for Least Recently Used, 'lfu' stands for Least Frequently Used, and 'slru' stands for Segmented Least Recently Used.",
|
||||
help="The eviction policy of radix trees. 'lru' stands for Least Recently Used, 'lfu' stands for Least Frequently Used, 'slru' stands for Segmented Least Recently Used, and 'priority' evicts lower request priority values first.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-prefill-delayer",
|
||||
|
||||
Reference in New Issue
Block a user