feat: add agentic pd hybrid benchmark prototype
This commit is contained in:
24
pyproject.toml
Normal file
24
pyproject.toml
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
[project]
|
||||||
|
name = "agentic-pd-hybrid"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Prototype for session-aware and KV-cache-aware PD routing on SGLang xPyD"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"httpx>=0.28.1",
|
||||||
|
"mooncake-transfer-engine",
|
||||||
|
"sglang==0.5.10",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
agentic-pd-hybrid = "agentic_pd_hybrid.cli:main"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=68"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = ["src"]
|
||||||
|
|
||||||
|
[tool.uv]
|
||||||
|
prerelease = "allow"
|
||||||
12
src/agentic_pd_hybrid/__init__.py
Normal file
12
src/agentic_pd_hybrid/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""Agentic PD hybrid prototype."""
|
||||||
|
|
||||||
|
# Submodules that ``from agentic_pd_hybrid import *`` should load.
# ``benchmark``, ``sampling``, ``stack`` and ``trace_profiles`` are imported by
# the package's own benchmark/CLI modules, so they are listed alongside the
# other submodules rather than being left out of the star-import surface.
__all__ = [
    "benchmark",
    "cli",
    "launcher",
    "metrics",
    "microbench",
    "policies",
    "replay",
    "sampling",
    "stack",
    "topology",
    "trace",
    "trace_profiles",
]
|
||||||
221
src/agentic_pd_hybrid/benchmark.py
Normal file
221
src/agentic_pd_hybrid/benchmark.py
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import signal
|
||||||
|
from dataclasses import asdict, dataclass, replace
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from agentic_pd_hybrid.replay import ReplayConfig, replay_trace
|
||||||
|
from agentic_pd_hybrid.sampling import SessionSampleConfig, sample_trace_sessions
|
||||||
|
from agentic_pd_hybrid.stack import ManagedPdStack, launch_pd_stack
|
||||||
|
from agentic_pd_hybrid.topology import SingleNodeTopology
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class BenchmarkConfig:
|
||||||
|
trace_path: Path
|
||||||
|
output_root: Path
|
||||||
|
topology: SingleNodeTopology
|
||||||
|
policy_name: str
|
||||||
|
mechanism_name: str = "pd-disaggregation"
|
||||||
|
target_duration_s: float = 600.0
|
||||||
|
start_time_s: float = 0.0
|
||||||
|
session_sample_rate: float = 0.01
|
||||||
|
min_turns: int = 1
|
||||||
|
time_scale: float = 1.0
|
||||||
|
concurrency_limit: int = 32
|
||||||
|
timeout_s: float = 1200.0
|
||||||
|
stream: bool = True
|
||||||
|
stream_idle_timeout_s: float | None = 900.0
|
||||||
|
kvcache_direct_max_uncached_tokens: int = 2048
|
||||||
|
kvcache_admission_mode: str = "router"
|
||||||
|
sample_profile: str = "default"
|
||||||
|
min_initial_input_tokens: int | None = None
|
||||||
|
max_initial_input_tokens: int | None = None
|
||||||
|
max_append_input_tokens: int | None = None
|
||||||
|
max_output_tokens: int | None = None
|
||||||
|
min_overlap_ratio: float | None = None
|
||||||
|
launch_stack: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class BenchmarkArtifacts:
    """Locations of the files associated with one benchmark run."""

    # Per-run directory created under the configured output root.
    run_dir: Path
    # Sampled trace shard that was replayed.
    sampled_trace_path: Path
    # Per-request metrics file (replay output path).
    metrics_path: Path
    # Summary JSON that sits next to the metrics file.
    summary_path: Path
    # JSON dump of the settings used for the run.
    benchmark_config_path: Path
|
||||||
|
|
||||||
|
|
||||||
|
# Mechanisms whose traffic goes through the PD router; every other mechanism
# replays with ``router_url=None`` and launches without a router.
_ROUTED_MECHANISMS = frozenset({"pd-disaggregation", "kvcache-centric"})


def run_live_benchmark(config: BenchmarkConfig) -> BenchmarkArtifacts:
    """Sample a trace shard, replay it against a PD stack, and record the run.

    A run directory named
    ``<mechanism>-<policy>[-<admission>-admission]-<UTC timestamp>`` is created
    under ``config.output_root``. Sessions are sampled into
    ``sampled-trace.jsonl``, the settings are dumped to
    ``benchmark-config.json``, the stack is launched when
    ``config.launch_stack`` is set, and the sampled trace is replayed into
    ``request-metrics.jsonl``.

    The settings dump is written *before* the stack is launched so that a run
    which fails or is interrupted still leaves a record of how it was
    configured.

    Args:
        config: Benchmark settings.

    Returns:
        The paths of the run directory and of the files associated with the
        run. ``summary_path`` is ``request-metrics.jsonl.summary.json``; this
        function does not write it itself (presumably the replay step does).

    Raises:
        SystemExit: With status ``128 + signum`` when SIGINT or SIGTERM is
            received while the benchmark is running.
    """
    run_id = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ")
    run_label = f"{config.mechanism_name}-{config.policy_name}"
    if config.mechanism_name == "kvcache-centric":
        run_label = f"{run_label}-{config.kvcache_admission_mode}-admission"
    run_dir = config.output_root / f"{run_label}-{run_id}"
    run_dir.mkdir(parents=True, exist_ok=True)

    topology = _effective_topology(config)
    uses_router = config.mechanism_name in _ROUTED_MECHANISMS

    sampled_trace_path = run_dir / "sampled-trace.jsonl"
    metrics_path = run_dir / "request-metrics.jsonl"
    benchmark_config_path = run_dir / "benchmark-config.json"

    sample_summary = sample_trace_sessions(
        SessionSampleConfig(
            trace_path=config.trace_path,
            output_path=sampled_trace_path,
            target_duration_s=config.target_duration_s,
            start_time_s=config.start_time_s,
            session_sample_rate=config.session_sample_rate,
            min_turns=config.min_turns,
            profile=config.sample_profile,  # type: ignore[arg-type]
            min_initial_input_tokens=config.min_initial_input_tokens,
            max_initial_input_tokens=config.max_initial_input_tokens,
            max_append_input_tokens=config.max_append_input_tokens,
            max_output_tokens=config.max_output_tokens,
            min_overlap_ratio=config.min_overlap_ratio,
        )
    )

    # Everything the dump needs is known at this point; persisting it now
    # keeps the provenance of runs that later crash or get interrupted.
    _write_benchmark_config(
        benchmark_config_path, config, topology, sample_summary
    )

    stack: ManagedPdStack | None = None
    previous_sigint = signal.getsignal(signal.SIGINT)
    previous_sigterm = signal.getsignal(signal.SIGTERM)

    def _handle_termination(signum, _frame) -> None:
        # Stop the servers first, then unwind through the ``finally`` below.
        if stack is not None:
            stack.stop()
        raise SystemExit(128 + signum)

    try:
        signal.signal(signal.SIGINT, _handle_termination)
        signal.signal(signal.SIGTERM, _handle_termination)

        if config.launch_stack:
            stack = launch_pd_stack(
                topology=topology,
                run_dir=run_dir,
                prefill_policy="round_robin",
                decode_policy=_decode_policy_for(config.policy_name),
                timeout_s=config.timeout_s,
                include_router=uses_router,
            )
            router_source = stack
        else:
            router_source = topology
        # ``router_url`` is only read for routed mechanisms.
        router_url = router_source.router_url if uses_router else None

        replay_config = ReplayConfig(
            trace_path=sampled_trace_path,
            output_path=metrics_path,
            policy_name=config.policy_name,
            mechanism_name=config.mechanism_name,
            topology=topology,
            router_url=router_url,
            model_name=topology.model_name,
            pace=True,
            time_scale=config.time_scale,
            request_limit=None,
            concurrency_limit=config.concurrency_limit,
            header_mode=_header_mode_for(config.policy_name),
            timeout_s=config.timeout_s,
            stream=config.stream,
            stream_idle_timeout_s=config.stream_idle_timeout_s,
            kvcache_direct_max_uncached_tokens=config.kvcache_direct_max_uncached_tokens,
            kvcache_admission_mode=config.kvcache_admission_mode,  # type: ignore[arg-type]
        )
        asyncio.run(replay_trace(replay_config))
    finally:
        # Restore the caller's handlers before tearing the stack down.
        signal.signal(signal.SIGINT, previous_sigint)
        signal.signal(signal.SIGTERM, previous_sigterm)
        if stack is not None:
            stack.stop()

    return BenchmarkArtifacts(
        run_dir=run_dir,
        sampled_trace_path=sampled_trace_path,
        metrics_path=metrics_path,
        summary_path=run_dir / "request-metrics.jsonl.summary.json",
        benchmark_config_path=benchmark_config_path,
    )


def _effective_topology(config: BenchmarkConfig) -> SingleNodeTopology:
    """Return the topology to run with, adding kvcache-centric server flags."""
    topology = config.topology
    if config.mechanism_name != "kvcache-centric":
        return topology
    return replace(
        topology,
        prefill_extra_server_args=topology.prefill_extra_server_args
        + ("--enable-streaming-session",),
        decode_extra_server_args=topology.decode_extra_server_args
        + (
            "--enable-streaming-session",
            "--disaggregation-decode-allow-local-prefill",
        ),
    )


def _write_benchmark_config(
    path: Path,
    config: BenchmarkConfig,
    topology: SingleNodeTopology,
    sample_summary,
) -> None:
    """Dump the run settings, sampling summary and topology ids as JSON."""
    payload = {
        "policy_name": config.policy_name,
        "mechanism_name": config.mechanism_name,
        "target_duration_s": config.target_duration_s,
        "start_time_s": config.start_time_s,
        "session_sample_rate": config.session_sample_rate,
        "min_turns": config.min_turns,
        "time_scale": config.time_scale,
        "concurrency_limit": config.concurrency_limit,
        "timeout_s": config.timeout_s,
        "stream": config.stream,
        "stream_idle_timeout_s": config.stream_idle_timeout_s,
        "kvcache_direct_max_uncached_tokens": config.kvcache_direct_max_uncached_tokens,
        "kvcache_admission_mode": config.kvcache_admission_mode,
        "sample_profile": config.sample_profile,
        "min_initial_input_tokens": config.min_initial_input_tokens,
        "max_initial_input_tokens": config.max_initial_input_tokens,
        "max_append_input_tokens": config.max_append_input_tokens,
        "max_output_tokens": config.max_output_tokens,
        "min_overlap_ratio": config.min_overlap_ratio,
        "sample_summary": asdict(sample_summary),
        "topology": {
            "model_path": topology.model_path,
            "router_url": topology.router_url,
            "transfer_backend": topology.transfer_backend,
            "force_rdma": topology.force_rdma,
            "ib_device": topology.ib_device,
            "prefill_workers": [
                worker.worker_id for worker in topology.prefill_workers
            ],
            "decode_workers": [
                worker.worker_id for worker in topology.decode_workers
            ],
            "direct_workers": [
                worker.worker_id for worker in topology.direct_workers
            ],
        },
    }
    with path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2, sort_keys=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_policy_for(policy_name: str) -> str:
|
||||||
|
if policy_name == "sticky":
|
||||||
|
return "manual"
|
||||||
|
if policy_name == "kv-aware":
|
||||||
|
return "consistent_hashing"
|
||||||
|
return "round_robin"
|
||||||
|
|
||||||
|
|
||||||
|
def _header_mode_for(policy_name: str) -> str:
|
||||||
|
if policy_name == "sticky":
|
||||||
|
return "routing-key"
|
||||||
|
if policy_name == "kv-aware":
|
||||||
|
return "target-worker"
|
||||||
|
return "none"
|
||||||
484
src/agentic_pd_hybrid/cli.py
Normal file
484
src/agentic_pd_hybrid/cli.py
Normal file
@@ -0,0 +1,484 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from agentic_pd_hybrid.benchmark import BenchmarkConfig, run_live_benchmark
|
||||||
|
from agentic_pd_hybrid.launcher import build_launch_plan
|
||||||
|
from agentic_pd_hybrid.microbench import SmallAppendTraceConfig, write_small_append_trace
|
||||||
|
from agentic_pd_hybrid.replay import ReplayConfig, replay_trace
|
||||||
|
from agentic_pd_hybrid.sampling import SessionSampleConfig, sample_trace_sessions
|
||||||
|
from agentic_pd_hybrid.trace_profiles import (
|
||||||
|
NormalizeTraceLengthsConfig,
|
||||||
|
normalize_trace_lengths,
|
||||||
|
)
|
||||||
|
from agentic_pd_hybrid.topology import build_single_node_topology
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_mechanism_name(name: str) -> str:
|
||||||
|
normalized = name.strip().lower()
|
||||||
|
aliases = {
|
||||||
|
"pd-disagg": "pd-disaggregation",
|
||||||
|
"pd-disaggregation": "pd-disaggregation",
|
||||||
|
"pd-hybrid": "pd-disaggregation",
|
||||||
|
"baseline-pd-disagg": "pd-disaggregation",
|
||||||
|
"pd-colo": "pd-colo",
|
||||||
|
"direct-d": "pd-colo",
|
||||||
|
"colocation": "pd-colo",
|
||||||
|
"kvcache-centric": "kvcache-centric",
|
||||||
|
"turn2+-direct-to-d": "kvcache-centric",
|
||||||
|
"pd-with-d-append": "kvcache-centric",
|
||||||
|
}
|
||||||
|
if normalized not in aliases:
|
||||||
|
raise ValueError(f"Unsupported mechanism: {name}")
|
||||||
|
return aliases[normalized]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_gpu_id_list(value: str | None) -> tuple[int, ...] | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
items = [item.strip() for item in value.split(",") if item.strip()]
|
||||||
|
if not items:
|
||||||
|
return tuple()
|
||||||
|
return tuple(int(item) for item in items)
|
||||||
|
|
||||||
|
|
||||||
|
# Mechanism spellings accepted by ``--mechanism``; each one is mapped to its
# canonical name by ``_normalize_mechanism_name`` before use.
_MECHANISM_CHOICES = (
    "pd-disaggregation",
    "pd-hybrid",
    "pd-disagg",
    "pd-colo",
    "direct-d",
    "kvcache-centric",
    "turn2+-direct-to-d",
    "pd-with-d-append",
)

# Values accepted by ``--policy``.
_POLICY_CHOICES = ("default", "sticky", "kv-aware")


def main() -> None:
    """Parse the command line and run the selected subcommand.

    Raises:
        AssertionError: If the parsed command has no handler (a parser and
            dispatch-table mismatch, not a user error).
    """
    args = _build_parser().parse_args()
    handlers = {
        "print-launch": _run_print_launch,
        "replay": _run_replay,
        "sample-sessions": _run_sample_sessions,
        "normalize-trace-lengths": _run_normalize_trace_lengths,
        "make-small-append-trace": _run_make_small_append_trace,
        "benchmark-live": _run_benchmark_live,
    }
    handler = handlers.get(args.command)
    if handler is None:
        raise AssertionError(f"Unhandled command: {args.command}")
    handler(args)


def _build_parser() -> argparse.ArgumentParser:
    """Build the top-level parser with one subparser per subcommand."""
    parser = argparse.ArgumentParser(description="Agentic PD hybrid prototype")
    subparsers = parser.add_subparsers(dest="command", required=True)
    _add_print_launch_parser(subparsers)
    _add_replay_parser(subparsers)
    _add_sample_sessions_parser(subparsers)
    _add_normalize_parser(subparsers)
    _add_small_append_parser(subparsers)
    _add_benchmark_parser(subparsers)
    return parser


def _add_policy_and_mechanism_arguments(parser: argparse.ArgumentParser) -> None:
    """Register ``--policy`` and ``--mechanism`` (shared by replay/benchmark)."""
    parser.add_argument(
        "--policy",
        choices=_POLICY_CHOICES,
        default="sticky",
    )
    parser.add_argument(
        "--mechanism",
        choices=_MECHANISM_CHOICES,
        default="pd-disaggregation",
    )


def _add_stream_and_kvcache_arguments(
    parser: argparse.ArgumentParser,
    *,
    no_stream_help: str,
) -> None:
    """Register the streaming and kvcache-routing options shared by replay/benchmark."""
    parser.add_argument(
        "--no-stream",
        action="store_true",
        help=no_stream_help,
    )
    parser.add_argument(
        "--stream-idle-timeout-s",
        type=float,
        default=900.0,
        help="Abort a streaming request if no SSE line arrives within this many seconds.",
    )
    parser.add_argument(
        "--kvcache-direct-max-uncached-tokens",
        type=int,
        default=2048,
        help="For kvcache-centric routing, bypass P when the uncached suffix is at most this many tokens.",
    )
    parser.add_argument(
        "--kvcache-admission-mode",
        choices=["router", "worker"],
        default="router",
        help=(
            "For kvcache-centric routing, use router shadow-state admission "
            "or query the decode worker on the critical path."
        ),
    )


def _add_session_window_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the sampling-window options shared by sample-sessions/benchmark-live."""
    parser.add_argument("--target-duration-s", type=float, default=600.0)
    parser.add_argument("--start-time-s", type=float, default=0.0)
    parser.add_argument("--session-sample-rate", type=float, default=0.01)
    parser.add_argument("--min-turns", type=int, default=1)


def _add_session_shape_arguments(
    parser: argparse.ArgumentParser,
    *,
    profile_flag: str,
    profile_help: str,
) -> None:
    """Register the profile option and the token/overlap filters.

    The profile flag is spelled differently per subcommand (``--profile`` vs
    ``--sample-profile``), so its name and help text are passed in.
    """
    parser.add_argument(
        profile_flag,
        choices=["default", "small-append"],
        default="default",
        help=profile_help,
    )
    parser.add_argument("--min-initial-input-tokens", type=int, default=None)
    parser.add_argument("--max-initial-input-tokens", type=int, default=None)
    parser.add_argument("--max-append-input-tokens", type=int, default=None)
    parser.add_argument("--max-output-tokens", type=int, default=None)
    parser.add_argument("--min-overlap-ratio", type=float, default=None)


def _add_print_launch_parser(subparsers) -> None:
    """Register the ``print-launch`` subcommand."""
    print_launch = subparsers.add_parser(
        "print-launch",
        help="Print one-node SGLang PD launch commands",
    )
    _add_topology_arguments(print_launch)
    print_launch.add_argument("--prefill-policy", default="round_robin")
    print_launch.add_argument("--decode-policy", default="manual")


def _add_replay_parser(subparsers) -> None:
    """Register the ``replay`` subcommand."""
    replay = subparsers.add_parser(
        "replay",
        help="Replay trace and log request-level metrics",
    )
    _add_topology_arguments(replay)
    replay.add_argument("--trace", type=Path, required=True)
    replay.add_argument("--output", type=Path, required=True)
    _add_policy_and_mechanism_arguments(replay)
    replay.add_argument("--router-url")
    replay.add_argument("--model")
    replay.add_argument(
        "--header-mode",
        choices=["auto", "none", "routing-key", "target-worker"],
        default="auto",
    )
    replay.add_argument(
        "--request-limit",
        type=int,
        default=None,
        help="Replay at most this many requests",
    )
    replay.add_argument(
        "--no-pace",
        action="store_true",
        help="Disable wall-clock pacing from trace timestamps",
    )
    replay.add_argument(
        "--time-scale",
        type=float,
        default=1.0,
        help="Scale trace timing by this factor when pacing is enabled",
    )
    replay.add_argument(
        "--concurrency-limit",
        type=int,
        default=32,
    )
    replay.add_argument(
        "--timeout-s",
        type=float,
        default=600.0,
    )
    _add_stream_and_kvcache_arguments(
        replay,
        no_stream_help="Use non-streaming OpenAI responses for more robust E2E-only replay.",
    )


def _add_sample_sessions_parser(subparsers) -> None:
    """Register the ``sample-sessions`` subcommand."""
    sample = subparsers.add_parser(
        "sample-sessions",
        help="Sample a session-granularity trace shard for live benchmarking",
    )
    sample.add_argument("--trace", type=Path, required=True)
    sample.add_argument("--output", type=Path, required=True)
    _add_session_window_arguments(sample)
    sample.add_argument("--max-requests", type=int, default=None)
    _add_session_shape_arguments(
        sample,
        profile_flag="--profile",
        profile_help="Optional workload-shape filter for live benchmarks.",
    )


def _add_normalize_parser(subparsers) -> None:
    """Register the ``normalize-trace-lengths`` subcommand."""
    normalize = subparsers.add_parser(
        "normalize-trace-lengths",
        help="Rewrite a trace to a fixed turn1/append/output length profile",
    )
    normalize.add_argument("--trace", type=Path, required=True)
    normalize.add_argument("--output", type=Path, required=True)
    normalize.add_argument("--initial-input-length", type=int, default=10_000)
    normalize.add_argument("--append-input-length", type=int, default=1_000)
    normalize.add_argument("--output-length", type=int, default=1_000)
    normalize.add_argument("--max-requests", type=int, default=None)


def _add_small_append_parser(subparsers) -> None:
    """Register the ``make-small-append-trace`` subcommand."""
    micro = subparsers.add_parser(
        "make-small-append-trace",
        help="Generate a synthetic multi-turn trace with small turn2+ appends",
    )
    micro.add_argument("--output", type=Path, required=True)
    micro.add_argument("--session-count", type=int, default=8)
    micro.add_argument("--turns-per-session", type=int, default=3)
    micro.add_argument("--initial-input-length", type=int, default=10_000)
    micro.add_argument("--append-input-length", type=int, default=1_000)
    micro.add_argument("--output-length", type=int, default=1_000)
    micro.add_argument("--inter-turn-gap-s", type=float, default=1.0)
    micro.add_argument("--session-stagger-s", type=float, default=0.1)


def _add_benchmark_parser(subparsers) -> None:
    """Register the ``benchmark-live`` subcommand."""
    benchmark = subparsers.add_parser(
        "benchmark-live",
        help="Launch a real PD stack, sample sessions, and collect live E2E numbers",
    )
    _add_topology_arguments(benchmark)
    benchmark.add_argument("--trace", type=Path, required=True)
    _add_policy_and_mechanism_arguments(benchmark)
    benchmark.add_argument("--output-root", type=Path, default=Path("outputs/live"))
    _add_session_window_arguments(benchmark)
    benchmark.add_argument("--time-scale", type=float, default=1.0)
    benchmark.add_argument("--concurrency-limit", type=int, default=32)
    benchmark.add_argument("--timeout-s", type=float, default=1200.0)
    _add_stream_and_kvcache_arguments(
        benchmark,
        no_stream_help="Use non-streaming OpenAI responses for E2E-only live benchmarking.",
    )
    _add_session_shape_arguments(
        benchmark,
        profile_flag="--sample-profile",
        profile_help="Optional session-shape filter applied before live replay.",
    )


def _run_print_launch(args: argparse.Namespace) -> None:
    """Print the launch commands for the requested topology."""
    topology = _topology_from_args(args)
    plan = build_launch_plan(
        topology,
        prefill_policy=args.prefill_policy,
        decode_policy=args.decode_policy,
        # A router only makes sense when both worker kinds are present.
        include_router=bool(topology.prefill_workers and topology.decode_workers),
    )
    print(plan.render())


def _run_replay(args: argparse.Namespace) -> None:
    """Replay a trace against an already running stack and report the outputs."""
    topology = _topology_from_args(args)
    config = ReplayConfig(
        trace_path=args.trace,
        output_path=args.output,
        policy_name=args.policy,
        mechanism_name=_normalize_mechanism_name(args.mechanism),
        topology=topology,
        router_url=args.router_url,
        model_name=args.model,
        pace=not args.no_pace,
        time_scale=args.time_scale,
        request_limit=args.request_limit,
        concurrency_limit=args.concurrency_limit,
        header_mode=args.header_mode,
        timeout_s=args.timeout_s,
        stream=not args.no_stream,
        stream_idle_timeout_s=args.stream_idle_timeout_s,
        kvcache_direct_max_uncached_tokens=args.kvcache_direct_max_uncached_tokens,
        kvcache_admission_mode=args.kvcache_admission_mode,
    )
    results = asyncio.run(replay_trace(config))
    print(
        f"wrote {len(results)} request records to {args.output} and "
        f"{args.output}.summary.json"
    )


def _run_sample_sessions(args: argparse.Namespace) -> None:
    """Sample sessions from a trace into a smaller shard."""
    summary = sample_trace_sessions(
        SessionSampleConfig(
            trace_path=args.trace,
            output_path=args.output,
            target_duration_s=args.target_duration_s,
            start_time_s=args.start_time_s,
            session_sample_rate=args.session_sample_rate,
            min_turns=args.min_turns,
            max_requests=args.max_requests,
            profile=args.profile,
            min_initial_input_tokens=args.min_initial_input_tokens,
            max_initial_input_tokens=args.max_initial_input_tokens,
            max_append_input_tokens=args.max_append_input_tokens,
            max_output_tokens=args.max_output_tokens,
            min_overlap_ratio=args.min_overlap_ratio,
        )
    )
    print(
        f"wrote {summary.request_count} requests from {summary.session_count} sessions "
        f"covering {summary.sampled_duration_s:.3f}s to {args.output}"
    )


def _run_normalize_trace_lengths(args: argparse.Namespace) -> None:
    """Rewrite a trace to fixed input/append/output lengths."""
    summary = normalize_trace_lengths(
        NormalizeTraceLengthsConfig(
            trace_path=args.trace,
            output_path=args.output,
            initial_input_length=args.initial_input_length,
            append_input_length=args.append_input_length,
            output_length=args.output_length,
            max_requests=args.max_requests,
        )
    )
    print(
        f"wrote {summary.request_count} normalized requests from "
        f"{summary.session_count} sessions to {args.output}"
    )


def _run_make_small_append_trace(args: argparse.Namespace) -> None:
    """Generate a synthetic multi-turn trace."""
    summary = write_small_append_trace(
        SmallAppendTraceConfig(
            output_path=args.output,
            session_count=args.session_count,
            turns_per_session=args.turns_per_session,
            initial_input_length=args.initial_input_length,
            append_input_length=args.append_input_length,
            output_length=args.output_length,
            inter_turn_gap_s=args.inter_turn_gap_s,
            session_stagger_s=args.session_stagger_s,
        )
    )
    print(
        f"wrote {summary.request_count} requests across {summary.session_count} sessions "
        f"to {args.output}"
    )


def _run_benchmark_live(args: argparse.Namespace) -> None:
    """Launch a stack, run the live benchmark, and report the artifact paths."""
    topology = _topology_from_args(args)
    artifacts = run_live_benchmark(
        BenchmarkConfig(
            trace_path=args.trace,
            output_root=args.output_root,
            topology=topology,
            policy_name=args.policy,
            mechanism_name=_normalize_mechanism_name(args.mechanism),
            target_duration_s=args.target_duration_s,
            start_time_s=args.start_time_s,
            session_sample_rate=args.session_sample_rate,
            min_turns=args.min_turns,
            time_scale=args.time_scale,
            concurrency_limit=args.concurrency_limit,
            timeout_s=args.timeout_s,
            stream=not args.no_stream,
            stream_idle_timeout_s=args.stream_idle_timeout_s,
            kvcache_direct_max_uncached_tokens=args.kvcache_direct_max_uncached_tokens,
            kvcache_admission_mode=args.kvcache_admission_mode,
            sample_profile=args.sample_profile,
            min_initial_input_tokens=args.min_initial_input_tokens,
            max_initial_input_tokens=args.max_initial_input_tokens,
            max_append_input_tokens=args.max_append_input_tokens,
            max_output_tokens=args.max_output_tokens,
            min_overlap_ratio=args.min_overlap_ratio,
            launch_stack=True,
        )
    )
    print(
        f"benchmark artifacts written under {artifacts.run_dir}; "
        f"metrics={artifacts.metrics_path} summary={artifacts.summary_path}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _add_topology_arguments(parser: argparse.ArgumentParser) -> None:
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-path",
|
||||||
|
default="~/models/Qwen/Qwen3-Coder-30B-A3B-Instruct",
|
||||||
|
)
|
||||||
|
parser.add_argument("--prefill-workers", type=int, default=1)
|
||||||
|
parser.add_argument("--decode-workers", type=int, default=1)
|
||||||
|
parser.add_argument("--direct-workers", type=int, default=0)
|
||||||
|
parser.add_argument("--prefill-tp-size", type=int, default=1)
|
||||||
|
parser.add_argument("--decode-tp-size", type=int, default=1)
|
||||||
|
parser.add_argument("--direct-tp-size", type=int, default=1)
|
||||||
|
parser.add_argument("--gpu-budget", type=int, default=8)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefill-gpu-ids",
|
||||||
|
default=None,
|
||||||
|
help="Comma-separated GPU IDs for prefill workers, e.g. 3,4",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode-gpu-ids",
|
||||||
|
default=None,
|
||||||
|
help="Comma-separated GPU IDs for decode workers, e.g. 5",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--direct-gpu-ids",
|
||||||
|
default=None,
|
||||||
|
help="Comma-separated GPU IDs for direct workers, e.g. 6",
|
||||||
|
)
|
||||||
|
parser.add_argument("--host", default="127.0.0.1")
|
||||||
|
parser.add_argument("--router-port", type=int, default=8000)
|
||||||
|
parser.add_argument("--prefill-port-base", type=int, default=30000)
|
||||||
|
parser.add_argument("--decode-port-base", type=int, default=31000)
|
||||||
|
parser.add_argument("--direct-port-base", type=int, default=32000)
|
||||||
|
parser.add_argument("--bootstrap-port-base", type=int, default=8998)
|
||||||
|
parser.add_argument("--transfer-backend", default="nixl")
|
||||||
|
parser.add_argument(
|
||||||
|
"--force-rdma",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"Force real RDMA transport for PD KV transfer. "
|
||||||
|
"Currently this requires Mooncake plus --ib-device."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument("--ib-device", default=None)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-trust-remote-code",
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _topology_from_args(args: argparse.Namespace):
    """Build the single-node topology described by parsed topology options.

    ``--force-rdma`` replaces whatever ``--transfer-backend`` was given with
    ``"mooncake"``. The model path has ``~`` expanded, the GPU id strings are
    parsed into int tuples, and direct workers always get
    ``--enable-streaming-session`` as an extra server argument.
    """
    requested_backend = args.transfer_backend
    backend = "mooncake" if args.force_rdma else requested_backend
    expanded_model_path = str(Path(args.model_path).expanduser())

    topology_kwargs = dict(
        model_path=expanded_model_path,
        prefill_worker_count=args.prefill_workers,
        decode_worker_count=args.decode_workers,
        direct_worker_count=args.direct_workers,
        prefill_tp_size=args.prefill_tp_size,
        decode_tp_size=args.decode_tp_size,
        direct_tp_size=args.direct_tp_size,
        prefill_gpu_ids=_parse_gpu_id_list(args.prefill_gpu_ids),
        decode_gpu_ids=_parse_gpu_id_list(args.decode_gpu_ids),
        direct_gpu_ids=_parse_gpu_id_list(args.direct_gpu_ids),
        total_gpu_budget=args.gpu_budget,
        host=args.host,
        router_port=args.router_port,
        prefill_port_base=args.prefill_port_base,
        decode_port_base=args.decode_port_base,
        direct_port_base=args.direct_port_base,
        bootstrap_port_base=args.bootstrap_port_base,
        transfer_backend=backend,
        force_rdma=args.force_rdma,
        # The CLI exposes the negative flag; the topology takes the positive.
        trust_remote_code=not args.no_trust_remote_code,
        ib_device=args.ib_device,
        direct_extra_server_args=("--enable-streaming-session",),
    )
    return build_single_node_topology(**topology_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
140
src/agentic_pd_hybrid/launcher.py
Normal file
140
src/agentic_pd_hybrid/launcher.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import shlex
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from agentic_pd_hybrid.topology import SingleNodeTopology, WorkerSpec
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class LaunchPlan:
    """Command lines for every process of a launch, grouped by role."""

    # One argv tuple per worker of each role.
    prefill_commands: tuple[tuple[str, ...], ...]
    decode_commands: tuple[tuple[str, ...], ...]
    direct_commands: tuple[tuple[str, ...], ...]
    # Router argv, or None when the plan has no router.
    router_command: tuple[str, ...] | None

    def render(self) -> str:
        """Return all commands as text, one named section per command.

        Sections appear in the order prefill, decode, direct, router and are
        separated by a blank line. Workers are named ``<role>-<index>``.
        """
        rendered = [
            _render_named_command(f"prefill-{idx}", command)
            for idx, command in enumerate(self.prefill_commands)
        ]
        rendered += [
            _render_named_command(f"decode-{idx}", command)
            for idx, command in enumerate(self.decode_commands)
        ]
        rendered += [
            _render_named_command(f"direct-{idx}", command)
            for idx, command in enumerate(self.direct_commands)
        ]
        if self.router_command is not None:
            rendered.append(_render_named_command("router", self.router_command))
        return "\n\n".join(rendered)
|
||||||
|
|
||||||
|
|
||||||
|
def build_launch_plan(
    topology: SingleNodeTopology,
    *,
    prefill_policy: str = "round_robin",
    decode_policy: str = "manual",
    include_router: bool = True,
) -> LaunchPlan:
    """Assemble the launch commands for every worker in ``topology``.

    One server command is produced per prefill, decode and direct worker.
    A router command is added only when ``include_router`` is set and the
    topology has both prefill and decode workers; otherwise the plan's
    ``router_command`` is ``None``.
    """

    def commands_for(workers) -> tuple[tuple[str, ...], ...]:
        return tuple(_build_server_command(topology, worker) for worker in workers)

    prefill_commands = commands_for(topology.prefill_workers)
    decode_commands = commands_for(topology.decode_workers)
    direct_commands = commands_for(topology.direct_workers)

    router_command = None
    if include_router and topology.prefill_workers and topology.decode_workers:
        router_command = _build_router_command(
            topology,
            prefill_policy=prefill_policy,
            decode_policy=decode_policy,
        )

    return LaunchPlan(
        prefill_commands=prefill_commands,
        decode_commands=decode_commands,
        direct_commands=direct_commands,
        router_command=router_command,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _build_server_command(
|
||||||
|
topology: SingleNodeTopology,
|
||||||
|
worker: WorkerSpec,
|
||||||
|
) -> tuple[str, ...]:
|
||||||
|
command = [
|
||||||
|
sys.executable,
|
||||||
|
"-B",
|
||||||
|
"-u",
|
||||||
|
"-m",
|
||||||
|
"sglang.launch_server",
|
||||||
|
"--model-path",
|
||||||
|
topology.model_path,
|
||||||
|
"--host",
|
||||||
|
worker.host,
|
||||||
|
"--port",
|
||||||
|
str(worker.port),
|
||||||
|
"--base-gpu-id",
|
||||||
|
str(worker.gpu_id),
|
||||||
|
"--disaggregation-mode",
|
||||||
|
_disaggregation_mode_for(worker),
|
||||||
|
"--disaggregation-transfer-backend",
|
||||||
|
topology.transfer_backend,
|
||||||
|
]
|
||||||
|
if worker.tp_size > 1:
|
||||||
|
command.extend(["--tp-size", str(worker.tp_size)])
|
||||||
|
if topology.trust_remote_code:
|
||||||
|
command.append("--trust-remote-code")
|
||||||
|
command.append("--enable-cache-report")
|
||||||
|
if worker.bootstrap_port is not None:
|
||||||
|
command.extend(
|
||||||
|
["--disaggregation-bootstrap-port", str(worker.bootstrap_port)]
|
||||||
|
)
|
||||||
|
if topology.ib_device:
|
||||||
|
command.extend(["--disaggregation-ib-device", topology.ib_device])
|
||||||
|
command.extend(topology.extra_server_args)
|
||||||
|
if worker.role == "prefill":
|
||||||
|
command.extend(topology.prefill_extra_server_args)
|
||||||
|
elif worker.role == "decode":
|
||||||
|
command.extend(topology.decode_extra_server_args)
|
||||||
|
else:
|
||||||
|
command.extend(topology.direct_extra_server_args)
|
||||||
|
return tuple(command)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_router_command(
    topology: SingleNodeTopology,
    *,
    prefill_policy: str,
    decode_policy: str,
) -> tuple[str, ...]:
    """Build the argv that starts ``agentic_pd_hybrid.pd_router``.

    Each prefill worker contributes ``--prefill URL PORT`` and each decode
    worker contributes ``--decode URL``.
    """
    argv: list[str] = [sys.executable, "-B", "-u", "-m", "agentic_pd_hybrid.pd_router"]
    argv += ["--host", topology.router_host, "--port", str(topology.router_port)]
    argv += ["--prefill-policy", prefill_policy, "--decode-policy", decode_policy]

    for prefill in topology.prefill_workers:
        url = prefill.url
        # A falsy bootstrap port (None or 0) falls back to the router port.
        port = prefill.bootstrap_port or topology.router_port
        argv += ["--prefill", url, str(port)]

    for decode in topology.decode_workers:
        argv += ["--decode", decode.url]

    return tuple(argv)
|
||||||
|
|
||||||
|
|
||||||
|
def _render_named_command(name: str, command: tuple[str, ...]) -> str:
    """Render ``command`` as one shell-quoted line under a ``# name`` header."""
    quoted_line = " ".join(map(shlex.quote, command))
    return f"# {name}\n{quoted_line}"
|
||||||
|
|
||||||
|
|
||||||
|
def _disaggregation_mode_for(worker: WorkerSpec) -> str:
    """Return the ``--disaggregation-mode`` value for ``worker``.

    Direct workers map to the literal ``"null"``; any other role is passed
    through unchanged.
    """
    return "null" if worker.role == "direct" else worker.role
|
||||||
165
src/agentic_pd_hybrid/metrics.py
Normal file
165
src/agentic_pd_hybrid/metrics.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import statistics
|
||||||
|
from collections import Counter
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentic_pd_hybrid.policies import RoutingDecision
|
||||||
|
from agentic_pd_hybrid.trace import TraceRequest
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class RequestMetrics:
    """Immutable per-request metrics row combining trace, routing and results."""

    # Identity and run labels.
    request_id: str
    session_id: str
    turn_id: int
    mechanism_name: str
    execution_mode: str
    # Copied from the trace request.
    trace_timestamp_s: float
    input_length: int
    output_length: int
    request_type: str
    # Copied from the routing decision.
    policy_name: str
    assigned_prefill_node: str
    assigned_decode_node: str
    assigned_decode_index: int
    inflight_decode_load_at_assignment: int
    reuse_expected: bool
    reuse_observed: bool
    observed_overlap_blocks: int
    kv_transfer_blocks: int
    # Supplied by the caller of from_decision, except re_prefill_required,
    # which is copied from the routing decision.
    actual_kv_transfer_blocks: int
    cached_tokens: int
    re_prefill_required: bool
    effective_input_length: int | None
    session_reused: bool
    session_reset: bool
    latency_s: float | None
    ttft_s: float | None
    tpot_s: float | None
    error: str | None = None

    @classmethod
    def from_decision(
        cls,
        request: TraceRequest,
        decision: RoutingDecision,
        *,
        mechanism_name: str,
        execution_mode: str,
        actual_kv_transfer_blocks: int,
        effective_input_length: int | None,
        cached_tokens: int,
        session_reused: bool,
        session_reset: bool,
        latency_s: float | None,
        ttft_s: float | None,
        tpot_s: float | None,
        error: str | None = None,
    ) -> "RequestMetrics":
        """Build a row from a trace request, its routing decision and results.

        Request- and decision-derived fields are read from the two objects;
        everything else is taken verbatim from the keyword arguments.
        """
        from_request = {
            "request_id": request.request_id,
            "session_id": request.session_id,
            "turn_id": request.turn_id,
            "trace_timestamp_s": request.timestamp_s,
            "input_length": request.input_length,
            "output_length": request.output_length,
            "request_type": request.request_type,
        }
        from_routing = {
            "policy_name": decision.policy_name,
            "assigned_prefill_node": decision.prefill_worker_id,
            "assigned_decode_node": decision.decode_worker_id,
            "assigned_decode_index": decision.decode_worker_index,
            "inflight_decode_load_at_assignment": decision.inflight_decode_load,
            "reuse_expected": decision.reuse_expected,
            "reuse_observed": decision.observed_reuse,
            "observed_overlap_blocks": decision.observed_overlap_blocks,
            "kv_transfer_blocks": decision.kv_transfer_blocks,
            "re_prefill_required": decision.re_prefill_required,
        }
        from_caller = {
            "mechanism_name": mechanism_name,
            "execution_mode": execution_mode,
            "actual_kv_transfer_blocks": actual_kv_transfer_blocks,
            "cached_tokens": cached_tokens,
            "effective_input_length": effective_input_length,
            "session_reused": session_reused,
            "session_reset": session_reset,
            "latency_s": latency_s,
            "ttft_s": ttft_s,
            "tpot_s": tpot_s,
            "error": error,
        }
        return cls(**from_request, **from_routing, **from_caller)
|
||||||
|
|
||||||
|
|
||||||
|
def write_metrics_jsonl(path: Path, rows: list[RequestMetrics]) -> None:
    """Write ``rows`` to ``path`` as JSON Lines, one key-sorted object per row.

    Missing parent directories are created and an existing file is overwritten.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            f"{json.dumps(asdict(record), sort_keys=True)}\n" for record in rows
        )
|
||||||
|
|
||||||
|
|
||||||
|
def write_summary_json(
    path: Path,
    rows: list[RequestMetrics],
    *,
    trace_path: Path,
    router_url: str | None,
) -> None:
    """Aggregate ``rows`` into a single summary object and write it as JSON.

    The summary holds counts, per-worker load histograms, and latency / TTFT /
    TPOT statistics over the rows where those values are present. Missing
    parent directories of ``path`` are created.
    """

    def histogram(values) -> dict:
        # Occurrence counts with keys in sorted order.
        return dict(sorted(Counter(values).items()))

    def count_true(flags) -> int:
        return sum(1 for flag in flags if flag)

    summary: dict[str, Any] = {}
    summary["trace_path"] = str(trace_path)
    summary["router_url"] = router_url
    summary["request_count"] = len(rows)
    summary["mechanisms"] = histogram(row.mechanism_name for row in rows)
    summary["execution_modes"] = histogram(row.execution_mode for row in rows)

    # Timing statistics only consider rows that recorded the value.
    summary["latency_stats_s"] = _stats(
        [row.latency_s for row in rows if row.latency_s is not None]
    )
    summary["ttft_stats_s"] = _stats(
        [row.ttft_s for row in rows if row.ttft_s is not None]
    )
    summary["tpot_stats_s"] = _stats(
        [row.tpot_s for row in rows if row.tpot_s is not None]
    )

    summary["reuse_expected_count"] = count_true(row.reuse_expected for row in rows)
    summary["reuse_observed_count"] = count_true(row.reuse_observed for row in rows)
    summary["re_prefill_count"] = count_true(row.re_prefill_required for row in rows)
    summary["cache_hit_request_count"] = count_true(
        row.cached_tokens > 0 for row in rows
    )
    summary["total_cached_tokens"] = sum(row.cached_tokens for row in rows)
    summary["cached_tokens_stats"] = _stats(
        [float(row.cached_tokens) for row in rows]
    )
    summary["session_reused_count"] = count_true(row.session_reused for row in rows)
    summary["session_reset_count"] = count_true(row.session_reset for row in rows)
    summary["total_kv_transfer_blocks"] = sum(row.kv_transfer_blocks for row in rows)
    summary["total_actual_kv_transfer_blocks"] = sum(
        row.actual_kv_transfer_blocks for row in rows
    )
    summary["per_decode_load"] = histogram(row.assigned_decode_node for row in rows)
    summary["per_prefill_load"] = histogram(row.assigned_prefill_node for row in rows)
    summary["error_count"] = count_true(row.error is not None for row in rows)

    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        json.dump(summary, sink, indent=2, sort_keys=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _stats(values: list[float | None]) -> dict[str, float] | None:
    """Summarise ``values`` as count, mean and p50/p90/p99.

    ``None`` entries are ignored. Returns ``None`` when nothing remains. The
    input list is not modified.
    """
    present = sorted(value for value in values if value is not None)
    if not present:
        return None

    summary: dict[str, float] = {
        "count": float(len(present)),
        "mean": statistics.fmean(present),
    }
    for label, fraction in (("p50", 0.50), ("p90", 0.90), ("p99", 0.99)):
        summary[label] = _percentile(present, fraction)
    return summary
|
||||||
|
|
||||||
|
|
||||||
|
def _percentile(sorted_values: list[float], percentile: float) -> float:
    """Return the element of ``sorted_values`` nearest to ``percentile``.

    ``percentile`` is a fraction in the 0..1 range. The index is
    ``round((len - 1) * percentile)``, so no interpolation takes place.

    Raises:
        ValueError: if ``sorted_values`` is empty.
    """
    if not sorted_values:
        raise ValueError("sorted_values must not be empty")
    size = len(sorted_values)
    if size == 1:
        return sorted_values[0]
    position = round((size - 1) * percentile)
    return sorted_values[position]
|
||||||
123
src/agentic_pd_hybrid/microbench.py
Normal file
123
src/agentic_pd_hybrid/microbench.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from math import ceil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
# Number of tokens counted per hash block when an input length is converted
# into a block count (length / BLOCK_TOKEN_BUDGET, rounded up).
BLOCK_TOKEN_BUDGET = 24
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class SmallAppendTraceConfig:
    """Parameters for the synthetic multi-turn trace written by write_small_append_trace."""

    # Destination of the JSONL trace; a summary JSON is written alongside it.
    output_path: Path
    # Number of sessions to generate (must be > 0).
    session_count: int = 8
    # Number of records generated per session (must be > 0).
    turns_per_session: int = 3
    # Input length of the first turn of every session.
    initial_input_length: int = 10_000
    # Each later turn's input grows by append_input_length + output_length.
    append_input_length: int = 1_000
    # Value written as every record's "output_length".
    output_length: int = 1_000
    # Timestamp distance between consecutive turns of one session.
    inter_turn_gap_s: float = 1.0
    # Timestamp offset between the first turns of consecutive sessions.
    session_stagger_s: float = 0.1
    # Value written as every record's "type".
    request_type: str = "coder"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class SmallAppendTraceSummary:
    """Description of a generated trace, returned and saved by write_small_append_trace."""

    # String form of the trace file path.
    output_path: str
    session_count: int
    turns_per_session: int
    # Total number of records written to the trace.
    request_count: int
    # The remaining fields echo the generating configuration.
    initial_input_length: int
    append_input_length: int
    output_length: int
    inter_turn_gap_s: float
    session_stagger_s: float
|
||||||
|
|
||||||
|
|
||||||
|
def write_small_append_trace(config: SmallAppendTraceConfig) -> SmallAppendTraceSummary:
    """Generate a synthetic multi-turn trace and write it as JSON Lines.

    Every session starts with ``initial_input_length`` input tokens and each
    later turn grows the input by ``append_input_length + output_length``.
    Each record carries the block hash ids covering its whole input, so later
    turns of a session share the hash-id prefix of earlier turns. Records are
    written sorted by timestamp, and a summary JSON is written next to the
    trace (the trace suffix plus ``.summary.json``).

    Args:
        config: Generation parameters.

    Returns:
        A summary of what was written.

    Raises:
        ValueError: if a count is not positive or a length is negative.
    """
    if config.session_count <= 0:
        raise ValueError("session_count must be > 0")
    if config.turns_per_session <= 0:
        raise ValueError("turns_per_session must be > 0")
    if config.initial_input_length < 0:
        raise ValueError("initial_input_length must be >= 0")
    if config.append_input_length < 0:
        raise ValueError("append_input_length must be >= 0")
    if config.output_length < 0:
        raise ValueError("output_length must be >= 0")

    config.output_path.parent.mkdir(parents=True, exist_ok=True)
    records: list[dict[str, object]] = []
    next_chat_id = 1_000_000

    for session_idx in range(config.session_count):
        # -1 marks the first turn of a session as having no parent.
        previous_chat_id = -1
        session_base_time = session_idx * config.session_stagger_s
        base_block_count = ceil(config.initial_input_length / BLOCK_TOKEN_BUDGET)
        base_hash_ids = [
            _hash_id_for(session_idx=session_idx, block_idx=block_idx)
            for block_idx in range(base_block_count)
        ]

        for turn_idx in range(config.turns_per_session):
            # Allocate a fresh id for every turn so that no two records share
            # a chat_id and no record is listed as its own parent.
            chat_id = next_chat_id
            next_chat_id += 1

            input_length = config.initial_input_length + turn_idx * (
                config.append_input_length + config.output_length
            )
            total_block_count = ceil(input_length / BLOCK_TOKEN_BUDGET)
            # Blocks beyond the initial prompt continue the session's numbering.
            hash_ids = base_hash_ids + [
                _hash_id_for(
                    session_idx=session_idx,
                    block_idx=base_block_count + append_block_idx,
                )
                for append_block_idx in range(max(0, total_block_count - base_block_count))
            ]

            records.append(
                {
                    "chat_id": chat_id,
                    "parent_chat_id": previous_chat_id,
                    "timestamp": session_base_time
                    + turn_idx * config.inter_turn_gap_s,
                    "input_length": input_length,
                    "output_length": config.output_length,
                    "type": config.request_type,
                    "turn": turn_idx + 1,
                    "hash_ids": hash_ids,
                }
            )
            previous_chat_id = chat_id

    records.sort(key=lambda item: float(item["timestamp"]))
    with config.output_path.open("w", encoding="utf-8") as handle:
        for record in records:
            handle.write(json.dumps(record, sort_keys=True) + "\n")

    summary = SmallAppendTraceSummary(
        output_path=str(config.output_path),
        session_count=config.session_count,
        turns_per_session=config.turns_per_session,
        request_count=len(records),
        initial_input_length=config.initial_input_length,
        append_input_length=config.append_input_length,
        output_length=config.output_length,
        inter_turn_gap_s=config.inter_turn_gap_s,
        session_stagger_s=config.session_stagger_s,
    )
    summary_path = config.output_path.with_suffix(
        config.output_path.suffix + ".summary.json"
    )
    with summary_path.open("w", encoding="utf-8") as handle:
        json.dump(asdict(summary), handle, indent=2, sort_keys=True)
    return summary
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_id_for(*, session_idx: int, block_idx: int) -> int:
    """Return the synthetic hash id of block ``block_idx`` in session ``session_idx``.

    Each session owns a band of one million consecutive ids.
    """
    session_band_start = session_idx * 1_000_000
    return session_band_start + block_idx
|
||||||
267
src/agentic_pd_hybrid/pd_router.py
Normal file
267
src/agentic_pd_hybrid/pd_router.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
import urllib.parse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from http import HTTPStatus
|
||||||
|
from itertools import chain
|
||||||
|
from typing import AsyncIterator
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import orjson
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||||
|
|
||||||
|
# Chunk size in bytes (64 KiB) requested from the decode response when a
# streamed reply is relayed.
_STREAM_CHUNK_SIZE = 1024 * 64
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RouterConfig:
    """Static settings of the local prefill/decode router."""

    # Address the router itself listens on.
    host: str
    port: int
    # One (base URL, bootstrap port) pair per prefill server.
    prefill_urls: list[tuple[str, int]]
    # Base URL of every decode server.
    decode_urls: list[str]
    # Stored only: the code in this module always picks prefill servers
    # round-robin and never reads this value.
    prefill_policy: str = "round_robin"
    # "consistent_hashing" honours the x-smg-target-worker header; "manual"
    # pins each x-smg-routing-key to one decode server; anything else is
    # plain round-robin.
    decode_policy: str = "manual"
    # Total timeout, in seconds, of the HTTP client session used per request.
    request_timeout_s: float = 1800.0
|
||||||
|
|
||||||
|
|
||||||
|
class RouterState:
    """Mutable selection state of the router: cursors and sticky assignments."""

    def __init__(self, config: RouterConfig):
        """Store ``config`` and reset all cursors.

        Raises:
            ValueError: if the config lists no prefill or no decode server.
        """
        if not config.prefill_urls:
            raise ValueError("At least one prefill worker is required")
        if not config.decode_urls:
            raise ValueError("At least one decode worker is required")
        self.config = config
        self.prefill_cursor = 0
        self.decode_cursor = 0
        # routing key -> decode index, filled in under the "manual" policy.
        self.sticky_decode_map: dict[str, int] = {}

    def select_pair(self, headers: dict[str, str]) -> tuple[str, int, str]:
        """Pick the servers for one request.

        Prefill servers are always taken round-robin; the decode server
        depends on the decode policy and the request headers.

        Args:
            headers: Request headers keyed by lower-case name.

        Returns:
            ``(prefill_url, bootstrap_port, decode_url)``.
        """
        prefill_url, bootstrap_port = self.config.prefill_urls[
            self.prefill_cursor % len(self.config.prefill_urls)
        ]
        self.prefill_cursor += 1
        decode_index = self._select_decode_index(headers)
        return prefill_url, bootstrap_port, self.config.decode_urls[decode_index]

    def _select_decode_index(self, headers: dict[str, str]) -> int:
        """Return the index of the decode server to use for these headers."""
        target_worker = headers.get("x-smg-target-worker")
        routing_key = headers.get("x-smg-routing-key")

        if (
            self.config.decode_policy == "consistent_hashing"
            and target_worker is not None
        ):
            # The header is client-supplied; a non-numeric value is treated
            # like an out-of-range index and falls through to the logic below
            # instead of failing the request.
            try:
                idx = int(target_worker)
            except ValueError:
                idx = -1
            if 0 <= idx < len(self.config.decode_urls):
                return idx

        if self.config.decode_policy == "manual" and routing_key:
            cached = self.sticky_decode_map.get(routing_key)
            if cached is not None:
                return cached
            # First request for this key: pin it to the next server in turn.
            idx = self._next_round_robin_decode_index()
            self.sticky_decode_map[routing_key] = idx
            return idx

        return self._next_round_robin_decode_index()

    def _next_round_robin_decode_index(self) -> int:
        """Return the next decode index in rotation and advance the cursor."""
        idx = self.decode_cursor % len(self.config.decode_urls)
        self.decode_cursor += 1
        return idx
|
||||||
|
|
||||||
|
|
||||||
|
# FastAPI application that the route handlers in this module are registered on.
app = FastAPI()
# Process-wide router state; stays None until main() installs it. Handlers
# read it through _require_state().
router_state: RouterState | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
async def health() -> Response:
    """Liveness probe: always answers 200 without contacting any backend."""
    return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health_generate")
async def health_generate() -> Response:
    """Call ``/health_generate`` on every prefill and decode server.

    Responses are awaited in completion order and released without their
    status being inspected, so 200 is returned as long as no request raises.
    """
    state = _require_state()
    async with aiohttp.ClientSession() as session:
        tasks = []
        for server in chain(
            (url for url, _ in state.config.prefill_urls),
            state.config.decode_urls,
        ):
            tasks.append(session.get(f"{server}/health_generate"))
        for response in asyncio.as_completed(tasks):
            # Entering and leaving the context releases the connection.
            async with await response:
                pass
    return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/v1/models")
async def models() -> ORJSONResponse:
    """Relay ``/v1/models`` from the first configured prefill server.

    The upstream JSON payload and HTTP status are passed through.
    """
    state = _require_state()
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{state.config.prefill_urls[0][0]}/v1/models") as response:
            payload = await response.json()
    return ORJSONResponse(payload, status_code=response.status)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Response:
    """Forward a chat-completions request to a prefill/decode server pair."""
    payload = await request.json()
    # Header names are lower-cased so routing lookups are case-insensitive.
    lowered_headers = dict(
        (name.lower(), value) for name, value in request.headers.items()
    )
    return await _forward_to_backend(
        request_data=payload,
        headers=lowered_headers,
        endpoint_name="v1/chat/completions",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/completions")
async def completions(request: Request) -> Response:
    """Forward a completions request to a prefill/decode server pair."""
    payload = await request.json()
    # Header names are lower-cased so routing lookups are case-insensitive.
    lowered_headers = dict(
        (name.lower(), value) for name, value in request.headers.items()
    )
    return await _forward_to_backend(
        request_data=payload,
        headers=lowered_headers,
        endpoint_name="v1/completions",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/generate")
async def generate(request: Request) -> Response:
    """Forward a ``/generate`` request to a prefill/decode server pair."""
    payload = await request.json()
    # Header names are lower-cased so routing lookups are case-insensitive.
    lowered_headers = dict(
        (name.lower(), value) for name, value in request.headers.items()
    )
    return await _forward_to_backend(
        request_data=payload,
        headers=lowered_headers,
        endpoint_name="generate",
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def _forward_to_backend(
    *,
    request_data: dict,
    headers: dict[str, str],
    endpoint_name: str,
) -> Response:
    """Send one request to a selected prefill/decode pair and return the decode reply.

    Args:
        request_data: Parsed JSON body of the incoming request.
        headers: Incoming headers keyed by lower-case name; used for routing.
        endpoint_name: Path (without leading slash) posted to on both servers.

    Returns:
        A streaming response when the body sets ``stream``; otherwise the
        decode server's body, status code and content type.
    """
    state = _require_state()
    prefill_server, bootstrap_port, decode_server = state.select_pair(headers)
    # Shallow copy, so the bootstrap fields are not added to the caller's dict.
    modified_request = request_data.copy()
    modified_request.update(_build_bootstrap_payload(prefill_server, bootstrap_port))

    if request_data.get("stream", False):
        return StreamingResponse(
            _stream_generate(
                modified_request=modified_request,
                prefill_server=prefill_server,
                decode_server=decode_server,
                endpoint_name=endpoint_name,
                timeout_s=state.config.request_timeout_s,
            ),
            media_type="text/event-stream",
        )

    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=state.config.request_timeout_s)
    ) as session:
        # Both servers receive the same payload, posted concurrently.
        prefill_response, decode_response = await asyncio.gather(
            session.post(f"{prefill_server}/{endpoint_name}", json=modified_request),
            session.post(f"{decode_server}/{endpoint_name}", json=modified_request),
        )
        # The prefill body is read to completion but its content is discarded.
        async with prefill_response:
            await prefill_response.read()
        async with decode_response:
            body = await decode_response.read()
        return Response(
            content=body,
            status_code=decode_response.status,
            media_type=decode_response.content_type,
        )
|
||||||
|
|
||||||
|
|
||||||
|
async def _stream_generate(
    *,
    modified_request: dict,
    prefill_server: str,
    decode_server: str,
    endpoint_name: str,
    timeout_s: float,
) -> AsyncIterator[bytes]:
    """Post to both servers and yield the decode server's reply as raw bytes.

    When the decode server does not answer 200, its whole body is yielded
    once and the stream ends. Otherwise the body is relayed in chunks of up
    to ``_STREAM_CHUNK_SIZE`` bytes. The prefill response is held open for
    the duration but its body is never read here.
    """
    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=timeout_s)
    ) as session:
        # Both servers receive the same payload, posted concurrently.
        prefill_response, decode_response = await asyncio.gather(
            session.post(f"{prefill_server}/{endpoint_name}", json=modified_request),
            session.post(f"{decode_server}/{endpoint_name}", json=modified_request),
        )
        async with prefill_response, decode_response:
            if decode_response.status != HTTPStatus.OK:
                payload = await decode_response.read()
                yield payload
                return
            async for chunk in decode_response.content.iter_chunked(_STREAM_CHUNK_SIZE):
                yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
def _build_bootstrap_payload(prefill_server: str, bootstrap_port: int) -> dict[str, object]:
    """Return the bootstrap fields merged into a forwarded request.

    The host is the hostname part of ``prefill_server``; the room is a random
    integer in ``[0, 2**63 - 1]`` drawn anew on every call.

    Raises:
        HTTPException: status 500 when no hostname can be parsed from the URL.
    """
    host = urllib.parse.urlparse(prefill_server).hostname
    if host is not None:
        return {
            "bootstrap_host": host,
            "bootstrap_port": bootstrap_port,
            "bootstrap_room": random.randint(0, 2**63 - 1),
        }
    raise HTTPException(
        status_code=500,
        detail=f"Unable to parse prefill hostname from {prefill_server}",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _require_state() -> RouterState:
    """Return the installed router state.

    Raises:
        HTTPException: status 500 when no state has been installed yet.
    """
    current = router_state
    if current is not None:
        return current
    raise HTTPException(status_code=500, detail="router not initialized")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Parse the command line, install the router state and serve the app."""
    global router_state

    cli = argparse.ArgumentParser(description="Minimal local PD router")
    cli.add_argument("--host", default="127.0.0.1")
    cli.add_argument("--port", type=int, default=8000)
    # Repeatable: each occurrence supplies one prefill server and its port.
    cli.add_argument(
        "--prefill",
        nargs=2,
        metavar=("URL", "BOOTSTRAP_PORT"),
        action="append",
        required=True,
    )
    # Repeatable: each occurrence supplies one decode server URL.
    cli.add_argument("--decode", action="append", required=True)
    cli.add_argument("--prefill-policy", default="round_robin")
    cli.add_argument("--decode-policy", default="manual")
    cli.add_argument("--request-timeout-s", type=float, default=1800.0)
    options = cli.parse_args()

    prefill_pairs = [
        (url, int(bootstrap_port)) for url, bootstrap_port in options.prefill
    ]
    config = RouterConfig(
        host=options.host,
        port=options.port,
        prefill_urls=prefill_pairs,
        decode_urls=list(options.decode),
        prefill_policy=options.prefill_policy,
        decode_policy=options.decode_policy,
        request_timeout_s=options.request_timeout_s,
    )
    router_state = RouterState(config)
    uvicorn.run(app, host=options.host, port=options.port, log_level="info")
|
||||||
|
|
||||||
|
|
||||||
|
# Entry point for running the module directly: start the router.
if __name__ == "__main__":
    main()
|
||||||
235
src/agentic_pd_hybrid/policies.py
Normal file
235
src/agentic_pd_hybrid/policies.py
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import Counter
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from agentic_pd_hybrid.topology import SingleNodeTopology
|
||||||
|
from agentic_pd_hybrid.trace import TraceRequest
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SessionRouteState:
    """Routing memory kept for each session."""

    # Id of the decode worker recorded by the most recent RoutingState.finish
    # call for this session; None until one has been recorded.
    last_decode_worker: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RoutingDecision:
    """Outcome of routing one request: chosen workers plus reuse bookkeeping."""

    policy_name: str
    # Workers chosen for the request.
    prefill_worker_id: str
    decode_worker_id: str
    decode_worker_index: int
    # Reuse and transfer bookkeeping recorded with the decision.
    reuse_expected: bool
    observed_overlap_blocks: int
    kv_transfer_blocks: int
    inflight_decode_load: int
    # Identity of the routed request.
    session_id: str
    request_id: str
    turn_id: int

    @property
    def observed_reuse(self) -> bool:
        """True when at least one overlapping block was observed."""
        overlap = self.observed_overlap_blocks
        return overlap > 0

    @property
    def re_prefill_required(self) -> bool:
        """True for a turn after the first that observed no overlapping block."""
        if self.turn_id > 1:
            return self.observed_overlap_blocks == 0
        return False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RoutingState:
    """Mutable bookkeeping shared by the routing policies."""

    # Round-robin positions for prefill and decode selection.
    prefill_cursor: int = 0
    decode_cursor: int = 0
    # Per-session routing memory, keyed by session id.
    session_state: dict[str, SessionRouteState] = field(default_factory=dict)
    # Counters keyed by decode worker id.
    inflight_decode: Counter[str] = field(default_factory=Counter)
    decode_assignment_counts: Counter[str] = field(default_factory=Counter)
    # Hash ids recorded per worker id by finish().
    decode_resident_blocks: dict[str, set[int]] = field(default_factory=dict)

    @classmethod
    def create(cls, topology: SingleNodeTopology) -> "RoutingState":
        """Return a fresh state with an empty block set for every route worker."""
        resident: dict[str, set[int]] = {}
        for worker in topology.route_workers:
            resident[worker.worker_id] = set()
        return cls(decode_resident_blocks=resident)

    def next_prefill_worker_id(self, topology: SingleNodeTopology) -> str:
        """Return the next prefill worker id in rotation.

        Returns the literal ``"none"`` when the topology has no prefill workers.
        """
        candidates = topology.prefill_workers
        if not candidates:
            return "none"
        slot = self.prefill_cursor % len(candidates)
        self.prefill_cursor += 1
        return candidates[slot].worker_id

    def next_decode_worker_id(self, topology: SingleNodeTopology) -> str:
        """Return the next route worker id in rotation."""
        candidates = topology.route_workers
        slot = self.decode_cursor % len(candidates)
        self.decode_cursor += 1
        return candidates[slot].worker_id

    def finish(self, request: TraceRequest, decision: RoutingDecision) -> None:
        """Record a completed request.

        Remembers the decode worker for the session, adds the request's hash
        ids to that worker's resident set, and decrements its in-flight count,
        dropping the entry once it reaches zero or below.
        """
        worker_id = decision.decode_worker_id
        route = self.session_state.setdefault(request.session_id, SessionRouteState())
        route.last_decode_worker = worker_id
        self.decode_resident_blocks[worker_id].update(request.hash_ids)

        inflight = self.inflight_decode
        inflight[worker_id] -= 1
        if inflight[worker_id] <= 0:
            del inflight[worker_id]
|
||||||
|
|
||||||
|
|
||||||
|
class RoutingPolicy(Protocol):
    """Structural interface implemented by the routing policies in this module.

    An implementation exposes a ``name`` and a ``select`` method that maps one
    trace request, given the topology and the mutable routing state, to a
    ``RoutingDecision``.
    """

    name: str

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Choose workers for ``request`` and return the resulting decision."""
        ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class DefaultPolicy:
    """Baseline policy: round-robin for both prefill and decode selection."""

    name: str = "default"

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Take the next prefill and decode worker in rotation.

        Reuse is never expected under this policy.
        """
        chosen_prefill = state.next_prefill_worker_id(topology)
        chosen_decode = state.next_decode_worker_id(topology)
        decision_fields = dict(
            policy_name=self.name,
            request=request,
            topology=topology,
            state=state,
            prefill_worker_id=chosen_prefill,
            decode_worker_id=chosen_decode,
            reuse_expected=False,
        )
        return _build_decision(**decision_fields)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class StickyDecodePolicy:
    """Policy that keeps later turns of a session on its previous decode worker."""

    name: str = "sticky"

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Reuse the session's last decode worker on turns after the first.

        A turn after the first whose session has a recorded decode worker goes
        back to that worker with reuse expected. Every other request takes the
        next decode worker in rotation. Prefill is always round-robin.
        """
        known_session = state.session_state.get(request.session_id)
        chosen_prefill = state.next_prefill_worker_id(topology)

        pinned_worker = None
        if request.turn_id > 1 and known_session:
            pinned_worker = known_session.last_decode_worker

        if pinned_worker is not None:
            chosen_decode, expects_reuse = pinned_worker, True
        else:
            chosen_decode, expects_reuse = state.next_decode_worker_id(topology), False

        return _build_decision(
            policy_name=self.name,
            request=request,
            topology=topology,
            state=state,
            prefill_worker_id=chosen_prefill,
            decode_worker_id=chosen_decode,
            reuse_expected=expects_reuse,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class KvAwarePolicy:
    """Policy that scores every route worker by cached-block overlap and load.

    Attributes are dataclass fields:
    ``name`` is the label written into decisions; ``sticky_bonus`` is the
    amount added to the overlap score of the session's last decode worker.
    """

    name: str = "kv-aware"
    sticky_bonus: int = 1

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Choose the decode worker with the highest lexicographic score.

        The score tuple is, in priority order:
        ``(overlap + sticky bonus, sticky flag, -inflight load, -assignment count)``.
        Ties keep the first worker encountered in ``topology.route_workers``.

        Raises:
            ValueError: if ``topology.route_workers`` is empty, so no decode
                worker can be selected.
        """
        prefill_worker_id = state.next_prefill_worker_id(topology)
        session = state.session_state.get(request.session_id)

        best_decode_worker_id: str | None = None
        # Four components, matching the tuple built in the loop below.
        best_score: tuple[int, int, int, int] | None = None
        for worker in topology.route_workers:
            overlap = _overlap_blocks(request, state, worker.worker_id)
            sticky = int(
                session is not None and session.last_decode_worker == worker.worker_id
            )
            # Negated so that a lower load compares as a higher score.
            inflight_penalty = -state.inflight_decode.get(worker.worker_id, 0)
            assignment_penalty = -state.decode_assignment_counts.get(worker.worker_id, 0)
            score = (
                overlap + sticky * self.sticky_bonus,
                sticky,
                inflight_penalty,
                assignment_penalty,
            )
            if best_score is None or score > best_score:
                best_score = score
                best_decode_worker_id = worker.worker_id

        # Explicit check instead of ``assert``: assertions vanish under ``-O``
        # and would let ``None`` flow into the decision as a worker id.
        if best_decode_worker_id is None or best_score is None:
            raise ValueError(
                "KvAwarePolicy requires at least one route worker in the topology"
            )

        reuse_expected = best_score[0] > 0
        return _build_decision(
            policy_name=self.name,
            request=request,
            topology=topology,
            state=state,
            prefill_worker_id=prefill_worker_id,
            decode_worker_id=best_decode_worker_id,
            reuse_expected=reuse_expected,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def create_policy(name: str) -> RoutingPolicy:
|
||||||
|
normalized = name.strip().lower()
|
||||||
|
if normalized == "default":
|
||||||
|
return DefaultPolicy()
|
||||||
|
if normalized == "sticky":
|
||||||
|
return StickyDecodePolicy()
|
||||||
|
if normalized in {"kv-aware", "kv_aware", "kv"}:
|
||||||
|
return KvAwarePolicy()
|
||||||
|
raise ValueError(f"Unsupported policy: {name}")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_decision(
    *,
    policy_name: str,
    request: TraceRequest,
    topology: SingleNodeTopology,
    state: RoutingState,
    prefill_worker_id: str,
    decode_worker_id: str,
    reuse_expected: bool,
) -> RoutingDecision:
    """Record the chosen decode worker in *state* and package the decision.

    Side effects: increments ``state.inflight_decode`` and
    ``state.decode_assignment_counts`` for *decode_worker_id*. The reported
    in-flight load is the value after that increment.
    """
    # Overlap is measured first; it only reads the resident-block map.
    reused_blocks = _overlap_blocks(request, state, decode_worker_id)

    state.inflight_decode[decode_worker_id] += 1
    state.decode_assignment_counts[decode_worker_id] += 1

    route_slot = topology.route_index(decode_worker_id)
    # Blocks not already resident on the decode worker; never negative.
    blocks_to_move = max(0, len(request.hash_ids) - reused_blocks)
    load_after_booking = state.inflight_decode[decode_worker_id]

    return RoutingDecision(
        policy_name=policy_name,
        prefill_worker_id=prefill_worker_id,
        decode_worker_id=decode_worker_id,
        decode_worker_index=route_slot,
        reuse_expected=reuse_expected,
        observed_overlap_blocks=reused_blocks,
        kv_transfer_blocks=blocks_to_move,
        inflight_decode_load=load_after_booking,
        session_id=request.session_id,
        request_id=request.request_id,
        turn_id=request.turn_id,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _overlap_blocks(
    request: TraceRequest,
    state: RoutingState,
    decode_worker_id: str,
) -> int:
    """Count how many of the request's block hashes are resident on a decode worker.

    Every entry of ``request.hash_ids`` is counted, so a repeated hash
    contributes once per occurrence. A worker with no entry in
    ``state.decode_resident_blocks`` yields ``0``.
    """
    cached = state.decode_resident_blocks.get(decode_worker_id, set())
    hits = 0
    for block_hash in request.hash_ids:
        if block_hash in cached:
            hits += 1
    return hits
|
||||||
2271
src/agentic_pd_hybrid/replay.py
Normal file
2271
src/agentic_pd_hybrid/replay.py
Normal file
File diff suppressed because it is too large
Load Diff
295
src/agentic_pd_hybrid/sampling.py
Normal file
295
src/agentic_pd_hybrid/sampling.py
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from agentic_pd_hybrid.trace import TraceRequest, load_trace
|
||||||
|
|
||||||
|
|
||||||
|
SampleProfile = Literal["default", "small-append"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class SessionSampleConfig:
    """Immutable settings for ``sample_trace_sessions``."""

    # Input trace and destination of the sampled trace.
    trace_path: Path
    output_path: Path

    # Sampling stops once the selected sessions span this many seconds.
    target_duration_s: float = 600.0
    # Sessions whose first request is earlier than this are skipped.
    start_time_s: float = 0.0
    # Fraction of sessions to keep; >= 1.0 keeps all, <= 0.0 keeps none.
    session_sample_rate: float = 1.0
    # Minimum number of requests a session must contain.
    min_turns: int = 1
    # Upper bound on emitted requests; None means unlimited.
    max_requests: int | None = None
    # Named preset; "small-append" fills in defaults for the unset filters below.
    profile: SampleProfile = "default"

    # Optional per-session filters; None disables the filter under "default".
    min_initial_input_tokens: int | None = None
    max_initial_input_tokens: int | None = None
    max_append_input_tokens: int | None = None
    max_output_tokens: int | None = None
    min_overlap_ratio: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class SessionSampleSummary:
    """Result record of a sampling run; also serialised next to the sampled trace."""

    # Paths, stored as strings so the record is JSON-serialisable.
    input_trace_path: str
    output_trace_path: str

    # Size of the sample.
    request_count: int
    session_count: int
    multi_turn_session_count: int

    # Timestamps of the first and last emitted request, and their difference.
    start_time_s: float
    end_time_s: float
    sampled_duration_s: float

    # Settings that produced the sample (filters are the resolved values).
    session_sample_rate: float
    min_turns: int
    profile: str
    min_initial_input_tokens: int | None
    max_initial_input_tokens: int | None
    max_append_input_tokens: int | None
    max_output_tokens: int | None
    min_overlap_ratio: float | None

    # Aggregates over consecutive turns; None when there are no turn pairs.
    mean_append_input_tokens: float | None
    mean_turn_overlap_ratio: float | None
|
||||||
|
|
||||||
|
|
||||||
|
def sample_trace_sessions(config: SessionSampleConfig) -> SessionSampleSummary:
    """Sample whole sessions from a trace and write them out as JSONL.

    Steps: load the trace, group requests by session, drop sessions that fail
    the resolved filters or the sampling rate, then take sessions in order of
    their first request's timestamp until ``max_requests`` or
    ``target_duration_s`` is reached. The sampled requests are written to
    ``config.output_path`` and a summary to ``<output_path>.summary.json``.

    Raises:
        ValueError: if no request survives sampling.
    """
    by_session: dict[str, list[TraceRequest]] = defaultdict(list)
    for loaded in load_trace(config.trace_path):
        by_session[loaded.session_id].append(loaded)

    filters = _resolve_filters(config)
    eligible_sessions: dict[str, list[TraceRequest]] = {}
    for sid, members in by_session.items():
        if len(members) < filters.min_turns:
            continue
        if not _session_matches_filters(members, filters):
            continue
        if _keep_session(sid, config.session_sample_rate):
            eligible_sessions[sid] = members

    picked: list[TraceRequest] = []
    window_start: float | None = None
    chronological = sorted(
        eligible_sessions.values(),
        key=lambda group: group[0].timestamp_s,
    )
    for members in chronological:
        first_ts = members[0].timestamp_s
        if first_ts < config.start_time_s:
            continue
        if window_start is None:
            # The first accepted session anchors the duration window.
            window_start = first_ts

        picked.extend(members)

        if config.max_requests is not None and len(picked) >= config.max_requests:
            break
        latest_ts = max(member.timestamp_s for member in members)
        if latest_ts - window_start >= config.target_duration_s:
            break

    picked.sort(key=lambda item: item.timestamp_s)
    if config.max_requests is not None:
        # Sessions are added whole, so the cap is enforced after sorting.
        picked = picked[: config.max_requests]

    if not picked:
        raise ValueError("Sampling produced no requests; adjust the sampling arguments")

    config.output_path.parent.mkdir(parents=True, exist_ok=True)
    with config.output_path.open("w", encoding="utf-8") as sink:
        for item in picked:
            row = {
                "request_id": item.request_id,
                "session_id": item.session_id,
                "chat_id": item.chat_id,
                "parent_chat_id": item.parent_chat_id,
                "timestamp": item.timestamp_s,
                "input_length": item.input_length,
                "output_length": item.output_length,
                "type": item.request_type,
                "turn": item.turn_id,
                "hash_ids": list(item.hash_ids),
            }
            sink.write(json.dumps(row, sort_keys=True) + "\n")

    # Turn statistics are computed over the full eligible sessions that
    # contributed at least one emitted request.
    picked_ids = {item.session_id for item in picked}
    picked_groups = [eligible_sessions[sid] for sid in picked_ids]
    append_lengths: list[int] = []
    overlap_ratios: list[float] = []
    for group in picked_groups:
        append_lengths.extend(_turn_append_lengths(group))
    for group in picked_groups:
        overlap_ratios.extend(_turn_overlap_ratios(group))

    first_emitted_ts = picked[0].timestamp_s
    last_emitted_ts = picked[-1].timestamp_s
    summary = SessionSampleSummary(
        input_trace_path=str(config.trace_path),
        output_trace_path=str(config.output_path),
        request_count=len(picked),
        session_count=len(picked_ids),
        multi_turn_session_count=len([g for g in picked_groups if len(g) > 1]),
        start_time_s=first_emitted_ts,
        end_time_s=last_emitted_ts,
        sampled_duration_s=last_emitted_ts - first_emitted_ts,
        session_sample_rate=config.session_sample_rate,
        min_turns=filters.min_turns,
        profile=config.profile,
        min_initial_input_tokens=filters.min_initial_input_tokens,
        max_initial_input_tokens=filters.max_initial_input_tokens,
        max_append_input_tokens=filters.max_append_input_tokens,
        max_output_tokens=filters.max_output_tokens,
        min_overlap_ratio=filters.min_overlap_ratio,
        mean_append_input_tokens=_mean(append_lengths),
        mean_turn_overlap_ratio=_mean(overlap_ratios),
    )

    # e.g. "sample.jsonl" -> "sample.jsonl.summary.json"
    summary_path = config.output_path.with_suffix(config.output_path.suffix + ".summary.json")
    with summary_path.open("w", encoding="utf-8") as sink:
        json.dump(asdict(summary), sink, indent=2, sort_keys=True)
    return summary
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class _ResolvedFilters:
    """Session filters after profile defaults have been applied."""

    # Minimum number of requests per session.
    min_turns: int
    # Bounds on the first request's input length; None disables the bound.
    min_initial_input_tokens: int | None
    max_initial_input_tokens: int | None
    # Upper bound on tokens appended between consecutive turns.
    max_append_input_tokens: int | None
    # Upper bound on any request's output length.
    max_output_tokens: int | None
    # Lower bound on the block-overlap ratio between consecutive turns.
    min_overlap_ratio: float | None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_filters(config: SessionSampleConfig) -> _ResolvedFilters:
    """Turn a sampling config into concrete filter values for its profile.

    ``default`` passes the configured values through unchanged.
    ``small-append`` requires at least two turns and substitutes preset
    limits for every filter the config leaves as ``None``.

    Raises:
        ValueError: if ``config.profile`` is neither of the two profiles.
    """
    profile = config.profile
    if profile not in ("default", "small-append"):
        raise ValueError(f"Unsupported sample profile: {config.profile}")

    if profile == "default":
        return _ResolvedFilters(
            min_turns=config.min_turns,
            min_initial_input_tokens=config.min_initial_input_tokens,
            max_initial_input_tokens=config.max_initial_input_tokens,
            max_append_input_tokens=config.max_append_input_tokens,
            max_output_tokens=config.max_output_tokens,
            min_overlap_ratio=config.min_overlap_ratio,
        )

    def explicit_or(preset, configured):
        # An explicitly configured value always wins over the preset.
        return preset if configured is None else configured

    return _ResolvedFilters(
        min_turns=max(config.min_turns, 2),
        min_initial_input_tokens=explicit_or(2048, config.min_initial_input_tokens),
        max_initial_input_tokens=explicit_or(16000, config.max_initial_input_tokens),
        max_append_input_tokens=explicit_or(2048, config.max_append_input_tokens),
        max_output_tokens=explicit_or(2048, config.max_output_tokens),
        min_overlap_ratio=explicit_or(0.75, config.min_overlap_ratio),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _session_matches_filters(
    session_requests: list[TraceRequest],
    filters: _ResolvedFilters,
) -> bool:
    """Return ``True`` when a session satisfies every active filter.

    Requests are ordered by ``(timestamp_s, turn_id, chat_id)`` first; the
    earliest one is treated as the initial request. A filter set to ``None``
    is skipped. An empty session never matches.
    """
    timeline = sorted(
        session_requests,
        key=lambda item: (item.timestamp_s, item.turn_id, item.chat_id),
    )
    if not timeline:
        return False

    opening = timeline[0]
    floor = filters.min_initial_input_tokens
    if floor is not None and opening.input_length < floor:
        return False
    ceiling = filters.max_initial_input_tokens
    if ceiling is not None and opening.input_length > ceiling:
        return False

    output_cap = filters.max_output_tokens
    if output_cap is not None:
        for item in timeline:
            if item.output_length > output_cap:
                return False

    deltas = _turn_append_lengths(timeline)
    append_cap = filters.max_append_input_tokens
    # With the filter active, every append must be positive and within the cap.
    if append_cap is not None and not all(0 < delta <= append_cap for delta in deltas):
        return False

    ratios = _turn_overlap_ratios(timeline)
    overlap_floor = filters.min_overlap_ratio
    if overlap_floor is not None:
        for ratio in ratios:
            if ratio < overlap_floor:
                return False

    return True
|
||||||
|
|
||||||
|
|
||||||
|
def _turn_append_lengths(session_requests: list[TraceRequest]) -> list[int]:
    """Return, per consecutive turn pair, the input tokens newly added by the later turn.

    Each value is the later turn's ``input_length`` minus the earlier turn's
    ``input_length + output_length``; it can be zero or negative. Requests
    are ordered by ``(timestamp_s, turn_id, chat_id)`` before pairing.
    """
    timeline = list(session_requests)
    timeline.sort(key=lambda item: (item.timestamp_s, item.turn_id, item.chat_id))

    deltas: list[int] = []
    for earlier, later in zip(timeline[:-1], timeline[1:]):
        carried_over = earlier.input_length + earlier.output_length
        deltas.append(later.input_length - carried_over)
    return deltas
|
||||||
|
|
||||||
|
|
||||||
|
def _turn_overlap_ratios(session_requests: list[TraceRequest]) -> list[float]:
    """Return, per consecutive turn pair, the share of the later turn's blocks seen before.

    The ratio is the number of entries in the later turn's ``hash_ids`` that
    also appear in the earlier turn, divided by the later turn's block count.
    A later turn without blocks contributes ``0.0``.
    """
    timeline = list(session_requests)
    timeline.sort(key=lambda item: (item.timestamp_s, item.turn_id, item.chat_id))

    def shared_fraction(earlier, later) -> float:
        if not later.hash_ids:
            return 0.0
        seen = set(earlier.hash_ids)
        shared = len([block_hash for block_hash in later.hash_ids if block_hash in seen])
        return shared / len(later.hash_ids)

    return [
        shared_fraction(earlier, later)
        for earlier, later in zip(timeline[:-1], timeline[1:])
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def _mean(values: list[int] | list[float]) -> float | None:
    """Return the arithmetic mean of *values*, or ``None`` for an empty list."""
    return sum(values) / len(values) if values else None
|
||||||
|
|
||||||
|
|
||||||
|
def _keep_session(session_id: str, sample_rate: float) -> bool:
    """Decide deterministically whether a session falls inside the sample.

    Rates of 1.0 or more keep every session and rates of 0.0 or less keep
    none. In between, the session id is hashed with BLAKE2b (8-byte digest)
    into a value in ``[0, 1)`` and kept when that value is below the rate, so
    the same id always gets the same answer for a given rate.
    """
    if sample_rate >= 1.0:
        return True
    if sample_rate <= 0.0:
        return False

    hasher = hashlib.blake2b(digest_size=8)
    hasher.update(session_id.encode("utf-8"))
    # 8 bytes -> unsigned 64-bit integer -> fraction of the 64-bit range.
    fraction = int.from_bytes(hasher.digest(), "big") / 2**64
    return fraction < sample_rate
|
||||||
222
src/agentic_pd_hybrid/stack.py
Normal file
222
src/agentic_pd_hybrid/stack.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from agentic_pd_hybrid.launcher import build_launch_plan
|
||||||
|
from agentic_pd_hybrid.topology import SingleNodeTopology
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ManagedProcess:
    """A launched child process together with its launch details."""

    # Short label such as "prefill-0" or "router".
    name: str
    # The argv that was executed.
    command: tuple[str, ...]
    # Handle used to poll, signal and wait for the child.
    process: subprocess.Popen[bytes]
    # File receiving the child's combined stdout/stderr.
    log_path: Path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ManagedPdStack:
    """Handle on every process launched for one prefill/decode stack."""

    topology: SingleNodeTopology
    run_dir: Path
    prefill_processes: list[ManagedProcess]
    decode_processes: list[ManagedProcess]
    direct_processes: list[ManagedProcess]
    # None when the stack was launched without a router.
    router_process: ManagedProcess | None

    @property
    def router_url(self) -> str:
        """Base URL of the router, taken from the topology."""
        return self.topology.router_url

    def stop(self) -> None:
        """Terminate all managed processes, escalating to SIGKILL if needed.

        Every process that is still running gets SIGTERM on its process
        group. All processes then share one 20-second deadline to exit; any
        that overruns it is sent SIGKILL and waited on for up to 5 seconds.
        Shutdown order is router, direct, decode, prefill.

        A process that exits between the liveness check and the signal is
        treated as already stopped, so one vanished process cannot abort
        cleanup of the others.

        Raises:
            subprocess.TimeoutExpired: if a process is still alive 5 seconds
                after SIGKILL.
        """
        processes = (
            ([self.router_process] if self.router_process is not None else [])
            + self.direct_processes
            + self.decode_processes
            + self.prefill_processes
        )
        for managed in processes:
            if managed.process.poll() is None:
                self._signal_group(managed, signal.SIGTERM)

        # Monotonic clock: the deadline must not move with wall-clock changes.
        deadline = time.monotonic() + 20
        for managed in processes:
            if managed.process.poll() is not None:
                continue
            remaining = max(0.0, deadline - time.monotonic())
            try:
                managed.process.wait(timeout=remaining)
            except subprocess.TimeoutExpired:
                if managed.process.poll() is None:
                    self._signal_group(managed, signal.SIGKILL)
                managed.process.wait(timeout=5)

    @staticmethod
    def _signal_group(managed: ManagedProcess, sig: int) -> None:
        """Send *sig* to the process group of *managed*; ignore an already-gone process."""
        try:
            os.killpg(os.getpgid(managed.process.pid), sig)
        except ProcessLookupError:
            # The process exited after poll() was checked; nothing to signal.
            pass
|
||||||
|
|
||||||
|
|
||||||
|
def launch_pd_stack(
    *,
    topology: SingleNodeTopology,
    run_dir: Path,
    prefill_policy: str,
    decode_policy: str,
    timeout_s: float = 1200.0,
    include_router: bool = True,
) -> ManagedPdStack:
    """Start all workers (and optionally the router) and wait until they respond.

    Workers are spawned in the order prefill, decode, direct. Each worker's
    ``/v1/models`` endpoint is then polled; the router is started only after
    all workers answer, and its ``/health`` endpoint is polled last.
    Logs go to ``<run_dir>/logs/<name>.log``.

    Args:
        topology: Worker layout to launch.
        run_dir: Directory for run artefacts; created if missing.
        prefill_policy: Passed through to ``build_launch_plan``.
        decode_policy: Passed through to ``build_launch_plan``.
        timeout_s: Readiness timeout, applied to each endpoint wait separately.
        include_router: Passed through to ``build_launch_plan``.

    Returns:
        A ``ManagedPdStack`` owning every launched process.

    Raises:
        Whatever spawning or the readiness waits raise. Before re-raising,
        every process started so far is stopped; this includes interruption
        by ``KeyboardInterrupt``.
    """
    run_dir.mkdir(parents=True, exist_ok=True)
    logs_dir = run_dir / "logs"
    logs_dir.mkdir(parents=True, exist_ok=True)

    plan = build_launch_plan(
        topology,
        prefill_policy=prefill_policy,
        decode_policy=decode_policy,
        include_router=include_router,
    )

    # Created up front so a failure at any point can stop what already runs.
    stack = ManagedPdStack(
        topology=topology,
        run_dir=run_dir,
        prefill_processes=[],
        decode_processes=[],
        direct_processes=[],
        router_process=None,
    )
    try:
        spawn_groups = (
            ("prefill", plan.prefill_commands, stack.prefill_processes),
            ("decode", plan.decode_commands, stack.decode_processes),
            ("direct", plan.direct_commands, stack.direct_processes),
        )
        for role, commands, launched in spawn_groups:
            for idx, command in enumerate(commands):
                launched.append(
                    _spawn_process(
                        name=f"{role}-{idx}",
                        command=command,
                        log_path=logs_dir / f"{role}-{idx}.log",
                        topology=topology,
                    )
                )

        for worker in (
            *topology.prefill_workers,
            *topology.decode_workers,
            *topology.direct_workers,
        ):
            _wait_for_ready_endpoint(f"{worker.url}/v1/models", timeout_s=timeout_s)

        if plan.router_command is not None:
            stack.router_process = _spawn_process(
                name="router",
                command=plan.router_command,
                log_path=logs_dir / "router.log",
                topology=topology,
            )
            _wait_for_ready_endpoint(f"{topology.router_url}/health", timeout_s=timeout_s)
    except BaseException:
        # BaseException so Ctrl-C during the long waits also cleans up.
        stack.stop()
        raise

    return stack
|
||||||
|
|
||||||
|
|
||||||
|
def _spawn_process(
    *,
    name: str,
    command: tuple[str, ...],
    log_path: Path,
    topology: SingleNodeTopology,
) -> ManagedProcess:
    """Start *command* as a child process that logs to *log_path*.

    The child runs in its own session, and therefore its own process group,
    so the whole group can be signalled later. Its stdout and stderr are
    both written to *log_path*, which is truncated on open. The environment
    comes from ``_build_process_env``.

    Returns:
        A ``ManagedProcess`` wrapping the running child.
    """
    env = _build_process_env(topology)
    # The child gets its own copy of the descriptor, so the parent's handle
    # is closed as soon as Popen returns (or fails) instead of leaking.
    with log_path.open("wb") as log_handle:
        process = subprocess.Popen(
            command,
            stdout=log_handle,
            stderr=subprocess.STDOUT,
            env=env,
            # Same effect as preexec_fn=os.setsid, without running Python
            # code between fork and exec.
            start_new_session=True,
        )
    return ManagedProcess(
        name=name,
        command=command,
        process=process,
        log_path=log_path,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _build_process_env(topology: SingleNodeTopology) -> dict[str, str]:
    """Build the environment for a child process from a copy of ``os.environ``.

    Proxy variables are removed and ``NO_PROXY``/``no_proxy`` set to ``*``.
    Timeout and FlashInfer defaults are added only when not already set.
    Mooncake RDMA variables are set when the topology asks for RDMA or names
    an IB device. ``PYTHONPATH`` is prefixed with this repository's ``src``
    and ``third_party/sglang/python`` directories.
    """
    env = dict(os.environ)
    env["PYTHONDONTWRITEBYTECODE"] = "1"
    env["PYTHONUNBUFFERED"] = "1"

    # SGLang's PD bootstrap path uses `requests`; force localhost traffic to stay local.
    for proxy_var in ("http_proxy", "https_proxy", "all_proxy"):
        env.pop(proxy_var, None)
        env.pop(proxy_var.upper(), None)
    env["NO_PROXY"] = "*"
    env["no_proxy"] = "*"

    # Values already present in the inherited environment take precedence.
    fallbacks = (
        ("SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", "600"),
        ("SGLANG_DISAGGREGATION_WAITING_TIMEOUT", "60"),
        ("FLASHINFER_DISABLE_VERSION_CHECK", "1"),
    )
    for key, value in fallbacks:
        if key not in env:
            env[key] = value

    if topology.force_rdma:
        env["MOONCAKE_PROTOCOL"] = "rdma"
        env["MC_MS_AUTO_DISC"] = "0"
    if topology.ib_device:
        env["MOONCAKE_DEVICE"] = topology.ib_device

    # Two levels above this file's directory.
    repo_root = Path(__file__).resolve().parents[2]
    search_path = [
        str(repo_root / "src"),
        str(repo_root / "third_party" / "sglang" / "python"),
    ]
    inherited = env.get("PYTHONPATH")
    if inherited:
        search_path = search_path + [inherited]
    env["PYTHONPATH"] = os.pathsep.join(search_path)
    return env
|
||||||
|
|
||||||
|
|
||||||
|
def _wait_for_ready_endpoint(url: str, *, timeout_s: float) -> None:
    """Poll *url* about once per second until it answers with HTTP 200.

    Each request has a 5-second timeout and ignores proxy settings from the
    environment. Request errors are remembered and polling continues.

    Raises:
        TimeoutError: if no 200 response arrives within *timeout_s* seconds;
            the message includes the last observed status or error.
    """
    began = time.perf_counter()
    failure: str | None = None
    with httpx.Client(timeout=5.0, trust_env=False) as client:
        while True:
            elapsed = time.perf_counter() - began
            if not (elapsed < timeout_s):
                break
            try:
                status = client.get(url).status_code
            except Exception as exc:  # pragma: no cover
                failure = f"{type(exc).__name__}: {exc}"
            else:
                if status == 200:
                    return
                failure = f"status={status}"
            time.sleep(1.0)
    raise TimeoutError(f"Timed out waiting for {url} ({failure})")
|
||||||
245
src/agentic_pd_hybrid/topology.py
Normal file
245
src/agentic_pd_hybrid/topology.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
WorkerRole = Literal["prefill", "decode", "direct"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class WorkerSpec:
    """Static description of one worker: its role, GPUs and network address."""

    role: WorkerRole
    # Zero-based position of the worker within its role.
    ordinal: int
    # GPUs assigned to this worker; the count is its tensor-parallel size.
    gpu_ids: tuple[int, ...]
    host: str
    port: int
    # Set for prefill workers by the topology builder; None otherwise.
    bootstrap_port: int | None = None

    @property
    def worker_id(self) -> str:
        """Identifier of the form ``<role>-<ordinal>``."""
        return "-".join((self.role, str(self.ordinal)))

    @property
    def url(self) -> str:
        """HTTP base URL of the worker."""
        return "http://" + f"{self.host}:{self.port}"

    @property
    def gpu_id(self) -> int:
        """First GPU assigned to the worker."""
        first_gpu, *_ = self.gpu_ids[:1] or (self.gpu_ids[0],)
        return first_gpu

    @property
    def tp_size(self) -> int:
        """Tensor-parallel size, i.e. the number of assigned GPUs."""
        assigned = self.gpu_ids
        return len(assigned)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class SingleNodeTopology:
    """Immutable layout of all workers and the router on one node."""

    model_path: str
    prefill_workers: tuple[WorkerSpec, ...]
    decode_workers: tuple[WorkerSpec, ...]
    direct_workers: tuple[WorkerSpec, ...]
    router_host: str
    router_port: int
    transfer_backend: str
    trust_remote_code: bool
    force_rdma: bool = False
    ib_device: str | None = None
    # Extra server arguments: shared, then per role.
    extra_server_args: tuple[str, ...] = ()
    prefill_extra_server_args: tuple[str, ...] = ()
    decode_extra_server_args: tuple[str, ...] = ()
    direct_extra_server_args: tuple[str, ...] = ()

    @property
    def model_name(self) -> str:
        """Final path component of ``model_path``."""
        return Path(self.model_path).name

    @property
    def router_url(self) -> str:
        """HTTP base URL of the router."""
        return "http://" + f"{self.router_host}:{self.router_port}"

    @property
    def route_workers(self) -> tuple[WorkerSpec, ...]:
        """Workers that requests are routed to: decode workers, else direct workers."""
        return self.decode_workers or self.direct_workers

    def route_index(self, worker_id: str) -> int:
        """Return the position of *worker_id* within ``route_workers``.

        Raises:
            KeyError: if no route worker has that id.
        """
        position = next(
            (
                slot
                for slot, candidate in enumerate(self.route_workers)
                if candidate.worker_id == worker_id
            ),
            None,
        )
        if position is None:
            raise KeyError(f"Unknown route worker: {worker_id}")
        return position
|
||||||
|
|
||||||
|
|
||||||
|
def build_single_node_topology(
    *,
    model_path: str,
    prefill_worker_count: int,
    decode_worker_count: int,
    direct_worker_count: int = 0,
    prefill_tp_size: int = 1,
    decode_tp_size: int = 1,
    direct_tp_size: int = 1,
    prefill_gpu_ids: tuple[int, ...] | None = None,
    decode_gpu_ids: tuple[int, ...] | None = None,
    direct_gpu_ids: tuple[int, ...] | None = None,
    total_gpu_budget: int = 8,
    host: str = "127.0.0.1",
    router_port: int = 8000,
    prefill_port_base: int = 30000,
    decode_port_base: int = 31000,
    direct_port_base: int = 32000,
    bootstrap_port_base: int = 8998,
    transfer_backend: str = "nixl",
    force_rdma: bool = False,
    trust_remote_code: bool = True,
    ib_device: str | None = None,
    extra_server_args: tuple[str, ...] = (),
    prefill_extra_server_args: tuple[str, ...] = (),
    decode_extra_server_args: tuple[str, ...] = (),
    direct_extra_server_args: tuple[str, ...] = (),
) -> SingleNodeTopology:
    """Validate the requested layout and build a ``SingleNodeTopology``.

    GPU ids that are not given are assigned consecutively in the order
    prefill, decode, direct. Each worker receives ``tp_size`` consecutive
    entries of its role's id tuple and the port ``<role>_port_base + ordinal``.
    Only prefill workers get a bootstrap port
    (``bootstrap_port_base + ordinal``). The router shares *host*.

    Raises:
        ValueError: for negative worker counts, no workers at all,
            non-positive TP sizes, ``force_rdma`` without an IB device or
            with a backend other than ``mooncake``, an exceeded GPU budget,
            GPU id tuples of the wrong length, duplicate GPU ids, or GPU ids
            outside ``[0, total_gpu_budget - 1]``.
    """
    layout = (
        ("prefill", prefill_worker_count, prefill_tp_size),
        ("decode", decode_worker_count, decode_tp_size),
        ("direct", direct_worker_count, direct_tp_size),
    )

    for role, count, _tp in layout:
        if count < 0:
            raise ValueError(f"{role}_worker_count must be >= 0")
    if all(count == 0 for _role, count, _tp in layout):
        raise ValueError("At least one worker must be configured")
    for role, _count, tp in layout:
        if tp <= 0:
            raise ValueError(f"{role}_tp_size must be >= 1")

    if force_rdma:
        if not ib_device:
            raise ValueError("force_rdma requires --ib-device to be set")
        if transfer_backend != "mooncake":
            raise ValueError(
                "force_rdma currently requires transfer_backend='mooncake' "
                "to guarantee an RDMA path"
            )

    prefill_need = prefill_worker_count * prefill_tp_size
    decode_need = decode_worker_count * decode_tp_size
    direct_need = direct_worker_count * direct_tp_size
    if prefill_need + decode_need + direct_need > total_gpu_budget:
        breakdown = " + ".join(f"{count} {role} x tp={tp}" for role, count, tp in layout)
        raise ValueError(
            "Single-node GPU budget exceeded: "
            f"{breakdown} > "
            f"{total_gpu_budget} GPUs"
        )

    # Unspecified id tuples continue numbering after the roles before them.
    if prefill_gpu_ids is None:
        prefill_gpu_ids = tuple(range(prefill_need))
    if decode_gpu_ids is None:
        offset = len(prefill_gpu_ids)
        decode_gpu_ids = tuple(range(offset, offset + decode_need))
    if direct_gpu_ids is None:
        offset = len(prefill_gpu_ids) + len(decode_gpu_ids)
        direct_gpu_ids = tuple(range(offset, offset + direct_need))

    for role, ids, need in (
        ("prefill", prefill_gpu_ids, prefill_need),
        ("decode", decode_gpu_ids, decode_need),
        ("direct", direct_gpu_ids, direct_need),
    ):
        if len(ids) != need:
            raise ValueError(
                f"{role}_gpu_ids length must equal {role}_worker_count * {role}_tp_size: "
                f"{len(ids)} != {need}"
            )

    assigned_gpu_ids = prefill_gpu_ids + decode_gpu_ids + direct_gpu_ids
    if len(set(assigned_gpu_ids)) != len(assigned_gpu_ids):
        raise ValueError("prefill/decode/direct GPU IDs must be unique")
    if not all(0 <= gpu_id < total_gpu_budget for gpu_id in assigned_gpu_ids):
        raise ValueError(
            "GPU IDs must fall within the single-node budget range "
            f"[0, {total_gpu_budget - 1}]"
        )

    def make_workers(role, count, tp, ids, port_base, with_bootstrap):
        # One WorkerSpec per ordinal, each taking a tp-sized slice of ids.
        return tuple(
            WorkerSpec(
                role=role,
                ordinal=idx,
                gpu_ids=tuple(ids[idx * tp : (idx + 1) * tp]),
                host=host,
                port=port_base + idx,
                bootstrap_port=(bootstrap_port_base + idx) if with_bootstrap else None,
            )
            for idx in range(count)
        )

    return SingleNodeTopology(
        model_path=model_path,
        prefill_workers=make_workers(
            "prefill", prefill_worker_count, prefill_tp_size,
            prefill_gpu_ids, prefill_port_base, True,
        ),
        decode_workers=make_workers(
            "decode", decode_worker_count, decode_tp_size,
            decode_gpu_ids, decode_port_base, False,
        ),
        direct_workers=make_workers(
            "direct", direct_worker_count, direct_tp_size,
            direct_gpu_ids, direct_port_base, False,
        ),
        router_host=host,
        router_port=router_port,
        transfer_backend=transfer_backend,
        trust_remote_code=trust_remote_code,
        force_rdma=force_rdma,
        ib_device=ib_device,
        extra_server_args=extra_server_args,
        prefill_extra_server_args=prefill_extra_server_args,
        decode_extra_server_args=decode_extra_server_args,
        direct_extra_server_args=direct_extra_server_args,
    )
|
||||||
106
src/agentic_pd_hybrid/trace.py
Normal file
106
src/agentic_pd_hybrid/trace.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class TraceRequest:
    """One immutable request record parsed from a trace file.

    Instances are produced by ``load_trace``; each one corresponds to a
    single non-skipped line of the trace.
    """

    # Built by load_trace as "<session_id>:<turn_id>:<chat_id>:<line index>".
    request_id: str
    # Session the request was assigned to when its parent chain was resolved.
    session_id: str
    # Identifier of this chat, read from the trace's "chat_id" key.
    chat_id: int
    # Identifier of the parent chat; a negative value marks a session root.
    parent_chat_id: int
    # Value of the trace's "timestamp" key (presumably seconds -- TODO confirm).
    timestamp_s: float
    # Length of the request input, from the trace's "input_length" key.
    input_length: int
    # Length of the expected output, from the trace's "output_length" key.
    output_length: int
    # Value of the trace's "type" key, stored as a string.
    request_type: str
    # Value of the trace's "turn" key.
    turn_id: int
    # Block hash ids from the optional "hash_ids" key (empty when absent).
    hash_ids: tuple[int, ...]
|
||||||
|
|
||||||
|
|
||||||
|
def load_trace(path: Path, *, request_limit: int | None = None) -> list[TraceRequest]:
    """Load a JSON Lines trace file into ``TraceRequest`` records.

    Each non-blank line must be a JSON object with the keys ``chat_id``,
    ``parent_chat_id``, ``turn``, ``timestamp``, ``input_length``,
    ``output_length`` and ``type``; ``hash_ids`` is optional and defaults to
    an empty sequence. Session ids are derived from the parent chain with
    ``_resolve_session_id`` in file order.

    Args:
        path: Trace file to read (opened as UTF-8 text).
        request_limit: Stop once this many requests have been collected;
            ``None`` reads the whole file.

    Returns:
        The parsed requests in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a required key is missing from a record.
    """
    chat_to_session: dict[int, str] = {}
    requests: list[TraceRequest] = []

    with path.open("r", encoding="utf-8") as handle:
        for index, line in enumerate(handle):
            if request_limit is not None and len(requests) >= request_limit:
                break

            # Blank lines (for example a trailing empty line at the end of
            # the file) carry no record; skip them instead of letting
            # json.loads fail on an empty document. ``index`` remains the
            # physical line number so request ids of real records do not shift.
            if not line.strip():
                continue

            payload = json.loads(line)
            chat_id = int(payload["chat_id"])
            parent_chat_id = int(payload["parent_chat_id"])
            session_id = _resolve_session_id(
                chat_id=chat_id,
                parent_chat_id=parent_chat_id,
                chat_to_session=chat_to_session,
            )
            turn_id = int(payload["turn"])
            # The line index makes the id unique even if a chat id repeats.
            request_id = f"{session_id}:{turn_id}:{chat_id}:{index}"
            requests.append(
                TraceRequest(
                    request_id=request_id,
                    session_id=session_id,
                    chat_id=chat_id,
                    parent_chat_id=parent_chat_id,
                    timestamp_s=float(payload["timestamp"]),
                    input_length=int(payload["input_length"]),
                    output_length=int(payload["output_length"]),
                    request_type=str(payload["type"]),
                    turn_id=turn_id,
                    hash_ids=tuple(int(item) for item in payload.get("hash_ids", [])),
                )
            )

    return requests
|
||||||
|
|
||||||
|
|
||||||
|
def build_synthetic_prompt(
    request: TraceRequest,
    *,
    block_token_budget: int = 24,
) -> str:
    """Return the synthetic prompt for *request* as one space-separated string.

    The tokens come from ``build_synthetic_prompt_tokens`` called with the
    same ``block_token_budget``.
    """
    prompt_tokens = build_synthetic_prompt_tokens(
        request,
        block_token_budget=block_token_budget,
    )
    return " ".join(prompt_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def build_synthetic_prompt_tokens(
    request: TraceRequest,
    *,
    block_token_budget: int = 24,
) -> list[str]:
    """Build the synthetic token list for *request*.

    Every hash id in ``request.hash_ids`` contributes ``block_token_budget``
    tokens named ``blk<hash_id>_<offset>``. If that yields fewer than
    ``request.input_length`` tokens, ``fill_<position % 64>`` tokens are
    appended; the result is finally sliced to ``request.input_length``.

    Args:
        request: Request providing ``hash_ids`` and ``input_length``.
        block_token_budget: Number of tokens emitted per hash id.

    Returns:
        The token strings, in order.
    """
    target_length = request.input_length

    # Block tokens first, in hash-id order.
    tokens = [
        f"blk{hash_id}_{offset}"
        for hash_id in request.hash_ids
        for offset in range(block_token_budget)
    ]

    # Pad up to the target; the range is empty when the blocks already suffice.
    tokens.extend(
        f"fill_{position % 64}" for position in range(len(tokens), target_length)
    )

    return tokens[:target_length]
|
||||||
|
|
||||||
|
|
||||||
|
def build_synthetic_append_chunk(
    request: TraceRequest,
    append_length: int,
) -> str:
    """Return ``append_length`` space-separated synthetic tokens for *request*.

    Tokens are named ``turn<turn_id>_append_<chat_id>_<offset>``. An empty
    string is returned when ``append_length`` is zero or negative.
    """
    if append_length > 0:
        # All tokens share this prefix; only the trailing offset varies.
        prefix = f"turn{request.turn_id}_append_{request.chat_id}_"
        return " ".join(f"{prefix}{offset}" for offset in range(append_length))
    return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_session_id(
|
||||||
|
*,
|
||||||
|
chat_id: int,
|
||||||
|
parent_chat_id: int,
|
||||||
|
chat_to_session: dict[int, str],
|
||||||
|
) -> str:
|
||||||
|
if parent_chat_id < 0:
|
||||||
|
session_id = str(chat_id)
|
||||||
|
else:
|
||||||
|
session_id = chat_to_session.get(parent_chat_id, str(parent_chat_id))
|
||||||
|
chat_to_session[chat_id] = session_id
|
||||||
|
return session_id
|
||||||
127
src/agentic_pd_hybrid/trace_profiles.py
Normal file
127
src/agentic_pd_hybrid/trace_profiles.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from math import ceil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from agentic_pd_hybrid.trace import TraceRequest, load_trace
|
||||||
|
|
||||||
|
|
||||||
|
# Divisor used by normalize_trace_lengths to turn an input length into a
# number of hash-id blocks (ceil(length / BLOCK_TOKEN_BUDGET)).
# NOTE(review): equals the default ``block_token_budget`` of the prompt
# builders in agentic_pd_hybrid.trace -- presumably the two must stay in sync.
BLOCK_TOKEN_BUDGET = 24
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class NormalizeTraceLengthsConfig:
    """Immutable settings consumed by ``normalize_trace_lengths``."""

    # Trace file to read.
    trace_path: Path
    # Destination of the normalized trace; a summary JSON is written next to it.
    output_path: Path
    # Input length assigned to the first turn of every session.
    initial_input_length: int = 10_000
    # Added to the input length, together with output_length, per later turn.
    append_input_length: int = 1_000
    # Output length written for every normalized request.
    output_length: int = 1_000
    # Passed to load_trace as request_limit; None reads the whole trace.
    max_requests: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class NormalizeTraceLengthsSummary:
    """Immutable report returned (and saved as JSON) by ``normalize_trace_lengths``."""

    # String form of the source trace path.
    input_trace_path: str
    # String form of the normalized trace path.
    output_trace_path: str
    # Number of normalized records written.
    request_count: int
    # Number of distinct sessions found in the trace.
    session_count: int
    # Number of sessions that contain more than one request.
    multi_turn_session_count: int
    # Echo of the configured initial input length.
    initial_input_length: int
    # Echo of the configured per-turn append length.
    append_input_length: int
    # Echo of the configured output length.
    output_length: int
    # Largest number of requests found in a single session.
    max_turns_per_session: int
    # Largest input length assigned to any normalized request.
    max_input_length: int
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_trace_lengths(
    config: NormalizeTraceLengthsConfig,
) -> NormalizeTraceLengthsSummary:
    """Rewrite a trace so every session follows one fixed length schedule.

    Requests are grouped by session and ordered by
    ``(timestamp, turn, chat_id)``. Turn ``k`` of a session (0-based) gets
    ``initial_input_length + k * (append_input_length + output_length)`` as
    its input length, ``output_length`` as its output length, and a
    session-specific list of hash ids with one id per ``BLOCK_TOKEN_BUDGET``
    tokens (rounded up). Chat ids, parent ids, timestamps, types and turn
    numbers are copied unchanged. The records are written to
    ``config.output_path`` as JSON Lines sorted by timestamp, and the summary
    is written next to it with ``.summary.json`` appended to the file name.

    Args:
        config: Paths and target lengths for the normalization.

    Returns:
        A summary describing what was written.

    Raises:
        ValueError: If any configured length is negative.
    """
    non_negative_checks = (
        (config.initial_input_length, "initial_input_length must be >= 0"),
        (config.append_input_length, "append_input_length must be >= 0"),
        (config.output_length, "output_length must be >= 0"),
    )
    for value, message in non_negative_checks:
        if value < 0:
            raise ValueError(message)

    # Group the loaded requests by the session they were resolved to.
    by_session: dict[str, list[TraceRequest]] = defaultdict(list)
    for loaded in load_trace(config.trace_path, request_limit=config.max_requests):
        by_session[loaded.session_id].append(loaded)

    growth_per_turn = config.append_input_length + config.output_length
    # Identical for every session, so computed once up front.
    base_block_count = ceil(config.initial_input_length / BLOCK_TOKEN_BUDGET)

    records: list[dict[str, object]] = []
    longest_session = 0
    largest_input = 0

    for session_idx, session_id in enumerate(sorted(by_session, key=_session_sort_key)):
        ordered_turns = sorted(
            by_session[session_id],
            key=lambda turn: (turn.timestamp_s, turn.turn_id, turn.chat_id),
        )
        longest_session = max(longest_session, len(ordered_turns))

        # Hash ids covering the initial input; reused by every turn of the session.
        shared_prefix = [
            _hash_id_for(session_idx=session_idx, block_idx=block_idx)
            for block_idx in range(base_block_count)
        ]

        for turn_idx, turn in enumerate(ordered_turns):
            input_length = config.initial_input_length + turn_idx * growth_per_turn
            total_block_count = ceil(input_length / BLOCK_TOKEN_BUDGET)
            # Blocks beyond the shared prefix; empty for the first turn.
            appended = [
                _hash_id_for(session_idx=session_idx, block_idx=block_idx)
                for block_idx in range(base_block_count, total_block_count)
            ]
            largest_input = max(largest_input, input_length)
            records.append(
                {
                    "chat_id": turn.chat_id,
                    "parent_chat_id": turn.parent_chat_id,
                    "timestamp": turn.timestamp_s,
                    "input_length": input_length,
                    "output_length": config.output_length,
                    "type": turn.request_type,
                    "turn": turn.turn_id,
                    "hash_ids": shared_prefix + appended,
                }
            )

    # Stable sort: records with equal timestamps keep their session order.
    records.sort(key=lambda item: float(item["timestamp"]))

    config.output_path.parent.mkdir(parents=True, exist_ok=True)
    with config.output_path.open("w", encoding="utf-8") as handle:
        handle.writelines(
            json.dumps(record, sort_keys=True) + "\n" for record in records
        )

    multi_turn_sessions = sum(1 for turns in by_session.values() if len(turns) > 1)
    summary = NormalizeTraceLengthsSummary(
        input_trace_path=str(config.trace_path),
        output_trace_path=str(config.output_path),
        request_count=len(records),
        session_count=len(by_session),
        multi_turn_session_count=multi_turn_sessions,
        initial_input_length=config.initial_input_length,
        append_input_length=config.append_input_length,
        output_length=config.output_length,
        max_turns_per_session=longest_session,
        max_input_length=largest_input,
    )

    # Append to the existing suffix, e.g. "trace.jsonl" -> "trace.jsonl.summary.json".
    summary_path = config.output_path.with_suffix(
        config.output_path.suffix + ".summary.json"
    )
    with summary_path.open("w", encoding="utf-8") as handle:
        json.dump(asdict(summary), handle, indent=2, sort_keys=True)

    return summary
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_id_for(*, session_idx: int, block_idx: int) -> int:
|
||||||
|
return session_idx * 1_000_000 + block_idx
|
||||||
|
|
||||||
|
|
||||||
|
def _session_sort_key(session_id: str) -> tuple[int, str]:
|
||||||
|
return (0, session_id) if session_id.isdigit() else (1, session_id)
|
||||||
Reference in New Issue
Block a user