Agentic workload PD separation analysis with trace-driven benchmarks
Systematic study of prefill-decode disaggregation for agentic LLM workloads using production GLM-5.1 coder trace (2.1M requests, 71B input tokens).

Key findings:
- Cache-aware routing improves TPOT p90 by 15% and APC from 20.8% to 44.7% without PD separation, matching PD-Sep's decode isolation benefit
- PD separation adds +72% TTFT overhead (KV transfer) with no TPOT gain when using the same cache-aware scheduler
- Prefill remains compute-bound even at 95% KV cache reuse (AI >1000x vs decode AI <2), but absolute FLOPs drop 71% from cache hits
- For agentic MoE workloads, cache-aware routing > PD separation

Infrastructure:
- Trace sampler preserving session structure + hash_ids for prefix sharing
- Async trace replayer with streaming TTFT/TPOT/E2E measurement
- Unified cache-aware + token-level load-balanced global scheduler proxy supporting both PD-colocated and PD-disaggregated (Mooncake/RDMA) modes
- vLLM 0.18.1 scheduler patch for KV transfer abort race condition
- Roofline analysis tool for prefill/decode compute characterization

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
0
replayer/__init__.py
Normal file
0
replayer/__init__.py
Normal file
55
replayer/__main__.py
Normal file
55
replayer/__main__.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""CLI entry point: python -m replayer replay ..."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from .replay import ReplayConfig, replay_trace
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> None:
    """Parse CLI arguments, replay the trace, and print a short result summary.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, so existing callers are unaffected.

    Numeric options that would break the replay later (a non-positive time
    scale, limit or timeout) are rejected up front through ``parser.error``,
    which exits with status 2.
    """
    p = argparse.ArgumentParser(description="Trace replayer for vLLM benchmarking")
    p.add_argument("--trace", type=Path, required=True, help="Sampled trace JSONL")
    p.add_argument("--output", type=Path, required=True, help="Output metrics JSONL")
    p.add_argument("--endpoint", type=str, required=True,
                   help="vLLM server URL (e.g. http://localhost:8000)")
    p.add_argument("--model", type=str, default="default", help="Model name for API")
    p.add_argument("--time-scale", type=float, default=1.0,
                   help="Time compression (>1 = faster)")
    p.add_argument("--max-inflight-sessions", type=int, default=32)
    p.add_argument("--concurrency-limit", type=int, default=256)
    p.add_argument("--request-timeout", type=float, default=600.0)
    p.add_argument("--request-limit", type=int, default=None,
                   help="Limit number of requests to replay")
    p.add_argument("-v", "--verbose", action="store_true")
    args = p.parse_args(argv)

    # Trace offsets are divided by the time scale, so it must be strictly
    # positive. ``not x > 0`` also rejects NaN.
    if not args.time_scale > 0:
        p.error("--time-scale must be greater than 0")
    # Both limits size asyncio semaphores; a value below 1 can never admit work.
    if args.max_inflight_sessions < 1:
        p.error("--max-inflight-sessions must be at least 1")
    if args.concurrency_limit < 1:
        p.error("--concurrency-limit must be at least 1")
    if not args.request_timeout > 0:
        p.error("--request-timeout must be greater than 0")

    # --endpoint may be a comma-separated list: normalise every entry, not
    # just the end of the whole string, and drop empty entries.
    endpoints = [e.strip().rstrip("/") for e in args.endpoint.split(",")]
    endpoints = [e for e in endpoints if e]
    if not endpoints:
        p.error("--endpoint must contain at least one URL")

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    config = ReplayConfig(
        trace_path=args.trace,
        output_path=args.output,
        endpoint_url=",".join(endpoints),
        model_name=args.model,
        time_scale=args.time_scale,
        max_inflight_sessions=args.max_inflight_sessions,
        concurrency_limit=args.concurrency_limit,
        request_timeout_s=args.request_timeout,
        request_limit=args.request_limit,
    )

    results = asyncio.run(replay_trace(config))
    succeeded = sum(1 for r in results if r.error is None)
    print(f"\nDone: {succeeded}/{len(results)} requests succeeded")
    print(f"Metrics: {args.output}")
    print(f"Summary: {args.output.with_suffix('.summary.json')}")


if __name__ == "__main__":
    main()
|
||||
107
replayer/metrics.py
Normal file
107
replayer/metrics.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Per-request metrics collection and summary reporting."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import statistics
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RequestMetrics:
    """Immutable record describing the outcome of one replayed request.

    The first block of fields has no defaults and must always be supplied;
    the trailing fields default to ``None``.
    """

    # --- identity of the request within the trace ---
    request_id: str
    session_id: str
    turn_id: int
    trace_timestamp_s: float

    # --- sizes and type taken from the trace ---
    input_length: int
    output_length: int
    request_type: str

    # --- what was actually observed during the replay ---
    effective_input_length: int | None  # length of the prompt that was sent
    cached_tokens: int                  # summed in the summary to derive the cache hit ratio
    latency_s: float | None             # end-to-end latency, seconds
    ttft_s: float | None                # time to first token, seconds
    tpot_s: float | None                # time per output token, seconds

    # --- optional details ---
    actual_output_tokens: int | None = None
    requested_output_tokens: int | None = None
    finish_reason: str | None = None
    error: str | None = None            # None means the request is counted as successful
|
||||
|
||||
|
||||
class IncrementalMetricSink:
|
||||
"""Append each RequestMetrics to JSONL immediately (crash-safe)."""
|
||||
|
||||
def __init__(self, path: Path):
|
||||
self.path = path
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text("")
|
||||
self._lock = asyncio.Lock()
|
||||
self._fh = path.open("a", encoding="utf-8", buffering=1)
|
||||
|
||||
async def append(self, metric: RequestMetrics) -> None:
|
||||
line = json.dumps(asdict(metric), sort_keys=True) + "\n"
|
||||
async with self._lock:
|
||||
self._fh.write(line)
|
||||
self._fh.flush()
|
||||
|
||||
def close(self) -> None:
|
||||
try:
|
||||
self._fh.flush()
|
||||
self._fh.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def write_summary_json(path: Path, rows: list[RequestMetrics]) -> None:
    """Aggregate per-request metrics into a JSON summary written to ``path``.

    Rows whose ``error`` is ``None`` count as successful. Latency, TTFT,
    TPOT and token statistics are computed over successful rows only, while
    ``request_count`` and ``error_count`` cover every row. Parent
    directories of ``path`` are created when missing.
    """
    ok_rows = [row for row in rows if row.error is None]

    def _present(attr: str) -> list:
        # Values of ``attr`` across successful rows, skipping missing ones.
        return [
            value
            for value in (getattr(row, attr) for row in ok_rows)
            if value is not None
        ]

    input_tokens = sum(row.input_length for row in ok_rows)
    cached_total = sum(row.cached_tokens for row in ok_rows)
    if input_tokens > 0:
        hit_ratio = cached_total / input_tokens
    else:
        # No input tokens at all: report 0.0 instead of dividing by zero.
        hit_ratio = 0.0

    summary: dict[str, Any] = {}
    summary["request_count"] = len(rows)
    summary["success_count"] = len(ok_rows)
    # Every row is either successful or failed.
    summary["error_count"] = len(rows) - len(ok_rows)
    summary["latency_stats_s"] = _stats(_present("latency_s"))
    summary["ttft_stats_s"] = _stats(_present("ttft_s"))
    summary["tpot_stats_s"] = _stats(_present("tpot_s"))
    summary["cache_hit_request_count"] = len(
        [row for row in ok_rows if row.cached_tokens > 0]
    )
    summary["total_input_tokens"] = input_tokens
    summary["total_cached_tokens"] = cached_total
    summary["prefix_cache_hit_ratio"] = hit_ratio
    summary["cached_tokens_stats"] = _stats(
        [float(row.cached_tokens) for row in ok_rows]
    )
    summary["actual_output_tokens_stats"] = _stats(
        [float(count) for count in _present("actual_output_tokens")]
    )

    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(
        json.dumps(summary, indent=2, sort_keys=True), encoding="utf-8"
    )
|
||||
|
||||
|
||||
def _stats(values: list[float | None]) -> dict[str, float] | None:
    """Summarise ``values`` as count, mean and p50/p90/p99.

    ``None`` entries are ignored. Returns ``None`` when no usable value
    remains. The input list itself is not modified.
    """
    ordered = sorted(v for v in values if v is not None)
    if not ordered:
        return None
    p50, p90, p99 = (_percentile(ordered, q) for q in (0.50, 0.90, 0.99))
    return {
        "count": float(len(ordered)),
        "mean": statistics.fmean(ordered),
        "p50": p50,
        "p90": p90,
        "p99": p99,
    }
|
||||
|
||||
|
||||
def _percentile(sorted_vals: list[float], pct: float) -> float:
    """Return the element of ``sorted_vals`` nearest to the ``pct`` quantile.

    ``sorted_vals`` must be non-empty and sorted ascending; ``pct`` is a
    fraction in [0, 1]. No interpolation is performed: the fractional rank
    is rounded with the built-in ``round`` (which rounds halves to the
    nearest even integer) and the value at that index is returned.
    """
    last = len(sorted_vals) - 1
    if last == 0:
        # A single sample is every percentile.
        return sorted_vals[0]
    return sorted_vals[round(last * pct)]
|
||||
343
replayer/replay.py
Normal file
343
replayer/replay.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""Trace replayer — send requests to vLLM following trace timing.
|
||||
|
||||
Requests are sent to the OpenAI-compatible /v1/completions endpoint with
streaming enabled. Uses hash_ids from the trace to construct
synthetic prompts that reproduce realistic prefix-cache hit patterns.
|
||||
|
||||
Key behaviors:
|
||||
- Per-session sequencing: turns within a session are sent in order,
|
||||
each waiting for the previous to complete before dispatching.
|
||||
- Inter-session arrival: sessions start at their trace timestamps,
|
||||
scaled by --time-scale.
|
||||
- Concurrency control: --max-inflight-sessions caps concurrent sessions;
|
||||
--concurrency-limit caps total in-flight requests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import random as _random
|
||||
|
||||
import httpx
|
||||
|
||||
from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json
|
||||
from .trace import TraceRequest, load_trace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Number of synthetic token IDs generated for each trace hash_id.
BLOCK_SIZE = 512
# NOTE(review): presumably the vocabulary size of the target model - confirm.
VOCAB_SIZE = 151936
# Synthetic token IDs are drawn from [TOKEN_RANGE_START, TOKEN_RANGE_END]
# (both inclusive), leaving a margin of 100 IDs at each end of the vocabulary.
TOKEN_RANGE_START = 100
TOKEN_RANGE_END = VOCAB_SIZE - 100

# Memo of already generated blocks, keyed by hash_id.
# NOTE(review): entries are never evicted, so this grows with the number of
# distinct hash_ids seen during a run.
_block_cache: dict[int, list[int]] = {}


def _hash_id_to_token_ids(hash_id: int) -> list[int]:
    """Return the BLOCK_SIZE token IDs that stand in for ``hash_id``.

    The IDs come from a ``random.Random`` seeded with ``hash_id``, so equal
    hash_ids always produce equal blocks. Results are memoised and the
    cached list object itself is returned; callers must not mutate it.
    """
    cached = _block_cache.get(hash_id)
    if cached is not None:
        return cached
    generator = _random.Random(hash_id)
    block = [
        generator.randint(TOKEN_RANGE_START, TOKEN_RANGE_END)
        for _ in range(BLOCK_SIZE)
    ]
    _block_cache[hash_id] = block
    return block
|
||||
|
||||
|
||||
@dataclass
class ReplayConfig:
    """Settings for one trace replay run.

    Raises:
        ValueError: from ``__post_init__`` when ``time_scale`` is not
            strictly positive or when either concurrency limit is below 1.
    """

    trace_path: Path                  # trace JSONL to replay
    output_path: Path                 # per-request metrics JSONL to write
    endpoint_url: str  # comma-separated for round-robin: "http://host:8000,http://host:8001"
    time_scale: float = 1.0           # trace time offsets are divided by this value
    max_inflight_sessions: int = 32   # size of the per-session semaphore
    concurrency_limit: int = 256      # size of the per-request semaphore
    request_timeout_s: float = 600.0  # HTTP timeout applied to each request
    request_limit: int | None = None  # stop loading the trace after this many requests
    model_name: str = "default"       # "model" field sent with every request

    def __post_init__(self) -> None:
        """Reject values that would otherwise fail or hang during the replay."""
        # Offsets are divided by time_scale; ``not x > 0`` also rejects NaN.
        if not self.time_scale > 0:
            raise ValueError(f"time_scale must be > 0, got {self.time_scale!r}")
        # A semaphore created with 0 permits would block every session/request.
        if self.max_inflight_sessions < 1:
            raise ValueError(
                f"max_inflight_sessions must be >= 1, got {self.max_inflight_sessions!r}"
            )
        if self.concurrency_limit < 1:
            raise ValueError(
                f"concurrency_limit must be >= 1, got {self.concurrency_limit!r}"
            )
|
||||
|
||||
|
||||
def _build_prompt_token_ids(req: TraceRequest) -> list[int]:
    """Assemble the synthetic prompt for ``req`` as a list of token IDs.

    Each hash_id contributes its deterministic block of token IDs, in
    order, so requests sharing a hash_id prefix share a token prefix. The
    result is then padded (with values from a generator seeded by
    ``req.chat_id``) or truncated to ``req.input_length``.
    """
    prompt: list[int] = []
    for block_hash in req.hash_ids:
        prompt += _hash_id_to_token_ids(block_hash)

    # Top up to the requested length; a non-positive shortfall adds nothing.
    filler = _random.Random(req.chat_id)
    shortfall = req.input_length - len(prompt)
    prompt.extend(
        filler.randint(TOKEN_RANGE_START, TOKEN_RANGE_END)
        for _ in range(shortfall)
    )
    return prompt[:req.input_length]
|
||||
|
||||
|
||||
@dataclass
class _SessionState:
    """Mutable bookkeeping for one session while it is being replayed."""

    session_id: str
    # Requests of the session, in the order in which they are dispatched.
    turns: list[TraceRequest]
    # Metrics gathered so far; every instance gets its own fresh list.
    metrics: list[RequestMetrics] = field(default_factory=list)
|
||||
|
||||
|
||||
# Index of the next endpoint to hand out; shared by every request in the process.
_endpoint_counter = 0


def _pick_endpoint(config: ReplayConfig) -> str:
    """Round-robin across comma-separated endpoints.

    Each entry of ``config.endpoint_url`` is stripped of surrounding
    whitespace and trailing slashes, and empty entries (for example from a
    trailing comma) are ignored, so an empty base URL is never returned.

    Raises:
        ValueError: if ``config.endpoint_url`` contains no usable endpoint.
    """
    global _endpoint_counter
    candidates = (part.strip().rstrip("/") for part in config.endpoint_url.split(","))
    endpoints = [url for url in candidates if url]
    if not endpoints:
        raise ValueError("endpoint_url does not contain any endpoint")
    url = endpoints[_endpoint_counter % len(endpoints)]
    _endpoint_counter += 1
    return url
|
||||
|
||||
|
||||
async def _dispatch_request(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
    req: TraceRequest,
    prompt_token_ids: list[int],
    sem: asyncio.Semaphore,
) -> RequestMetrics:
    """Send one request via /v1/completions (streaming) and collect metrics.

    Args:
        client: Shared HTTP client used to open the streaming POST.
        config: Supplies the endpoint list, model name and request timeout.
        req: Trace entry being replayed; its identifiers and lengths are
            copied into the returned metrics.
        prompt_token_ids: Token IDs sent as the prompt.
        sem: Semaphore bounding the number of requests in flight.

    Returns:
        A ``RequestMetrics`` for the request. Any exception raised while
        sending or streaming is caught and stored (as a truncated ``repr``)
        in ``error`` rather than propagated. ``tpot_s`` is ``None`` when
        fewer than two text chunks were received, because no inter-chunk
        gap exists in that case; reporting ``0.0`` would drag TPOT
        statistics toward zero.
    """
    endpoint = _pick_endpoint(config)
    payload = {
        "model": config.model_name,
        "prompt": prompt_token_ids,
        # Always ask for at least one token, even if the trace recorded zero.
        "max_tokens": max(1, req.output_length),
        "temperature": 0,
        "stream": True,
        "stream_options": {"include_usage": True},
    }

    # NOTE: the clock starts before the semaphore is acquired, so latency and
    # TTFT include any time spent waiting for a free slot in ``sem``.
    start = time.perf_counter()
    ttft_s = None
    n_output = 0
    cached_tokens = 0
    finish_reason = None
    err = None
    # Arrival time of every chunk that carried non-empty text.
    token_times: list[float] = []

    async with sem:
        try:
            async with client.stream(
                "POST",
                f"{endpoint}/v1/completions",
                json=payload,
                timeout=config.request_timeout_s,
            ) as resp:
                resp.raise_for_status()
                async for raw_line in resp.aiter_lines():
                    # Only "data:" lines of the event stream carry payloads.
                    if not raw_line.startswith("data:"):
                        continue
                    data = raw_line[5:].strip()
                    if data == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                    except json.JSONDecodeError:
                        continue

                    now = time.perf_counter()
                    if ttft_s is None:
                        # The first successfully parsed chunk defines TTFT.
                        ttft_s = now - start

                    choices = chunk.get("choices", [])
                    if choices:
                        if choices[0].get("text", ""):
                            token_times.append(now)
                        fr = choices[0].get("finish_reason")
                        if fr:
                            finish_reason = fr

                    usage = chunk.get("usage")
                    if usage:
                        n_output = usage.get("completion_tokens", n_output)
                        cached_tokens = _extract_cached_tokens(usage)
        except Exception as exc:
            # Best effort: record the failure and let the replay continue.
            err = repr(exc)[:300]

    e2e = time.perf_counter() - start
    if n_output == 0 and token_times:
        # No usage block was reported: fall back to the number of text chunks.
        n_output = len(token_times)

    tpot = None
    if len(token_times) > 1:
        # The mean of consecutive gaps equals (last - first) / number_of_gaps.
        tpot = (token_times[-1] - token_times[0]) / (len(token_times) - 1)

    return RequestMetrics(
        request_id=req.request_id,
        session_id=req.session_id,
        turn_id=req.turn_id,
        trace_timestamp_s=req.timestamp_s,
        input_length=req.input_length,
        output_length=req.output_length,
        request_type=req.request_type,
        effective_input_length=len(prompt_token_ids),
        cached_tokens=cached_tokens,
        latency_s=e2e,
        ttft_s=ttft_s,
        tpot_s=tpot,
        actual_output_tokens=n_output,
        requested_output_tokens=req.output_length,
        finish_reason=finish_reason,
        error=err,
    )
|
||||
|
||||
|
||||
def _extract_cached_tokens(usage: dict) -> int:
    """Read the cached-token count from a ``usage`` object.

    ``usage["prompt_tokens_details"]["cached_tokens"]`` is preferred; when
    it is missing or zero, the top-level ``usage["cached_tokens"]`` is used
    instead. Missing or ``None`` values count as 0.
    """
    nested = usage.get("prompt_tokens_details")
    if isinstance(nested, dict):
        count = nested.get("cached_tokens", 0) or 0
    else:
        count = 0
    return int(count or usage.get("cached_tokens", 0) or 0)
|
||||
|
||||
|
||||
async def _run_session(
    *,
    state: _SessionState,
    config: ReplayConfig,
    client: httpx.AsyncClient,
    session_sem: asyncio.Semaphore,
    request_sem: asyncio.Semaphore,
    earliest_ts: float,
    sweep_start: float,
    sink: IncrementalMetricSink,
) -> list[RequestMetrics]:
    """Replay every turn of one session, in order, honouring trace timing.

    Args:
        state: Session to replay; each metric is appended to ``state.metrics``.
        config: Supplies ``time_scale`` and is forwarded to the dispatcher.
        client: Shared HTTP client.
        session_sem: Held for the whole session, bounding concurrent sessions.
        request_sem: Forwarded to the dispatcher to bound in-flight requests.
        earliest_ts: Trace timestamp that corresponds to ``sweep_start``.
        sweep_start: ``time.perf_counter()`` value taken when the replay began.
        sink: Receives every metric as soon as its request finishes.

    Returns:
        ``state.metrics`` - one entry per turn, in dispatch order.
    """
    async with session_sem:
        first_ts = state.turns[0].timestamp_s

        # Wait until this session's start time
        offset = (first_ts - earliest_ts) / config.time_scale
        wait = offset - (time.perf_counter() - sweep_start)
        if wait > 0:
            await asyncio.sleep(wait)

        # The first turn is identified by position rather than by comparing
        # whole request objects field by field.
        for turn_index, req in enumerate(state.turns):
            if turn_index > 0:
                # Intra-session: wait for turn's relative offset. A turn is
                # never sent before the previous one completed, because the
                # dispatch below is awaited.
                target = (req.timestamp_s - first_ts) / config.time_scale
                elapsed = time.perf_counter() - sweep_start - offset
                if elapsed < target:
                    await asyncio.sleep(target - elapsed)

            token_ids = _build_prompt_token_ids(req)
            metric = await _dispatch_request(
                client=client, config=config, req=req,
                prompt_token_ids=token_ids, sem=request_sem,
            )
            state.metrics.append(metric)
            await sink.append(metric)

    return state.metrics
|
||||
|
||||
|
||||
async def _snapshot_prefix_cache_metrics(url_csv: str) -> dict[str, float]:
    """Scrape vLLM /metrics for prefix cache counters (aggregated across endpoints).

    Args:
        url_csv: Comma-separated endpoint base URLs. Surrounding whitespace,
            trailing slashes and empty entries are ignored.

    Returns:
        ``{"queries": ..., "hits": ...}`` summed over every endpoint that
        could be scraped. Scraping is best effort: an endpoint that cannot
        be reached, answers with an error status, or returns an unparsable
        value is skipped with a warning and contributes nothing.
    """
    total = {"queries": 0.0, "hits": 0.0}
    endpoints = [e.strip().rstrip("/") for e in url_csv.split(",")]
    endpoints = [e for e in endpoints if e]
    # trust_env=False matches the client used for the replay itself, so both
    # reach the servers the same way regardless of proxy environment variables.
    async with httpx.AsyncClient(timeout=10, trust_env=False) as c:
        for url in endpoints:
            try:
                r = await c.get(f"{url}/metrics")
                r.raise_for_status()
                # Accumulate locally so that a failure half-way through
                # parsing does not leave partial counts in the total.
                queries = hits = 0.0
                for line in r.text.splitlines():
                    if line.startswith("vllm:prefix_cache_queries_total"):
                        queries += float(line.split()[-1])
                    elif line.startswith("vllm:prefix_cache_hits_total"):
                        hits += float(line.split()[-1])
            except Exception as exc:
                logger.warning(
                    "Could not read prefix cache counters from %s/metrics: %r",
                    url, exc,
                )
                continue
            total["queries"] += queries
            total["hits"] += hits
    return total
|
||||
|
||||
|
||||
async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]:
    """Main entry: load trace, replay against endpoint, return metrics.

    Requests are grouped by session, each session is replayed as its own
    task, and per-request metrics are streamed to ``config.output_path``.
    A summary is written next to it with the suffix ``.summary.json``.

    The summary's ``prefix_cache_hit_ratio`` is replaced by the ratio
    derived from the servers' /metrics counters only when those counters
    actually advanced during the run; otherwise the ratio computed from
    per-request usage data is kept instead of being overwritten with 0.0.

    Returns:
        All request metrics, flattened across sessions. An empty list is
        returned (and no file is written) when the trace has no requests.
    """
    requests = load_trace(config.trace_path, request_limit=config.request_limit)
    if not requests:
        return []

    by_session: dict[str, list[TraceRequest]] = defaultdict(list)
    for r in requests:
        by_session[r.session_id].append(r)
    for turns in by_session.values():
        turns.sort(key=lambda r: (r.turn_id, r.timestamp_s))

    # Sessions are started in order of their first turn's timestamp.
    sessions = sorted(by_session.items(), key=lambda kv: kv[1][0].timestamp_s)
    earliest_ts = sessions[0][1][0].timestamp_s

    session_sem = asyncio.Semaphore(config.max_inflight_sessions)
    request_sem = asyncio.Semaphore(config.concurrency_limit)

    sink = IncrementalMetricSink(config.output_path)

    n_sessions = len(sessions)
    n_requests = len(requests)
    logger.info("Replaying %d sessions (%d requests), time_scale=%.1f",
                n_sessions, n_requests, config.time_scale)

    pre_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)
    sweep_start = time.perf_counter()

    try:
        limits = httpx.Limits(
            max_connections=2000,
            max_keepalive_connections=500,
            keepalive_expiry=30.0,
        )
        async with httpx.AsyncClient(
            timeout=config.request_timeout_s,
            trust_env=False,
            limits=limits,
        ) as client:
            tasks = [
                asyncio.create_task(_run_session(
                    state=_SessionState(session_id=sid, turns=turns),
                    config=config, client=client,
                    session_sem=session_sem, request_sem=request_sem,
                    earliest_ts=earliest_ts, sweep_start=sweep_start,
                    sink=sink,
                ))
                for sid, turns in sessions
            ]
            all_results = await asyncio.gather(*tasks)
    finally:
        # Close the metrics file even when the replay is interrupted.
        sink.close()

    sweep_elapsed = time.perf_counter() - sweep_start
    post_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)

    flat = [m for group in all_results for m in group]
    summary_path = config.output_path.with_suffix(".summary.json")
    write_summary_json(summary_path, flat)

    # Compute aggregate prefix cache hit ratio from /metrics deltas
    delta_queries = post_metrics.get("queries", 0) - pre_metrics.get("queries", 0)
    delta_hits = post_metrics.get("hits", 0) - pre_metrics.get("hits", 0)
    have_server_counters = delta_queries > 0
    hit_ratio = delta_hits / delta_queries if have_server_counters else 0.0

    logger.info("Done: %d/%d succeeded in %.1fs", sum(1 for m in flat if m.error is None), len(flat), sweep_elapsed)
    logger.info("Prefix cache: %.1f%% hit ratio (%d/%d tokens)",
                hit_ratio * 100, int(delta_hits), int(delta_queries))

    # Append cache stats to summary
    summary = json.loads(summary_path.read_text(encoding="utf-8"))
    summary["prefix_cache_queries_tokens"] = int(delta_queries)
    summary["prefix_cache_hits_tokens"] = int(delta_hits)
    if have_server_counters:
        summary["prefix_cache_hit_ratio"] = hit_ratio
    else:
        logger.warning(
            "No prefix cache counters observed via /metrics; keeping the "
            "request-level prefix_cache_hit_ratio in the summary"
        )
    summary["wall_clock_s"] = sweep_elapsed
    summary_path.write_text(
        json.dumps(summary, indent=2, sort_keys=True), encoding="utf-8"
    )

    logger.info("Summary written to %s", summary_path)
    return flat
|
||||
84
replayer/trace.py
Normal file
84
replayer/trace.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Trace data structures and loader for the Ali agentic-coder trace format.
|
||||
|
||||
Trace format (one JSON per line):
|
||||
chat_id, parent_chat_id, timestamp, input_length, output_length,
|
||||
type, turn, hash_ids[]
|
||||
|
||||
Sessions are derived from parent_chat_id chains:
|
||||
- parent_chat_id == -1 → new session root
|
||||
- parent_chat_id >= 0 → belongs to the same session as the parent
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TraceRequest:
    """One request of the trace, immutable once loaded."""

    # --- identity ---
    request_id: str        # "<session_id>:<turn>:<chat_id>:<line index>", built by the loader
    session_id: str        # taken from the row, or resolved via the parent_chat_id chain
    chat_id: int
    parent_chat_id: int    # negative for the root of a session

    # --- payload description, converted from the trace row ---
    timestamp_s: float     # trace field "timestamp"
    input_length: int
    output_length: int
    request_type: str      # trace field "type"
    turn_id: int           # trace field "turn"
    hash_ids: tuple[int, ...]  # empty when the row has no "hash_ids"
|
||||
|
||||
|
||||
def load_trace(
    path: Path,
    *,
    request_limit: int | None = None,
) -> list[TraceRequest]:
    """Load trace and resolve session IDs from parent_chat_id chains.

    Args:
        path: JSONL trace file with one JSON object per line.
        request_limit: Stop after this many requests have been loaded;
            ``None`` loads the whole file.

    Returns:
        The requests in file order. Blank lines (for example a trailing
        empty line) are skipped instead of raising a JSON decode error.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a row lacks one of the required fields.
    """
    chat_to_session: dict[int, str] = {}
    requests: list[TraceRequest] = []

    with path.open("r", encoding="utf-8") as fh:
        for idx, line in enumerate(fh):
            if request_limit is not None and len(requests) >= request_limit:
                break
            if not line.strip():
                # Tolerate blank lines. ``idx`` still counts them, so the
                # request IDs of the remaining rows are unaffected.
                continue
            row = json.loads(line)
            chat_id = int(row["chat_id"])
            parent_chat_id = int(row["parent_chat_id"])

            if "session_id" in row:
                # An explicit session ID takes precedence over the chain.
                session_id = str(row["session_id"])
            else:
                session_id = _resolve_session_id(
                    chat_id, parent_chat_id, chat_to_session,
                )
            # Remember the mapping so that later children of this chat
            # resolve to the same session.
            chat_to_session[chat_id] = session_id

            requests.append(TraceRequest(
                request_id=f"{session_id}:{row['turn']}:{chat_id}:{idx}",
                session_id=session_id,
                chat_id=chat_id,
                parent_chat_id=parent_chat_id,
                timestamp_s=float(row["timestamp"]),
                input_length=int(row["input_length"]),
                output_length=int(row["output_length"]),
                request_type=str(row["type"]),
                turn_id=int(row["turn"]),
                hash_ids=tuple(int(h) for h in row.get("hash_ids", [])),
            ))

    return requests
|
||||
|
||||
|
||||
def _resolve_session_id(
    chat_id: int,
    parent_chat_id: int,
    chat_to_session: dict[int, str],
) -> str:
    """Determine the session of ``chat_id`` and record it in ``chat_to_session``.

    A negative ``parent_chat_id`` marks a session root, whose session ID is
    its own chat ID. Otherwise the parent's session is inherited; when the
    parent has not been seen, the parent's chat ID is used as the session ID.
    """
    is_root = parent_chat_id < 0
    resolved = (
        str(chat_id)
        if is_root
        else chat_to_session.get(parent_chat_id, str(parent_chat_id))
    )
    chat_to_session[chat_id] = resolved
    return resolved
|
||||
Reference in New Issue
Block a user