feat: add agentic pd hybrid benchmark prototype

This commit is contained in:
2026-04-24 12:17:46 +00:00
parent d2fe014db7
commit 4bca741f32
16 changed files with 9182 additions and 0 deletions

24
pyproject.toml Normal file
View File

@@ -0,0 +1,24 @@
[project]
name = "agentic-pd-hybrid"
version = "0.1.0"
description = "Prototype for session-aware and KV-cache-aware PD routing on SGLang xPyD"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"httpx>=0.28.1",
"mooncake-transfer-engine",
"sglang==0.5.10",
]
[project.scripts]
agentic-pd-hybrid = "agentic_pd_hybrid.cli:main"
[build-system]
requires = ["setuptools>=68"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = ["src"]
[tool.uv]
prerelease = "allow"

View File

@@ -0,0 +1,12 @@
"""Agentic PD hybrid prototype."""
# Submodules pulled in by ``from agentic_pd_hybrid import *``.
# NOTE(review): other modules imported elsewhere in this commit (e.g.
# ``sampling``, ``stack``, ``benchmark``) are not listed here — confirm that
# the omission is intentional.
__all__ = [
    "cli",
    "launcher",
    "metrics",
    "microbench",
    "policies",
    "replay",
    "topology",
    "trace",
]

View File

@@ -0,0 +1,221 @@
from __future__ import annotations
import asyncio
import json
import signal
from dataclasses import asdict, dataclass, replace
from datetime import UTC, datetime
from pathlib import Path
from agentic_pd_hybrid.replay import ReplayConfig, replay_trace
from agentic_pd_hybrid.sampling import SessionSampleConfig, sample_trace_sessions
from agentic_pd_hybrid.stack import ManagedPdStack, launch_pd_stack
from agentic_pd_hybrid.topology import SingleNodeTopology
@dataclass(frozen=True)
class BenchmarkConfig:
    """Immutable settings for a single ``run_live_benchmark`` invocation."""

    # Source trace handed to the session sampler.
    trace_path: Path
    # Parent directory; each run creates its own timestamped sub-directory.
    output_root: Path
    # Worker/router layout used for launching and for replay.
    topology: SingleNodeTopology
    # Selects the router decode policy and the replay header mode.
    policy_name: str
    # "kvcache-centric" adds extra server flags; together with
    # "pd-disaggregation" it is one of the mechanisms that use the router.
    mechanism_name: str = "pd-disaggregation"
    # The next four values are forwarded unchanged to the session sampler.
    target_duration_s: float = 600.0
    start_time_s: float = 0.0
    session_sample_rate: float = 0.01
    min_turns: int = 1
    # Replay settings forwarded to ``ReplayConfig``.
    time_scale: float = 1.0
    concurrency_limit: int = 32
    # Used both as the stack launch timeout and as the replay timeout.
    timeout_s: float = 1200.0
    stream: bool = True
    stream_idle_timeout_s: float | None = 900.0
    kvcache_direct_max_uncached_tokens: int = 2048
    kvcache_admission_mode: str = "router"
    # Optional sampler filters; ``None`` is passed through to the sampler as-is.
    sample_profile: str = "default"
    min_initial_input_tokens: int | None = None
    max_initial_input_tokens: int | None = None
    max_append_input_tokens: int | None = None
    max_output_tokens: int | None = None
    min_overlap_ratio: float | None = None
    # When False no stack is launched and ``topology.router_url`` is used
    # for the mechanisms that need a router.
    launch_stack: bool = True
@dataclass(frozen=True)
class BenchmarkArtifacts:
    """Paths of the files associated with one benchmark run directory."""

    # Per-run directory created under ``BenchmarkConfig.output_root``.
    run_dir: Path
    # Sampled trace written before the replay starts.
    sampled_trace_path: Path
    # Request-level metrics file given to the replay as its output path.
    metrics_path: Path
    # ``<metrics_path>.summary.json``; presumably written by the replay —
    # TODO confirm against ``replay_trace``.
    summary_path: Path
    # JSON record of the effective benchmark settings.
    benchmark_config_path: Path
# Mechanisms whose requests are sent through the PD router.
_ROUTED_MECHANISMS = frozenset({"pd-disaggregation", "kvcache-centric"})


def _benchmark_config_payload(
    config: BenchmarkConfig,
    topology: SingleNodeTopology,
    sample_summary,
) -> dict[str, object]:
    """Build the JSON-serialisable record stored as ``benchmark-config.json``."""
    return {
        "policy_name": config.policy_name,
        "mechanism_name": config.mechanism_name,
        "target_duration_s": config.target_duration_s,
        "start_time_s": config.start_time_s,
        "session_sample_rate": config.session_sample_rate,
        "min_turns": config.min_turns,
        "time_scale": config.time_scale,
        "concurrency_limit": config.concurrency_limit,
        "timeout_s": config.timeout_s,
        "stream": config.stream,
        "stream_idle_timeout_s": config.stream_idle_timeout_s,
        "kvcache_direct_max_uncached_tokens": config.kvcache_direct_max_uncached_tokens,
        "kvcache_admission_mode": config.kvcache_admission_mode,
        "sample_profile": config.sample_profile,
        "min_initial_input_tokens": config.min_initial_input_tokens,
        "max_initial_input_tokens": config.max_initial_input_tokens,
        "max_append_input_tokens": config.max_append_input_tokens,
        "max_output_tokens": config.max_output_tokens,
        "min_overlap_ratio": config.min_overlap_ratio,
        "sample_summary": asdict(sample_summary),
        "topology": {
            "model_path": config.topology.model_path,
            "router_url": topology.router_url,
            "transfer_backend": topology.transfer_backend,
            "force_rdma": topology.force_rdma,
            "ib_device": topology.ib_device,
            "prefill_workers": [
                worker.worker_id for worker in topology.prefill_workers
            ],
            "decode_workers": [
                worker.worker_id for worker in topology.decode_workers
            ],
            "direct_workers": [
                worker.worker_id for worker in topology.direct_workers
            ],
        },
    }


def _restore_signal_handlers(previous_handlers: dict) -> None:
    """Reinstall the handlers captured before the benchmark started.

    ``signal.getsignal`` returns ``None`` for a handler that was not installed
    from Python, and ``signal.signal`` rejects ``None``; such entries fall back
    to ``signal.SIG_DFL`` so the benchmark's own handler never stays installed.
    """
    for signum, handler in previous_handlers.items():
        signal.signal(signum, signal.SIG_DFL if handler is None else handler)


def run_live_benchmark(config: BenchmarkConfig) -> BenchmarkArtifacts:
    """Sample a trace, optionally launch the PD stack, replay, and record the run.

    A timestamped run directory is created under ``config.output_root``. The
    sampled trace is written there, the replay is executed with SIGINT/SIGTERM
    handlers that stop the launched stack, and finally the effective settings
    are written to ``benchmark-config.json``.

    Args:
        config: Benchmark settings.

    Returns:
        The paths associated with this run.

    Raises:
        SystemExit: With code ``128 + signum`` when SIGINT or SIGTERM arrives
            while the handlers are installed.
    """
    run_id = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ")
    run_label = f"{config.mechanism_name}-{config.policy_name}"
    if config.mechanism_name == "kvcache-centric":
        run_label = f"{run_label}-{config.kvcache_admission_mode}-admission"
    run_dir = config.output_root / f"{run_label}-{run_id}"
    run_dir.mkdir(parents=True, exist_ok=True)

    topology = config.topology
    if config.mechanism_name == "kvcache-centric":
        # This mechanism needs streaming sessions on both worker roles and
        # local prefill on the decode workers.
        topology = replace(
            topology,
            prefill_extra_server_args=topology.prefill_extra_server_args
            + ("--enable-streaming-session",),
            decode_extra_server_args=topology.decode_extra_server_args
            + (
                "--enable-streaming-session",
                "--disaggregation-decode-allow-local-prefill",
            ),
        )

    sampled_trace_path = run_dir / "sampled-trace.jsonl"
    sample_summary = sample_trace_sessions(
        SessionSampleConfig(
            trace_path=config.trace_path,
            output_path=sampled_trace_path,
            target_duration_s=config.target_duration_s,
            start_time_s=config.start_time_s,
            session_sample_rate=config.session_sample_rate,
            min_turns=config.min_turns,
            profile=config.sample_profile,  # type: ignore[arg-type]
            min_initial_input_tokens=config.min_initial_input_tokens,
            max_initial_input_tokens=config.max_initial_input_tokens,
            max_append_input_tokens=config.max_append_input_tokens,
            max_output_tokens=config.max_output_tokens,
            min_overlap_ratio=config.min_overlap_ratio,
        )
    )

    uses_router = config.mechanism_name in _ROUTED_MECHANISMS
    metrics_path = run_dir / "request-metrics.jsonl"
    stack: ManagedPdStack | None = None
    previous_handlers = {
        signum: signal.getsignal(signum)
        for signum in (signal.SIGINT, signal.SIGTERM)
    }

    def _handle_termination(signum, _frame) -> None:
        # Stop the workers before the interpreter unwinds.
        if stack is not None:
            stack.stop()
        raise SystemExit(128 + signum)

    try:
        for signum in previous_handlers:
            signal.signal(signum, _handle_termination)
        if config.launch_stack:
            stack = launch_pd_stack(
                topology=topology,
                run_dir=run_dir,
                prefill_policy="round_robin",
                decode_policy=_decode_policy_for(config.policy_name),
                timeout_s=config.timeout_s,
                include_router=uses_router,
            )
            router_url = stack.router_url if uses_router else None
        else:
            router_url = topology.router_url if uses_router else None
        replay_config = ReplayConfig(
            trace_path=sampled_trace_path,
            output_path=metrics_path,
            policy_name=config.policy_name,
            mechanism_name=config.mechanism_name,
            topology=topology,
            router_url=router_url,
            model_name=topology.model_name,
            pace=True,
            time_scale=config.time_scale,
            request_limit=None,
            concurrency_limit=config.concurrency_limit,
            header_mode=_header_mode_for(config.policy_name),
            timeout_s=config.timeout_s,
            stream=config.stream,
            stream_idle_timeout_s=config.stream_idle_timeout_s,
            kvcache_direct_max_uncached_tokens=config.kvcache_direct_max_uncached_tokens,
            kvcache_admission_mode=config.kvcache_admission_mode,  # type: ignore[arg-type]
        )
        asyncio.run(replay_trace(replay_config))
    finally:
        # The stack must be stopped even if restoring a handler fails,
        # otherwise the launched server processes would be leaked.
        try:
            _restore_signal_handlers(previous_handlers)
        finally:
            if stack is not None:
                stack.stop()

    benchmark_config_path = run_dir / "benchmark-config.json"
    with benchmark_config_path.open("w", encoding="utf-8") as handle:
        json.dump(
            _benchmark_config_payload(config, topology, sample_summary),
            handle,
            indent=2,
            sort_keys=True,
        )
    return BenchmarkArtifacts(
        run_dir=run_dir,
        sampled_trace_path=sampled_trace_path,
        metrics_path=metrics_path,
        summary_path=run_dir / "request-metrics.jsonl.summary.json",
        benchmark_config_path=benchmark_config_path,
    )
def _decode_policy_for(policy_name: str) -> str:
if policy_name == "sticky":
return "manual"
if policy_name == "kv-aware":
return "consistent_hashing"
return "round_robin"
def _header_mode_for(policy_name: str) -> str:
if policy_name == "sticky":
return "routing-key"
if policy_name == "kv-aware":
return "target-worker"
return "none"

View File

@@ -0,0 +1,484 @@
from __future__ import annotations
import argparse
import asyncio
from pathlib import Path
from agentic_pd_hybrid.benchmark import BenchmarkConfig, run_live_benchmark
from agentic_pd_hybrid.launcher import build_launch_plan
from agentic_pd_hybrid.microbench import SmallAppendTraceConfig, write_small_append_trace
from agentic_pd_hybrid.replay import ReplayConfig, replay_trace
from agentic_pd_hybrid.sampling import SessionSampleConfig, sample_trace_sessions
from agentic_pd_hybrid.trace_profiles import (
NormalizeTraceLengthsConfig,
normalize_trace_lengths,
)
from agentic_pd_hybrid.topology import build_single_node_topology
def _normalize_mechanism_name(name: str) -> str:
normalized = name.strip().lower()
aliases = {
"pd-disagg": "pd-disaggregation",
"pd-disaggregation": "pd-disaggregation",
"pd-hybrid": "pd-disaggregation",
"baseline-pd-disagg": "pd-disaggregation",
"pd-colo": "pd-colo",
"direct-d": "pd-colo",
"colocation": "pd-colo",
"kvcache-centric": "kvcache-centric",
"turn2+-direct-to-d": "kvcache-centric",
"pd-with-d-append": "kvcache-centric",
}
if normalized not in aliases:
raise ValueError(f"Unsupported mechanism: {name}")
return aliases[normalized]
def _parse_gpu_id_list(value: str | None) -> tuple[int, ...] | None:
if value is None:
return None
items = [item.strip() for item in value.split(",") if item.strip()]
if not items:
return tuple()
return tuple(int(item) for item in items)
_POLICY_CHOICES = ("default", "sticky", "kv-aware")
# Spellings accepted on the command line; each is canonicalised with
# _normalize_mechanism_name before use.
_MECHANISM_CHOICES = (
    "pd-disaggregation",
    "pd-hybrid",
    "pd-disagg",
    "pd-colo",
    "direct-d",
    "kvcache-centric",
    "turn2+-direct-to-d",
    "pd-with-d-append",
)
_SAMPLE_PROFILE_CHOICES = ("default", "small-append")


def _add_policy_and_mechanism_arguments(parser: argparse.ArgumentParser) -> None:
    """Register ``--policy`` and ``--mechanism`` (shared by replay and benchmark-live)."""
    parser.add_argument("--policy", choices=_POLICY_CHOICES, default="sticky")
    parser.add_argument(
        "--mechanism",
        choices=_MECHANISM_CHOICES,
        default="pd-disaggregation",
    )


def _add_sampling_window_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the session-sampling window options."""
    parser.add_argument("--target-duration-s", type=float, default=600.0)
    parser.add_argument("--start-time-s", type=float, default=0.0)
    parser.add_argument("--session-sample-rate", type=float, default=0.01)
    parser.add_argument("--min-turns", type=int, default=1)


def _add_sample_filter_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the optional token-length / overlap filters for sampling."""
    parser.add_argument("--min-initial-input-tokens", type=int, default=None)
    parser.add_argument("--max-initial-input-tokens", type=int, default=None)
    parser.add_argument("--max-append-input-tokens", type=int, default=None)
    parser.add_argument("--max-output-tokens", type=int, default=None)
    parser.add_argument("--min-overlap-ratio", type=float, default=None)


def _add_stream_and_kvcache_arguments(
    parser: argparse.ArgumentParser,
    *,
    no_stream_help: str,
) -> None:
    """Register the streaming and kvcache-centric options shared by two commands."""
    parser.add_argument("--no-stream", action="store_true", help=no_stream_help)
    parser.add_argument(
        "--stream-idle-timeout-s",
        type=float,
        default=900.0,
        help="Abort a streaming request if no SSE line arrives within this many seconds.",
    )
    parser.add_argument(
        "--kvcache-direct-max-uncached-tokens",
        type=int,
        default=2048,
        help="For kvcache-centric routing, bypass P when the uncached suffix is at most this many tokens.",
    )
    parser.add_argument(
        "--kvcache-admission-mode",
        choices=["router", "worker"],
        default="router",
        help=(
            "For kvcache-centric routing, use router shadow-state admission "
            "or query the decode worker on the critical path."
        ),
    )


def _build_parser() -> argparse.ArgumentParser:
    """Create the top-level parser with all six subcommands."""
    parser = argparse.ArgumentParser(description="Agentic PD hybrid prototype")
    subparsers = parser.add_subparsers(dest="command", required=True)

    print_launch = subparsers.add_parser(
        "print-launch",
        help="Print one-node SGLang PD launch commands",
    )
    _add_topology_arguments(print_launch)
    print_launch.add_argument("--prefill-policy", default="round_robin")
    print_launch.add_argument("--decode-policy", default="manual")

    replay = subparsers.add_parser(
        "replay",
        help="Replay trace and log request-level metrics",
    )
    _add_topology_arguments(replay)
    replay.add_argument("--trace", type=Path, required=True)
    replay.add_argument("--output", type=Path, required=True)
    _add_policy_and_mechanism_arguments(replay)
    replay.add_argument("--router-url")
    replay.add_argument("--model")
    replay.add_argument(
        "--header-mode",
        choices=["auto", "none", "routing-key", "target-worker"],
        default="auto",
    )
    replay.add_argument(
        "--request-limit",
        type=int,
        default=None,
        help="Replay at most this many requests",
    )
    replay.add_argument(
        "--no-pace",
        action="store_true",
        help="Disable wall-clock pacing from trace timestamps",
    )
    replay.add_argument(
        "--time-scale",
        type=float,
        default=1.0,
        help="Scale trace timing by this factor when pacing is enabled",
    )
    replay.add_argument("--concurrency-limit", type=int, default=32)
    replay.add_argument("--timeout-s", type=float, default=600.0)
    _add_stream_and_kvcache_arguments(
        replay,
        no_stream_help="Use non-streaming OpenAI responses for more robust E2E-only replay.",
    )

    sample = subparsers.add_parser(
        "sample-sessions",
        help="Sample a session-granularity trace shard for live benchmarking",
    )
    sample.add_argument("--trace", type=Path, required=True)
    sample.add_argument("--output", type=Path, required=True)
    _add_sampling_window_arguments(sample)
    sample.add_argument("--max-requests", type=int, default=None)
    sample.add_argument(
        "--profile",
        choices=_SAMPLE_PROFILE_CHOICES,
        default="default",
        help="Optional workload-shape filter for live benchmarks.",
    )
    _add_sample_filter_arguments(sample)

    normalize = subparsers.add_parser(
        "normalize-trace-lengths",
        help="Rewrite a trace to a fixed turn1/append/output length profile",
    )
    normalize.add_argument("--trace", type=Path, required=True)
    normalize.add_argument("--output", type=Path, required=True)
    normalize.add_argument("--initial-input-length", type=int, default=10_000)
    normalize.add_argument("--append-input-length", type=int, default=1_000)
    normalize.add_argument("--output-length", type=int, default=1_000)
    normalize.add_argument("--max-requests", type=int, default=None)

    micro = subparsers.add_parser(
        "make-small-append-trace",
        help="Generate a synthetic multi-turn trace with small turn2+ appends",
    )
    micro.add_argument("--output", type=Path, required=True)
    micro.add_argument("--session-count", type=int, default=8)
    micro.add_argument("--turns-per-session", type=int, default=3)
    micro.add_argument("--initial-input-length", type=int, default=10_000)
    micro.add_argument("--append-input-length", type=int, default=1_000)
    micro.add_argument("--output-length", type=int, default=1_000)
    micro.add_argument("--inter-turn-gap-s", type=float, default=1.0)
    micro.add_argument("--session-stagger-s", type=float, default=0.1)

    benchmark = subparsers.add_parser(
        "benchmark-live",
        help="Launch a real PD stack, sample sessions, and collect live E2E numbers",
    )
    _add_topology_arguments(benchmark)
    benchmark.add_argument("--trace", type=Path, required=True)
    _add_policy_and_mechanism_arguments(benchmark)
    benchmark.add_argument("--output-root", type=Path, default=Path("outputs/live"))
    _add_sampling_window_arguments(benchmark)
    benchmark.add_argument("--time-scale", type=float, default=1.0)
    benchmark.add_argument("--concurrency-limit", type=int, default=32)
    benchmark.add_argument("--timeout-s", type=float, default=1200.0)
    _add_stream_and_kvcache_arguments(
        benchmark,
        no_stream_help="Use non-streaming OpenAI responses for E2E-only live benchmarking.",
    )
    benchmark.add_argument(
        "--sample-profile",
        choices=_SAMPLE_PROFILE_CHOICES,
        default="default",
        help="Optional session-shape filter applied before live replay.",
    )
    _add_sample_filter_arguments(benchmark)
    return parser


def _run_print_launch(args: argparse.Namespace) -> None:
    """Print the launch commands for the requested topology."""
    topology = _topology_from_args(args)
    plan = build_launch_plan(
        topology,
        prefill_policy=args.prefill_policy,
        decode_policy=args.decode_policy,
        include_router=bool(topology.prefill_workers and topology.decode_workers),
    )
    print(plan.render())


def _run_replay(args: argparse.Namespace) -> None:
    """Replay a trace against an already running deployment."""
    config = ReplayConfig(
        trace_path=args.trace,
        output_path=args.output,
        policy_name=args.policy,
        mechanism_name=_normalize_mechanism_name(args.mechanism),
        topology=_topology_from_args(args),
        router_url=args.router_url,
        model_name=args.model,
        pace=not args.no_pace,
        time_scale=args.time_scale,
        request_limit=args.request_limit,
        concurrency_limit=args.concurrency_limit,
        header_mode=args.header_mode,
        timeout_s=args.timeout_s,
        stream=not args.no_stream,
        stream_idle_timeout_s=args.stream_idle_timeout_s,
        kvcache_direct_max_uncached_tokens=args.kvcache_direct_max_uncached_tokens,
        kvcache_admission_mode=args.kvcache_admission_mode,
    )
    results = asyncio.run(replay_trace(config))
    print(
        f"wrote {len(results)} request records to {args.output} and "
        f"{args.output}.summary.json"
    )


def _run_sample_sessions(args: argparse.Namespace) -> None:
    """Write a sampled trace shard and report its size."""
    summary = sample_trace_sessions(
        SessionSampleConfig(
            trace_path=args.trace,
            output_path=args.output,
            target_duration_s=args.target_duration_s,
            start_time_s=args.start_time_s,
            session_sample_rate=args.session_sample_rate,
            min_turns=args.min_turns,
            max_requests=args.max_requests,
            profile=args.profile,
            min_initial_input_tokens=args.min_initial_input_tokens,
            max_initial_input_tokens=args.max_initial_input_tokens,
            max_append_input_tokens=args.max_append_input_tokens,
            max_output_tokens=args.max_output_tokens,
            min_overlap_ratio=args.min_overlap_ratio,
        )
    )
    print(
        f"wrote {summary.request_count} requests from {summary.session_count} sessions "
        f"covering {summary.sampled_duration_s:.3f}s to {args.output}"
    )


def _run_normalize_trace_lengths(args: argparse.Namespace) -> None:
    """Rewrite a trace with fixed lengths and report the result."""
    summary = normalize_trace_lengths(
        NormalizeTraceLengthsConfig(
            trace_path=args.trace,
            output_path=args.output,
            initial_input_length=args.initial_input_length,
            append_input_length=args.append_input_length,
            output_length=args.output_length,
            max_requests=args.max_requests,
        )
    )
    print(
        f"wrote {summary.request_count} normalized requests from "
        f"{summary.session_count} sessions to {args.output}"
    )


def _run_make_small_append_trace(args: argparse.Namespace) -> None:
    """Generate the synthetic small-append trace and report its size."""
    summary = write_small_append_trace(
        SmallAppendTraceConfig(
            output_path=args.output,
            session_count=args.session_count,
            turns_per_session=args.turns_per_session,
            initial_input_length=args.initial_input_length,
            append_input_length=args.append_input_length,
            output_length=args.output_length,
            inter_turn_gap_s=args.inter_turn_gap_s,
            session_stagger_s=args.session_stagger_s,
        )
    )
    print(
        f"wrote {summary.request_count} requests across {summary.session_count} sessions "
        f"to {args.output}"
    )


def _run_benchmark_live(args: argparse.Namespace) -> None:
    """Run the live benchmark and report where its artifacts were written."""
    artifacts = run_live_benchmark(
        BenchmarkConfig(
            trace_path=args.trace,
            output_root=args.output_root,
            topology=_topology_from_args(args),
            policy_name=args.policy,
            mechanism_name=_normalize_mechanism_name(args.mechanism),
            target_duration_s=args.target_duration_s,
            start_time_s=args.start_time_s,
            session_sample_rate=args.session_sample_rate,
            min_turns=args.min_turns,
            time_scale=args.time_scale,
            concurrency_limit=args.concurrency_limit,
            timeout_s=args.timeout_s,
            stream=not args.no_stream,
            stream_idle_timeout_s=args.stream_idle_timeout_s,
            kvcache_direct_max_uncached_tokens=args.kvcache_direct_max_uncached_tokens,
            kvcache_admission_mode=args.kvcache_admission_mode,
            sample_profile=args.sample_profile,
            min_initial_input_tokens=args.min_initial_input_tokens,
            max_initial_input_tokens=args.max_initial_input_tokens,
            max_append_input_tokens=args.max_append_input_tokens,
            max_output_tokens=args.max_output_tokens,
            min_overlap_ratio=args.min_overlap_ratio,
            launch_stack=True,
        )
    )
    print(
        f"benchmark artifacts written under {artifacts.run_dir}; "
        f"metrics={artifacts.metrics_path} summary={artifacts.summary_path}"
    )


_COMMAND_HANDLERS = {
    "print-launch": _run_print_launch,
    "replay": _run_replay,
    "sample-sessions": _run_sample_sessions,
    "normalize-trace-lengths": _run_normalize_trace_lengths,
    "make-small-append-trace": _run_make_small_append_trace,
    "benchmark-live": _run_benchmark_live,
}


def main() -> None:
    """Parse the command line and run the selected subcommand.

    Raises:
        AssertionError: If the parser accepts a command that has no handler.
    """
    args = _build_parser().parse_args()
    handler = _COMMAND_HANDLERS.get(args.command)
    if handler is None:
        raise AssertionError(f"Unhandled command: {args.command}")
    handler(args)
def _add_topology_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--model-path",
default="~/models/Qwen/Qwen3-Coder-30B-A3B-Instruct",
)
parser.add_argument("--prefill-workers", type=int, default=1)
parser.add_argument("--decode-workers", type=int, default=1)
parser.add_argument("--direct-workers", type=int, default=0)
parser.add_argument("--prefill-tp-size", type=int, default=1)
parser.add_argument("--decode-tp-size", type=int, default=1)
parser.add_argument("--direct-tp-size", type=int, default=1)
parser.add_argument("--gpu-budget", type=int, default=8)
parser.add_argument(
"--prefill-gpu-ids",
default=None,
help="Comma-separated GPU IDs for prefill workers, e.g. 3,4",
)
parser.add_argument(
"--decode-gpu-ids",
default=None,
help="Comma-separated GPU IDs for decode workers, e.g. 5",
)
parser.add_argument(
"--direct-gpu-ids",
default=None,
help="Comma-separated GPU IDs for direct workers, e.g. 6",
)
parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--router-port", type=int, default=8000)
parser.add_argument("--prefill-port-base", type=int, default=30000)
parser.add_argument("--decode-port-base", type=int, default=31000)
parser.add_argument("--direct-port-base", type=int, default=32000)
parser.add_argument("--bootstrap-port-base", type=int, default=8998)
parser.add_argument("--transfer-backend", default="nixl")
parser.add_argument(
"--force-rdma",
action="store_true",
help=(
"Force real RDMA transport for PD KV transfer. "
"Currently this requires Mooncake plus --ib-device."
),
)
parser.add_argument("--ib-device", default=None)
parser.add_argument(
"--no-trust-remote-code",
action="store_true",
)
def _topology_from_args(args: argparse.Namespace):
    """Build the single-node topology described by parsed CLI arguments.

    ``--force-rdma`` overrides the chosen transfer backend with
    ``"mooncake"``. Direct workers always receive
    ``--enable-streaming-session`` as an extra server argument.
    """
    backend = "mooncake" if args.force_rdma else args.transfer_backend

    # The per-role options share a naming scheme, so gather them in one pass.
    per_role_options = {}
    for role in ("prefill", "decode", "direct"):
        per_role_options[f"{role}_worker_count"] = getattr(args, f"{role}_workers")
        per_role_options[f"{role}_tp_size"] = getattr(args, f"{role}_tp_size")
        per_role_options[f"{role}_gpu_ids"] = _parse_gpu_id_list(
            getattr(args, f"{role}_gpu_ids")
        )
        per_role_options[f"{role}_port_base"] = getattr(args, f"{role}_port_base")

    resolved_model_path = str(Path(args.model_path).expanduser())
    return build_single_node_topology(
        model_path=resolved_model_path,
        total_gpu_budget=args.gpu_budget,
        host=args.host,
        router_port=args.router_port,
        bootstrap_port_base=args.bootstrap_port_base,
        transfer_backend=backend,
        force_rdma=args.force_rdma,
        trust_remote_code=not args.no_trust_remote_code,
        ib_device=args.ib_device,
        direct_extra_server_args=("--enable-streaming-session",),
        **per_role_options,
    )
# Run the CLI when this module is executed directly rather than imported.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,140 @@
from __future__ import annotations
import shlex
import sys
from dataclasses import dataclass
from agentic_pd_hybrid.topology import SingleNodeTopology, WorkerSpec
@dataclass(frozen=True)
class LaunchPlan:
    """Command lines for every process of a single-node deployment."""

    # One argv tuple per worker, grouped by role.
    prefill_commands: tuple[tuple[str, ...], ...]
    decode_commands: tuple[tuple[str, ...], ...]
    direct_commands: tuple[tuple[str, ...], ...]
    # ``None`` when no router is part of the plan.
    router_command: tuple[str, ...] | None

    def render(self) -> str:
        """Return all commands as labelled, shell-quoted text.

        Sections appear in the order prefill, decode, direct, router and are
        separated by a blank line.
        """
        role_groups = (
            ("prefill", self.prefill_commands),
            ("decode", self.decode_commands),
            ("direct", self.direct_commands),
        )
        labelled = [
            (f"{role}-{position}", argv)
            for role, group in role_groups
            for position, argv in enumerate(group)
        ]
        if self.router_command is not None:
            labelled.append(("router", self.router_command))
        return "\n\n".join(
            _render_named_command(label, argv) for label, argv in labelled
        )
def build_launch_plan(
    topology: SingleNodeTopology,
    *,
    prefill_policy: str = "round_robin",
    decode_policy: str = "manual",
    include_router: bool = True,
) -> LaunchPlan:
    """Assemble the server and router command lines for ``topology``.

    A router command is produced only when ``include_router`` is true and the
    topology has at least one prefill worker and one decode worker.
    """

    def _commands_for(workers) -> tuple[tuple[str, ...], ...]:
        return tuple(_build_server_command(topology, worker) for worker in workers)

    prefill_commands = _commands_for(topology.prefill_workers)
    decode_commands = _commands_for(topology.decode_workers)
    direct_commands = _commands_for(topology.direct_workers)

    router_command = None
    if include_router and topology.prefill_workers and topology.decode_workers:
        router_command = _build_router_command(
            topology,
            prefill_policy=prefill_policy,
            decode_policy=decode_policy,
        )

    return LaunchPlan(
        prefill_commands=prefill_commands,
        decode_commands=decode_commands,
        direct_commands=direct_commands,
        router_command=router_command,
    )
def _build_server_command(
topology: SingleNodeTopology,
worker: WorkerSpec,
) -> tuple[str, ...]:
command = [
sys.executable,
"-B",
"-u",
"-m",
"sglang.launch_server",
"--model-path",
topology.model_path,
"--host",
worker.host,
"--port",
str(worker.port),
"--base-gpu-id",
str(worker.gpu_id),
"--disaggregation-mode",
_disaggregation_mode_for(worker),
"--disaggregation-transfer-backend",
topology.transfer_backend,
]
if worker.tp_size > 1:
command.extend(["--tp-size", str(worker.tp_size)])
if topology.trust_remote_code:
command.append("--trust-remote-code")
command.append("--enable-cache-report")
if worker.bootstrap_port is not None:
command.extend(
["--disaggregation-bootstrap-port", str(worker.bootstrap_port)]
)
if topology.ib_device:
command.extend(["--disaggregation-ib-device", topology.ib_device])
command.extend(topology.extra_server_args)
if worker.role == "prefill":
command.extend(topology.prefill_extra_server_args)
elif worker.role == "decode":
command.extend(topology.decode_extra_server_args)
else:
command.extend(topology.direct_extra_server_args)
return tuple(command)
def _build_router_command(
topology: SingleNodeTopology,
*,
prefill_policy: str,
decode_policy: str,
) -> tuple[str, ...]:
command: list[str] = [
sys.executable,
"-B",
"-u",
"-m",
"agentic_pd_hybrid.pd_router",
"--host",
topology.router_host,
"--port",
str(topology.router_port),
"--prefill-policy",
prefill_policy,
"--decode-policy",
decode_policy,
]
for worker in topology.prefill_workers:
command.extend(
["--prefill", worker.url, str(worker.bootstrap_port or topology.router_port)]
)
for worker in topology.decode_workers:
command.extend(["--decode", worker.url])
return tuple(command)
def _render_named_command(name: str, command: tuple[str, ...]) -> str:
return f"# {name}\n" + " ".join(shlex.quote(part) for part in command)
def _disaggregation_mode_for(worker: WorkerSpec) -> str:
if worker.role == "direct":
return "null"
return worker.role

View File

@@ -0,0 +1,165 @@
from __future__ import annotations
import json
import statistics
from collections import Counter
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
from agentic_pd_hybrid.policies import RoutingDecision
from agentic_pd_hybrid.trace import TraceRequest
@dataclass(frozen=True)
class RequestMetrics:
    """One row of request-level replay metrics."""

    # Identity and shape of the request, copied from the trace.
    request_id: str
    session_id: str
    turn_id: int
    mechanism_name: str
    execution_mode: str
    trace_timestamp_s: float
    input_length: int
    output_length: int
    request_type: str
    # Routing outcome, copied from the routing decision.
    policy_name: str
    assigned_prefill_node: str
    assigned_decode_node: str
    assigned_decode_index: int
    inflight_decode_load_at_assignment: int
    reuse_expected: bool
    reuse_observed: bool
    observed_overlap_blocks: int
    kv_transfer_blocks: int
    actual_kv_transfer_blocks: int
    cached_tokens: int
    re_prefill_required: bool
    effective_input_length: int | None
    session_reused: bool
    session_reset: bool
    # Timings; ``None`` when a value was not measured.
    latency_s: float | None
    ttft_s: float | None
    tpot_s: float | None
    error: str | None = None

    @classmethod
    def from_decision(
        cls,
        request: TraceRequest,
        decision: RoutingDecision,
        *,
        mechanism_name: str,
        execution_mode: str,
        actual_kv_transfer_blocks: int,
        effective_input_length: int | None,
        cached_tokens: int,
        session_reused: bool,
        session_reset: bool,
        latency_s: float | None,
        ttft_s: float | None,
        tpot_s: float | None,
        error: str | None = None,
    ) -> "RequestMetrics":
        """Combine a trace request, its routing decision and measured values."""
        trace_fields = {
            "request_id": request.request_id,
            "session_id": request.session_id,
            "turn_id": request.turn_id,
            "trace_timestamp_s": request.timestamp_s,
            "input_length": request.input_length,
            "output_length": request.output_length,
            "request_type": request.request_type,
        }
        routing_fields = {
            "policy_name": decision.policy_name,
            "assigned_prefill_node": decision.prefill_worker_id,
            "assigned_decode_node": decision.decode_worker_id,
            "assigned_decode_index": decision.decode_worker_index,
            "inflight_decode_load_at_assignment": decision.inflight_decode_load,
            "reuse_expected": decision.reuse_expected,
            "reuse_observed": decision.observed_reuse,
            "observed_overlap_blocks": decision.observed_overlap_blocks,
            "kv_transfer_blocks": decision.kv_transfer_blocks,
            "re_prefill_required": decision.re_prefill_required,
        }
        measured_fields = {
            "mechanism_name": mechanism_name,
            "execution_mode": execution_mode,
            "actual_kv_transfer_blocks": actual_kv_transfer_blocks,
            "effective_input_length": effective_input_length,
            "cached_tokens": cached_tokens,
            "session_reused": session_reused,
            "session_reset": session_reset,
            "latency_s": latency_s,
            "ttft_s": ttft_s,
            "tpot_s": tpot_s,
            "error": error,
        }
        return cls(**trace_fields, **routing_fields, **measured_fields)
def write_metrics_jsonl(path: Path, rows: list[RequestMetrics]) -> None:
    """Write ``rows`` to ``path`` as JSON Lines, one object per row.

    Missing parent directories are created and an existing file is
    overwritten. Keys within each object are sorted.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(asdict(record), sort_keys=True) + "\n" for record in rows
        )
def write_summary_json(
    path: Path,
    rows: list[RequestMetrics],
    *,
    trace_path: Path,
    router_url: str | None,
) -> None:
    """Aggregate ``rows`` into a summary and write it to ``path`` as JSON.

    The summary holds request counts, latency/TTFT/TPOT statistics, reuse and
    cache counters, KV-transfer totals and per-node request counts. Missing
    parent directories are created.
    """

    def _count_where(predicate) -> int:
        return sum(1 for row in rows if predicate(row))

    def _sorted_tally(values) -> dict:
        return dict(sorted(Counter(values).items()))

    def _present(values) -> list:
        return [value for value in values if value is not None]

    summary: dict[str, Any] = {}
    summary["trace_path"] = str(trace_path)
    summary["router_url"] = router_url
    summary["request_count"] = len(rows)
    summary["mechanisms"] = _sorted_tally(row.mechanism_name for row in rows)
    summary["execution_modes"] = _sorted_tally(row.execution_mode for row in rows)

    # Timing statistics only consider rows where the value was measured.
    summary["latency_stats_s"] = _stats(_present(row.latency_s for row in rows))
    summary["ttft_stats_s"] = _stats(_present(row.ttft_s for row in rows))
    summary["tpot_stats_s"] = _stats(_present(row.tpot_s for row in rows))

    summary["reuse_expected_count"] = _count_where(lambda row: row.reuse_expected)
    summary["reuse_observed_count"] = _count_where(lambda row: row.reuse_observed)
    summary["re_prefill_count"] = _count_where(lambda row: row.re_prefill_required)
    summary["cache_hit_request_count"] = _count_where(
        lambda row: row.cached_tokens > 0
    )
    summary["total_cached_tokens"] = sum(row.cached_tokens for row in rows)
    summary["cached_tokens_stats"] = _stats(
        [float(row.cached_tokens) for row in rows]
    )
    summary["session_reused_count"] = _count_where(lambda row: row.session_reused)
    summary["session_reset_count"] = _count_where(lambda row: row.session_reset)
    summary["total_kv_transfer_blocks"] = sum(
        row.kv_transfer_blocks for row in rows
    )
    summary["total_actual_kv_transfer_blocks"] = sum(
        row.actual_kv_transfer_blocks for row in rows
    )
    summary["per_decode_load"] = _sorted_tally(
        row.assigned_decode_node for row in rows
    )
    summary["per_prefill_load"] = _sorted_tally(
        row.assigned_prefill_node for row in rows
    )
    summary["error_count"] = _count_where(lambda row: row.error is not None)

    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        json.dump(summary, sink, indent=2, sort_keys=True)
def _stats(values: list[float | None]) -> dict[str, float] | None:
clean = [value for value in values if value is not None]
if not clean:
return None
clean.sort()
return {
"count": float(len(clean)),
"mean": statistics.fmean(clean),
"p50": _percentile(clean, 0.50),
"p90": _percentile(clean, 0.90),
"p99": _percentile(clean, 0.99),
}
def _percentile(sorted_values: list[float], percentile: float) -> float:
if not sorted_values:
raise ValueError("sorted_values must not be empty")
if len(sorted_values) == 1:
return sorted_values[0]
index = round((len(sorted_values) - 1) * percentile)
return sorted_values[index]

View File

@@ -0,0 +1,123 @@
from __future__ import annotations
import json
from dataclasses import asdict, dataclass
from math import ceil
from pathlib import Path
# Divisor used with ceil() to turn a token count into a number of hash blocks
# when the synthetic trace is generated.
BLOCK_TOKEN_BUDGET = 24
@dataclass(frozen=True)
class SmallAppendTraceConfig:
    """Parameters for ``write_small_append_trace``."""

    # Destination file; its parent directory is created if missing.
    output_path: Path
    # Must be > 0.
    session_count: int = 8
    # Must be > 0.
    turns_per_session: int = 3
    # Token lengths; each must be >= 0.
    initial_input_length: int = 10_000
    append_input_length: int = 1_000
    output_length: int = 1_000
    inter_turn_gap_s: float = 1.0
    # Multiplied by the session index to get each session's base time.
    session_stagger_s: float = 0.1
    request_type: str = "coder"
@dataclass(frozen=True)
class SmallAppendTraceSummary:
    """Result record returned by write_small_append_trace.

    Apart from ``request_count`` (the number of records written) every field
    echoes the corresponding configuration value; ``output_path`` is stored
    as a string so the summary can be dumped as JSON.
    """

    output_path: str
    session_count: int
    turns_per_session: int
    request_count: int
    initial_input_length: int
    append_input_length: int
    output_length: int
    inter_turn_gap_s: float
    session_stagger_s: float
def write_small_append_trace(config: SmallAppendTraceConfig) -> SmallAppendTraceSummary:
    """Generate a synthetic multi-turn trace and write it as JSON lines.

    Every session starts with ``initial_input_length`` tokens and each later
    turn grows the input by ``append_input_length + output_length``.  The
    hash ids of a turn are the session's base blocks followed by one id per
    extra block, so successive turns of a session share a prefix.  Records
    are sorted by timestamp before being written, and a JSON summary is
    written next to the trace file.

    Raises:
        ValueError: if a count is not positive or a length is negative.
    """
    if config.session_count <= 0:
        raise ValueError("session_count must be > 0")
    if config.turns_per_session <= 0:
        raise ValueError("turns_per_session must be > 0")
    for length_field in ("initial_input_length", "append_input_length", "output_length"):
        if getattr(config, length_field) < 0:
            raise ValueError(f"{length_field} must be >= 0")

    trace_path = config.output_path
    trace_path.parent.mkdir(parents=True, exist_ok=True)

    growth_per_turn = config.append_input_length + config.output_length
    base_block_count = ceil(config.initial_input_length / BLOCK_TOKEN_BUDGET)

    records: list[dict[str, object]] = []
    next_chat_id = 1_000_000
    for session_idx in range(config.session_count):
        session_start = session_idx * config.session_stagger_s
        shared_prefix = [
            _hash_id_for(session_idx=session_idx, block_idx=block_idx)
            for block_idx in range(base_block_count)
        ]
        parent_id = -1  # the first turn of a session has no parent
        for turn_idx in range(config.turns_per_session):
            # The first turn consumes one id; every later turn consumes two
            # (one when it is assigned, one more at the end of the turn).
            chat_id = next_chat_id
            next_chat_id += 1 if turn_idx == 0 else 2

            input_length = config.initial_input_length + turn_idx * growth_per_turn
            total_block_count = ceil(input_length / BLOCK_TOKEN_BUDGET)
            appended = [
                _hash_id_for(session_idx=session_idx, block_idx=block_idx)
                for block_idx in range(
                    base_block_count, max(base_block_count, total_block_count)
                )
            ]
            record: dict[str, object] = {
                "chat_id": chat_id,
                "parent_chat_id": parent_id,
                "timestamp": session_start + turn_idx * config.inter_turn_gap_s,
                "input_length": input_length,
                "output_length": config.output_length,
                "type": config.request_type,
                "turn": turn_idx + 1,
                "hash_ids": shared_prefix + appended,
            }
            records.append(record)
            parent_id = chat_id

    records.sort(key=lambda item: float(item["timestamp"]))
    with trace_path.open("w", encoding="utf-8") as handle:
        handle.writelines(
            json.dumps(record, sort_keys=True) + "\n" for record in records
        )

    summary = SmallAppendTraceSummary(
        output_path=str(trace_path),
        session_count=config.session_count,
        turns_per_session=config.turns_per_session,
        request_count=len(records),
        initial_input_length=config.initial_input_length,
        append_input_length=config.append_input_length,
        output_length=config.output_length,
        inter_turn_gap_s=config.inter_turn_gap_s,
        session_stagger_s=config.session_stagger_s,
    )
    summary_path = trace_path.with_suffix(trace_path.suffix + ".summary.json")
    with summary_path.open("w", encoding="utf-8") as handle:
        json.dump(asdict(summary), handle, indent=2, sort_keys=True)
    return summary
def _hash_id_for(*, session_idx: int, block_idx: int) -> int:
return session_idx * 1_000_000 + block_idx

View File

@@ -0,0 +1,267 @@
from __future__ import annotations
import argparse
import asyncio
import random
import urllib.parse
from dataclasses import dataclass
from http import HTTPStatus
from itertools import chain
from typing import AsyncIterator
import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
# Chunk size (64 KiB) requested from iter_chunked when relaying the decode
# worker's streamed response body.
_STREAM_CHUNK_SIZE = 1024 * 64
@dataclass
class RouterConfig:
    """Static configuration of the minimal PD router."""

    # Listen address; main() passes the same values straight to uvicorn.
    host: str
    port: int
    # (base URL, bootstrap port) for each prefill worker.
    prefill_urls: list[tuple[str, int]]
    # Base URL of each decode worker.
    decode_urls: list[str]
    # NOTE(review): no code visible in this module reads prefill_policy;
    # prefill workers are always picked round-robin.
    prefill_policy: str = "round_robin"
    # "consistent_hashing" honours the x-smg-target-worker header, "manual"
    # pins each x-smg-routing-key to a worker; anything else is round-robin.
    decode_policy: str = "manual"
    # Total timeout, in seconds, applied to the backend HTTP sessions.
    request_timeout_s: float = 1800.0
class RouterState:
    """Mutable routing state: round-robin cursors plus the sticky decode map."""

    def __init__(self, config: RouterConfig):
        """Store ``config`` and reset the cursors.

        Raises:
            ValueError: if no prefill or no decode worker is configured.
        """
        if not config.prefill_urls:
            raise ValueError("At least one prefill worker is required")
        if not config.decode_urls:
            raise ValueError("At least one decode worker is required")
        self.config = config
        self.prefill_cursor = 0
        self.decode_cursor = 0
        # routing key -> decode worker index, used by the "manual" policy.
        self.sticky_decode_map: dict[str, int] = {}

    def select_pair(self, headers: dict[str, str]) -> tuple[str, int, str]:
        """Pick the workers for one request.

        Prefill workers are taken round-robin; the decode worker depends on
        the configured decode policy and the (lower-cased) request headers.

        Returns:
            ``(prefill_url, bootstrap_port, decode_url)``.
        """
        prefill_url, bootstrap_port = self.config.prefill_urls[
            self.prefill_cursor % len(self.config.prefill_urls)
        ]
        self.prefill_cursor += 1
        decode_index = self._select_decode_index(headers)
        return prefill_url, bootstrap_port, self.config.decode_urls[decode_index]

    def _select_decode_index(self, headers: dict[str, str]) -> int:
        """Return the index of the decode worker to use for ``headers``."""
        decode_count = len(self.config.decode_urls)
        policy = self.config.decode_policy

        if policy == "consistent_hashing":
            requested = self._parse_worker_index(headers.get("x-smg-target-worker"))
            if requested is not None and 0 <= requested < decode_count:
                return requested
            # A missing, malformed or out-of-range target falls through to
            # round-robin below.

        routing_key = headers.get("x-smg-routing-key")
        if policy == "manual" and routing_key:
            cached = self.sticky_decode_map.get(routing_key)
            if cached is not None:
                return cached
            idx = self._next_round_robin_index()
            self.sticky_decode_map[routing_key] = idx
            return idx

        return self._next_round_robin_index()

    def _next_round_robin_index(self) -> int:
        """Advance the decode cursor and return the index it pointed at."""
        idx = self.decode_cursor % len(self.config.decode_urls)
        self.decode_cursor += 1
        return idx

    @staticmethod
    def _parse_worker_index(raw: str | None) -> int | None:
        """Parse a client-supplied worker index, or return None if unusable.

        The value comes from a request header, so a non-numeric string must
        not raise out of the router.
        """
        if raw is None:
            return None
        try:
            return int(raw)
        except ValueError:
            return None
# ASGI application on which the route handlers in this module are registered.
app = FastAPI()
# Process-wide router state; stays None until main() installs a RouterState.
# _require_state() answers 500 while it is unset.
router_state: RouterState | None = None
@app.get("/health")
async def health() -> Response:
    """Liveness probe: answer 200 without consulting router state or workers."""
    alive = Response(status_code=200)
    return alive
@app.get("/health_generate")
async def health_generate() -> Response:
    """Call ``/health_generate`` on every prefill and decode worker.

    The workers' status codes are not inspected: each response is awaited
    and released, and 200 is returned once all requests have completed.
    An exception raised by a request propagates to the caller.
    """
    state = _require_state()
    backends = [url for url, _ in state.config.prefill_urls]
    backends.extend(state.config.decode_urls)
    async with aiohttp.ClientSession() as session:
        pending = [
            session.get(f"{backend}/health_generate") for backend in backends
        ]
        for finished in asyncio.as_completed(pending):
            async with await finished:
                pass
    return Response(status_code=200)
@app.get("/v1/models")
async def models() -> ORJSONResponse:
    """Proxy ``/v1/models`` from the first configured prefill worker.

    The worker's JSON payload is returned with the worker's status code.
    """
    state = _require_state()
    first_prefill_url, _ = state.config.prefill_urls[0]
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{first_prefill_url}/v1/models") as upstream:
            payload = await upstream.json()
            upstream_status = upstream.status
    return ORJSONResponse(payload, status_code=upstream_status)
@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Response:
    """Forward a chat-completions request to a prefill/decode worker pair.

    Header names are lower-cased before being handed to the router so that
    routing headers can be looked up case-insensitively.
    """
    body = await request.json()
    lowered_headers: dict[str, str] = {}
    for header_name, header_value in request.headers.items():
        lowered_headers[header_name.lower()] = header_value
    return await _forward_to_backend(
        request_data=body,
        headers=lowered_headers,
        endpoint_name="v1/chat/completions",
    )
@app.post("/v1/completions")
async def completions(request: Request) -> Response:
    """Forward a completions request to a prefill/decode worker pair.

    Header names are lower-cased before being handed to the router so that
    routing headers can be looked up case-insensitively.
    """
    body = await request.json()
    lowered_headers: dict[str, str] = {}
    for header_name, header_value in request.headers.items():
        lowered_headers[header_name.lower()] = header_value
    return await _forward_to_backend(
        request_data=body,
        headers=lowered_headers,
        endpoint_name="v1/completions",
    )
@app.post("/generate")
async def generate(request: Request) -> Response:
    """Forward a ``/generate`` request to a prefill/decode worker pair.

    Header names are lower-cased before being handed to the router so that
    routing headers can be looked up case-insensitively.
    """
    body = await request.json()
    lowered_headers: dict[str, str] = {}
    for header_name, header_value in request.headers.items():
        lowered_headers[header_name.lower()] = header_value
    return await _forward_to_backend(
        request_data=body,
        headers=lowered_headers,
        endpoint_name="generate",
    )
async def _forward_to_backend(
    *,
    request_data: dict,
    headers: dict[str, str],
    endpoint_name: str,
) -> Response:
    """Send one request to a selected prefill worker and decode worker.

    The request body is copied and extended with the bootstrap fields for
    the chosen prefill worker, then posted to both workers concurrently.
    When the client asked for streaming, the decode worker's body is relayed
    as ``text/event-stream``.  Otherwise the prefill response is read and
    discarded, and the decode worker's body, status and content type are
    returned.
    """
    state = _require_state()
    prefill_server, bootstrap_port, decode_server = state.select_pair(headers)
    modified_request = {
        **request_data,
        **_build_bootstrap_payload(prefill_server, bootstrap_port),
    }
    timeout_s = state.config.request_timeout_s

    wants_stream = request_data.get("stream", False)
    if wants_stream:
        relay = _stream_generate(
            modified_request=modified_request,
            prefill_server=prefill_server,
            decode_server=decode_server,
            endpoint_name=endpoint_name,
            timeout_s=timeout_s,
        )
        return StreamingResponse(relay, media_type="text/event-stream")

    prefill_target = f"{prefill_server}/{endpoint_name}"
    decode_target = f"{decode_server}/{endpoint_name}"
    client_timeout = aiohttp.ClientTimeout(total=timeout_s)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        prefill_response, decode_response = await asyncio.gather(
            session.post(prefill_target, json=modified_request),
            session.post(decode_target, json=modified_request),
        )
        # Drain the prefill response so its connection can be released.
        async with prefill_response:
            await prefill_response.read()
        async with decode_response:
            body = await decode_response.read()
    return Response(
        content=body,
        status_code=decode_response.status,
        media_type=decode_response.content_type,
    )
async def _stream_generate(
    *,
    modified_request: dict,
    prefill_server: str,
    decode_server: str,
    endpoint_name: str,
    timeout_s: float,
) -> AsyncIterator[bytes]:
    """Post to both workers and yield the decode worker's body.

    On a 200 from the decode worker its body is relayed in chunks of
    ``_STREAM_CHUNK_SIZE`` bytes; on any other status the whole body is
    yielded once.  The prefill response is held open alongside the decode
    response but its body is not forwarded.
    """
    client_timeout = aiohttp.ClientTimeout(total=timeout_s)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        prefill_call = session.post(
            f"{prefill_server}/{endpoint_name}", json=modified_request
        )
        decode_call = session.post(
            f"{decode_server}/{endpoint_name}", json=modified_request
        )
        prefill_response, decode_response = await asyncio.gather(
            prefill_call, decode_call
        )
        async with prefill_response, decode_response:
            if decode_response.status == HTTPStatus.OK:
                chunks = decode_response.content.iter_chunked(_STREAM_CHUNK_SIZE)
                async for chunk in chunks:
                    yield chunk
            else:
                yield await decode_response.read()
def _build_bootstrap_payload(prefill_server: str, bootstrap_port: int) -> dict[str, object]:
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname
if hostname is None:
raise HTTPException(
status_code=500,
detail=f"Unable to parse prefill hostname from {prefill_server}",
)
return {
"bootstrap_host": hostname,
"bootstrap_port": bootstrap_port,
"bootstrap_room": random.randint(0, 2**63 - 1),
}
def _require_state() -> RouterState:
    """Return the installed router state.

    Raises:
        HTTPException: status 500 when ``router_state`` has not been set.
    """
    current = router_state
    if current is not None:
        return current
    raise HTTPException(status_code=500, detail="router not initialized")
def main() -> None:
    """Parse the command line, install the router state and serve with uvicorn.

    ``--prefill URL BOOTSTRAP_PORT`` and ``--decode URL`` may be repeated and
    are both required at least once.
    """
    global router_state

    parser = argparse.ArgumentParser(description="Minimal local PD router")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--prefill",
        nargs=2,
        metavar=("URL", "BOOTSTRAP_PORT"),
        action="append",
        required=True,
    )
    parser.add_argument("--decode", action="append", required=True)
    parser.add_argument("--prefill-policy", default="round_robin")
    parser.add_argument("--decode-policy", default="manual")
    parser.add_argument("--request-timeout-s", type=float, default=1800.0)
    args = parser.parse_args()

    # Each --prefill occurrence arrives as a [url, port-string] pair.
    prefill_pairs: list[tuple[str, int]] = []
    for url, bootstrap_port in args.prefill:
        prefill_pairs.append((url, int(bootstrap_port)))

    config = RouterConfig(
        host=args.host,
        port=args.port,
        prefill_urls=prefill_pairs,
        decode_urls=list(args.decode),
        prefill_policy=args.prefill_policy,
        decode_policy=args.decode_policy,
        request_timeout_s=args.request_timeout_s,
    )
    router_state = RouterState(config)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
# Start the router when this module is executed as a script.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,235 @@
from __future__ import annotations
from collections import Counter
from dataclasses import dataclass, field
from typing import Protocol
from agentic_pd_hybrid.topology import SingleNodeTopology
from agentic_pd_hybrid.trace import TraceRequest
@dataclass
class SessionRouteState:
    """Per-session routing memory."""

    # Id of the decode worker recorded for this session by the most recent
    # RoutingState.finish call; None until a request of the session finishes.
    last_decode_worker: str | None = None
@dataclass
class RoutingDecision:
    """Outcome of routing one request, plus the bookkeeping derived from it."""

    policy_name: str
    prefill_worker_id: str
    decode_worker_id: str
    decode_worker_index: int
    reuse_expected: bool
    observed_overlap_blocks: int
    kv_transfer_blocks: int
    inflight_decode_load: int
    session_id: str
    request_id: str
    turn_id: int

    @property
    def observed_reuse(self) -> bool:
        """True when at least one requested block overlapped on the chosen worker."""
        return 0 < self.observed_overlap_blocks

    @property
    def re_prefill_required(self) -> bool:
        """True for a follow-up turn (turn_id > 1) that found no overlapping block."""
        if self.turn_id <= 1:
            return False
        return self.observed_overlap_blocks == 0
@dataclass
class RoutingState:
    """Mutable bookkeeping shared by the routing policies."""

    # Round-robin positions for prefill and decode selection.
    prefill_cursor: int = 0
    decode_cursor: int = 0
    # session id -> routing memory for that session.
    session_state: dict[str, SessionRouteState] = field(default_factory=dict)
    # decode worker id -> number of routed but unfinished requests.
    inflight_decode: Counter[str] = field(default_factory=Counter)
    # decode worker id -> total number of requests ever assigned.
    decode_assignment_counts: Counter[str] = field(default_factory=Counter)
    # decode worker id -> hash ids of finished requests routed to it.
    decode_resident_blocks: dict[str, set[int]] = field(default_factory=dict)

    @classmethod
    def create(cls, topology: SingleNodeTopology) -> "RoutingState":
        """Build a state with an empty resident-block set per route worker."""
        return cls(
            decode_resident_blocks={
                worker.worker_id: set() for worker in topology.route_workers
            }
        )

    def next_prefill_worker_id(self, topology: SingleNodeTopology) -> str:
        """Return the next prefill worker id round-robin, or "none" if there are none."""
        if not topology.prefill_workers:
            return "none"
        worker = topology.prefill_workers[self.prefill_cursor % len(topology.prefill_workers)]
        self.prefill_cursor += 1
        return worker.worker_id

    def next_decode_worker_id(self, topology: SingleNodeTopology) -> str:
        """Return the next route worker id round-robin."""
        route_workers = topology.route_workers
        worker = route_workers[self.decode_cursor % len(route_workers)]
        self.decode_cursor += 1
        return worker.worker_id

    def finish(self, request: TraceRequest, decision: RoutingDecision) -> None:
        """Record that ``request`` completed on the worker chosen in ``decision``.

        Remembers the decode worker for the session, marks the request's
        blocks as resident on that worker and releases one in-flight slot.
        """
        worker_id = decision.decode_worker_id

        session = self.session_state.get(request.session_id)
        if session is None:
            session = SessionRouteState()
            self.session_state[request.session_id] = session
        session.last_decode_worker = worker_id

        # A state built without create() has no pre-seeded worker entries;
        # create the set on demand, matching the tolerant lookup used when
        # overlaps are counted.
        self.decode_resident_blocks.setdefault(worker_id, set()).update(request.hash_ids)

        self.inflight_decode[worker_id] -= 1
        if self.inflight_decode[worker_id] <= 0:
            # Drop empty counters so the mapping only lists busy workers.
            del self.inflight_decode[worker_id]
class RoutingPolicy(Protocol):
    """Structural interface implemented by the routing policies."""

    # Identifier reported as RoutingDecision.policy_name.
    name: str

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Choose the prefill and decode workers for ``request``."""
        ...
@dataclass(frozen=True)
class DefaultPolicy:
    """Round-robin over both prefill and decode workers; never expects reuse."""

    name: str = "default"

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Take the next prefill worker and the next decode worker in turn."""
        chosen_prefill = state.next_prefill_worker_id(topology)
        chosen_decode = state.next_decode_worker_id(topology)
        decision = _build_decision(
            policy_name=self.name,
            request=request,
            topology=topology,
            state=state,
            prefill_worker_id=chosen_prefill,
            decode_worker_id=chosen_decode,
            reuse_expected=False,
        )
        return decision
@dataclass(frozen=True)
class StickyDecodePolicy:
    """Keep follow-up turns of a session on the decode worker used before."""

    name: str = "sticky"

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Reuse the session's last decode worker for turns after the first.

        When the session has no recorded decode worker, or this is the first
        turn, the decode worker is taken round-robin and no reuse is expected.
        """
        session = state.session_state.get(request.session_id)
        prefill_worker_id = state.next_prefill_worker_id(topology)

        remembered: str | None = None
        if request.turn_id > 1 and session:
            remembered = session.last_decode_worker

        reuse_expected = remembered is not None
        if reuse_expected:
            decode_worker_id = remembered
        else:
            decode_worker_id = state.next_decode_worker_id(topology)

        return _build_decision(
            policy_name=self.name,
            request=request,
            topology=topology,
            state=state,
            prefill_worker_id=prefill_worker_id,
            decode_worker_id=decode_worker_id,
            reuse_expected=reuse_expected,
        )
@dataclass(frozen=True)
class KvAwarePolicy:
    """Route to the decode worker that already holds most of the request's blocks."""

    name: str = "kv-aware"
    # Extra score given to the session's previous decode worker.
    sticky_bonus: int = 1

    def select(
        self,
        request: TraceRequest,
        *,
        topology: SingleNodeTopology,
        state: RoutingState,
    ) -> RoutingDecision:
        """Score every route worker and pick the best one.

        Workers are compared by the tuple
        ``(overlap + sticky bonus, sticky, -inflight, -assignments)``; on a
        full tie the earliest worker in ``topology.route_workers`` wins.
        Reuse is expected when the winning first component is positive.

        Raises:
            ValueError: if the topology has no route workers.
        """
        prefill_worker_id = state.next_prefill_worker_id(topology)
        session = state.session_state.get(request.session_id)
        best_decode_worker_id: str | None = None
        best_score: tuple[int, int, int, int] | None = None
        for worker in topology.route_workers:
            worker_id = worker.worker_id
            overlap = _overlap_blocks(request, state, worker_id)
            sticky = int(session is not None and session.last_decode_worker == worker_id)
            # Load terms are negated so that a lighter worker compares higher.
            inflight_penalty = -state.inflight_decode.get(worker_id, 0)
            assignment_penalty = -state.decode_assignment_counts.get(worker_id, 0)
            score = (
                overlap + sticky * self.sticky_bonus,
                sticky,
                inflight_penalty,
                assignment_penalty,
            )
            if best_score is None or score > best_score:
                best_score = score
                best_decode_worker_id = worker_id
        if best_decode_worker_id is None or best_score is None:
            # An explicit error survives ``python -O``, unlike an assert.
            raise ValueError("KvAwarePolicy requires at least one route worker")
        reuse_expected = best_score[0] > 0
        return _build_decision(
            policy_name=self.name,
            request=request,
            topology=topology,
            state=state,
            prefill_worker_id=prefill_worker_id,
            decode_worker_id=best_decode_worker_id,
            reuse_expected=reuse_expected,
        )
def create_policy(name: str) -> RoutingPolicy:
normalized = name.strip().lower()
if normalized == "default":
return DefaultPolicy()
if normalized == "sticky":
return StickyDecodePolicy()
if normalized in {"kv-aware", "kv_aware", "kv"}:
return KvAwarePolicy()
raise ValueError(f"Unsupported policy: {name}")
def _build_decision(
    *,
    policy_name: str,
    request: TraceRequest,
    topology: SingleNodeTopology,
    state: RoutingState,
    prefill_worker_id: str,
    decode_worker_id: str,
    reuse_expected: bool,
) -> RoutingDecision:
    """Assemble the RoutingDecision and account for the new assignment.

    The overlap is measured before the counters are touched; the in-flight
    and assignment counters of the decode worker are then incremented, so
    ``inflight_decode_load`` already includes this request.
    """
    overlap = _overlap_blocks(request, state, decode_worker_id)

    for counter in (state.inflight_decode, state.decode_assignment_counts):
        counter[decode_worker_id] += 1

    # Blocks not already resident have to be transferred; never negative.
    missing_blocks = len(request.hash_ids) - overlap
    transfer_blocks = missing_blocks if missing_blocks > 0 else 0

    fields = {
        "policy_name": policy_name,
        "prefill_worker_id": prefill_worker_id,
        "decode_worker_id": decode_worker_id,
        "decode_worker_index": topology.route_index(decode_worker_id),
        "reuse_expected": reuse_expected,
        "observed_overlap_blocks": overlap,
        "kv_transfer_blocks": transfer_blocks,
        "inflight_decode_load": state.inflight_decode[decode_worker_id],
        "session_id": request.session_id,
        "request_id": request.request_id,
        "turn_id": request.turn_id,
    }
    return RoutingDecision(**fields)
def _overlap_blocks(
request: TraceRequest,
state: RoutingState,
decode_worker_id: str,
) -> int:
resident = state.decode_resident_blocks.get(decode_worker_id, set())
return sum(1 for block in request.hash_ids if block in resident)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,295 @@
from __future__ import annotations
import hashlib
import json
from collections import defaultdict
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Literal
from agentic_pd_hybrid.trace import TraceRequest, load_trace
# Names of the sampling profiles accepted by SessionSampleConfig.profile.
SampleProfile = Literal["default", "small-append"]
@dataclass(frozen=True)
class SessionSampleConfig:
    """Options controlling which sessions of a trace are sampled."""

    # Input trace to load and destination JSONL for the sampled requests.
    trace_path: Path
    output_path: Path
    # Stop adding sessions once the latest session's end is this many seconds
    # after the first selected session's start.
    target_duration_s: float = 600.0
    # Sessions whose first request is earlier than this are skipped.
    start_time_s: float = 0.0
    # Fraction of sessions kept by the hash-based sampler (>= 1 keeps all,
    # <= 0 keeps none).
    session_sample_rate: float = 1.0
    # Minimum number of requests a session must have.
    min_turns: int = 1
    # Optional cap on the number of requests written.
    max_requests: int | None = None
    # "small-append" fills unset filters below with profile defaults and
    # requires at least two turns; "default" uses the values as given.
    profile: SampleProfile = "default"
    # Bounds on the first request's input_length.
    min_initial_input_tokens: int | None = None
    max_initial_input_tokens: int | None = None
    # Upper bound on per-turn appended input; when set, turns that append
    # zero or fewer tokens also disqualify the session.
    max_append_input_tokens: int | None = None
    # Upper bound on any request's output_length.
    max_output_tokens: int | None = None
    # Minimum share of a turn's hash ids already present in the previous turn.
    min_overlap_ratio: float | None = None
@dataclass(frozen=True)
class SessionSampleSummary:
    """Summary of one sampling run, also written as JSON next to the output."""

    input_trace_path: str
    output_trace_path: str
    # Number of requests written and number of distinct sessions among them.
    request_count: int
    session_count: int
    # Selected sessions that have more than one request.
    multi_turn_session_count: int
    # Timestamps of the first and last written request, and their difference.
    start_time_s: float
    end_time_s: float
    sampled_duration_s: float
    session_sample_rate: float
    # The filter values below are the resolved ones, i.e. after profile
    # defaults have been applied.
    min_turns: int
    profile: str
    min_initial_input_tokens: int | None
    max_initial_input_tokens: int | None
    max_append_input_tokens: int | None
    max_output_tokens: int | None
    min_overlap_ratio: float | None
    # Means over all consecutive turn pairs of the selected sessions; None
    # when there is no such pair.
    mean_append_input_tokens: float | None
    mean_turn_overlap_ratio: float | None
def sample_trace_sessions(config: SessionSampleConfig) -> SessionSampleSummary:
    """Sample whole sessions from a trace and write them as JSON lines.

    Sessions are kept when they have enough turns, satisfy the resolved
    filters and pass the hash-based sampler.  Kept sessions are then taken
    in order of their first timestamp, starting at ``config.start_time_s``,
    until either ``max_requests`` is reached or the current session ends at
    least ``target_duration_s`` after the first selected session began.
    The selected requests are sorted by timestamp, truncated to
    ``max_requests`` and written to ``config.output_path``; a summary is
    written next to it and returned.

    Raises:
        ValueError: if no request is selected.
    """
    by_session: dict[str, list[TraceRequest]] = defaultdict(list)
    for request in load_trace(config.trace_path):
        by_session[request.session_id].append(request)

    filters = _resolve_filters(config)
    eligible_sessions: dict[str, list[TraceRequest]] = {}
    for session_id, members in by_session.items():
        if len(members) < filters.min_turns:
            continue
        if not _session_matches_filters(members, filters):
            continue
        if not _keep_session(session_id, config.session_sample_rate):
            continue
        eligible_sessions[session_id] = members

    selected_requests: list[TraceRequest] = []
    window_start: float | None = None
    chronological = sorted(
        eligible_sessions.values(), key=lambda members: members[0].timestamp_s
    )
    for members in chronological:
        first_timestamp = members[0].timestamp_s
        if first_timestamp < config.start_time_s:
            continue
        if window_start is None:
            window_start = first_timestamp
        selected_requests.extend(members)
        if (
            config.max_requests is not None
            and len(selected_requests) >= config.max_requests
        ):
            break
        session_end = max(request.timestamp_s for request in members)
        if session_end - window_start >= config.target_duration_s:
            break

    selected_requests.sort(key=lambda request: request.timestamp_s)
    if config.max_requests is not None:
        selected_requests = selected_requests[: config.max_requests]
    if not selected_requests:
        raise ValueError("Sampling produced no requests; adjust the sampling arguments")

    config.output_path.parent.mkdir(parents=True, exist_ok=True)
    with config.output_path.open("w", encoding="utf-8") as handle:
        for request in selected_requests:
            payload = {
                "request_id": request.request_id,
                "session_id": request.session_id,
                "chat_id": request.chat_id,
                "parent_chat_id": request.parent_chat_id,
                "timestamp": request.timestamp_s,
                "input_length": request.input_length,
                "output_length": request.output_length,
                "type": request.request_type,
                "turn": request.turn_id,
                "hash_ids": list(request.hash_ids),
            }
            handle.write(json.dumps(payload, sort_keys=True) + "\n")

    # Statistics are computed over the full eligible sessions that
    # contributed at least one written request.
    selected_session_ids = {request.session_id for request in selected_requests}
    append_lengths: list[int] = []
    overlap_ratios: list[float] = []
    multi_turn_sessions = 0
    for session_id in selected_session_ids:
        members = eligible_sessions[session_id]
        if len(members) > 1:
            multi_turn_sessions += 1
    for session_id in selected_session_ids:
        append_lengths.extend(_turn_append_lengths(eligible_sessions[session_id]))
    for session_id in selected_session_ids:
        overlap_ratios.extend(_turn_overlap_ratios(eligible_sessions[session_id]))

    first_written = selected_requests[0].timestamp_s
    last_written = selected_requests[-1].timestamp_s
    summary = SessionSampleSummary(
        input_trace_path=str(config.trace_path),
        output_trace_path=str(config.output_path),
        request_count=len(selected_requests),
        session_count=len(selected_session_ids),
        multi_turn_session_count=multi_turn_sessions,
        start_time_s=first_written,
        end_time_s=last_written,
        sampled_duration_s=last_written - first_written,
        session_sample_rate=config.session_sample_rate,
        min_turns=filters.min_turns,
        profile=config.profile,
        min_initial_input_tokens=filters.min_initial_input_tokens,
        max_initial_input_tokens=filters.max_initial_input_tokens,
        max_append_input_tokens=filters.max_append_input_tokens,
        max_output_tokens=filters.max_output_tokens,
        min_overlap_ratio=filters.min_overlap_ratio,
        mean_append_input_tokens=_mean(append_lengths),
        mean_turn_overlap_ratio=_mean(overlap_ratios),
    )
    summary_path = config.output_path.with_suffix(
        config.output_path.suffix + ".summary.json"
    )
    with summary_path.open("w", encoding="utf-8") as handle:
        json.dump(asdict(summary), handle, indent=2, sort_keys=True)
    return summary
@dataclass(frozen=True)
class _ResolvedFilters:
    """Session filters after profile defaults have been applied.

    A ``None`` bound means that filter is disabled.
    """

    min_turns: int
    min_initial_input_tokens: int | None
    max_initial_input_tokens: int | None
    max_append_input_tokens: int | None
    max_output_tokens: int | None
    min_overlap_ratio: float | None
def _resolve_filters(config: SessionSampleConfig) -> _ResolvedFilters:
if config.profile == "default":
return _ResolvedFilters(
min_turns=config.min_turns,
min_initial_input_tokens=config.min_initial_input_tokens,
max_initial_input_tokens=config.max_initial_input_tokens,
max_append_input_tokens=config.max_append_input_tokens,
max_output_tokens=config.max_output_tokens,
min_overlap_ratio=config.min_overlap_ratio,
)
if config.profile != "small-append":
raise ValueError(f"Unsupported sample profile: {config.profile}")
return _ResolvedFilters(
min_turns=max(config.min_turns, 2),
min_initial_input_tokens=(
2048
if config.min_initial_input_tokens is None
else config.min_initial_input_tokens
),
max_initial_input_tokens=(
16000
if config.max_initial_input_tokens is None
else config.max_initial_input_tokens
),
max_append_input_tokens=(
2048
if config.max_append_input_tokens is None
else config.max_append_input_tokens
),
max_output_tokens=(
2048 if config.max_output_tokens is None else config.max_output_tokens
),
min_overlap_ratio=(
0.75 if config.min_overlap_ratio is None else config.min_overlap_ratio
),
)
def _session_matches_filters(
session_requests: list[TraceRequest],
filters: _ResolvedFilters,
) -> bool:
ordered = sorted(
session_requests,
key=lambda request: (request.timestamp_s, request.turn_id, request.chat_id),
)
if not ordered:
return False
initial = ordered[0]
if (
filters.min_initial_input_tokens is not None
and initial.input_length < filters.min_initial_input_tokens
):
return False
if (
filters.max_initial_input_tokens is not None
and initial.input_length > filters.max_initial_input_tokens
):
return False
if filters.max_output_tokens is not None and any(
request.output_length > filters.max_output_tokens for request in ordered
):
return False
append_lengths = _turn_append_lengths(ordered)
if filters.max_append_input_tokens is not None and any(
append_length <= 0 or append_length > filters.max_append_input_tokens
for append_length in append_lengths
):
return False
overlap_ratios = _turn_overlap_ratios(ordered)
if filters.min_overlap_ratio is not None and any(
overlap_ratio < filters.min_overlap_ratio for overlap_ratio in overlap_ratios
):
return False
return True
def _turn_append_lengths(session_requests: list[TraceRequest]) -> list[int]:
ordered = sorted(
session_requests,
key=lambda request: (request.timestamp_s, request.turn_id, request.chat_id),
)
return [
current.input_length - (previous.input_length + previous.output_length)
for previous, current in zip(ordered, ordered[1:], strict=False)
]
def _turn_overlap_ratios(session_requests: list[TraceRequest]) -> list[float]:
ordered = sorted(
session_requests,
key=lambda request: (request.timestamp_s, request.turn_id, request.chat_id),
)
ratios: list[float] = []
for previous, current in zip(ordered, ordered[1:], strict=False):
if not current.hash_ids:
ratios.append(0.0)
continue
previous_blocks = set(previous.hash_ids)
overlap = sum(1 for block in current.hash_ids if block in previous_blocks)
ratios.append(overlap / len(current.hash_ids))
return ratios
def _mean(values: list[int] | list[float]) -> float | None:
if not values:
return None
return sum(values) / len(values)
def _keep_session(session_id: str, sample_rate: float) -> bool:
if sample_rate >= 1.0:
return True
if sample_rate <= 0.0:
return False
digest = hashlib.blake2b(session_id.encode("utf-8"), digest_size=8).digest()
bucket = int.from_bytes(digest, byteorder="big", signed=False) / 2**64
return bucket < sample_rate

View File

@@ -0,0 +1,222 @@
from __future__ import annotations
import os
import signal
import subprocess
import time
from dataclasses import dataclass
from pathlib import Path
import httpx
from agentic_pd_hybrid.launcher import build_launch_plan
from agentic_pd_hybrid.topology import SingleNodeTopology
@dataclass
class ManagedProcess:
    """A spawned child process together with how it was started."""

    # Label such as "prefill-0" or "router".
    name: str
    # Argument vector the process was launched with.
    command: tuple[str, ...]
    # Handle of the running child.
    process: subprocess.Popen[bytes]
    # File receiving the child's combined stdout and stderr.
    log_path: Path
@dataclass
class ManagedPdStack:
    """Handle on every process launched for one PD stack."""

    topology: SingleNodeTopology
    run_dir: Path
    prefill_processes: list[ManagedProcess]
    decode_processes: list[ManagedProcess]
    direct_processes: list[ManagedProcess]
    router_process: ManagedProcess | None

    @property
    def router_url(self) -> str:
        """Base URL of the router, as defined by the topology."""
        return self.topology.router_url

    def stop(self) -> None:
        """Terminate every process of the stack.

        All live process groups first receive SIGTERM (router first, then
        direct, decode and prefill workers).  They share one 20 second grace
        period; a process still running after its wait times out has its
        group killed with SIGKILL and is waited on for up to 5 more seconds.
        A process that exits on its own between the liveness check and the
        signal is treated as already stopped.
        """
        processes = (
            ([self.router_process] if self.router_process is not None else [])
            + self.direct_processes
            + self.decode_processes
            + self.prefill_processes
        )
        for managed in processes:
            if managed.process.poll() is None:
                self._signal_group(managed.process, signal.SIGTERM)
        deadline = time.monotonic() + 20
        for managed in processes:
            if managed.process.poll() is not None:
                continue
            remaining = max(0.0, deadline - time.monotonic())
            try:
                managed.process.wait(timeout=remaining)
            except subprocess.TimeoutExpired:
                if managed.process.poll() is None:
                    self._signal_group(managed.process, signal.SIGKILL)
                managed.process.wait(timeout=5)

    @staticmethod
    def _signal_group(process: subprocess.Popen[bytes], signum: int) -> None:
        """Send ``signum`` to the process group of ``process``.

        The child may exit between the caller's ``poll()`` and this call; a
        vanished process must not abort the shutdown of the remaining ones.
        """
        try:
            os.killpg(os.getpgid(process.pid), signum)
        except ProcessLookupError:
            pass
def launch_pd_stack(
    *,
    topology: SingleNodeTopology,
    run_dir: Path,
    prefill_policy: str,
    decode_policy: str,
    timeout_s: float = 1200.0,
    include_router: bool = True,
) -> ManagedPdStack:
    """Start every process of the stack and wait until each one is ready.

    Prefill, decode and direct workers are spawned first, each logging to
    ``run_dir / "logs"``.  Once all of them answer ``/v1/models`` with 200,
    the router (if the launch plan has one) is spawned and polled on
    ``/health``.

    If anything fails along the way, including a failure while spawning a
    later worker or an interrupt during the readiness wait, every process
    started so far is stopped before the error is re-raised.

    Args:
        topology: Workers and router to launch.
        run_dir: Directory for this run; created if missing.
        prefill_policy: Passed to the launch plan.
        decode_policy: Passed to the launch plan.
        timeout_s: Readiness timeout applied to each endpoint separately.
        include_router: Passed to the launch plan.

    Returns:
        A ManagedPdStack owning all started processes.
    """
    run_dir.mkdir(parents=True, exist_ok=True)
    logs_dir = run_dir / "logs"
    logs_dir.mkdir(parents=True, exist_ok=True)
    plan = build_launch_plan(
        topology,
        prefill_policy=prefill_policy,
        decode_policy=decode_policy,
        include_router=include_router,
    )

    prefill_processes: list[ManagedProcess] = []
    decode_processes: list[ManagedProcess] = []
    direct_processes: list[ManagedProcess] = []
    router_process: ManagedProcess | None = None
    worker_groups = (
        ("prefill", plan.prefill_commands, prefill_processes),
        ("decode", plan.decode_commands, decode_processes),
        ("direct", plan.direct_commands, direct_processes),
    )
    try:
        # Spawning happens inside the try so that a failure part-way through
        # still cleans up the processes that were already started.
        for role, commands, bucket in worker_groups:
            for idx, command in enumerate(commands):
                bucket.append(
                    _spawn_process(
                        name=f"{role}-{idx}",
                        command=command,
                        log_path=logs_dir / f"{role}-{idx}.log",
                        topology=topology,
                    )
                )
        for worker in (
            *topology.prefill_workers,
            *topology.decode_workers,
            *topology.direct_workers,
        ):
            _wait_for_ready_endpoint(f"{worker.url}/v1/models", timeout_s=timeout_s)
        if plan.router_command is not None:
            router_process = _spawn_process(
                name="router",
                command=plan.router_command,
                log_path=logs_dir / "router.log",
                topology=topology,
            )
            _wait_for_ready_endpoint(f"{topology.router_url}/health", timeout_s=timeout_s)
    except BaseException:
        # BaseException so that Ctrl-C during the (potentially very long)
        # readiness wait does not leave worker processes running.
        ManagedPdStack(
            topology=topology,
            run_dir=run_dir,
            prefill_processes=prefill_processes,
            decode_processes=decode_processes,
            direct_processes=direct_processes,
            router_process=router_process,
        ).stop()
        raise
    return ManagedPdStack(
        topology=topology,
        run_dir=run_dir,
        prefill_processes=prefill_processes,
        decode_processes=decode_processes,
        direct_processes=direct_processes,
        router_process=router_process,
    )
def _spawn_process(
    *,
    name: str,
    command: tuple[str, ...],
    log_path: Path,
    topology: SingleNodeTopology,
) -> ManagedProcess:
    """Start ``command`` in its own session with output sent to ``log_path``.

    stdout and stderr are both written to the log file, which is truncated
    on open.  The child becomes the leader of a new session, and therefore
    of a new process group, so the whole group can be signalled later.
    """
    env = _build_process_env(topology)
    # The child receives its own duplicate of the log descriptor, so the
    # parent's handle is closed as soon as the process has been started.
    with log_path.open("wb") as log_handle:
        process = subprocess.Popen(
            command,
            stdout=log_handle,
            stderr=subprocess.STDOUT,
            env=env,
            # Equivalent to calling setsid() in the child, without the
            # thread-safety caveats of preexec_fn.
            start_new_session=True,
        )
    return ManagedProcess(
        name=name,
        command=command,
        process=process,
        log_path=log_path,
    )
def _build_process_env(topology: SingleNodeTopology) -> dict[str, str]:
env = os.environ.copy()
env["PYTHONDONTWRITEBYTECODE"] = "1"
env["PYTHONUNBUFFERED"] = "1"
# SGLang's PD bootstrap path uses `requests`; force localhost traffic to stay local.
for key in (
"http_proxy",
"https_proxy",
"all_proxy",
"HTTP_PROXY",
"HTTPS_PROXY",
"ALL_PROXY",
):
env.pop(key, None)
env["NO_PROXY"] = "*"
env["no_proxy"] = "*"
env.setdefault("SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", "600")
env.setdefault("SGLANG_DISAGGREGATION_WAITING_TIMEOUT", "60")
env.setdefault("FLASHINFER_DISABLE_VERSION_CHECK", "1")
if topology.force_rdma:
env["MOONCAKE_PROTOCOL"] = "rdma"
env["MC_MS_AUTO_DISC"] = "0"
if topology.ib_device:
env["MOONCAKE_DEVICE"] = topology.ib_device
repo_root = Path(__file__).resolve().parents[2]
python_paths = [
str(repo_root / "src"),
str(repo_root / "third_party" / "sglang" / "python"),
]
existing_pythonpath = env.get("PYTHONPATH")
if existing_pythonpath:
python_paths.append(existing_pythonpath)
env["PYTHONPATH"] = os.pathsep.join(python_paths)
return env
def _wait_for_ready_endpoint(url: str, *, timeout_s: float) -> None:
    """Poll ``url`` once per second until it answers with status 200.

    Each request has a 5 second timeout and ignores proxy settings from the
    environment.  Request errors are remembered and polling continues.

    Raises:
        TimeoutError: if no 200 was seen within ``timeout_s`` seconds; the
            message carries the last status or error observed.
    """
    began = time.perf_counter()
    failure: str | None = None
    with httpx.Client(timeout=5.0, trust_env=False) as client:
        while time.perf_counter() - began < timeout_s:
            try:
                status = client.get(url).status_code
            except Exception as exc:  # pragma: no cover
                failure = f"{type(exc).__name__}: {exc}"
            else:
                if status == 200:
                    return
                failure = f"status={status}"
            time.sleep(1.0)
    raise TimeoutError(f"Timed out waiting for {url} ({failure})")

View File

@@ -0,0 +1,245 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
# Roles a WorkerSpec can take.
WorkerRole = Literal["prefill", "decode", "direct"]
@dataclass(frozen=True)
class WorkerSpec:
    """Placement and address of one worker process."""

    role: WorkerRole
    ordinal: int
    gpu_ids: tuple[int, ...]
    host: str
    port: int
    bootstrap_port: int | None = None

    @property
    def worker_id(self) -> str:
        """Identifier of the form ``<role>-<ordinal>``."""
        return "{}-{}".format(self.role, self.ordinal)

    @property
    def url(self) -> str:
        """HTTP base URL built from host and port."""
        return "http://{}:{}".format(self.host, self.port)

    @property
    def gpu_id(self) -> int:
        """First entry of ``gpu_ids``."""
        first_gpu = self.gpu_ids[0]
        return first_gpu

    @property
    def tp_size(self) -> int:
        """Number of GPUs assigned to the worker."""
        gpu_count = len(self.gpu_ids)
        return gpu_count
@dataclass(frozen=True)
class SingleNodeTopology:
    """Workers, router address and server options of a single-node deployment."""

    model_path: str
    prefill_workers: tuple[WorkerSpec, ...]
    decode_workers: tuple[WorkerSpec, ...]
    direct_workers: tuple[WorkerSpec, ...]
    router_host: str
    router_port: int
    transfer_backend: str
    trust_remote_code: bool
    force_rdma: bool = False
    ib_device: str | None = None
    extra_server_args: tuple[str, ...] = ()
    prefill_extra_server_args: tuple[str, ...] = ()
    decode_extra_server_args: tuple[str, ...] = ()
    direct_extra_server_args: tuple[str, ...] = ()

    @property
    def model_name(self) -> str:
        """Final path component of ``model_path``."""
        model_dir = Path(self.model_path)
        return model_dir.name

    @property
    def router_url(self) -> str:
        """HTTP base URL of the router."""
        return "http://{}:{}".format(self.router_host, self.router_port)

    @property
    def route_workers(self) -> tuple[WorkerSpec, ...]:
        """Workers that requests are routed to: decode workers, else direct workers."""
        return self.decode_workers or self.direct_workers

    def route_index(self, worker_id: str) -> int:
        """Return the position of ``worker_id`` within ``route_workers``.

        Raises:
            KeyError: if no route worker has that id.
        """
        known_ids = [worker.worker_id for worker in self.route_workers]
        if worker_id in known_ids:
            return known_ids.index(worker_id)
        raise KeyError(f"Unknown route worker: {worker_id}")
def build_single_node_topology(
    *,
    model_path: str,
    prefill_worker_count: int,
    decode_worker_count: int,
    direct_worker_count: int = 0,
    prefill_tp_size: int = 1,
    decode_tp_size: int = 1,
    direct_tp_size: int = 1,
    prefill_gpu_ids: tuple[int, ...] | None = None,
    decode_gpu_ids: tuple[int, ...] | None = None,
    direct_gpu_ids: tuple[int, ...] | None = None,
    total_gpu_budget: int = 8,
    host: str = "127.0.0.1",
    router_port: int = 8000,
    prefill_port_base: int = 30000,
    decode_port_base: int = 31000,
    direct_port_base: int = 32000,
    bootstrap_port_base: int = 8998,
    transfer_backend: str = "nixl",
    force_rdma: bool = False,
    trust_remote_code: bool = True,
    ib_device: str | None = None,
    extra_server_args: tuple[str, ...] = (),
    prefill_extra_server_args: tuple[str, ...] = (),
    decode_extra_server_args: tuple[str, ...] = (),
    direct_extra_server_args: tuple[str, ...] = (),
) -> SingleNodeTopology:
    """Validate a worker layout and assemble the matching topology.

    Every role receives ``<role>_worker_count`` workers, each owning
    ``<role>_tp_size`` consecutive entries of ``<role>_gpu_ids``. When a
    role's GPU IDs are omitted they default to a contiguous range that
    starts right after the *number* of IDs held by the preceding roles
    (prefill, then decode, then direct). Worker ports are
    ``<role>_port_base + ordinal``; only prefill workers additionally get
    ``bootstrap_port_base + ordinal`` as their bootstrap port. The router
    shares ``host`` with the workers.

    Raises:
        ValueError: If a worker count is negative, no worker is configured,
            a TP size is below 1, ``force_rdma`` is set without
            ``ib_device`` or with a backend other than ``"mooncake"``, the
            layout needs more GPUs than ``total_gpu_budget``, a GPU ID list
            has the wrong length, GPU IDs repeat, or a GPU ID lies outside
            ``[0, total_gpu_budget - 1]``.
    """

    def _gpu_slice(
        pool: tuple[int, ...], width: int, ordinal: int
    ) -> tuple[int, ...]:
        # Worker ``ordinal`` owns the ``ordinal``-th run of ``width`` IDs.
        first = ordinal * width
        return tuple(pool[first : first + width])

    # --- scalar argument validation (order of checks is significant) ---
    for count, message in (
        (prefill_worker_count, "prefill_worker_count must be >= 0"),
        (decode_worker_count, "decode_worker_count must be >= 0"),
        (direct_worker_count, "direct_worker_count must be >= 0"),
    ):
        if count < 0:
            raise ValueError(message)
    if not (prefill_worker_count or decode_worker_count or direct_worker_count):
        raise ValueError("At least one worker must be configured")
    for tp_size, message in (
        (prefill_tp_size, "prefill_tp_size must be >= 1"),
        (decode_tp_size, "decode_tp_size must be >= 1"),
        (direct_tp_size, "direct_tp_size must be >= 1"),
    ):
        if tp_size <= 0:
            raise ValueError(message)
    if force_rdma:
        if not ib_device:
            raise ValueError("force_rdma requires --ib-device to be set")
        if transfer_backend != "mooncake":
            raise ValueError(
                "force_rdma currently requires transfer_backend='mooncake' "
                "to guarantee an RDMA path"
            )

    # --- GPU budget ---
    prefill_gpu_need = prefill_worker_count * prefill_tp_size
    decode_gpu_need = decode_worker_count * decode_tp_size
    direct_gpu_need = direct_worker_count * direct_tp_size
    if prefill_gpu_need + decode_gpu_need + direct_gpu_need > total_gpu_budget:
        raise ValueError(
            "Single-node GPU budget exceeded: "
            f"{prefill_worker_count} prefill x tp={prefill_tp_size} + "
            f"{decode_worker_count} decode x tp={decode_tp_size} + "
            f"{direct_worker_count} direct x tp={direct_tp_size} > "
            f"{total_gpu_budget} GPUs"
        )

    # --- default GPU assignment: contiguous ranges, role after role ---
    if prefill_gpu_ids is None:
        prefill_gpu_ids = tuple(range(prefill_gpu_need))
    if decode_gpu_ids is None:
        decode_start = len(prefill_gpu_ids)
        decode_gpu_ids = tuple(range(decode_start, decode_start + decode_gpu_need))
    if direct_gpu_ids is None:
        direct_start = len(prefill_gpu_ids) + len(decode_gpu_ids)
        direct_gpu_ids = tuple(range(direct_start, direct_start + direct_gpu_need))

    # --- GPU ID list validation ---
    if len(prefill_gpu_ids) != prefill_gpu_need:
        raise ValueError(
            "prefill_gpu_ids length must equal prefill_worker_count * prefill_tp_size: "
            f"{len(prefill_gpu_ids)} != {prefill_worker_count * prefill_tp_size}"
        )
    if len(decode_gpu_ids) != decode_gpu_need:
        raise ValueError(
            "decode_gpu_ids length must equal decode_worker_count * decode_tp_size: "
            f"{len(decode_gpu_ids)} != {decode_worker_count * decode_tp_size}"
        )
    if len(direct_gpu_ids) != direct_gpu_need:
        raise ValueError(
            "direct_gpu_ids length must equal direct_worker_count * direct_tp_size: "
            f"{len(direct_gpu_ids)} != {direct_worker_count * direct_tp_size}"
        )
    assigned_gpu_ids = prefill_gpu_ids + decode_gpu_ids + direct_gpu_ids
    if len(assigned_gpu_ids) != len(set(assigned_gpu_ids)):
        raise ValueError("prefill/decode/direct GPU IDs must be unique")
    if not all(0 <= gpu_id < total_gpu_budget for gpu_id in assigned_gpu_ids):
        raise ValueError(
            "GPU IDs must fall within the single-node budget range "
            f"[0, {total_gpu_budget - 1}]"
        )

    # --- worker construction ---
    prefill_workers = tuple(
        WorkerSpec(
            role="prefill",
            ordinal=ordinal,
            gpu_ids=_gpu_slice(prefill_gpu_ids, prefill_tp_size, ordinal),
            host=host,
            port=prefill_port_base + ordinal,
            bootstrap_port=bootstrap_port_base + ordinal,
        )
        for ordinal in range(prefill_worker_count)
    )
    decode_workers = tuple(
        WorkerSpec(
            role="decode",
            ordinal=ordinal,
            gpu_ids=_gpu_slice(decode_gpu_ids, decode_tp_size, ordinal),
            host=host,
            port=decode_port_base + ordinal,
        )
        for ordinal in range(decode_worker_count)
    )
    direct_workers = tuple(
        WorkerSpec(
            role="direct",
            ordinal=ordinal,
            gpu_ids=_gpu_slice(direct_gpu_ids, direct_tp_size, ordinal),
            host=host,
            port=direct_port_base + ordinal,
        )
        for ordinal in range(direct_worker_count)
    )

    return SingleNodeTopology(
        model_path=model_path,
        prefill_workers=prefill_workers,
        decode_workers=decode_workers,
        direct_workers=direct_workers,
        router_host=host,
        router_port=router_port,
        transfer_backend=transfer_backend,
        trust_remote_code=trust_remote_code,
        force_rdma=force_rdma,
        ib_device=ib_device,
        extra_server_args=extra_server_args,
        prefill_extra_server_args=prefill_extra_server_args,
        decode_extra_server_args=decode_extra_server_args,
        direct_extra_server_args=direct_extra_server_args,
    )

View File

@@ -0,0 +1,106 @@
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
@dataclass(frozen=True)
class TraceRequest:
    """One request parsed from a line of a JSON-lines trace file.

    Instances are produced by ``load_trace``; the comments below describe
    how that loader fills each field.
    """

    # "<session_id>:<turn_id>:<chat_id>:<line index>" as built by load_trace.
    request_id: str
    # Session the request belongs to, as resolved by _resolve_session_id.
    session_id: str
    # Value of the "chat_id" key.
    chat_id: int
    # Value of the "parent_chat_id" key; a negative value marks a request
    # with no parent (it starts its own session).
    parent_chat_id: int
    # Value of the "timestamp" key as a float (presumably seconds, per the
    # field name - confirm against the trace producer).
    timestamp_s: float
    # Value of the "input_length" key; the synthetic prompt builders emit
    # exactly this many tokens.
    input_length: int
    # Value of the "output_length" key.
    output_length: int
    # Value of the "type" key.
    request_type: str
    # Value of the "turn" key.
    turn_id: int
    # Values of the optional "hash_ids" key (empty when the key is absent);
    # the synthetic prompt builders expand each ID into a block of tokens.
    hash_ids: tuple[int, ...]
def load_trace(path: Path, *, request_limit: int | None = None) -> list[TraceRequest]:
    """Load a JSON-lines trace file into ``TraceRequest`` records.

    Each non-blank line must be a JSON object with the keys ``chat_id``,
    ``parent_chat_id``, ``timestamp``, ``input_length``, ``output_length``,
    ``type`` and ``turn``; ``hash_ids`` is optional and defaults to empty.
    Session IDs are resolved in file order, so a request inherits its
    parent's session only when the parent appears earlier in the file;
    otherwise the parent's chat ID itself is used as the session ID.

    Args:
        path: Trace file, read as UTF-8.
        request_limit: Stop after this many requests have been loaded;
            ``None`` loads the whole file.

    Returns:
        The parsed requests in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a required key is missing from a record.
    """
    chat_to_session: dict[int, str] = {}
    requests: list[TraceRequest] = []
    with path.open("r", encoding="utf-8") as handle:
        for index, line in enumerate(handle):
            if request_limit is not None and len(requests) >= request_limit:
                break
            if not line.strip():
                # Blank or whitespace-only lines (for example a stray
                # trailing newline) carry no record; skipping them avoids a
                # JSONDecodeError. ``index`` stays the physical line index,
                # so request IDs of existing traces are unchanged.
                continue
            payload = json.loads(line)
            chat_id = int(payload["chat_id"])
            parent_chat_id = int(payload["parent_chat_id"])
            session_id = _resolve_session_id(
                chat_id=chat_id,
                parent_chat_id=parent_chat_id,
                chat_to_session=chat_to_session,
            )
            turn_id = int(payload["turn"])
            request_id = f"{session_id}:{turn_id}:{chat_id}:{index}"
            requests.append(
                TraceRequest(
                    request_id=request_id,
                    session_id=session_id,
                    chat_id=chat_id,
                    parent_chat_id=parent_chat_id,
                    timestamp_s=float(payload["timestamp"]),
                    input_length=int(payload["input_length"]),
                    output_length=int(payload["output_length"]),
                    request_type=str(payload["type"]),
                    turn_id=turn_id,
                    hash_ids=tuple(int(item) for item in payload.get("hash_ids", [])),
                )
            )
    return requests
def build_synthetic_prompt(
    request: TraceRequest,
    *,
    block_token_budget: int = 24,
) -> str:
    """Return the synthetic prompt for ``request`` as one string.

    The prompt consists of the tokens produced by
    ``build_synthetic_prompt_tokens`` for the same arguments, separated by
    single spaces.
    """
    prompt_tokens = build_synthetic_prompt_tokens(
        request,
        block_token_budget=block_token_budget,
    )
    return " ".join(prompt_tokens)
def build_synthetic_prompt_tokens(
    request: TraceRequest,
    *,
    block_token_budget: int = 24,
) -> list[str]:
    """Return exactly ``request.input_length`` deterministic prompt tokens.

    Every ID in ``request.hash_ids`` contributes ``block_token_budget``
    tokens named ``blk<hash_id>_<offset>``. If that is fewer tokens than
    ``request.input_length``, the list is padded with ``fill_<position % 64>``
    tokens; if it is more, the surplus is cut off the end.
    """
    target_length = request.input_length
    tokens = [
        f"blk{hash_id}_{offset}"
        for hash_id in request.hash_ids
        for offset in range(block_token_budget)
    ]
    # Filler names depend on the absolute position of the token in the list.
    tokens.extend(
        f"fill_{position % 64}" for position in range(len(tokens), target_length)
    )
    return tokens[:target_length]
def build_synthetic_append_chunk(
    request: TraceRequest,
    append_length: int,
) -> str:
    """Return ``append_length`` space-separated tokens unique to this request.

    Tokens are named ``turn<turn_id>_append_<chat_id>_<offset>``. A
    non-positive ``append_length`` yields the empty string.
    """
    if append_length <= 0:
        return ""
    chunk_tokens = [
        f"turn{request.turn_id}_append_{request.chat_id}_{offset}"
        for offset in range(append_length)
    ]
    return " ".join(chunk_tokens)
def _resolve_session_id(
    *,
    chat_id: int,
    parent_chat_id: int,
    chat_to_session: dict[int, str],
) -> str:
    """Return the session ID for ``chat_id`` and record it in the mapping.

    A chat with a negative ``parent_chat_id`` starts a session named after
    its own ID. Any other chat joins its parent's session, or - when the
    parent has not been recorded - a session named after the parent's ID.
    ``chat_to_session`` is updated in place with the result.
    """
    resolved = (
        str(chat_id)
        if parent_chat_id < 0
        else chat_to_session.get(parent_chat_id, str(parent_chat_id))
    )
    chat_to_session[chat_id] = resolved
    return resolved

View File

@@ -0,0 +1,127 @@
from __future__ import annotations
import json
from collections import defaultdict
from dataclasses import asdict, dataclass
from math import ceil
from pathlib import Path
from agentic_pd_hybrid.trace import TraceRequest, load_trace
# Number of tokens represented by one hash block when an input length is
# converted into a block count. NOTE(review): this equals the default
# ``block_token_budget`` of the synthetic prompt builders in
# agentic_pd_hybrid.trace - presumably the two must stay in sync; confirm.
BLOCK_TOKEN_BUDGET = 24
@dataclass(frozen=True)
class NormalizeTraceLengthsConfig:
    """Settings consumed by ``normalize_trace_lengths``."""

    # Trace file to read (handed to load_trace).
    trace_path: Path
    # Destination of the normalized JSON-lines trace; a summary file is
    # written next to it with ".summary.json" appended to its suffix.
    output_path: Path
    # Input length given to the first request of every session.
    initial_input_length: int = 10_000
    # Together with output_length, the amount by which the input length
    # grows on each later request of a session.
    append_input_length: int = 1_000
    # Output length written for every request.
    output_length: int = 1_000
    # Forwarded to load_trace as request_limit; None loads everything.
    max_requests: int | None = None
@dataclass(frozen=True)
class NormalizeTraceLengthsSummary:
    """Result report returned (and saved as JSON) by ``normalize_trace_lengths``."""

    # String form of the input trace path.
    input_trace_path: str
    # String form of the normalized trace path.
    output_trace_path: str
    # Number of records written to the normalized trace.
    request_count: int
    # Number of distinct sessions found in the input trace.
    session_count: int
    # Number of sessions that contain more than one request.
    multi_turn_session_count: int
    # The three length settings, copied from the config that was used.
    initial_input_length: int
    append_input_length: int
    output_length: int
    # Largest number of requests found in a single session (0 if none).
    max_turns_per_session: int
    # Largest input length assigned to any record (0 if none).
    max_input_length: int
def normalize_trace_lengths(
    config: NormalizeTraceLengthsConfig,
) -> NormalizeTraceLengthsSummary:
    """Rewrite a trace so every session follows a fixed length schedule.

    Requests are grouped by session and ordered within each session by
    ``(timestamp_s, turn_id, chat_id)``. The request at position ``k`` of a
    session gets ``initial_input_length + k * (append_input_length +
    output_length)`` as input length and ``output_length`` as output
    length. Its ``hash_ids`` are regenerated: one ID per
    ``BLOCK_TOKEN_BUDGET`` tokens (rounded up, never fewer than the blocks
    of the initial length), derived from the session's rank and the block
    index so that later turns of a session extend the IDs of earlier ones.
    Chat IDs, parent chat IDs, timestamps, types and turn numbers are kept.

    The records are written to ``config.output_path`` as JSON lines sorted
    by timestamp (parent directories are created), and the summary is
    saved next to it with ``.summary.json`` appended to the file suffix.

    Raises:
        ValueError: If any of the three configured lengths is negative.
    """
    for value, message in (
        (config.initial_input_length, "initial_input_length must be >= 0"),
        (config.append_input_length, "append_input_length must be >= 0"),
        (config.output_length, "output_length must be >= 0"),
    ):
        if value < 0:
            raise ValueError(message)

    by_session: dict[str, list[TraceRequest]] = defaultdict(list)
    for loaded in load_trace(config.trace_path, request_limit=config.max_requests):
        by_session[loaded.session_id].append(loaded)

    # Both values are the same for every session and turn.
    growth_per_turn = config.append_input_length + config.output_length
    base_block_total = ceil(config.initial_input_length / BLOCK_TOKEN_BUDGET)

    records: list[dict[str, object]] = []
    longest_session = 0
    longest_input = 0
    ordered_session_ids = sorted(by_session, key=_session_sort_key)
    for session_idx, session_id in enumerate(ordered_session_ids):
        turns = sorted(
            by_session[session_id],
            key=lambda entry: (entry.timestamp_s, entry.turn_id, entry.chat_id),
        )
        longest_session = max(longest_session, len(turns))
        for turn_idx, request in enumerate(turns):
            input_length = config.initial_input_length + turn_idx * growth_per_turn
            # Base blocks are always present; further blocks are added only
            # when this turn's length needs more than the base amount.
            block_total = max(
                base_block_total,
                ceil(input_length / BLOCK_TOKEN_BUDGET),
            )
            hash_ids = [
                _hash_id_for(session_idx=session_idx, block_idx=block_idx)
                for block_idx in range(block_total)
            ]
            longest_input = max(longest_input, input_length)
            records.append(
                {
                    "chat_id": request.chat_id,
                    "parent_chat_id": request.parent_chat_id,
                    "timestamp": request.timestamp_s,
                    "input_length": input_length,
                    "output_length": config.output_length,
                    "type": request.request_type,
                    "turn": request.turn_id,
                    "hash_ids": hash_ids,
                }
            )

    # Stable sort: records sharing a timestamp keep their per-session order.
    records = sorted(records, key=lambda item: float(item["timestamp"]))

    config.output_path.parent.mkdir(parents=True, exist_ok=True)
    with config.output_path.open("w", encoding="utf-8") as handle:
        handle.writelines(
            json.dumps(record, sort_keys=True) + "\n" for record in records
        )

    multi_turn_sessions = sum(
        1 for session_turns in by_session.values() if len(session_turns) > 1
    )
    summary = NormalizeTraceLengthsSummary(
        input_trace_path=str(config.trace_path),
        output_trace_path=str(config.output_path),
        request_count=len(records),
        session_count=len(by_session),
        multi_turn_session_count=multi_turn_sessions,
        initial_input_length=config.initial_input_length,
        append_input_length=config.append_input_length,
        output_length=config.output_length,
        max_turns_per_session=longest_session,
        max_input_length=longest_input,
    )

    summary_suffix = config.output_path.suffix + ".summary.json"
    summary_path = config.output_path.with_suffix(summary_suffix)
    with summary_path.open("w", encoding="utf-8") as handle:
        json.dump(asdict(summary), handle, indent=2, sort_keys=True)
    return summary
def _hash_id_for(*, session_idx: int, block_idx: int) -> int:
    """Return the synthetic hash ID of block ``block_idx`` in a session.

    Each session index owns a band of 1,000,000 consecutive IDs; the block
    index is the offset inside that band.
    """
    session_band_start = session_idx * 1_000_000
    return session_band_start + block_idx
def _session_sort_key(session_id: str) -> tuple[int, str]:
    """Return a sort key that places all-digit session IDs first.

    The first element is 0 for IDs made only of digits and 1 otherwise;
    the second is the ID itself, so ties are ordered as plain strings.
    """
    non_numeric_rank = int(not session_id.isdigit())
    return (non_numeric_rank, session_id)

4245
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff