feat: add agentic pd hybrid benchmark prototype
This commit is contained in:
24
pyproject.toml
Normal file
24
pyproject.toml
Normal file
@@ -0,0 +1,24 @@
|
||||
[project]
|
||||
name = "agentic-pd-hybrid"
|
||||
version = "0.1.0"
|
||||
description = "Prototype for session-aware and KV-cache-aware PD routing on SGLang xPyD"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"httpx>=0.28.1",
|
||||
"mooncake-transfer-engine",
|
||||
"sglang==0.5.10",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
agentic-pd-hybrid = "agentic_pd_hybrid.cli:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=68"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
[tool.uv]
|
||||
prerelease = "allow"
|
||||
12
src/agentic_pd_hybrid/__init__.py
Normal file
12
src/agentic_pd_hybrid/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Agentic PD hybrid prototype."""
|
||||
|
||||
# Names bound by ``from agentic_pd_hybrid import *``.
# NOTE(review): these look like submodule names; other modules imported
# elsewhere in the package (e.g. ``benchmark``) are not listed here —
# confirm whether that omission is intentional.
__all__ = [
    "cli",
    "launcher",
    "metrics",
    "microbench",
    "policies",
    "replay",
    "topology",
    "trace",
]
|
||||
221
src/agentic_pd_hybrid/benchmark.py
Normal file
221
src/agentic_pd_hybrid/benchmark.py
Normal file
@@ -0,0 +1,221 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import signal
|
||||
from dataclasses import asdict, dataclass, replace
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from agentic_pd_hybrid.replay import ReplayConfig, replay_trace
|
||||
from agentic_pd_hybrid.sampling import SessionSampleConfig, sample_trace_sessions
|
||||
from agentic_pd_hybrid.stack import ManagedPdStack, launch_pd_stack
|
||||
from agentic_pd_hybrid.topology import SingleNodeTopology
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class BenchmarkConfig:
    """Immutable settings for a single ``run_live_benchmark`` invocation.

    The first four fields are required; everything else has a default.
    Field order is part of the positional constructor interface and must
    not be changed.
    """

    # --- Required inputs -------------------------------------------------
    # Source trace handed to the session sampler.
    trace_path: Path
    # Parent directory under which the per-run directory is created.
    output_root: Path
    # Worker/router layout used for launching and for replay.
    topology: SingleNodeTopology
    # Routing policy; also selects the decode policy and the header mode.
    policy_name: str

    # --- Mechanism -------------------------------------------------------
    # "pd-disaggregation" and "kvcache-centric" are the router-based
    # mechanisms; any other value runs without a router URL.
    mechanism_name: str = "pd-disaggregation"

    # --- Session sampling (forwarded to SessionSampleConfig) --------------
    target_duration_s: float = 600.0
    start_time_s: float = 0.0
    session_sample_rate: float = 0.01
    min_turns: int = 1

    # --- Replay behaviour (forwarded to ReplayConfig) ---------------------
    time_scale: float = 1.0
    concurrency_limit: int = 32
    # Also passed to the stack launcher as its timeout.
    timeout_s: float = 1200.0
    stream: bool = True
    stream_idle_timeout_s: float | None = 900.0
    kvcache_direct_max_uncached_tokens: int = 2048
    kvcache_admission_mode: str = "router"

    # --- Optional session-shape filters (forwarded to the sampler) --------
    sample_profile: str = "default"
    min_initial_input_tokens: int | None = None
    max_initial_input_tokens: int | None = None
    max_append_input_tokens: int | None = None
    max_output_tokens: int | None = None
    min_overlap_ratio: float | None = None

    # When False no stack is launched and the router URL is taken from
    # the topology instead of from the launched stack.
    launch_stack: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class BenchmarkArtifacts:
    """Filesystem locations reported back by ``run_live_benchmark``."""

    # Directory created for this run; the other paths live inside it.
    run_dir: Path
    # Sampled trace shard that was replayed.
    sampled_trace_path: Path
    # Per-request metrics file (JSON Lines).
    metrics_path: Path
    # Summary file that sits next to the metrics file.
    summary_path: Path
    # JSON record of the benchmark settings used for the run.
    benchmark_config_path: Path
|
||||
|
||||
|
||||
def run_live_benchmark(config: BenchmarkConfig) -> BenchmarkArtifacts:
    """Sample a trace, optionally launch the PD stack, replay it, and record the run.

    The function performs these steps in order:

    1. Create a timestamped run directory under ``config.output_root``.
    2. For the ``kvcache-centric`` mechanism, extend the prefill and decode
       server arguments with the streaming-session flags.
    3. Sample sessions from ``config.trace_path`` into ``sampled-trace.jsonl``.
    4. Install SIGINT/SIGTERM handlers, launch the stack when
       ``config.launch_stack`` is set, and replay the sampled trace.
    5. Restore the previous signal handlers and stop the stack.
    6. Write ``benchmark-config.json`` describing the run.

    Args:
        config: Benchmark settings.

    Returns:
        The paths of the artifacts produced by the run.

    Raises:
        SystemExit: With status ``128 + signum`` when SIGINT or SIGTERM is
            delivered while the handlers installed here are active.
    """
    run_id = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ")
    kvcache_centric = config.mechanism_name == "kvcache-centric"
    # Only these two mechanisms talk to the router; the others get no URL.
    uses_router = config.mechanism_name in {"pd-disaggregation", "kvcache-centric"}

    run_label = f"{config.mechanism_name}-{config.policy_name}"
    if kvcache_centric:
        run_label = f"{run_label}-{config.kvcache_admission_mode}-admission"
    run_dir = config.output_root / f"{run_label}-{run_id}"
    run_dir.mkdir(parents=True, exist_ok=True)

    topology = config.topology
    if kvcache_centric:
        topology = replace(
            topology,
            prefill_extra_server_args=topology.prefill_extra_server_args
            + ("--enable-streaming-session",),
            decode_extra_server_args=topology.decode_extra_server_args
            + (
                "--enable-streaming-session",
                "--disaggregation-decode-allow-local-prefill",
            ),
        )

    sampled_trace_path = run_dir / "sampled-trace.jsonl"
    metrics_path = run_dir / "request-metrics.jsonl"

    sample_summary = sample_trace_sessions(
        SessionSampleConfig(
            trace_path=config.trace_path,
            output_path=sampled_trace_path,
            target_duration_s=config.target_duration_s,
            start_time_s=config.start_time_s,
            session_sample_rate=config.session_sample_rate,
            min_turns=config.min_turns,
            profile=config.sample_profile,  # type: ignore[arg-type]
            min_initial_input_tokens=config.min_initial_input_tokens,
            max_initial_input_tokens=config.max_initial_input_tokens,
            max_append_input_tokens=config.max_append_input_tokens,
            max_output_tokens=config.max_output_tokens,
            min_overlap_ratio=config.min_overlap_ratio,
        )
    )

    stack: ManagedPdStack | None = None
    previous_sigint = signal.getsignal(signal.SIGINT)
    previous_sigterm = signal.getsignal(signal.SIGTERM)

    def _handle_termination(signum, _frame) -> None:
        # Stop the launched servers before exiting so they are not orphaned.
        if stack is not None:
            stack.stop()
        raise SystemExit(128 + signum)

    try:
        signal.signal(signal.SIGINT, _handle_termination)
        signal.signal(signal.SIGTERM, _handle_termination)

        if config.launch_stack:
            stack = launch_pd_stack(
                topology=topology,
                run_dir=run_dir,
                prefill_policy="round_robin",
                decode_policy=_decode_policy_for(config.policy_name),
                timeout_s=config.timeout_s,
                include_router=uses_router,
            )
            router_url = stack.router_url if uses_router else None
        else:
            router_url = topology.router_url if uses_router else None

        replay_config = ReplayConfig(
            trace_path=sampled_trace_path,
            output_path=metrics_path,
            policy_name=config.policy_name,
            mechanism_name=config.mechanism_name,
            topology=topology,
            router_url=router_url,
            model_name=topology.model_name,
            pace=True,
            time_scale=config.time_scale,
            request_limit=None,
            concurrency_limit=config.concurrency_limit,
            header_mode=_header_mode_for(config.policy_name),
            timeout_s=config.timeout_s,
            stream=config.stream,
            stream_idle_timeout_s=config.stream_idle_timeout_s,
            kvcache_direct_max_uncached_tokens=config.kvcache_direct_max_uncached_tokens,
            kvcache_admission_mode=config.kvcache_admission_mode,  # type: ignore[arg-type]
        )
        asyncio.run(replay_trace(replay_config))
    finally:
        # The nested try/finally guarantees the stack is stopped even if
        # restoring a handler fails.
        try:
            _restore_signal_handler(signal.SIGINT, previous_sigint)
            _restore_signal_handler(signal.SIGTERM, previous_sigterm)
        finally:
            if stack is not None:
                stack.stop()

    benchmark_config_path = run_dir / "benchmark-config.json"
    with benchmark_config_path.open("w", encoding="utf-8") as handle:
        json.dump(
            _benchmark_config_record(config, topology, sample_summary),
            handle,
            indent=2,
            sort_keys=True,
        )

    return BenchmarkArtifacts(
        run_dir=run_dir,
        sampled_trace_path=sampled_trace_path,
        metrics_path=metrics_path,
        summary_path=run_dir / "request-metrics.jsonl.summary.json",
        benchmark_config_path=benchmark_config_path,
    )


def _restore_signal_handler(signum: int, previous) -> None:
    """Reinstall ``previous`` as the handler for ``signum``.

    ``signal.getsignal`` returns ``None`` when the earlier handler was not
    installed from Python, and ``signal.signal`` rejects ``None`` with a
    ``TypeError``. In that case fall back to the default disposition.
    """
    signal.signal(signum, signal.SIG_DFL if previous is None else previous)


def _benchmark_config_record(
    config: BenchmarkConfig,
    topology: SingleNodeTopology,
    sample_summary,
) -> dict:
    """Build the JSON-serialisable record written to ``benchmark-config.json``.

    ``topology`` is the effective topology used for the run (it may carry
    extra server arguments compared with ``config.topology``);
    ``sample_summary`` is the dataclass instance returned by the sampler.
    """
    return {
        "policy_name": config.policy_name,
        "mechanism_name": config.mechanism_name,
        "target_duration_s": config.target_duration_s,
        "start_time_s": config.start_time_s,
        "session_sample_rate": config.session_sample_rate,
        "min_turns": config.min_turns,
        "time_scale": config.time_scale,
        "concurrency_limit": config.concurrency_limit,
        "timeout_s": config.timeout_s,
        "stream": config.stream,
        "stream_idle_timeout_s": config.stream_idle_timeout_s,
        "kvcache_direct_max_uncached_tokens": config.kvcache_direct_max_uncached_tokens,
        "kvcache_admission_mode": config.kvcache_admission_mode,
        "sample_profile": config.sample_profile,
        "min_initial_input_tokens": config.min_initial_input_tokens,
        "max_initial_input_tokens": config.max_initial_input_tokens,
        "max_append_input_tokens": config.max_append_input_tokens,
        "max_output_tokens": config.max_output_tokens,
        "min_overlap_ratio": config.min_overlap_ratio,
        "sample_summary": asdict(sample_summary),
        "topology": {
            "model_path": config.topology.model_path,
            "router_url": topology.router_url,
            "transfer_backend": topology.transfer_backend,
            "force_rdma": topology.force_rdma,
            "ib_device": topology.ib_device,
            "prefill_workers": [
                worker.worker_id for worker in topology.prefill_workers
            ],
            "decode_workers": [
                worker.worker_id for worker in topology.decode_workers
            ],
            "direct_workers": [
                worker.worker_id for worker in topology.direct_workers
            ],
        },
    }
|
||||
|
||||
|
||||
def _decode_policy_for(policy_name: str) -> str:
|
||||
if policy_name == "sticky":
|
||||
return "manual"
|
||||
if policy_name == "kv-aware":
|
||||
return "consistent_hashing"
|
||||
return "round_robin"
|
||||
|
||||
|
||||
def _header_mode_for(policy_name: str) -> str:
|
||||
if policy_name == "sticky":
|
||||
return "routing-key"
|
||||
if policy_name == "kv-aware":
|
||||
return "target-worker"
|
||||
return "none"
|
||||
484
src/agentic_pd_hybrid/cli.py
Normal file
484
src/agentic_pd_hybrid/cli.py
Normal file
@@ -0,0 +1,484 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from agentic_pd_hybrid.benchmark import BenchmarkConfig, run_live_benchmark
|
||||
from agentic_pd_hybrid.launcher import build_launch_plan
|
||||
from agentic_pd_hybrid.microbench import SmallAppendTraceConfig, write_small_append_trace
|
||||
from agentic_pd_hybrid.replay import ReplayConfig, replay_trace
|
||||
from agentic_pd_hybrid.sampling import SessionSampleConfig, sample_trace_sessions
|
||||
from agentic_pd_hybrid.trace_profiles import (
|
||||
NormalizeTraceLengthsConfig,
|
||||
normalize_trace_lengths,
|
||||
)
|
||||
from agentic_pd_hybrid.topology import build_single_node_topology
|
||||
|
||||
|
||||
def _normalize_mechanism_name(name: str) -> str:
|
||||
normalized = name.strip().lower()
|
||||
aliases = {
|
||||
"pd-disagg": "pd-disaggregation",
|
||||
"pd-disaggregation": "pd-disaggregation",
|
||||
"pd-hybrid": "pd-disaggregation",
|
||||
"baseline-pd-disagg": "pd-disaggregation",
|
||||
"pd-colo": "pd-colo",
|
||||
"direct-d": "pd-colo",
|
||||
"colocation": "pd-colo",
|
||||
"kvcache-centric": "kvcache-centric",
|
||||
"turn2+-direct-to-d": "kvcache-centric",
|
||||
"pd-with-d-append": "kvcache-centric",
|
||||
}
|
||||
if normalized not in aliases:
|
||||
raise ValueError(f"Unsupported mechanism: {name}")
|
||||
return aliases[normalized]
|
||||
|
||||
|
||||
def _parse_gpu_id_list(value: str | None) -> tuple[int, ...] | None:
|
||||
if value is None:
|
||||
return None
|
||||
items = [item.strip() for item in value.split(",") if item.strip()]
|
||||
if not items:
|
||||
return tuple()
|
||||
return tuple(int(item) for item in items)
|
||||
|
||||
|
||||
def main() -> None:
    """Parse the command line and run the selected subcommand.

    Supported subcommands are ``print-launch``, ``replay``,
    ``sample-sessions``, ``normalize-trace-lengths``,
    ``make-small-append-trace`` and ``benchmark-live``. Each one prints a
    short result line to stdout when it finishes.

    Raises:
        AssertionError: If a subcommand was registered with the parser
            but has no dispatch branch here.
    """
    args = _build_parser().parse_args()

    if args.command == "print-launch":
        topology = _topology_from_args(args)
        plan = build_launch_plan(
            topology,
            prefill_policy=args.prefill_policy,
            decode_policy=args.decode_policy,
            include_router=bool(topology.prefill_workers and topology.decode_workers),
        )
        print(plan.render())
        return

    if args.command == "replay":
        topology = _topology_from_args(args)
        config = ReplayConfig(
            trace_path=args.trace,
            output_path=args.output,
            policy_name=args.policy,
            mechanism_name=_normalize_mechanism_name(args.mechanism),
            topology=topology,
            router_url=args.router_url,
            model_name=args.model,
            pace=not args.no_pace,
            time_scale=args.time_scale,
            request_limit=args.request_limit,
            concurrency_limit=args.concurrency_limit,
            header_mode=args.header_mode,
            timeout_s=args.timeout_s,
            stream=not args.no_stream,
            stream_idle_timeout_s=args.stream_idle_timeout_s,
            kvcache_direct_max_uncached_tokens=args.kvcache_direct_max_uncached_tokens,
            kvcache_admission_mode=args.kvcache_admission_mode,
        )
        results = asyncio.run(replay_trace(config))
        print(
            f"wrote {len(results)} request records to {args.output} and "
            f"{args.output}.summary.json"
        )
        return

    if args.command == "sample-sessions":
        summary = sample_trace_sessions(
            SessionSampleConfig(
                trace_path=args.trace,
                output_path=args.output,
                target_duration_s=args.target_duration_s,
                start_time_s=args.start_time_s,
                session_sample_rate=args.session_sample_rate,
                min_turns=args.min_turns,
                max_requests=args.max_requests,
                profile=args.profile,
                min_initial_input_tokens=args.min_initial_input_tokens,
                max_initial_input_tokens=args.max_initial_input_tokens,
                max_append_input_tokens=args.max_append_input_tokens,
                max_output_tokens=args.max_output_tokens,
                min_overlap_ratio=args.min_overlap_ratio,
            )
        )
        print(
            f"wrote {summary.request_count} requests from {summary.session_count} sessions "
            f"covering {summary.sampled_duration_s:.3f}s to {args.output}"
        )
        return

    if args.command == "normalize-trace-lengths":
        summary = normalize_trace_lengths(
            NormalizeTraceLengthsConfig(
                trace_path=args.trace,
                output_path=args.output,
                initial_input_length=args.initial_input_length,
                append_input_length=args.append_input_length,
                output_length=args.output_length,
                max_requests=args.max_requests,
            )
        )
        print(
            f"wrote {summary.request_count} normalized requests from "
            f"{summary.session_count} sessions to {args.output}"
        )
        return

    if args.command == "make-small-append-trace":
        summary = write_small_append_trace(
            SmallAppendTraceConfig(
                output_path=args.output,
                session_count=args.session_count,
                turns_per_session=args.turns_per_session,
                initial_input_length=args.initial_input_length,
                append_input_length=args.append_input_length,
                output_length=args.output_length,
                inter_turn_gap_s=args.inter_turn_gap_s,
                session_stagger_s=args.session_stagger_s,
            )
        )
        print(
            f"wrote {summary.request_count} requests across {summary.session_count} sessions "
            f"to {args.output}"
        )
        return

    if args.command == "benchmark-live":
        topology = _topology_from_args(args)
        artifacts = run_live_benchmark(
            BenchmarkConfig(
                trace_path=args.trace,
                output_root=args.output_root,
                topology=topology,
                policy_name=args.policy,
                mechanism_name=_normalize_mechanism_name(args.mechanism),
                target_duration_s=args.target_duration_s,
                start_time_s=args.start_time_s,
                session_sample_rate=args.session_sample_rate,
                min_turns=args.min_turns,
                time_scale=args.time_scale,
                concurrency_limit=args.concurrency_limit,
                timeout_s=args.timeout_s,
                stream=not args.no_stream,
                stream_idle_timeout_s=args.stream_idle_timeout_s,
                kvcache_direct_max_uncached_tokens=args.kvcache_direct_max_uncached_tokens,
                kvcache_admission_mode=args.kvcache_admission_mode,
                sample_profile=args.sample_profile,
                min_initial_input_tokens=args.min_initial_input_tokens,
                max_initial_input_tokens=args.max_initial_input_tokens,
                max_append_input_tokens=args.max_append_input_tokens,
                max_output_tokens=args.max_output_tokens,
                min_overlap_ratio=args.min_overlap_ratio,
                launch_stack=True,
            )
        )
        print(
            f"benchmark artifacts written under {artifacts.run_dir}; "
            f"metrics={artifacts.metrics_path} summary={artifacts.summary_path}"
        )
        return

    raise AssertionError(f"Unhandled command: {args.command}")


def _build_parser() -> argparse.ArgumentParser:
    """Construct the top-level parser with all six subcommands."""
    parser = argparse.ArgumentParser(description="Agentic PD hybrid prototype")
    subparsers = parser.add_subparsers(dest="command", required=True)

    print_launch = subparsers.add_parser(
        "print-launch",
        help="Print one-node SGLang PD launch commands",
    )
    _add_topology_arguments(print_launch)
    print_launch.add_argument("--prefill-policy", default="round_robin")
    print_launch.add_argument("--decode-policy", default="manual")

    replay = subparsers.add_parser(
        "replay",
        help="Replay trace and log request-level metrics",
    )
    _add_topology_arguments(replay)
    replay.add_argument("--trace", type=Path, required=True)
    replay.add_argument("--output", type=Path, required=True)
    _add_policy_and_mechanism_arguments(replay)
    replay.add_argument("--router-url")
    replay.add_argument("--model")
    replay.add_argument(
        "--header-mode",
        choices=["auto", "none", "routing-key", "target-worker"],
        default="auto",
    )
    replay.add_argument(
        "--request-limit",
        type=int,
        default=None,
        help="Replay at most this many requests",
    )
    replay.add_argument(
        "--no-pace",
        action="store_true",
        help="Disable wall-clock pacing from trace timestamps",
    )
    replay.add_argument(
        "--time-scale",
        type=float,
        default=1.0,
        help="Scale trace timing by this factor when pacing is enabled",
    )
    replay.add_argument("--concurrency-limit", type=int, default=32)
    replay.add_argument("--timeout-s", type=float, default=600.0)
    _add_stream_and_kvcache_arguments(
        replay,
        no_stream_help="Use non-streaming OpenAI responses for more robust E2E-only replay.",
    )

    sample = subparsers.add_parser(
        "sample-sessions",
        help="Sample a session-granularity trace shard for live benchmarking",
    )
    sample.add_argument("--trace", type=Path, required=True)
    sample.add_argument("--output", type=Path, required=True)
    _add_sampling_window_arguments(sample)
    sample.add_argument("--max-requests", type=int, default=None)
    sample.add_argument(
        "--profile",
        choices=["default", "small-append"],
        default="default",
        help="Optional workload-shape filter for live benchmarks.",
    )
    _add_session_filter_arguments(sample)

    normalize = subparsers.add_parser(
        "normalize-trace-lengths",
        help="Rewrite a trace to a fixed turn1/append/output length profile",
    )
    normalize.add_argument("--trace", type=Path, required=True)
    normalize.add_argument("--output", type=Path, required=True)
    _add_length_profile_arguments(normalize)
    normalize.add_argument("--max-requests", type=int, default=None)

    micro = subparsers.add_parser(
        "make-small-append-trace",
        help="Generate a synthetic multi-turn trace with small turn2+ appends",
    )
    micro.add_argument("--output", type=Path, required=True)
    micro.add_argument("--session-count", type=int, default=8)
    micro.add_argument("--turns-per-session", type=int, default=3)
    _add_length_profile_arguments(micro)
    micro.add_argument("--inter-turn-gap-s", type=float, default=1.0)
    micro.add_argument("--session-stagger-s", type=float, default=0.1)

    benchmark = subparsers.add_parser(
        "benchmark-live",
        help="Launch a real PD stack, sample sessions, and collect live E2E numbers",
    )
    _add_topology_arguments(benchmark)
    benchmark.add_argument("--trace", type=Path, required=True)
    _add_policy_and_mechanism_arguments(benchmark)
    benchmark.add_argument("--output-root", type=Path, default=Path("outputs/live"))
    _add_sampling_window_arguments(benchmark)
    benchmark.add_argument("--time-scale", type=float, default=1.0)
    benchmark.add_argument("--concurrency-limit", type=int, default=32)
    benchmark.add_argument("--timeout-s", type=float, default=1200.0)
    _add_stream_and_kvcache_arguments(
        benchmark,
        no_stream_help="Use non-streaming OpenAI responses for E2E-only live benchmarking.",
    )
    benchmark.add_argument(
        "--sample-profile",
        choices=["default", "small-append"],
        default="default",
        help="Optional session-shape filter applied before live replay.",
    )
    _add_session_filter_arguments(benchmark)

    return parser


def _add_policy_and_mechanism_arguments(parser: argparse.ArgumentParser) -> None:
    """Register ``--policy`` and ``--mechanism`` (shared by replay and benchmark-live)."""
    parser.add_argument(
        "--policy",
        choices=["default", "sticky", "kv-aware"],
        default="sticky",
    )
    parser.add_argument(
        "--mechanism",
        choices=[
            "pd-disaggregation",
            "pd-hybrid",
            "pd-disagg",
            "pd-colo",
            "direct-d",
            "kvcache-centric",
            "turn2+-direct-to-d",
            "pd-with-d-append",
        ],
        default="pd-disaggregation",
    )


def _add_stream_and_kvcache_arguments(
    parser: argparse.ArgumentParser,
    *,
    no_stream_help: str,
) -> None:
    """Register the streaming and kvcache-routing options.

    ``no_stream_help`` is the help text for ``--no-stream``; it is the
    only part that differs between the two subcommands using this helper.
    """
    parser.add_argument("--no-stream", action="store_true", help=no_stream_help)
    parser.add_argument(
        "--stream-idle-timeout-s",
        type=float,
        default=900.0,
        help="Abort a streaming request if no SSE line arrives within this many seconds.",
    )
    parser.add_argument(
        "--kvcache-direct-max-uncached-tokens",
        type=int,
        default=2048,
        help="For kvcache-centric routing, bypass P when the uncached suffix is at most this many tokens.",
    )
    parser.add_argument(
        "--kvcache-admission-mode",
        choices=["router", "worker"],
        default="router",
        help=(
            "For kvcache-centric routing, use router shadow-state admission "
            "or query the decode worker on the critical path."
        ),
    )


def _add_sampling_window_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the session-sampling window options (sample-sessions, benchmark-live)."""
    parser.add_argument("--target-duration-s", type=float, default=600.0)
    parser.add_argument("--start-time-s", type=float, default=0.0)
    parser.add_argument("--session-sample-rate", type=float, default=0.01)
    parser.add_argument("--min-turns", type=int, default=1)


def _add_session_filter_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the optional session-shape filters (sample-sessions, benchmark-live)."""
    parser.add_argument("--min-initial-input-tokens", type=int, default=None)
    parser.add_argument("--max-initial-input-tokens", type=int, default=None)
    parser.add_argument("--max-append-input-tokens", type=int, default=None)
    parser.add_argument("--max-output-tokens", type=int, default=None)
    parser.add_argument("--min-overlap-ratio", type=float, default=None)


def _add_length_profile_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the fixed initial/append/output length options."""
    parser.add_argument("--initial-input-length", type=int, default=10_000)
    parser.add_argument("--append-input-length", type=int, default=1_000)
    parser.add_argument("--output-length", type=int, default=1_000)
|
||||
|
||||
|
||||
def _add_topology_arguments(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
default="~/models/Qwen/Qwen3-Coder-30B-A3B-Instruct",
|
||||
)
|
||||
parser.add_argument("--prefill-workers", type=int, default=1)
|
||||
parser.add_argument("--decode-workers", type=int, default=1)
|
||||
parser.add_argument("--direct-workers", type=int, default=0)
|
||||
parser.add_argument("--prefill-tp-size", type=int, default=1)
|
||||
parser.add_argument("--decode-tp-size", type=int, default=1)
|
||||
parser.add_argument("--direct-tp-size", type=int, default=1)
|
||||
parser.add_argument("--gpu-budget", type=int, default=8)
|
||||
parser.add_argument(
|
||||
"--prefill-gpu-ids",
|
||||
default=None,
|
||||
help="Comma-separated GPU IDs for prefill workers, e.g. 3,4",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decode-gpu-ids",
|
||||
default=None,
|
||||
help="Comma-separated GPU IDs for decode workers, e.g. 5",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--direct-gpu-ids",
|
||||
default=None,
|
||||
help="Comma-separated GPU IDs for direct workers, e.g. 6",
|
||||
)
|
||||
parser.add_argument("--host", default="127.0.0.1")
|
||||
parser.add_argument("--router-port", type=int, default=8000)
|
||||
parser.add_argument("--prefill-port-base", type=int, default=30000)
|
||||
parser.add_argument("--decode-port-base", type=int, default=31000)
|
||||
parser.add_argument("--direct-port-base", type=int, default=32000)
|
||||
parser.add_argument("--bootstrap-port-base", type=int, default=8998)
|
||||
parser.add_argument("--transfer-backend", default="nixl")
|
||||
parser.add_argument(
|
||||
"--force-rdma",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Force real RDMA transport for PD KV transfer. "
|
||||
"Currently this requires Mooncake plus --ib-device."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--ib-device", default=None)
|
||||
parser.add_argument(
|
||||
"--no-trust-remote-code",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
|
||||
def _topology_from_args(args: argparse.Namespace):
    """Build the single-node topology described by parsed CLI ``args``.

    ``--force-rdma`` overrides ``--transfer-backend`` with ``"mooncake"``.
    The model path has ``~`` expanded, the GPU ID options are parsed into
    tuples, and direct workers always receive
    ``--enable-streaming-session`` as an extra server argument.
    """
    backend = "mooncake" if args.force_rdma else args.transfer_backend
    expanded_model_path = str(Path(args.model_path).expanduser())

    gpu_ids = {
        "prefill_gpu_ids": _parse_gpu_id_list(args.prefill_gpu_ids),
        "decode_gpu_ids": _parse_gpu_id_list(args.decode_gpu_ids),
        "direct_gpu_ids": _parse_gpu_id_list(args.direct_gpu_ids),
    }

    return build_single_node_topology(
        model_path=expanded_model_path,
        prefill_worker_count=args.prefill_workers,
        decode_worker_count=args.decode_workers,
        direct_worker_count=args.direct_workers,
        prefill_tp_size=args.prefill_tp_size,
        decode_tp_size=args.decode_tp_size,
        direct_tp_size=args.direct_tp_size,
        **gpu_ids,
        total_gpu_budget=args.gpu_budget,
        host=args.host,
        router_port=args.router_port,
        prefill_port_base=args.prefill_port_base,
        decode_port_base=args.decode_port_base,
        direct_port_base=args.direct_port_base,
        bootstrap_port_base=args.bootstrap_port_base,
        transfer_backend=backend,
        force_rdma=args.force_rdma,
        trust_remote_code=not args.no_trust_remote_code,
        ib_device=args.ib_device,
        direct_extra_server_args=("--enable-streaming-session",),
    )
|
||||
|
||||
|
||||
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
||||
140
src/agentic_pd_hybrid/launcher.py
Normal file
140
src/agentic_pd_hybrid/launcher.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import shlex
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agentic_pd_hybrid.topology import SingleNodeTopology, WorkerSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LaunchPlan:
    """Argv tuples for every process in a single-node launch.

    Each command is a tuple of argument strings; ``router_command`` is
    ``None`` when no router is part of the plan.
    """

    prefill_commands: tuple[tuple[str, ...], ...]
    decode_commands: tuple[tuple[str, ...], ...]
    direct_commands: tuple[tuple[str, ...], ...]
    router_command: tuple[str, ...] | None

    def render(self) -> str:
        """Return all commands as text, one labelled section per process.

        Sections appear in the order prefill, decode, direct, router and
        are separated by a blank line. Workers are labelled
        ``<role>-<index>`` with the index counted per role.
        """
        groups = (
            ("prefill", self.prefill_commands),
            ("decode", self.decode_commands),
            ("direct", self.direct_commands),
        )
        rendered: list[str] = []
        for role, commands in groups:
            rendered.extend(
                _render_named_command(f"{role}-{position}", argv)
                for position, argv in enumerate(commands)
            )
        if self.router_command is not None:
            rendered.append(_render_named_command("router", self.router_command))
        return "\n\n".join(rendered)
|
||||
|
||||
|
||||
def build_launch_plan(
    topology: SingleNodeTopology,
    *,
    prefill_policy: str = "round_robin",
    decode_policy: str = "manual",
    include_router: bool = True,
) -> LaunchPlan:
    """Assemble the launch commands for every worker in ``topology``.

    Args:
        topology: Worker layout to launch.
        prefill_policy: Value passed to the router's ``--prefill-policy``.
        decode_policy: Value passed to the router's ``--decode-policy``.
        include_router: Whether a router command may be generated.

    Returns:
        A ``LaunchPlan``. Its router command is ``None`` unless
        ``include_router`` is true and the topology has at least one
        prefill worker and one decode worker.
    """

    def server_commands(workers) -> tuple[tuple[str, ...], ...]:
        return tuple(_build_server_command(topology, spec) for spec in workers)

    prefill_commands = server_commands(topology.prefill_workers)
    decode_commands = server_commands(topology.decode_workers)
    direct_commands = server_commands(topology.direct_workers)

    router_command = None
    if include_router and topology.prefill_workers and topology.decode_workers:
        router_command = _build_router_command(
            topology,
            prefill_policy=prefill_policy,
            decode_policy=decode_policy,
        )

    return LaunchPlan(
        prefill_commands=prefill_commands,
        decode_commands=decode_commands,
        direct_commands=direct_commands,
        router_command=router_command,
    )
|
||||
|
||||
|
||||
def _build_server_command(
    topology: SingleNodeTopology,
    worker: WorkerSpec,
) -> tuple[str, ...]:
    """Return the ``sglang.launch_server`` argv for one worker.

    The command runs under the current interpreter (``sys.executable``)
    with ``-B -u``. Optional flags are appended only when relevant; the
    topology-wide extra arguments come next and the role-specific extra
    arguments come last.
    """
    argv: list[str] = [sys.executable, "-B", "-u", "-m", "sglang.launch_server"]
    argv += ["--model-path", topology.model_path]
    argv += ["--host", worker.host]
    argv += ["--port", str(worker.port)]
    argv += ["--base-gpu-id", str(worker.gpu_id)]
    argv += ["--disaggregation-mode", _disaggregation_mode_for(worker)]
    argv += ["--disaggregation-transfer-backend", topology.transfer_backend]

    if worker.tp_size > 1:
        argv += ["--tp-size", str(worker.tp_size)]
    if topology.trust_remote_code:
        argv.append("--trust-remote-code")
    argv.append("--enable-cache-report")
    if worker.bootstrap_port is not None:
        argv += ["--disaggregation-bootstrap-port", str(worker.bootstrap_port)]
    if topology.ib_device:
        argv += ["--disaggregation-ib-device", topology.ib_device]

    argv.extend(topology.extra_server_args)

    # Any role other than prefill/decode uses the direct-worker extras.
    if worker.role == "prefill":
        role_extras = topology.prefill_extra_server_args
    elif worker.role == "decode":
        role_extras = topology.decode_extra_server_args
    else:
        role_extras = topology.direct_extra_server_args
    argv.extend(role_extras)

    return tuple(argv)
|
||||
|
||||
|
||||
def _build_router_command(
    topology: SingleNodeTopology,
    *,
    prefill_policy: str,
    decode_policy: str,
) -> tuple[str, ...]:
    """Assemble the argv that starts the ``agentic_pd_hybrid.pd_router`` module.

    Args:
        topology: Supplies the router host/port and the prefill and decode
            workers to register.
        prefill_policy: Value passed through as ``--prefill-policy``.
        decode_policy: Value passed through as ``--decode-policy``.

    Returns:
        The command line as a tuple: interpreter and module, listen address,
        policies, one ``--prefill URL PORT`` triple per prefill worker, then
        one ``--decode URL`` pair per decode worker.
    """
    launcher = (sys.executable, "-B", "-u", "-m", "agentic_pd_hybrid.pd_router")
    listen_args = ("--host", topology.router_host, "--port", str(topology.router_port))
    policy_args = ("--prefill-policy", prefill_policy, "--decode-policy", decode_policy)
    # A falsy bootstrap port (None or 0) is replaced by the router port.
    prefill_args = tuple(
        token
        for node in topology.prefill_workers
        for token in (
            "--prefill",
            node.url,
            str(node.bootstrap_port or topology.router_port),
        )
    )
    decode_args = tuple(
        token for node in topology.decode_workers for token in ("--decode", node.url)
    )
    return launcher + listen_args + policy_args + prefill_args + decode_args
|
||||
|
||||
|
||||
def _render_named_command(name: str, command: tuple[str, ...]) -> str:
    """Render *command* as one shell-quoted line preceded by a ``# name`` comment line."""
    quoted_parts = [shlex.quote(token) for token in command]
    shell_line = " ".join(quoted_parts)
    return f"# {name}\n{shell_line}"
|
||||
|
||||
|
||||
def _disaggregation_mode_for(worker: WorkerSpec) -> str:
    """Return the ``--disaggregation-mode`` value for *worker*.

    A ``"direct"`` worker maps to ``"null"``; every other role is passed
    through unchanged.
    """
    role = worker.role
    return "null" if role == "direct" else role
|
||||
165
src/agentic_pd_hybrid/metrics.py
Normal file
165
src/agentic_pd_hybrid/metrics.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import statistics
|
||||
from collections import Counter
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from agentic_pd_hybrid.policies import RoutingDecision
|
||||
from agentic_pd_hybrid.trace import TraceRequest
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RequestMetrics:
    """Immutable per-request record joining trace, routing and measured data.

    Instances are plain data: every field is set by the constructor (or by
    ``from_decision``) and never modified afterwards.
    """

    # Identity and shape of the request, copied from the trace.
    request_id: str
    session_id: str
    turn_id: int
    mechanism_name: str
    execution_mode: str
    trace_timestamp_s: float
    input_length: int
    output_length: int
    request_type: str
    # Outcome of the routing decision.
    policy_name: str
    assigned_prefill_node: str
    assigned_decode_node: str
    assigned_decode_index: int
    inflight_decode_load_at_assignment: int
    reuse_expected: bool
    reuse_observed: bool
    observed_overlap_blocks: int
    kv_transfer_blocks: int
    # Values supplied by the caller after (or while) executing the request.
    actual_kv_transfer_blocks: int
    cached_tokens: int
    re_prefill_required: bool
    effective_input_length: int | None
    session_reused: bool
    session_reset: bool
    latency_s: float | None
    ttft_s: float | None
    tpot_s: float | None
    error: str | None = None

    @classmethod
    def from_decision(
        cls,
        request: TraceRequest,
        decision: RoutingDecision,
        *,
        mechanism_name: str,
        execution_mode: str,
        actual_kv_transfer_blocks: int,
        effective_input_length: int | None,
        cached_tokens: int,
        session_reused: bool,
        session_reset: bool,
        latency_s: float | None,
        ttft_s: float | None,
        tpot_s: float | None,
        error: str | None = None,
    ) -> "RequestMetrics":
        """Build a record from a trace request, its routing decision and measurements.

        Args:
            request: Source of the identity, timestamp, length and type fields.
            decision: Source of the policy, worker-assignment, reuse and
                KV-transfer estimate fields.
            mechanism_name, execution_mode, actual_kv_transfer_blocks,
            effective_input_length, cached_tokens, session_reused,
            session_reset, latency_s, ttft_s, tpot_s, error: Stored verbatim
                in the fields of the same name.

        Returns:
            A new frozen ``RequestMetrics`` instance.
        """
        trace_fields = {
            "request_id": request.request_id,
            "session_id": request.session_id,
            "turn_id": request.turn_id,
            "trace_timestamp_s": request.timestamp_s,
            "input_length": request.input_length,
            "output_length": request.output_length,
            "request_type": request.request_type,
        }
        routing_fields = {
            "policy_name": decision.policy_name,
            "assigned_prefill_node": decision.prefill_worker_id,
            "assigned_decode_node": decision.decode_worker_id,
            "assigned_decode_index": decision.decode_worker_index,
            "inflight_decode_load_at_assignment": decision.inflight_decode_load,
            "reuse_expected": decision.reuse_expected,
            "reuse_observed": decision.observed_reuse,
            "observed_overlap_blocks": decision.observed_overlap_blocks,
            "kv_transfer_blocks": decision.kv_transfer_blocks,
            "re_prefill_required": decision.re_prefill_required,
        }
        measured_fields = {
            "mechanism_name": mechanism_name,
            "execution_mode": execution_mode,
            "actual_kv_transfer_blocks": actual_kv_transfer_blocks,
            "cached_tokens": cached_tokens,
            "effective_input_length": effective_input_length,
            "session_reused": session_reused,
            "session_reset": session_reset,
            "latency_s": latency_s,
            "ttft_s": ttft_s,
            "tpot_s": tpot_s,
            "error": error,
        }
        return cls(**trace_fields, **routing_fields, **measured_fields)
|
||||
|
||||
|
||||
def write_metrics_jsonl(path: Path, rows: list[RequestMetrics]) -> None:
    """Write *rows* to *path* as JSON Lines, one dataclass per line.

    Parent directories are created when missing, an existing file is
    overwritten, and keys within each line are sorted.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(asdict(record), sort_keys=True) + "\n" for record in rows
        )
|
||||
|
||||
|
||||
def write_summary_json(
    path: Path,
    rows: list[RequestMetrics],
    *,
    trace_path: Path,
    router_url: str | None,
) -> None:
    """Aggregate *rows* into one summary object and write it to *path* as JSON.

    The summary holds request/error counts, per-mechanism and
    per-execution-mode counts, latency/TTFT/TPOT statistics, reuse and cache
    counters, KV-transfer totals and per-worker assignment counts. Parent
    directories are created when missing; the file is written with sorted
    keys and two-space indentation.

    Args:
        path: Destination file.
        rows: Per-request metrics to aggregate.
        trace_path: Recorded verbatim (as a string) under ``trace_path``.
        router_url: Recorded verbatim under ``router_url``.
    """

    def histogram(attribute: str) -> dict[str, int]:
        # Occurrences of each distinct attribute value, ordered by value.
        counts = Counter(getattr(row, attribute) for row in rows)
        return dict(sorted(counts.items()))

    def count_truthy(attribute: str) -> int:
        return sum(1 for row in rows if getattr(row, attribute))

    def present(attribute: str) -> list[float]:
        # Timing fields may be None; only recorded values are aggregated.
        values = (getattr(row, attribute) for row in rows)
        return [value for value in values if value is not None]

    summary: dict[str, Any] = {
        "trace_path": str(trace_path),
        "router_url": router_url,
        "request_count": len(rows),
        "mechanisms": histogram("mechanism_name"),
        "execution_modes": histogram("execution_mode"),
        "latency_stats_s": _stats(present("latency_s")),
        "ttft_stats_s": _stats(present("ttft_s")),
        "tpot_stats_s": _stats(present("tpot_s")),
        "reuse_expected_count": count_truthy("reuse_expected"),
        "reuse_observed_count": count_truthy("reuse_observed"),
        "re_prefill_count": count_truthy("re_prefill_required"),
        "cache_hit_request_count": sum(1 for row in rows if row.cached_tokens > 0),
        "total_cached_tokens": sum(row.cached_tokens for row in rows),
        "cached_tokens_stats": _stats([float(row.cached_tokens) for row in rows]),
        "session_reused_count": count_truthy("session_reused"),
        "session_reset_count": count_truthy("session_reset"),
        "total_kv_transfer_blocks": sum(row.kv_transfer_blocks for row in rows),
        "total_actual_kv_transfer_blocks": sum(
            row.actual_kv_transfer_blocks for row in rows
        ),
        "per_decode_load": histogram("assigned_decode_node"),
        "per_prefill_load": histogram("assigned_prefill_node"),
        "error_count": sum(1 for row in rows if row.error is not None),
    }

    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        json.dump(summary, sink, indent=2, sort_keys=True)
|
||||
|
||||
|
||||
def _stats(values: list[float | None]) -> dict[str, float] | None:
    """Summarize *values* as count, mean and p50/p90/p99.

    ``None`` entries are ignored. Returns ``None`` when no value remains,
    otherwise a dict with the keys ``count``, ``mean``, ``p50``, ``p90`` and
    ``p99`` (``count`` is reported as a float).
    """
    ordered = sorted(value for value in values if value is not None)
    if not ordered:
        return None
    summary: dict[str, float] = {
        "count": float(len(ordered)),
        "mean": statistics.fmean(ordered),
    }
    for label, fraction in (("p50", 0.50), ("p90", 0.90), ("p99", 0.99)):
        summary[label] = _percentile(ordered, fraction)
    return summary
|
||||
|
||||
|
||||
def _percentile(sorted_values: list[float], percentile: float) -> float:
    """Return the nearest-rank percentile of an ascending-sorted list.

    The rank ``(len - 1) * percentile`` is rounded half *up*. The built-in
    ``round`` rounds halves to the nearest even integer, which made the
    median of even-sized samples alternate between the lower and the upper
    middle element depending on the sample size.

    Args:
        sorted_values: Non-empty list sorted in ascending order.
        percentile: Fraction in ``[0.0, 1.0]`` (``0.5`` is the median).

    Returns:
        The element of ``sorted_values`` at the selected rank.

    Raises:
        ValueError: If ``sorted_values`` is empty or ``percentile`` is outside
            ``[0.0, 1.0]``.
    """
    if not sorted_values:
        raise ValueError("sorted_values must not be empty")
    if not 0.0 <= percentile <= 1.0:
        raise ValueError("percentile must be within [0.0, 1.0]")
    if len(sorted_values) == 1:
        return sorted_values[0]
    # int() truncates, so adding 0.5 first rounds the non-negative rank half up.
    index = int((len(sorted_values) - 1) * percentile + 0.5)
    return sorted_values[index]
|
||||
123
src/agentic_pd_hybrid/microbench.py
Normal file
123
src/agentic_pd_hybrid/microbench.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Number of tokens represented by one synthetic hash block; block counts are
# derived as ceil(input_length / BLOCK_TOKEN_BUDGET).
BLOCK_TOKEN_BUDGET = 24


@dataclass(frozen=True)
class SmallAppendTraceConfig:
    """Parameters of the synthetic multi-turn "small append" trace."""

    output_path: Path  # JSONL file to write; parent directories are created
    session_count: int = 8  # number of independent sessions
    turns_per_session: int = 3  # requests generated per session
    initial_input_length: int = 10_000  # input_length of each session's first turn
    append_input_length: int = 1_000  # added to input_length on every later turn
    output_length: int = 1_000  # per-request output_length, also added to later inputs
    inter_turn_gap_s: float = 1.0  # timestamp spacing between turns of a session
    session_stagger_s: float = 0.1  # timestamp offset between consecutive sessions
    request_type: str = "coder"  # written verbatim as the record's "type"


@dataclass(frozen=True)
class SmallAppendTraceSummary:
    """Description of a generated trace; also written next to it as JSON."""

    output_path: str
    session_count: int
    turns_per_session: int
    request_count: int
    initial_input_length: int
    append_input_length: int
    output_length: int
    inter_turn_gap_s: float
    session_stagger_s: float


def write_small_append_trace(config: SmallAppendTraceConfig) -> SmallAppendTraceSummary:
    """Generate the synthetic trace described by *config* and write it to disk.

    Each session produces ``turns_per_session`` records. Turn ``t`` (0-based)
    has ``input_length = initial + t * (append + output)`` and carries one
    hash id per ``BLOCK_TOKEN_BUDGET`` tokens, so the hash ids of a turn are a
    prefix of those of the next turn in the same session. Records are written
    as JSON Lines ordered by timestamp, and a summary is written to
    ``<output_path>.summary.json``.

    Every record receives a unique, sequential ``chat_id`` starting at
    1_000_000, and ``parent_chat_id`` is the id of the previous turn in the
    same session (``-1`` for the first turn). Previously the second turn of a
    session reused the first turn's id and therefore listed itself as parent.

    Args:
        config: Trace parameters.

    Returns:
        A summary of the generated trace.

    Raises:
        ValueError: If a count is not positive or a length is negative.
    """
    if config.session_count <= 0:
        raise ValueError("session_count must be > 0")
    if config.turns_per_session <= 0:
        raise ValueError("turns_per_session must be > 0")
    if config.initial_input_length < 0:
        raise ValueError("initial_input_length must be >= 0")
    if config.append_input_length < 0:
        raise ValueError("append_input_length must be >= 0")
    if config.output_length < 0:
        raise ValueError("output_length must be >= 0")

    config.output_path.parent.mkdir(parents=True, exist_ok=True)
    records: list[dict[str, object]] = []
    next_chat_id = 1_000_000
    # Each later turn re-sends the whole context plus the new user input.
    per_turn_growth = config.append_input_length + config.output_length

    for session_idx in range(config.session_count):
        previous_chat_id = -1
        session_base_time = session_idx * config.session_stagger_s

        for turn_idx in range(config.turns_per_session):
            # Allocate one fresh id per record so ids are never shared.
            chat_id = next_chat_id
            next_chat_id += 1

            input_length = config.initial_input_length + turn_idx * per_turn_growth
            total_block_count = ceil(input_length / BLOCK_TOKEN_BUDGET)
            hash_ids = [
                _hash_id_for(session_idx=session_idx, block_idx=block_idx)
                for block_idx in range(total_block_count)
            ]

            records.append(
                {
                    "chat_id": chat_id,
                    "parent_chat_id": previous_chat_id,
                    "timestamp": session_base_time
                    + turn_idx * config.inter_turn_gap_s,
                    "input_length": input_length,
                    "output_length": config.output_length,
                    "type": config.request_type,
                    "turn": turn_idx + 1,
                    "hash_ids": hash_ids,
                }
            )
            previous_chat_id = chat_id

    records.sort(key=lambda item: float(item["timestamp"]))
    with config.output_path.open("w", encoding="utf-8") as handle:
        for record in records:
            handle.write(json.dumps(record, sort_keys=True) + "\n")

    summary = SmallAppendTraceSummary(
        output_path=str(config.output_path),
        session_count=config.session_count,
        turns_per_session=config.turns_per_session,
        request_count=len(records),
        initial_input_length=config.initial_input_length,
        append_input_length=config.append_input_length,
        output_length=config.output_length,
        inter_turn_gap_s=config.inter_turn_gap_s,
        session_stagger_s=config.session_stagger_s,
    )
    summary_path = config.output_path.with_suffix(
        config.output_path.suffix + ".summary.json"
    )
    with summary_path.open("w", encoding="utf-8") as handle:
        json.dump(asdict(summary), handle, indent=2, sort_keys=True)
    return summary


def _hash_id_for(*, session_idx: int, block_idx: int) -> int:
    """Return the synthetic hash id of block *block_idx* within session *session_idx*.

    Ids of different sessions do not collide while ``block_idx`` stays below
    1_000_000.
    """
    return session_idx * 1_000_000 + block_idx
|
||||
267
src/agentic_pd_hybrid/pd_router.py
Normal file
267
src/agentic_pd_hybrid/pd_router.py
Normal file
@@ -0,0 +1,267 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import random
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from itertools import chain
|
||||
from typing import AsyncIterator
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
# Chunk size in bytes (64 KiB) passed to iter_chunked() when relaying a
# streamed decode response.
_STREAM_CHUNK_SIZE = 1024 * 64
|
||||
|
||||
|
||||
@dataclass
class RouterConfig:
    """Settings of the local PD router process."""

    # Address the router listens on.
    host: str
    port: int
    # One (base URL, bootstrap port) pair per prefill worker.
    prefill_urls: list[tuple[str, int]]
    # Base URL of every decode worker.
    decode_urls: list[str]
    # Policy names as received from the command line.
    prefill_policy: str = "round_robin"
    decode_policy: str = "manual"
    # Total timeout, in seconds, applied to each forwarded request.
    request_timeout_s: float = 1800.0
|
||||
|
||||
|
||||
class RouterState:
    """Mutable worker-selection state of the router.

    Holds the round-robin cursors for prefill and decode workers and the
    routing-key to decode-index map used by the ``"manual"`` decode policy.
    """

    def __init__(self, config: RouterConfig):
        """Store *config* and reset all cursors.

        Raises:
            ValueError: If the config lists no prefill or no decode worker.
        """
        if not config.prefill_urls:
            raise ValueError("At least one prefill worker is required")
        if not config.decode_urls:
            raise ValueError("At least one decode worker is required")
        self.config = config
        self.prefill_cursor = 0
        self.decode_cursor = 0
        self.sticky_decode_map: dict[str, int] = {}

    def select_pair(self, headers: dict[str, str]) -> tuple[str, int, str]:
        """Pick the workers that serve one request.

        Args:
            headers: Request headers with lower-cased names.

        Returns:
            ``(prefill_url, bootstrap_port, decode_url)``. The prefill worker
            is always chosen round-robin.
        """
        prefill_url, bootstrap_port = self.config.prefill_urls[
            self.prefill_cursor % len(self.config.prefill_urls)
        ]
        self.prefill_cursor += 1
        decode_index = self._select_decode_index(headers)
        return prefill_url, bootstrap_port, self.config.decode_urls[decode_index]

    def _select_decode_index(self, headers: dict[str, str]) -> int:
        """Return the index into ``config.decode_urls`` for this request.

        * ``consistent_hashing``: honour ``x-smg-target-worker`` when it is a
          valid in-range integer.
        * ``manual``: pin each ``x-smg-routing-key`` to the decode worker
          chosen the first time the key was seen.
        * Otherwise (including invalid or out-of-range targets): round-robin.
        """
        target_worker = headers.get("x-smg-target-worker")
        routing_key = headers.get("x-smg-routing-key")

        if (
            self.config.decode_policy == "consistent_hashing"
            and target_worker is not None
        ):
            # The header is client-controlled: a non-numeric value must not
            # crash the request, so it is treated like an out-of-range index.
            try:
                idx = int(target_worker)
            except ValueError:
                idx = -1
            if 0 <= idx < len(self.config.decode_urls):
                return idx

        if self.config.decode_policy == "manual" and routing_key:
            cached = self.sticky_decode_map.get(routing_key)
            if cached is not None:
                return cached
            idx = self._next_round_robin_decode_index()
            self.sticky_decode_map[routing_key] = idx
            return idx

        return self._next_round_robin_decode_index()

    def _next_round_robin_decode_index(self) -> int:
        """Return the next decode index in rotation and advance the cursor."""
        idx = self.decode_cursor % len(self.config.decode_urls)
        self.decode_cursor += 1
        return idx
|
||||
|
||||
|
||||
# ASGI application on which the route decorators below register their
# handlers; main() serves it with uvicorn.
app = FastAPI()
|
||||
# Process-wide router state: None until main() installs it; handlers read it
# through _require_state().
router_state: RouterState | None = None
|
||||
|
||||
|
||||
@app.get("/health")
async def health() -> Response:
    """Liveness probe: answer 200 without contacting any worker."""
    alive = Response(status_code=200)
    return alive
|
||||
|
||||
|
||||
@app.get("/health_generate")
async def health_generate() -> Response:
    """Readiness probe: query ``/health_generate`` on every worker.

    Returns 200 only when every prefill and decode worker answers 200;
    otherwise 503. Previously the workers' status codes were ignored, so an
    unhealthy worker still produced 200, and an unreachable worker surfaced
    as an unhandled exception instead of a clean unhealthy status.
    """
    state = _require_state()
    servers = [url for url, _ in state.config.prefill_urls]
    servers.extend(state.config.decode_urls)

    async def _is_healthy(session: aiohttp.ClientSession, server: str) -> bool:
        # A connection failure or timeout counts as "not healthy".
        try:
            async with session.get(f"{server}/health_generate") as response:
                return response.status == HTTPStatus.OK
        except (aiohttp.ClientError, asyncio.TimeoutError):
            return False

    async with aiohttp.ClientSession() as session:
        # Probe all workers concurrently.
        verdicts = await asyncio.gather(
            *(_is_healthy(session, server) for server in servers)
        )

    if all(verdicts):
        return Response(status_code=200)
    return Response(status_code=503)
|
||||
|
||||
|
||||
@app.get("/v1/models")
async def models() -> ORJSONResponse:
    """Proxy ``/v1/models`` from the first configured prefill worker.

    The worker's JSON body and HTTP status are relayed unchanged.
    """
    state = _require_state()
    first_prefill_url = state.config.prefill_urls[0][0]
    models_url = f"{first_prefill_url}/v1/models"
    async with aiohttp.ClientSession() as session, session.get(models_url) as upstream:
        payload = await upstream.json()
        upstream_status = upstream.status
    return ORJSONResponse(payload, status_code=upstream_status)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Response:
    """Forward a chat-completions request to a prefill/decode worker pair."""
    payload = await request.json()
    # Header names are lower-cased so routing headers can be looked up directly.
    lowered_headers: dict[str, str] = {}
    for name, value in request.headers.items():
        lowered_headers[name.lower()] = value
    return await _forward_to_backend(
        request_data=payload,
        headers=lowered_headers,
        endpoint_name="v1/chat/completions",
    )
|
||||
|
||||
|
||||
@app.post("/v1/completions")
async def completions(request: Request) -> Response:
    """Forward a completions request to a prefill/decode worker pair."""
    body = await request.json()
    # Header names are lower-cased so routing headers can be looked up directly.
    normalized = dict((name.lower(), value) for name, value in request.headers.items())
    response = await _forward_to_backend(
        request_data=body,
        headers=normalized,
        endpoint_name="v1/completions",
    )
    return response
|
||||
|
||||
|
||||
@app.post("/generate")
async def generate(request: Request) -> Response:
    """Forward a native ``/generate`` request to a prefill/decode worker pair."""
    incoming = await request.json()
    # Header names are lower-cased so routing headers can be looked up directly.
    header_map = {name.lower(): text for name, text in request.headers.items()}
    forwarded = await _forward_to_backend(
        request_data=incoming,
        headers=header_map,
        endpoint_name="generate",
    )
    return forwarded
|
||||
|
||||
|
||||
async def _forward_to_backend(
    *,
    request_data: dict,
    headers: dict[str, str],
    endpoint_name: str,
) -> Response:
    """Send one request to a selected prefill/decode pair and relay the decode reply.

    The same payload (the original request plus bootstrap host, port and
    room) is posted to both workers concurrently. When the request asks for
    streaming, a ``StreamingResponse`` backed by ``_stream_generate`` is
    returned; otherwise the prefill body is read and discarded and the decode
    worker's body, status and content type are returned.

    Args:
        request_data: Parsed JSON body of the incoming request.
        headers: Incoming headers with lower-cased names (used for routing).
        endpoint_name: Worker path without the leading slash.
    """
    state = _require_state()
    prefill_server, bootstrap_port, decode_server = state.select_pair(headers)
    # The caller's dict is left untouched; bootstrap keys override on clash.
    outbound = {
        **request_data,
        **_build_bootstrap_payload(prefill_server, bootstrap_port),
    }
    timeout_s = state.config.request_timeout_s

    wants_stream = request_data.get("stream", False)
    if wants_stream:
        event_stream = _stream_generate(
            modified_request=outbound,
            prefill_server=prefill_server,
            decode_server=decode_server,
            endpoint_name=endpoint_name,
            timeout_s=timeout_s,
        )
        return StreamingResponse(event_stream, media_type="text/event-stream")

    prefill_url = f"{prefill_server}/{endpoint_name}"
    decode_url = f"{decode_server}/{endpoint_name}"
    client_timeout = aiohttp.ClientTimeout(total=timeout_s)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        prefill_response, decode_response = await asyncio.gather(
            session.post(prefill_url, json=outbound),
            session.post(decode_url, json=outbound),
        )
        # Drain the prefill reply so its connection is released cleanly.
        async with prefill_response:
            await prefill_response.read()
        async with decode_response:
            body = await decode_response.read()
            return Response(
                content=body,
                status_code=decode_response.status,
                media_type=decode_response.content_type,
            )
|
||||
|
||||
|
||||
async def _stream_generate(
    *,
    modified_request: dict,
    prefill_server: str,
    decode_server: str,
    endpoint_name: str,
    timeout_s: float,
) -> AsyncIterator[bytes]:
    """Yield the decode worker's response body for a streaming request.

    The payload is posted to the prefill and decode workers concurrently.
    On a 200 decode status the body is relayed in chunks of at most
    ``_STREAM_CHUNK_SIZE`` bytes; on any other status the whole body is
    yielded once. The prefill response is never read.
    """
    client_timeout = aiohttp.ClientTimeout(total=timeout_s)
    prefill_url = f"{prefill_server}/{endpoint_name}"
    decode_url = f"{decode_server}/{endpoint_name}"
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        prefill_response, decode_response = await asyncio.gather(
            session.post(prefill_url, json=modified_request),
            session.post(decode_url, json=modified_request),
        )
        async with prefill_response, decode_response:
            if decode_response.status == HTTPStatus.OK:
                stream = decode_response.content.iter_chunked(_STREAM_CHUNK_SIZE)
                async for piece in stream:
                    yield piece
            else:
                # Relay the worker's error body in a single piece.
                yield await decode_response.read()
|
||||
|
||||
|
||||
def _build_bootstrap_payload(prefill_server: str, bootstrap_port: int) -> dict[str, object]:
    """Return the bootstrap fields merged into every forwarded request.

    Args:
        prefill_server: URL of the selected prefill worker; only its hostname
            is used.
        bootstrap_port: Bootstrap port registered for that worker.

    Returns:
        A dict with ``bootstrap_host``, ``bootstrap_port`` and a random
        ``bootstrap_room`` in ``[0, 2**63 - 1]``.

    Raises:
        HTTPException: With status 500 when the URL has no parsable hostname.
    """
    hostname = urllib.parse.urlparse(prefill_server).hostname
    if hostname is not None:
        room = random.randint(0, 2**63 - 1)
        return dict(
            bootstrap_host=hostname,
            bootstrap_port=bootstrap_port,
            bootstrap_room=room,
        )
    raise HTTPException(
        status_code=500,
        detail=f"Unable to parse prefill hostname from {prefill_server}",
    )
|
||||
|
||||
|
||||
def _require_state() -> RouterState:
    """Return the installed router state.

    Raises:
        HTTPException: With status 500 when no state has been installed yet.
    """
    state = router_state
    if state is not None:
        return state
    raise HTTPException(status_code=500, detail="router not initialized")
|
||||
|
||||
|
||||
def main() -> None:
    """Parse the command line, install the router state and serve the app.

    ``--prefill URL BOOTSTRAP_PORT`` and ``--decode URL`` may be repeated and
    each must be given at least once.
    """
    global router_state

    cli = argparse.ArgumentParser(description="Minimal local PD router")
    cli.add_argument("--host", default="127.0.0.1")
    cli.add_argument("--port", type=int, default=8000)
    cli.add_argument(
        "--prefill",
        nargs=2,
        metavar=("URL", "BOOTSTRAP_PORT"),
        action="append",
        required=True,
    )
    cli.add_argument("--decode", action="append", required=True)
    cli.add_argument("--prefill-policy", default="round_robin")
    cli.add_argument("--decode-policy", default="manual")
    cli.add_argument("--request-timeout-s", type=float, default=1800.0)
    options = cli.parse_args()

    # Each --prefill occurrence arrives as [url, port-string].
    prefill_pairs = [(url, int(port_text)) for url, port_text in options.prefill]
    config = RouterConfig(
        host=options.host,
        port=options.port,
        prefill_urls=prefill_pairs,
        decode_urls=list(options.decode),
        prefill_policy=options.prefill_policy,
        decode_policy=options.decode_policy,
        request_timeout_s=options.request_timeout_s,
    )
    router_state = RouterState(config)
    uvicorn.run(app, host=options.host, port=options.port, log_level="info")


if __name__ == "__main__":
    main()
|
||||
235
src/agentic_pd_hybrid/policies.py
Normal file
235
src/agentic_pd_hybrid/policies.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Protocol
|
||||
|
||||
from agentic_pd_hybrid.topology import SingleNodeTopology
|
||||
from agentic_pd_hybrid.trace import TraceRequest
|
||||
|
||||
|
||||
@dataclass
class SessionRouteState:
    """Routing memory kept per session."""

    # Decode worker id recorded when a request of this session last finished;
    # None until then.
    last_decode_worker: str | None = None
|
||||
|
||||
|
||||
@dataclass
class RoutingDecision:
    """Result of routing one request, plus the facts the decision was based on."""

    policy_name: str
    prefill_worker_id: str
    decode_worker_id: str
    decode_worker_index: int
    # Whether the policy expected the decode worker to already hold KV blocks.
    reuse_expected: bool
    # Request blocks found resident on the chosen decode worker.
    observed_overlap_blocks: int
    # Request blocks not resident on the chosen decode worker.
    kv_transfer_blocks: int
    # In-flight request count of the decode worker, this request included.
    inflight_decode_load: int
    session_id: str
    request_id: str
    turn_id: int

    @property
    def observed_reuse(self) -> bool:
        """True when at least one request block was already resident."""
        has_overlap = self.observed_overlap_blocks > 0
        return has_overlap

    @property
    def re_prefill_required(self) -> bool:
        """True for a follow-up turn (turn_id > 1) that found no resident block."""
        if self.turn_id <= 1:
            return False
        return self.observed_overlap_blocks == 0
|
||||
|
||||
|
||||
@dataclass
class RoutingState:
    """Mutable bookkeeping shared by the routing policies."""

    # Round-robin positions.
    prefill_cursor: int = 0
    decode_cursor: int = 0
    # Per-session routing memory, keyed by session id.
    session_state: dict[str, SessionRouteState] = field(default_factory=dict)
    # Requests currently assigned to each decode worker and not yet finished.
    inflight_decode: Counter[str] = field(default_factory=Counter)
    # Total requests ever assigned to each decode worker.
    decode_assignment_counts: Counter[str] = field(default_factory=Counter)
    # Hash ids of the blocks recorded as resident on each decode worker.
    decode_resident_blocks: dict[str, set[int]] = field(default_factory=dict)

    @classmethod
    def create(cls, topology: SingleNodeTopology) -> "RoutingState":
        """Return a state with an empty resident-block set per route worker."""
        return cls(
            decode_resident_blocks={
                worker.worker_id: set() for worker in topology.route_workers
            }
        )

    def next_prefill_worker_id(self, topology: SingleNodeTopology) -> str:
        """Return the next prefill worker id in rotation.

        Returns the placeholder ``"none"`` (without advancing the cursor)
        when the topology has no prefill workers.
        """
        if not topology.prefill_workers:
            return "none"
        worker = topology.prefill_workers[self.prefill_cursor % len(topology.prefill_workers)]
        self.prefill_cursor += 1
        return worker.worker_id

    def next_decode_worker_id(self, topology: SingleNodeTopology) -> str:
        """Return the next route worker id in rotation and advance the cursor."""
        route_workers = topology.route_workers
        worker = route_workers[self.decode_cursor % len(route_workers)]
        self.decode_cursor += 1
        return worker.worker_id

    def finish(self, request: TraceRequest, decision: RoutingDecision) -> None:
        """Record that *request* completed on the worker chosen in *decision*.

        Remembers the decode worker for the session, marks the request's
        blocks as resident on that worker and releases one in-flight slot.
        """
        worker_id = decision.decode_worker_id

        session = self.session_state.get(request.session_id)
        if session is None:
            session = SessionRouteState()
            self.session_state[request.session_id] = session
        session.last_decode_worker = worker_id

        # setdefault keeps this working for a state that was not built with
        # create() or for a worker id the state has not seen, which matches
        # how overlap lookups treat unknown workers (as empty).
        self.decode_resident_blocks.setdefault(worker_id, set()).update(request.hash_ids)

        self.inflight_decode[worker_id] -= 1
        if self.inflight_decode[worker_id] <= 0:
            # Drop exhausted counters so idle workers do not linger as keys.
            del self.inflight_decode[worker_id]
|
||||
|
||||
|
||||
class RoutingPolicy(Protocol):
    """Structural interface implemented by every routing policy."""

    # Identifier of the policy.
    name: str

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Choose the prefill and decode workers for *request*."""
        ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DefaultPolicy:
    """Round-robin over prefill and decode workers; never expects KV reuse."""

    name: str = "default"

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Return a decision using the next prefill and next decode worker in rotation."""
        chosen_prefill = state.next_prefill_worker_id(topology)
        chosen_decode = state.next_decode_worker_id(topology)
        decision = _build_decision(
            policy_name=self.name,
            request=request,
            topology=topology,
            state=state,
            prefill_worker_id=chosen_prefill,
            decode_worker_id=chosen_decode,
            reuse_expected=False,
        )
        return decision
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class StickyDecodePolicy:
    """Keep follow-up turns of a session on the decode worker used before."""

    name: str = "sticky"

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Return a decision that reuses the session's last decode worker when possible.

        A request with ``turn_id > 1`` whose session has a recorded decode
        worker goes back to that worker with reuse expected; any other
        request takes the next decode worker in rotation. The prefill worker
        is always the next one in rotation.
        """
        prior = state.session_state.get(request.session_id)
        prefill_worker_id = state.next_prefill_worker_id(topology)

        can_reuse = bool(
            request.turn_id > 1 and prior and prior.last_decode_worker is not None
        )
        if can_reuse:
            decode_worker_id = prior.last_decode_worker
        else:
            decode_worker_id = state.next_decode_worker_id(topology)

        return _build_decision(
            policy_name=self.name,
            request=request,
            topology=topology,
            state=state,
            prefill_worker_id=prefill_worker_id,
            decode_worker_id=decode_worker_id,
            reuse_expected=can_reuse,
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class KvAwarePolicy:
    """Pick the decode worker with the most resident blocks for the request.

    Attributes:
        name: Policy identifier.
        sticky_bonus: Added to the overlap of the session's previous decode
            worker before candidates are compared.
    """

    name: str = "kv-aware"
    sticky_bonus: int = 1

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Return a decision for the best-scoring route worker.

        Candidates are ranked by the tuple (overlap + sticky bonus, is the
        session's previous worker, fewer in-flight requests, fewer total
        assignments); on a full tie the first worker in topology order wins.
        Reuse is expected when the winning first component is positive.

        Raises:
            ValueError: If the topology has no route workers.
        """
        prefill_worker_id = state.next_prefill_worker_id(topology)
        session = state.session_state.get(request.session_id)

        best_decode_worker_id: str | None = None
        best_score: tuple[int, int, int, int] | None = None
        for worker in topology.route_workers:
            overlap = _overlap_blocks(request, state, worker.worker_id)
            sticky = int(session is not None and session.last_decode_worker == worker.worker_id)
            # Loads are negated so that a larger tuple always means "better".
            inflight_penalty = -state.inflight_decode.get(worker.worker_id, 0)
            assignment_penalty = -state.decode_assignment_counts.get(worker.worker_id, 0)
            score = (
                overlap + sticky * self.sticky_bonus,
                sticky,
                inflight_penalty,
                assignment_penalty,
            )
            if best_score is None or score > best_score:
                best_score = score
                best_decode_worker_id = worker.worker_id

        # Explicit check instead of `assert`, which disappears under `python -O`.
        if best_decode_worker_id is None or best_score is None:
            raise ValueError("KvAwarePolicy requires at least one route worker")

        reuse_expected = best_score[0] > 0
        return _build_decision(
            policy_name=self.name,
            request=request,
            topology=topology,
            state=state,
            prefill_worker_id=prefill_worker_id,
            decode_worker_id=best_decode_worker_id,
            reuse_expected=reuse_expected,
        )
|
||||
|
||||
|
||||
def create_policy(name: str) -> RoutingPolicy:
|
||||
normalized = name.strip().lower()
|
||||
if normalized == "default":
|
||||
return DefaultPolicy()
|
||||
if normalized == "sticky":
|
||||
return StickyDecodePolicy()
|
||||
if normalized in {"kv-aware", "kv_aware", "kv"}:
|
||||
return KvAwarePolicy()
|
||||
raise ValueError(f"Unsupported policy: {name}")
|
||||
|
||||
|
||||
def _build_decision(
    *,
    policy_name: str,
    request: TraceRequest,
    topology: SingleNodeTopology,
    state: RoutingState,
    prefill_worker_id: str,
    decode_worker_id: str,
    reuse_expected: bool,
) -> RoutingDecision:
    """Register the assignment in *state* and package it as a ``RoutingDecision``.

    The overlap is measured before the counters change. Both the in-flight
    and the total-assignment counter of the decode worker are then
    incremented, so the reported in-flight load includes this request.
    """
    overlap = _overlap_blocks(request, state, decode_worker_id)
    for counter in (state.inflight_decode, state.decode_assignment_counts):
        counter[decode_worker_id] += 1

    # Blocks that are not already resident on the chosen decode worker.
    missing_blocks = max(0, len(request.hash_ids) - overlap)
    load_after_assignment = state.inflight_decode[decode_worker_id]
    return RoutingDecision(
        policy_name=policy_name,
        prefill_worker_id=prefill_worker_id,
        decode_worker_id=decode_worker_id,
        decode_worker_index=topology.route_index(decode_worker_id),
        reuse_expected=reuse_expected,
        observed_overlap_blocks=overlap,
        kv_transfer_blocks=missing_blocks,
        inflight_decode_load=load_after_assignment,
        session_id=request.session_id,
        request_id=request.request_id,
        turn_id=request.turn_id,
    )
|
||||
|
||||
|
||||
def _overlap_blocks(
    request: TraceRequest,
    state: RoutingState,
    decode_worker_id: str,
) -> int:
    """Count the entries of ``request.hash_ids`` resident on *decode_worker_id*.

    Repeated hash ids are counted once per occurrence; a worker unknown to
    *state* is treated as holding no blocks.
    """
    resident = state.decode_resident_blocks.get(decode_worker_id, set())
    matches = [block_hash for block_hash in request.hash_ids if block_hash in resident]
    return len(matches)
|
||||
2271
src/agentic_pd_hybrid/replay.py
Normal file
2271
src/agentic_pd_hybrid/replay.py
Normal file
File diff suppressed because it is too large
Load Diff
295
src/agentic_pd_hybrid/sampling.py
Normal file
295
src/agentic_pd_hybrid/sampling.py
Normal file
@@ -0,0 +1,295 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from agentic_pd_hybrid.trace import TraceRequest, load_trace
|
||||
|
||||
|
||||
# Names accepted for SessionSampleConfig.profile.
SampleProfile = Literal["default", "small-append"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SessionSampleConfig:
    """Parameters controlling how whole sessions are sampled from a trace."""

    # Trace to read and trace file to write.
    trace_path: Path
    output_path: Path
    # Sampling window and volume limits.
    target_duration_s: float = 600.0
    start_time_s: float = 0.0
    session_sample_rate: float = 1.0
    min_turns: int = 1
    max_requests: int | None = None
    # Named profile plus optional explicit filter thresholds; None leaves a
    # threshold unset in this config.
    profile: SampleProfile = "default"
    min_initial_input_tokens: int | None = None
    max_initial_input_tokens: int | None = None
    max_append_input_tokens: int | None = None
    max_output_tokens: int | None = None
    min_overlap_ratio: float | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SessionSampleSummary:
    """Immutable description of one session-sampling run."""

    # Input and output locations, stored as strings.
    input_trace_path: str
    output_trace_path: str
    # Size of the sampled trace.
    request_count: int
    session_count: int
    multi_turn_session_count: int
    # Time span covered by the sample.
    start_time_s: float
    end_time_s: float
    sampled_duration_s: float
    # Sampling parameters and filter thresholds recorded with the result.
    session_sample_rate: float
    min_turns: int
    profile: str
    min_initial_input_tokens: int | None
    max_initial_input_tokens: int | None
    max_append_input_tokens: int | None
    max_output_tokens: int | None
    min_overlap_ratio: float | None
    # Aggregates over the sample; typed as optional.
    mean_append_input_tokens: float | None
    mean_turn_overlap_ratio: float | None
|
||||
|
||||
|
||||
def sample_trace_sessions(config: SessionSampleConfig) -> SessionSampleSummary:
    """Sample whole sessions from a trace and write them as a JSONL trace.

    Requests are grouped by ``session_id``; a session is eligible when it has
    at least ``min_turns`` requests, passes the resolved filters and is kept
    by the deterministic ``session_sample_rate`` hash. Eligible sessions are
    visited in order of their first request's timestamp and accumulated until
    either ``max_requests`` or ``target_duration_s`` is reached. The selected
    requests are then sorted by timestamp, truncated to ``max_requests`` and
    written one JSON object per line to ``config.output_path``. A summary is
    written to ``<output_path><suffix>.summary.json`` and returned.

    The per-session statistics in the summary are computed from the requests
    that were actually written, so they stay consistent with the output file
    even when the ``max_requests`` truncation cuts a session short.

    Raises:
        ValueError: if no request survives the sampling, or the profile is
            not supported (raised by ``_resolve_filters``).
    """
    requests = load_trace(config.trace_path)
    sessions: dict[str, list[TraceRequest]] = defaultdict(list)
    for request in requests:
        sessions[request.session_id].append(request)

    filters = _resolve_filters(config)
    eligible_sessions = {
        session_id: session_requests
        for session_id, session_requests in sessions.items()
        if len(session_requests) >= filters.min_turns
        and _session_matches_filters(session_requests, filters)
        and _keep_session(session_id, config.session_sample_rate)
    }
    ordered_sessions = sorted(
        eligible_sessions.values(),
        key=lambda session_requests: session_requests[0].timestamp_s,
    )

    selected_requests: list[TraceRequest] = []
    sampled_start: float | None = None
    for session_requests in ordered_sessions:
        session_first = session_requests[0].timestamp_s
        if session_first < config.start_time_s:
            continue

        if sampled_start is None:
            sampled_start = session_first

        selected_requests.extend(session_requests)
        sampled_end = max(request.timestamp_s for request in session_requests)

        if config.max_requests is not None and len(selected_requests) >= config.max_requests:
            break
        if sampled_end - sampled_start >= config.target_duration_s:
            break

    selected_requests.sort(key=lambda request: request.timestamp_s)
    if config.max_requests is not None:
        # Whole sessions were added above, so the cap may cut the last ones short.
        selected_requests = selected_requests[: config.max_requests]

    if not selected_requests:
        raise ValueError("Sampling produced no requests; adjust the sampling arguments")

    config.output_path.parent.mkdir(parents=True, exist_ok=True)
    with config.output_path.open("w", encoding="utf-8") as handle:
        for request in selected_requests:
            payload = {
                "request_id": request.request_id,
                "session_id": request.session_id,
                "chat_id": request.chat_id,
                "parent_chat_id": request.parent_chat_id,
                "timestamp": request.timestamp_s,
                "input_length": request.input_length,
                "output_length": request.output_length,
                "type": request.request_type,
                "turn": request.turn_id,
                "hash_ids": list(request.hash_ids),
            }
            handle.write(json.dumps(payload, sort_keys=True) + "\n")

    # Regroup what was written so the statistics describe the output file,
    # not the untruncated source sessions.
    written_by_session: dict[str, list[TraceRequest]] = defaultdict(list)
    for request in selected_requests:
        written_by_session[request.session_id].append(request)

    append_lengths = [
        length
        for session_requests in written_by_session.values()
        for length in _turn_append_lengths(session_requests)
    ]
    overlap_ratios = [
        ratio
        for session_requests in written_by_session.values()
        for ratio in _turn_overlap_ratios(session_requests)
    ]
    summary = SessionSampleSummary(
        input_trace_path=str(config.trace_path),
        output_trace_path=str(config.output_path),
        request_count=len(selected_requests),
        session_count=len(written_by_session),
        multi_turn_session_count=sum(
            1
            for session_requests in written_by_session.values()
            if len(session_requests) > 1
        ),
        start_time_s=selected_requests[0].timestamp_s,
        end_time_s=selected_requests[-1].timestamp_s,
        sampled_duration_s=selected_requests[-1].timestamp_s
        - selected_requests[0].timestamp_s,
        session_sample_rate=config.session_sample_rate,
        min_turns=filters.min_turns,
        profile=config.profile,
        min_initial_input_tokens=filters.min_initial_input_tokens,
        max_initial_input_tokens=filters.max_initial_input_tokens,
        max_append_input_tokens=filters.max_append_input_tokens,
        max_output_tokens=filters.max_output_tokens,
        min_overlap_ratio=filters.min_overlap_ratio,
        mean_append_input_tokens=_mean(append_lengths),
        mean_turn_overlap_ratio=_mean(overlap_ratios),
    )
    summary_path = config.output_path.with_suffix(config.output_path.suffix + ".summary.json")
    with summary_path.open("w", encoding="utf-8") as handle:
        json.dump(asdict(summary), handle, indent=2, sort_keys=True)
    return summary
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _ResolvedFilters:
    """Session filters after the sampling profile has been applied.

    A ``None`` limit means the corresponding check is skipped by
    ``_session_matches_filters``.
    """

    # Minimum number of requests per session.
    min_turns: int
    # Inclusive bounds on the first request's input length.
    min_initial_input_tokens: int | None
    max_initial_input_tokens: int | None
    # Largest accepted number of tokens appended between turns.
    max_append_input_tokens: int | None
    # Largest accepted output length of any request.
    max_output_tokens: int | None
    # Smallest accepted block-overlap ratio between consecutive turns.
    min_overlap_ratio: float | None
|
||||
|
||||
|
||||
def _resolve_filters(config: SessionSampleConfig) -> _ResolvedFilters:
    """Turn a sampling config into concrete filter limits.

    The ``"default"`` profile forwards every configured value unchanged. The
    ``"small-append"`` profile requires at least two turns and substitutes a
    preset for every limit the caller left as ``None``.

    Raises:
        ValueError: for any other profile name.
    """
    profile = config.profile
    if profile != "default" and profile != "small-append":
        raise ValueError(f"Unsupported sample profile: {config.profile}")

    limits = {
        "min_initial_input_tokens": config.min_initial_input_tokens,
        "max_initial_input_tokens": config.max_initial_input_tokens,
        "max_append_input_tokens": config.max_append_input_tokens,
        "max_output_tokens": config.max_output_tokens,
        "min_overlap_ratio": config.min_overlap_ratio,
    }
    if profile == "default":
        return _ResolvedFilters(min_turns=config.min_turns, **limits)

    # "small-append": explicit caller values win, presets fill the gaps.
    presets = {
        "min_initial_input_tokens": 2048,
        "max_initial_input_tokens": 16000,
        "max_append_input_tokens": 2048,
        "max_output_tokens": 2048,
        "min_overlap_ratio": 0.75,
    }
    for limit_name, preset in presets.items():
        if limits[limit_name] is None:
            limits[limit_name] = preset
    return _ResolvedFilters(min_turns=max(config.min_turns, 2), **limits)
|
||||
|
||||
|
||||
def _session_matches_filters(
    session_requests: list[TraceRequest],
    filters: _ResolvedFilters,
) -> bool:
    """Return whether a session satisfies every configured limit.

    Requests are ordered by ``(timestamp_s, turn_id, chat_id)``; the first one
    is checked against the initial-input bounds, every request against the
    output cap, and every consecutive pair against the append cap and the
    minimum overlap ratio. A limit of ``None`` disables its check. An empty
    session never matches.
    """
    if not session_requests:
        return False
    timeline = sorted(
        session_requests,
        key=lambda request: (request.timestamp_s, request.turn_id, request.chat_id),
    )

    opening = timeline[0]
    floor = filters.min_initial_input_tokens
    if floor is not None and opening.input_length < floor:
        return False
    ceiling = filters.max_initial_input_tokens
    if ceiling is not None and opening.input_length > ceiling:
        return False

    output_cap = filters.max_output_tokens
    if output_cap is not None:
        for request in timeline:
            if request.output_length > output_cap:
                return False

    append_lengths = _turn_append_lengths(timeline)
    append_cap = filters.max_append_input_tokens
    if append_cap is not None:
        for append_length in append_lengths:
            # With a cap in place, a turn must add a positive number of tokens.
            if append_length <= 0 or append_length > append_cap:
                return False

    overlap_ratios = _turn_overlap_ratios(timeline)
    overlap_floor = filters.min_overlap_ratio
    if overlap_floor is not None:
        for overlap_ratio in overlap_ratios:
            if overlap_ratio < overlap_floor:
                return False

    return True
|
||||
|
||||
|
||||
def _turn_append_lengths(session_requests: list[TraceRequest]) -> list[int]:
    """Return, per consecutive turn pair, the tokens newly added to the input.

    Requests are ordered by ``(timestamp_s, turn_id, chat_id)``. For each pair
    the value is the current input length minus the previous turn's input and
    output lengths combined; it may be zero or negative.
    """
    timeline = sorted(
        session_requests,
        key=lambda request: (request.timestamp_s, request.turn_id, request.chat_id),
    )
    deltas: list[int] = []
    earlier = None
    for later in timeline:
        if earlier is not None:
            carried_over = earlier.input_length + earlier.output_length
            deltas.append(later.input_length - carried_over)
        earlier = later
    return deltas
|
||||
|
||||
|
||||
def _turn_overlap_ratios(session_requests: list[TraceRequest]) -> list[float]:
    """Return, per consecutive turn pair, the share of reused block hashes.

    Requests are ordered by ``(timestamp_s, turn_id, chat_id)``. The ratio is
    the number of entries in the current turn's ``hash_ids`` that also occur
    in the previous turn's, divided by the current turn's entry count. A turn
    without any ``hash_ids`` contributes ``0.0``.
    """
    timeline = sorted(
        session_requests,
        key=lambda request: (request.timestamp_s, request.turn_id, request.chat_id),
    )
    ratios: list[float] = []
    for position, later in enumerate(timeline[1:]):
        blocks = later.hash_ids
        if blocks:
            seen_before = set(timeline[position].hash_ids)
            shared = len([block for block in blocks if block in seen_before])
            ratios.append(shared / len(blocks))
        else:
            ratios.append(0.0)
    return ratios
|
||||
|
||||
|
||||
def _mean(values: list[int] | list[float]) -> float | None:
    """Return the arithmetic mean of ``values``, or ``None`` when it is empty."""
    return sum(values) / len(values) if values else None
|
||||
|
||||
|
||||
def _keep_session(session_id: str, sample_rate: float) -> bool:
    """Decide deterministically whether a session belongs to the sample.

    Rates of 1.0 or more keep everything and rates of 0.0 or less keep
    nothing. Otherwise the session id is hashed with 8-byte BLAKE2b, the
    digest is scaled into ``[0, 1)`` and the session is kept when that value
    is below ``sample_rate``.
    """
    if not 0.0 < sample_rate < 1.0:
        # Out-of-range rates short-circuit without hashing.
        return sample_rate >= 1.0
    hasher = hashlib.blake2b(digest_size=8)
    hasher.update(session_id.encode("utf-8"))
    bucket = int.from_bytes(hasher.digest(), byteorder="big", signed=False) / 2**64
    return bucket < sample_rate
|
||||
222
src/agentic_pd_hybrid/stack.py
Normal file
222
src/agentic_pd_hybrid/stack.py
Normal file
@@ -0,0 +1,222 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from agentic_pd_hybrid.launcher import build_launch_plan
|
||||
from agentic_pd_hybrid.topology import SingleNodeTopology
|
||||
|
||||
|
||||
@dataclass
class ManagedProcess:
    """A launched child process together with its launch metadata."""

    # Label such as "prefill-0" or "router" assigned by the launcher.
    name: str
    # Argument vector the child was started with.
    command: tuple[str, ...]
    # Handle used to poll, wait for and signal the child.
    process: subprocess.Popen[bytes]
    # File receiving the child's combined stdout/stderr.
    log_path: Path
|
||||
|
||||
|
||||
@dataclass
class ManagedPdStack:
    """Handle on every process launched for one serving stack.

    Holds the prefill, decode and direct worker processes plus the optional
    router process, and knows how to shut all of them down.
    """

    topology: SingleNodeTopology
    run_dir: Path
    prefill_processes: list[ManagedProcess]
    decode_processes: list[ManagedProcess]
    direct_processes: list[ManagedProcess]
    router_process: ManagedProcess | None

    @property
    def router_url(self) -> str:
        """Base URL of the router, as defined by the topology."""
        return self.topology.router_url

    def stop(self) -> None:
        """Terminate all managed processes, escalating to SIGKILL if needed.

        Every still-running process group receives SIGTERM (router first, then
        direct, decode and prefill workers). The groups then share a 20 second
        grace period; any process still alive afterwards has its group killed
        with SIGKILL and is waited on for up to 5 more seconds.

        A process that exits between the liveness check and the signal is
        treated as already stopped rather than aborting the shutdown of the
        remaining processes.
        """
        processes = (
            ([self.router_process] if self.router_process is not None else [])
            + self.direct_processes
            + self.decode_processes
            + self.prefill_processes
        )
        for managed in processes:
            if managed.process.poll() is None:
                self._signal_group(managed, signal.SIGTERM)
        # Monotonic clock: the grace period must not shrink or stretch when
        # the wall clock is adjusted.
        deadline = time.monotonic() + 20
        for managed in processes:
            if managed.process.poll() is not None:
                continue
            remaining = max(0.0, deadline - time.monotonic())
            try:
                managed.process.wait(timeout=remaining)
            except subprocess.TimeoutExpired:
                if managed.process.poll() is None:
                    self._signal_group(managed, signal.SIGKILL)
                managed.process.wait(timeout=5)

    @staticmethod
    def _signal_group(managed: ManagedProcess, sig: int) -> None:
        """Send ``sig`` to the process group of ``managed``, if it still exists."""
        try:
            os.killpg(os.getpgid(managed.process.pid), sig)
        except ProcessLookupError:
            # The child exited after poll() reported it alive; nothing to signal.
            pass
|
||||
|
||||
|
||||
def launch_pd_stack(
    *,
    topology: SingleNodeTopology,
    run_dir: Path,
    prefill_policy: str,
    decode_policy: str,
    timeout_s: float = 1200.0,
    include_router: bool = True,
) -> ManagedPdStack:
    """Start every process of the stack and wait until all are reachable.

    Prefill, decode and direct workers are spawned first (logging to
    ``run_dir/logs/<role>-<idx>.log``) and each worker's ``/v1/models``
    endpoint is polled until it answers. If the launch plan contains a router
    command, the router is started afterwards and its ``/health`` endpoint is
    polled. ``timeout_s`` applies to each endpoint separately.

    If anything fails or the launch is interrupted (including Ctrl-C), every
    process started so far is stopped before the error propagates. The
    children run in their own sessions, so they would otherwise be orphaned.

    Returns:
        The ``ManagedPdStack`` owning all launched processes.
    """
    run_dir.mkdir(parents=True, exist_ok=True)
    logs_dir = run_dir / "logs"
    logs_dir.mkdir(parents=True, exist_ok=True)

    plan = build_launch_plan(
        topology,
        prefill_policy=prefill_policy,
        decode_policy=decode_policy,
        include_router=include_router,
    )

    # Register processes on the stack as they start so that a failure
    # half-way through still lets stop() reach all of them.
    stack = ManagedPdStack(
        topology=topology,
        run_dir=run_dir,
        prefill_processes=[],
        decode_processes=[],
        direct_processes=[],
        router_process=None,
    )
    spawn_groups = (
        ("prefill", plan.prefill_commands, stack.prefill_processes),
        ("decode", plan.decode_commands, stack.decode_processes),
        ("direct", plan.direct_commands, stack.direct_processes),
    )
    try:
        for role, commands, started in spawn_groups:
            for idx, command in enumerate(commands):
                started.append(
                    _spawn_process(
                        name=f"{role}-{idx}",
                        command=command,
                        log_path=logs_dir / f"{role}-{idx}.log",
                        topology=topology,
                    )
                )

        for worker in topology.prefill_workers:
            _wait_for_ready_endpoint(f"{worker.url}/v1/models", timeout_s=timeout_s)
        for worker in topology.decode_workers:
            _wait_for_ready_endpoint(f"{worker.url}/v1/models", timeout_s=timeout_s)
        for worker in topology.direct_workers:
            _wait_for_ready_endpoint(f"{worker.url}/v1/models", timeout_s=timeout_s)

        if plan.router_command is not None:
            stack.router_process = _spawn_process(
                name="router",
                command=plan.router_command,
                log_path=logs_dir / "router.log",
                topology=topology,
            )
            _wait_for_ready_endpoint(f"{topology.router_url}/health", timeout_s=timeout_s)
    except BaseException:
        # BaseException on purpose: KeyboardInterrupt/SystemExit must also
        # tear the children down. The error is always re-raised.
        stack.stop()
        raise

    return stack
|
||||
|
||||
|
||||
def _spawn_process(
    *,
    name: str,
    command: tuple[str, ...],
    log_path: Path,
    topology: SingleNodeTopology,
) -> ManagedProcess:
    """Start ``command`` in its own session with output redirected to a log.

    stdout and stderr of the child both go to ``log_path`` (opened for
    writing in binary mode, so an existing file is truncated). The child gets
    the environment from ``_build_process_env`` and becomes the leader of a
    new session, which lets its whole process group be signalled later.
    """
    # The child inherits its own copy of the descriptor, so the parent's
    # handle is closed as soon as Popen returns (or fails) instead of leaking.
    with log_path.open("wb") as log_handle:
        env = _build_process_env(topology)
        process = subprocess.Popen(
            command,
            stdout=log_handle,
            stderr=subprocess.STDOUT,
            env=env,
            # Same effect as preexec_fn=os.setsid, without running Python
            # code between fork and exec.
            start_new_session=True,
        )
    return ManagedProcess(
        name=name,
        command=command,
        process=process,
        log_path=log_path,
    )
|
||||
|
||||
|
||||
def _build_process_env(topology: SingleNodeTopology) -> dict[str, str]:
    """Build the environment passed to every child process.

    Starts from a copy of the current environment, disables bytecode writing
    and output buffering, removes proxy settings, fills in defaults for a few
    SGLang/FlashInfer variables, applies the RDMA settings requested by the
    topology and puts the repository's ``src`` and
    ``third_party/sglang/python`` directories at the front of ``PYTHONPATH``.
    """
    env = dict(os.environ)
    env.update(PYTHONDONTWRITEBYTECODE="1", PYTHONUNBUFFERED="1")

    # SGLang's PD bootstrap path uses `requests`; force localhost traffic to stay local.
    for scheme in ("http", "https", "all"):
        env.pop(f"{scheme}_proxy", None)
    for scheme in ("HTTP", "HTTPS", "ALL"):
        env.pop(f"{scheme}_PROXY", None)
    env["NO_PROXY"] = "*"
    env["no_proxy"] = "*"

    # Defaults only: values already present in the environment win.
    for variable, fallback in (
        ("SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", "600"),
        ("SGLANG_DISAGGREGATION_WAITING_TIMEOUT", "60"),
        ("FLASHINFER_DISABLE_VERSION_CHECK", "1"),
    ):
        env.setdefault(variable, fallback)

    if topology.force_rdma:
        env["MOONCAKE_PROTOCOL"] = "rdma"
        env["MC_MS_AUTO_DISC"] = "0"
        # NOTE(review): assumes the device override only applies when RDMA is
        # forced; confirm against the original nesting of this check.
        if topology.ib_device:
            env["MOONCAKE_DEVICE"] = topology.ib_device

    repo_root = Path(__file__).resolve().parents[2]
    search_path = [
        str(repo_root / "src"),
        str(repo_root / "third_party" / "sglang" / "python"),
    ]
    inherited = env.get("PYTHONPATH")
    if inherited:
        # Keep whatever the caller already had, after our own entries.
        search_path.append(inherited)
    env["PYTHONPATH"] = os.pathsep.join(search_path)
    return env
|
||||
|
||||
|
||||
def _wait_for_ready_endpoint(url: str, *, timeout_s: float) -> None:
    """Poll ``url`` once per second until it answers with HTTP 200.

    Each request uses a 5 second client timeout and ignores proxy settings
    from the environment. Request errors and non-200 answers are remembered
    and retried.

    Raises:
        TimeoutError: when ``timeout_s`` seconds pass without a 200 answer;
            the message carries the last observed status or error.
    """
    began = time.perf_counter()
    last_error: str | None = None
    with httpx.Client(timeout=5.0, trust_env=False) as client:
        while True:
            if not (time.perf_counter() - began < timeout_s):
                break
            try:
                status = client.get(url).status_code
            except Exception as exc:  # pragma: no cover
                last_error = f"{type(exc).__name__}: {exc}"
            else:
                if status == 200:
                    return
                last_error = f"status={status}"
            time.sleep(1.0)
    raise TimeoutError(f"Timed out waiting for {url} ({last_error})")
|
||||
245
src/agentic_pd_hybrid/topology.py
Normal file
245
src/agentic_pd_hybrid/topology.py
Normal file
@@ -0,0 +1,245 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
|
||||
WorkerRole = Literal["prefill", "decode", "direct"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WorkerSpec:
    """Placement and addressing of one worker process."""

    # "prefill", "decode" or "direct".
    role: WorkerRole
    # Zero-based index of the worker within its role.
    ordinal: int
    # GPUs assigned to this worker; their count is the tensor-parallel size.
    gpu_ids: tuple[int, ...]
    host: str
    port: int
    # Only set for workers that were given a bootstrap port.
    bootstrap_port: int | None = None

    @property
    def worker_id(self) -> str:
        """Identifier of the form ``<role>-<ordinal>``."""
        return f"{self.role}-{self.ordinal}"

    @property
    def url(self) -> str:
        """HTTP base URL of the worker."""
        return "http://" + f"{self.host}:{self.port}"

    @property
    def gpu_id(self) -> int:
        """First GPU assigned to the worker."""
        return self.gpu_ids[0]

    @property
    def tp_size(self) -> int:
        """Number of GPUs assigned to the worker."""
        return len(self.gpu_ids)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SingleNodeTopology:
    """Immutable description of all workers and the router on one node."""

    model_path: str
    prefill_workers: tuple[WorkerSpec, ...]
    decode_workers: tuple[WorkerSpec, ...]
    direct_workers: tuple[WorkerSpec, ...]
    router_host: str
    router_port: int
    transfer_backend: str
    trust_remote_code: bool
    force_rdma: bool = False
    ib_device: str | None = None
    # Extra server arguments: shared ones first, then per-role additions.
    extra_server_args: tuple[str, ...] = ()
    prefill_extra_server_args: tuple[str, ...] = ()
    decode_extra_server_args: tuple[str, ...] = ()
    direct_extra_server_args: tuple[str, ...] = ()

    @property
    def model_name(self) -> str:
        """Final path component of ``model_path``."""
        return Path(self.model_path).name

    @property
    def router_url(self) -> str:
        """HTTP base URL of the router."""
        return "http://" + f"{self.router_host}:{self.router_port}"

    @property
    def route_workers(self) -> tuple[WorkerSpec, ...]:
        """Decode workers when there are any, otherwise the direct workers."""
        return self.decode_workers or self.direct_workers

    def route_index(self, worker_id: str) -> int:
        """Return the position of ``worker_id`` within ``route_workers``.

        Raises:
            KeyError: if no route worker has that id.
        """
        matches = (
            position
            for position, worker in enumerate(self.route_workers)
            if worker.worker_id == worker_id
        )
        found = next(matches, None)
        if found is None:
            raise KeyError(f"Unknown route worker: {worker_id}")
        return found
|
||||
|
||||
|
||||
def build_single_node_topology(
    *,
    model_path: str,
    prefill_worker_count: int,
    decode_worker_count: int,
    direct_worker_count: int = 0,
    prefill_tp_size: int = 1,
    decode_tp_size: int = 1,
    direct_tp_size: int = 1,
    prefill_gpu_ids: tuple[int, ...] | None = None,
    decode_gpu_ids: tuple[int, ...] | None = None,
    direct_gpu_ids: tuple[int, ...] | None = None,
    total_gpu_budget: int = 8,
    host: str = "127.0.0.1",
    router_port: int = 8000,
    prefill_port_base: int = 30000,
    decode_port_base: int = 31000,
    direct_port_base: int = 32000,
    bootstrap_port_base: int = 8998,
    transfer_backend: str = "nixl",
    force_rdma: bool = False,
    trust_remote_code: bool = True,
    ib_device: str | None = None,
    extra_server_args: tuple[str, ...] = (),
    prefill_extra_server_args: tuple[str, ...] = (),
    decode_extra_server_args: tuple[str, ...] = (),
    direct_extra_server_args: tuple[str, ...] = (),
) -> SingleNodeTopology:
    """Validate the requested layout and build a ``SingleNodeTopology``.

    Each worker of a role receives ``<role>_tp_size`` consecutive entries of
    that role's GPU id tuple and the port ``<role>_port_base + index``;
    prefill workers additionally get ``bootstrap_port_base + index``. GPU id
    tuples that are not given default to consecutive ranges laid out as
    prefill, then decode, then direct. The router shares ``host`` with the
    workers.

    Raises:
        ValueError: for negative worker counts, no workers at all,
            non-positive tensor-parallel sizes, ``force_rdma`` without an IB
            device or with a backend other than ``"mooncake"``, a GPU demand
            above ``total_gpu_budget``, GPU id tuples of the wrong length,
            duplicate GPU ids, or ids outside ``[0, total_gpu_budget - 1]``.
    """
    worker_counts = (
        ("prefill", prefill_worker_count),
        ("decode", decode_worker_count),
        ("direct", direct_worker_count),
    )
    for role, count in worker_counts:
        if count < 0:
            raise ValueError(f"{role}_worker_count must be >= 0")
    if all(count == 0 for _, count in worker_counts):
        raise ValueError("At least one worker must be configured")
    for role, tp_size in (
        ("prefill", prefill_tp_size),
        ("decode", decode_tp_size),
        ("direct", direct_tp_size),
    ):
        if tp_size <= 0:
            raise ValueError(f"{role}_tp_size must be >= 1")
    if force_rdma:
        if not ib_device:
            raise ValueError("force_rdma requires --ib-device to be set")
        if transfer_backend != "mooncake":
            raise ValueError(
                "force_rdma currently requires transfer_backend='mooncake' "
                "to guarantee an RDMA path"
            )

    # Number of GPUs each role needs in total.
    prefill_slots = prefill_worker_count * prefill_tp_size
    decode_slots = decode_worker_count * decode_tp_size
    direct_slots = direct_worker_count * direct_tp_size
    if prefill_slots + decode_slots + direct_slots > total_gpu_budget:
        raise ValueError(
            "Single-node GPU budget exceeded: "
            f"{prefill_worker_count} prefill x tp={prefill_tp_size} + "
            f"{decode_worker_count} decode x tp={decode_tp_size} + "
            f"{direct_worker_count} direct x tp={direct_tp_size} > "
            f"{total_gpu_budget} GPUs"
        )

    # Missing id tuples continue where the previous role's ids end.
    if prefill_gpu_ids is None:
        prefill_gpu_ids = tuple(range(prefill_slots))
    if decode_gpu_ids is None:
        first_decode = len(prefill_gpu_ids)
        decode_gpu_ids = tuple(range(first_decode, first_decode + decode_slots))
    if direct_gpu_ids is None:
        first_direct = len(prefill_gpu_ids) + len(decode_gpu_ids)
        direct_gpu_ids = tuple(range(first_direct, first_direct + direct_slots))

    for role, gpu_ids, slots in (
        ("prefill", prefill_gpu_ids, prefill_slots),
        ("decode", decode_gpu_ids, decode_slots),
        ("direct", direct_gpu_ids, direct_slots),
    ):
        if len(gpu_ids) != slots:
            raise ValueError(
                f"{role}_gpu_ids length must equal {role}_worker_count * {role}_tp_size: "
                f"{len(gpu_ids)} != {slots}"
            )
    assigned_gpu_ids = prefill_gpu_ids + decode_gpu_ids + direct_gpu_ids
    if len(assigned_gpu_ids) != len(set(assigned_gpu_ids)):
        raise ValueError("prefill/decode/direct GPU IDs must be unique")
    if not all(0 <= gpu_id < total_gpu_budget for gpu_id in assigned_gpu_ids):
        raise ValueError(
            "GPU IDs must fall within the single-node budget range "
            f"[0, {total_gpu_budget - 1}]"
        )

    def carve_workers(role, gpu_pool, count, tp_size, port_base, bootstrap_base=None):
        # Slice the role's GPU pool into per-worker chunks of tp_size ids.
        return tuple(
            WorkerSpec(
                role=role,
                ordinal=idx,
                gpu_ids=tuple(gpu_pool[idx * tp_size : (idx + 1) * tp_size]),
                host=host,
                port=port_base + idx,
                bootstrap_port=None if bootstrap_base is None else bootstrap_base + idx,
            )
            for idx in range(count)
        )

    return SingleNodeTopology(
        model_path=model_path,
        prefill_workers=carve_workers(
            "prefill",
            prefill_gpu_ids,
            prefill_worker_count,
            prefill_tp_size,
            prefill_port_base,
            bootstrap_port_base,
        ),
        decode_workers=carve_workers(
            "decode", decode_gpu_ids, decode_worker_count, decode_tp_size, decode_port_base
        ),
        direct_workers=carve_workers(
            "direct", direct_gpu_ids, direct_worker_count, direct_tp_size, direct_port_base
        ),
        router_host=host,
        router_port=router_port,
        transfer_backend=transfer_backend,
        trust_remote_code=trust_remote_code,
        force_rdma=force_rdma,
        ib_device=ib_device,
        extra_server_args=extra_server_args,
        prefill_extra_server_args=prefill_extra_server_args,
        decode_extra_server_args=decode_extra_server_args,
        direct_extra_server_args=direct_extra_server_args,
    )
|
||||
106
src/agentic_pd_hybrid/trace.py
Normal file
106
src/agentic_pd_hybrid/trace.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TraceRequest:
    """One request parsed from a trace file."""

    # Composed by ``load_trace`` as "<session>:<turn>:<chat>:<line index>".
    request_id: str
    # Session the request was attributed to via its parent chain.
    session_id: str
    chat_id: int
    # Negative when the request starts a new session.
    parent_chat_id: int
    timestamp_s: float
    input_length: int
    output_length: int
    request_type: str
    turn_id: int
    # Block hashes of the prompt; empty when the trace omits "hash_ids".
    hash_ids: tuple[int, ...]
|
||||
|
||||
|
||||
def load_trace(path: Path, *, request_limit: int | None = None) -> list[TraceRequest]:
    """Parse a JSONL trace file into ``TraceRequest`` objects.

    Every non-blank line must be a JSON object with the keys ``chat_id``,
    ``parent_chat_id``, ``timestamp``, ``input_length``, ``output_length``,
    ``type`` and ``turn``; ``hash_ids`` is optional. Sessions are
    reconstructed by following ``parent_chat_id`` links in file order.
    Whitespace-only lines are skipped.

    Args:
        path: Trace file to read (UTF-8).
        request_limit: Stop after this many requests; ``None`` reads all.

    Returns:
        The requests in file order.
    """
    chat_to_session: dict[int, str] = {}
    requests: list[TraceRequest] = []

    with path.open("r", encoding="utf-8") as handle:
        for index, line in enumerate(handle):
            if request_limit is not None and len(requests) >= request_limit:
                break
            if not line.strip():
                # Tolerate blank separator lines instead of failing in
                # json.loads. ``index`` stays the physical line number, so
                # request ids of the other lines are unaffected.
                continue

            payload = json.loads(line)
            chat_id = int(payload["chat_id"])
            parent_chat_id = int(payload["parent_chat_id"])
            session_id = _resolve_session_id(
                chat_id=chat_id,
                parent_chat_id=parent_chat_id,
                chat_to_session=chat_to_session,
            )
            turn_id = int(payload["turn"])
            request_id = f"{session_id}:{turn_id}:{chat_id}:{index}"
            requests.append(
                TraceRequest(
                    request_id=request_id,
                    session_id=session_id,
                    chat_id=chat_id,
                    parent_chat_id=parent_chat_id,
                    timestamp_s=float(payload["timestamp"]),
                    input_length=int(payload["input_length"]),
                    output_length=int(payload["output_length"]),
                    request_type=str(payload["type"]),
                    turn_id=turn_id,
                    hash_ids=tuple(int(item) for item in payload.get("hash_ids", [])),
                )
            )

    return requests
|
||||
|
||||
|
||||
def build_synthetic_prompt(
    request: TraceRequest,
    *,
    block_token_budget: int = 24,
) -> str:
    """Return the request's synthetic prompt tokens joined by single spaces."""
    words = build_synthetic_prompt_tokens(
        request,
        block_token_budget=block_token_budget,
    )
    return " ".join(words)
|
||||
|
||||
|
||||
def build_synthetic_prompt_tokens(
    request: TraceRequest,
    *,
    block_token_budget: int = 24,
) -> list[str]:
    """Build a deterministic token list of ``request.input_length`` entries.

    Every hash id contributes ``block_token_budget`` tokens named
    ``blk<hash>_<offset>``. If that is shorter than the input length, the
    list is padded with ``fill_<position % 64>`` tokens; if it is longer, it
    is cut off.
    """
    wanted = request.input_length
    tokens = [
        f"blk{hash_id}_{offset}"
        for hash_id in request.hash_ids
        for offset in range(block_token_budget)
    ]
    # Filler names depend on the absolute position within the prompt.
    tokens.extend(f"fill_{position % 64}" for position in range(len(tokens), wanted))
    return tokens[:wanted]
|
||||
|
||||
|
||||
def build_synthetic_append_chunk(
    request: TraceRequest,
    append_length: int,
) -> str:
    """Return ``append_length`` space-separated tokens unique to this turn.

    Tokens are named ``turn<turn_id>_append_<chat_id>_<offset>``. A
    non-positive ``append_length`` yields the empty string.
    """
    if append_length <= 0:
        return ""
    stem = f"turn{request.turn_id}_append_{request.chat_id}_"
    pieces = [f"{stem}{offset}" for offset in range(append_length)]
    return " ".join(pieces)
|
||||
|
||||
|
||||
def _resolve_session_id(
    *,
    chat_id: int,
    parent_chat_id: int,
    chat_to_session: dict[int, str],
) -> str:
    """Return the session a chat belongs to and record it in the mapping.

    A negative ``parent_chat_id`` starts a new session named after the chat
    itself. Otherwise the parent's session is inherited; a parent that has
    not been seen makes its own id the session name. ``chat_to_session`` is
    updated in place with the result.
    """
    if parent_chat_id >= 0:
        fallback = str(parent_chat_id)
        session_id = chat_to_session.get(parent_chat_id, fallback)
    else:
        session_id = str(chat_id)
    chat_to_session[chat_id] = session_id
    return session_id
|
||||
127
src/agentic_pd_hybrid/trace_profiles.py
Normal file
127
src/agentic_pd_hybrid/trace_profiles.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from dataclasses import asdict, dataclass
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
|
||||
from agentic_pd_hybrid.trace import TraceRequest, load_trace
|
||||
|
||||
|
||||
BLOCK_TOKEN_BUDGET = 24
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NormalizeTraceLengthsConfig:
    """Immutable settings consumed by ``normalize_trace_lengths``."""

    # Trace to read and destination of the rewritten JSONL trace.
    trace_path: Path
    output_path: Path
    # Input length given to the first turn of every session.
    initial_input_length: int = 10_000
    # Tokens added to the input by each later turn, on top of the output length.
    append_input_length: int = 1_000
    # Output length given to every request.
    output_length: int = 1_000
    # Optional limit on how many requests are read from the trace.
    max_requests: int | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NormalizeTraceLengthsSummary:
    """Report returned by ``normalize_trace_lengths`` and dumped next to the trace."""

    # Paths, stored as strings so the summary serialises to JSON.
    input_trace_path: str
    output_trace_path: str
    # Size of the rewritten trace.
    request_count: int
    session_count: int
    multi_turn_session_count: int
    # The length settings that were applied.
    initial_input_length: int
    append_input_length: int
    output_length: int
    # Largest session size and largest generated input length.
    max_turns_per_session: int
    max_input_length: int
|
||||
|
||||
|
||||
def normalize_trace_lengths(
    config: NormalizeTraceLengthsConfig,
) -> NormalizeTraceLengthsSummary:
    """Rewrite a trace so that every session follows the same length schedule.

    Chat ids, parent links, timestamps, types and turn numbers are kept. The
    n-th request of a session (ordered by timestamp, turn and chat id) gets
    the input length ``initial + n * (append + output)`` and the configured
    output length. Its ``hash_ids`` are regenerated per session, one id per
    started block of ``BLOCK_TOKEN_BUDGET`` tokens, so later turns extend the
    ids of earlier ones. Records are written sorted by timestamp, and a
    summary is written to ``<output_path><suffix>.summary.json``.

    Raises:
        ValueError: if any configured length is negative.
    """
    for length_name in ("initial_input_length", "append_input_length", "output_length"):
        if getattr(config, length_name) < 0:
            raise ValueError(f"{length_name} must be >= 0")

    grouped: dict[str, list[TraceRequest]] = defaultdict(list)
    for request in load_trace(config.trace_path, request_limit=config.max_requests):
        grouped[request.session_id].append(request)

    growth_per_turn = config.append_input_length + config.output_length
    base_block_count = ceil(config.initial_input_length / BLOCK_TOKEN_BUDGET)

    normalized_records: list[dict[str, object]] = []
    longest_session = 0
    longest_input = 0
    for session_idx, session_id in enumerate(sorted(grouped, key=_session_sort_key)):
        timeline = sorted(
            grouped[session_id],
            key=lambda request: (request.timestamp_s, request.turn_id, request.chat_id),
        )
        longest_session = max(longest_session, len(timeline))

        for turn_idx, request in enumerate(timeline):
            input_length = config.initial_input_length + turn_idx * growth_per_turn
            # Never fewer blocks than the initial prompt occupies.
            block_count = max(base_block_count, ceil(input_length / BLOCK_TOKEN_BUDGET))
            hash_ids = [
                _hash_id_for(session_idx=session_idx, block_idx=block_idx)
                for block_idx in range(block_count)
            ]
            longest_input = max(longest_input, input_length)
            normalized_records.append(
                {
                    "chat_id": request.chat_id,
                    "parent_chat_id": request.parent_chat_id,
                    "timestamp": request.timestamp_s,
                    "input_length": input_length,
                    "output_length": config.output_length,
                    "type": request.request_type,
                    "turn": request.turn_id,
                    "hash_ids": hash_ids,
                }
            )

    normalized_records.sort(key=lambda item: float(item["timestamp"]))
    config.output_path.parent.mkdir(parents=True, exist_ok=True)
    with config.output_path.open("w", encoding="utf-8") as handle:
        for record in normalized_records:
            handle.write(json.dumps(record, sort_keys=True) + "\n")

    multi_turn_sessions = [
        session_requests
        for session_requests in grouped.values()
        if len(session_requests) > 1
    ]
    summary = NormalizeTraceLengthsSummary(
        input_trace_path=str(config.trace_path),
        output_trace_path=str(config.output_path),
        request_count=len(normalized_records),
        session_count=len(grouped),
        multi_turn_session_count=len(multi_turn_sessions),
        initial_input_length=config.initial_input_length,
        append_input_length=config.append_input_length,
        output_length=config.output_length,
        max_turns_per_session=longest_session,
        max_input_length=longest_input,
    )
    summary_path = config.output_path.with_suffix(config.output_path.suffix + ".summary.json")
    with summary_path.open("w", encoding="utf-8") as handle:
        json.dump(asdict(summary), handle, indent=2, sort_keys=True)
    return summary
|
||||
|
||||
|
||||
def _hash_id_for(*, session_idx: int, block_idx: int) -> int:
    """Return the block id ``session_idx * 1_000_000 + block_idx``."""
    session_base = session_idx * 1_000_000
    return session_base + block_idx
|
||||
|
||||
|
||||
def _session_sort_key(session_id: str) -> tuple[int, str]:
    """Sort key placing all-digit session ids before every other id.

    Within each group the ids compare as plain strings.
    """
    group = int(not session_id.isdigit())
    return (group, session_id)
|
||||
Reference in New Issue
Block a user