Files
agentic-pd-hybrid/scripts/prepare_real_ali_samples.py

451 lines
16 KiB
Python
Executable File

#!/usr/bin/env python3
"""Prepare balanced real-Ali trace samples for KVC experiments.
The generic sampler is duration-oriented and can be dominated by one long
session. This script keeps real request lengths/timestamps but caps turns per
session so live sweeps can compare policies on a repeatable multi-session
workload.
"""
from __future__ import annotations
import argparse
import json
import statistics
from collections import defaultdict
from dataclasses import asdict, dataclass
from pathlib import Path
from agentic_pd_hybrid.trace import TraceRequest, load_trace
@dataclass(frozen=True)
class SampleSummary:
    """Immutable description of one sampled trace written by this script.

    ``_write_sample`` builds one instance per output file and serialises it
    with ``dataclasses.asdict`` into the ``*.summary.json`` file that sits next
    to the sampled trace.
    """

    # Where the sample came from, where it went, and which profile produced it.
    input_trace_path: str
    output_trace_path: str
    profile: str

    # Size of the sample.
    request_count: int
    session_count: int
    multi_turn_session_count: int  # sessions contributing more than one request

    # Follow-up turns: every request after the first one of its session.
    turn2plus_count: int
    # Follow-up turns with 0 < appended tokens <= 2048 and a non-zero block
    # overlap with the preceding turn.
    direct_eligible_turn2plus_count: int
    direct_eligible_turn2plus_ratio: float  # 0.0 when there are no follow-up turns
    # Requests whose non-negative parent_chat_id is not part of the sample.
    missing_parent_count: int

    # Sampling caps exactly as supplied by the caller (recorded, not enforced here).
    max_sessions: int
    max_turns_per_session: int

    # Time span of the sample; start is reported as 0.0 when timestamps were rebased.
    start_time_s: float
    end_time_s: float
    sampled_duration_s: float
    rebased_timestamps: bool

    # Distribution summaries as produced by ``_stats``; None when no values exist.
    input_tokens: dict[str, float] | None
    output_tokens: dict[str, float] | None
    append_tokens: dict[str, float] | None
    inter_turn_gap_s: dict[str, float] | None
    overlap_ratio: dict[str, float] | None
def main() -> None:
    """Command-line entry point: load a trace and write the requested samples.

    The steps run in this order:

    1. Parse the command-line flags.
    2. Load the trace named by ``--trace`` and group its requests by session id.
    3. When ``--window-duration-s`` is given, write one continuous-window
       sample (a plain time window, or a bucketed session sample when
       ``--window-target-requests`` is also given).
    4. Write one ``ali-<profile>.jsonl`` sample for each entry of ``--profiles``.

    Each sample goes through ``_write_sample`` and is reported with one printed
    line.
    """
    profile_names = ["representative-mt", "kvc-fit-smallappend"]

    cli = argparse.ArgumentParser()
    cli.add_argument("--trace", type=Path, required=True)
    cli.add_argument("--output-root", type=Path, required=True)
    cli.add_argument("--max-sessions", type=int, default=64)
    cli.add_argument("--max-turns-per-session", type=int, default=12)
    cli.add_argument("--start-time-s", type=float, default=0.0)
    cli.add_argument(
        "--window-duration-s",
        type=float,
        default=None,
        help=(
            "If set, also write continuous-window samples that keep only requests "
            "inside [start-time, start-time + window-duration]."
        ),
    )
    cli.add_argument(
        "--window-target-requests",
        type=int,
        default=None,
        help=(
            "For continuous-window samples, select whole sessions across time "
            "buckets until at least this many requests are included. This keeps "
            "the window span while making live runs tractable."
        ),
    )
    cli.add_argument(
        "--window-buckets",
        type=int,
        default=15,
        help="Number of time buckets used with --window-target-requests.",
    )
    cli.add_argument(
        "--window-min-turns",
        type=int,
        default=1,
        help=(
            "Minimum number of in-window turns per selected session for "
            "continuous-window samples."
        ),
    )
    cli.add_argument(
        "--window-output-name",
        default="ali-window.jsonl",
        help="Output filename for the continuous-window sample.",
    )
    cli.add_argument(
        "--max-sampled-duration-s",
        type=float,
        default=None,
        help=(
            "For balanced profile samples, drop requests after the first selected "
            "timestamp plus this duration. Use only for quick smoke runs; headline "
            "runs should preserve the full sampled span."
        ),
    )
    cli.add_argument(
        "--profiles",
        nargs="+",
        default=list(profile_names),
        choices=profile_names,
    )
    cli.add_argument(
        "--no-rebase-timestamps",
        action="store_true",
        help="Keep original timestamps instead of shifting the sample to start at 0.",
    )
    args = cli.parse_args()

    trace_requests = load_trace(args.trace)
    by_session: dict[str, list[TraceRequest]] = defaultdict(list)
    for item in trace_requests:
        by_session[item.session_id].append(item)

    args.output_root.mkdir(parents=True, exist_ok=True)

    # Options that are identical for every sample written during this run.
    shared_write_options = {
        "input_trace_path": args.trace,
        "max_sessions": args.max_sessions,
        "max_turns_per_session": args.max_turns_per_session,
        "rebase_timestamps": not args.no_rebase_timestamps,
    }

    if args.window_duration_s is not None:
        if args.window_target_requests is not None:
            window_requests = _select_window_session_sample(
                sessions=by_session,
                start_time_s=args.start_time_s,
                window_duration_s=args.window_duration_s,
                target_requests=args.window_target_requests,
                bucket_count=args.window_buckets,
                min_turns=args.window_min_turns,
            )
            if args.window_min_turns <= 1:
                window_profile = "window-session-sample"
            else:
                window_profile = (
                    f"window-session-sample-min{args.window_min_turns}turns"
                )
        else:
            window_requests = _select_window(
                requests=trace_requests,
                start_time_s=args.start_time_s,
                window_duration_s=args.window_duration_s,
            )
            window_profile = "window"

        window_path = args.output_root / args.window_output_name
        window_summary = _write_sample(
            selected=window_requests,
            output_path=window_path,
            profile=window_profile,
            **shared_write_options,
        )
        print(
            f"window: wrote {window_summary.request_count} requests from "
            f"{window_summary.session_count} sessions to {window_path}"
        )

    # The balanced per-profile samples are always written, window or not.
    for profile_name in args.profiles:
        profile_requests = _select_profile(
            sessions=by_session,
            profile=profile_name,
            start_time_s=args.start_time_s,
            max_sessions=args.max_sessions,
            max_turns_per_session=args.max_turns_per_session,
            max_sampled_duration_s=args.max_sampled_duration_s,
        )
        profile_path = args.output_root / f"ali-{profile_name}.jsonl"
        profile_summary = _write_sample(
            selected=profile_requests,
            output_path=profile_path,
            profile=profile_name,
            **shared_write_options,
        )
        print(
            f"{profile_name}: wrote {profile_summary.request_count} requests from "
            f"{profile_summary.session_count} sessions to {profile_path}"
        )
def _select_profile(
*,
sessions: dict[str, list[TraceRequest]],
profile: str,
start_time_s: float,
max_sessions: int,
max_turns_per_session: int,
max_sampled_duration_s: float | None,
) -> list[TraceRequest]:
eligible: list[list[TraceRequest]] = []
for session_requests in sessions.values():
ordered = _ordered(session_requests)
if len(ordered) < 2:
continue
if ordered[0].timestamp_s < start_time_s:
continue
if profile == "kvc-fit-smallappend" and not _is_kvc_fit_smallappend(ordered):
continue
eligible.append(ordered[:max_turns_per_session])
eligible.sort(key=lambda items: (items[0].timestamp_s, items[0].session_id))
selected_sessions = eligible[:max_sessions]
selected = [request for items in selected_sessions for request in items]
selected.sort(key=lambda request: (request.timestamp_s, request.chat_id))
if selected and max_sampled_duration_s is not None:
first_ts = selected[0].timestamp_s
end_ts = first_ts + max_sampled_duration_s
selected = [
request for request in selected if request.timestamp_s <= end_ts
]
return selected
def _select_window(
*,
requests: list[TraceRequest],
start_time_s: float,
window_duration_s: float,
) -> list[TraceRequest]:
end_time_s = start_time_s + window_duration_s
selected = [
request
for request in requests
if start_time_s <= request.timestamp_s <= end_time_s
]
selected.sort(key=lambda request: (request.timestamp_s, request.chat_id))
return selected
def _select_window_session_sample(
*,
sessions: dict[str, list[TraceRequest]],
start_time_s: float,
window_duration_s: float,
target_requests: int,
bucket_count: int,
min_turns: int,
) -> list[TraceRequest]:
if target_requests <= 0:
raise ValueError("--window-target-requests must be positive")
if bucket_count <= 0:
raise ValueError("--window-buckets must be positive")
if min_turns <= 0:
raise ValueError("--window-min-turns must be positive")
end_time_s = start_time_s + window_duration_s
bucket_width_s = window_duration_s / bucket_count
buckets: list[list[list[TraceRequest]]] = [[] for _ in range(bucket_count)]
for session_requests in sessions.values():
ordered = _ordered(session_requests)
if not ordered:
continue
first = ordered[0]
if first.timestamp_s < start_time_s or first.timestamp_s > end_time_s:
continue
in_window = [
request
for request in ordered
if start_time_s <= request.timestamp_s <= end_time_s
]
if len(in_window) < min_turns:
continue
bucket_index = min(
bucket_count - 1,
int((first.timestamp_s - start_time_s) / bucket_width_s),
)
buckets[bucket_index].append(in_window)
for bucket in buckets:
bucket.sort(key=lambda items: (items[0].timestamp_s, items[0].session_id))
selected_sessions: list[list[TraceRequest]] = []
selected_count = 0
positions = [0 for _ in range(bucket_count)]
while selected_count < target_requests:
progressed = False
for index, bucket in enumerate(buckets):
if positions[index] >= len(bucket):
continue
session_requests = bucket[positions[index]]
positions[index] += 1
selected_sessions.append(session_requests)
selected_count += len(session_requests)
progressed = True
if selected_count >= target_requests:
break
if not progressed:
break
selected = [request for items in selected_sessions for request in items]
selected.sort(key=lambda request: (request.timestamp_s, request.chat_id))
if len(selected) < target_requests:
raise ValueError(
f"window session sample selected only {len(selected)} requests; "
f"target was {target_requests}"
)
return selected
def _is_kvc_fit_smallappend(session_requests: list[TraceRequest]) -> bool:
initial = session_requests[0]
if initial.input_length < 2048 or initial.input_length > 16000:
return False
for request in session_requests:
if request.output_length > 2048:
return False
for previous, current in zip(session_requests, session_requests[1:], strict=False):
append_tokens = current.input_length - (
previous.input_length + previous.output_length
)
if append_tokens <= 0 or append_tokens > 2048:
return False
if _overlap_ratio(previous, current) < 0.75:
return False
return True
def _write_sample(
    *,
    selected: list[TraceRequest],
    input_trace_path: Path,
    output_path: Path,
    profile: str,
    max_sessions: int,
    max_turns_per_session: int,
    rebase_timestamps: bool,
) -> SampleSummary:
    """Write a sampled trace as JSONL plus a JSON summary next to it.

    The trace goes to ``output_path`` with one JSON object per request. The
    summary goes to the same path with ``.summary.json`` appended to its
    suffix.

    Args:
        selected: Requests to write, in output order. The first element is
            used as the origin when rebasing, so it should be the earliest
            request (every selector in this file returns a sorted list).
        input_trace_path: Source trace path, recorded in the summary.
        output_path: Destination of the JSONL sample.
        profile: Profile label recorded in the summary.
        max_sessions: Recorded in the summary as given.
        max_turns_per_session: Recorded in the summary as given.
        rebase_timestamps: When true, written timestamps are shifted so the
            first request is at 0.

    Returns:
        The summary that was also written to disk.

    Raises:
        ValueError: If ``selected`` is empty.
    """
    if not selected:
        raise ValueError(f"profile {profile!r} selected no requests")

    origin_ts = selected[0].timestamp_s
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as trace_file:
        for item in selected:
            if rebase_timestamps:
                emitted_ts = item.timestamp_s - origin_ts
            else:
                emitted_ts = item.timestamp_s
            record = {
                "chat_id": item.chat_id,
                "parent_chat_id": item.parent_chat_id,
                "timestamp": round(emitted_ts, 6),
                "input_length": item.input_length,
                "output_length": item.output_length,
                "type": item.request_type,
                "turn": item.turn_id,
                "hash_ids": list(item.hash_ids),
            }
            trace_file.write(json.dumps(record, sort_keys=True) + "\n")

    by_session = defaultdict(list)
    for item in selected:
        by_session[item.session_id].append(item)

    # A parent id that is non-negative but not in the sample means the parent
    # request was left out of the selection.
    known_chat_ids = {item.chat_id for item in selected}
    orphaned = [
        item
        for item in selected
        if item.parent_chat_id >= 0 and item.parent_chat_id not in known_chat_ids
    ]

    appended_tokens: list[float] = []
    turn_gaps_s: list[float] = []
    overlaps: list[float] = []
    direct_eligible = 0
    followup_turns = 0
    for members in by_session.values():
        followup_turns += max(0, len(members) - 1)
        turns = _ordered(members)
        for earlier, later in zip(turns, turns[1:]):
            grown_by = later.input_length - (
                earlier.input_length + earlier.output_length
            )
            shared_fraction = _overlap_ratio(earlier, later)
            appended_tokens.append(float(grown_by))
            turn_gaps_s.append(float(later.timestamp_s - earlier.timestamp_s))
            overlaps.append(shared_fraction)
            if 0 < grown_by <= 2048 and shared_fraction > 0:
                direct_eligible += 1

    stamps = [item.timestamp_s for item in selected]
    earliest = min(stamps)
    latest = max(stamps)
    span_s = latest - earliest
    if rebase_timestamps:
        reported_start, reported_end = 0.0, span_s
    else:
        reported_start, reported_end = earliest, latest

    summary = SampleSummary(
        input_trace_path=str(input_trace_path),
        output_trace_path=str(output_path),
        profile=profile,
        request_count=len(selected),
        session_count=len(by_session),
        multi_turn_session_count=sum(
            1 for members in by_session.values() if len(members) > 1
        ),
        turn2plus_count=followup_turns,
        direct_eligible_turn2plus_count=direct_eligible,
        direct_eligible_turn2plus_ratio=(
            direct_eligible / followup_turns if followup_turns else 0.0
        ),
        missing_parent_count=len(orphaned),
        max_sessions=max_sessions,
        max_turns_per_session=max_turns_per_session,
        start_time_s=reported_start,
        end_time_s=reported_end,
        sampled_duration_s=span_s,
        rebased_timestamps=rebase_timestamps,
        input_tokens=_stats([float(item.input_length) for item in selected]),
        output_tokens=_stats([float(item.output_length) for item in selected]),
        append_tokens=_stats(appended_tokens),
        inter_turn_gap_s=_stats(turn_gaps_s),
        overlap_ratio=_stats(overlaps),
    )

    summary_path = output_path.with_suffix(output_path.suffix + ".summary.json")
    with summary_path.open("w", encoding="utf-8") as summary_file:
        json.dump(asdict(summary), summary_file, indent=2, sort_keys=True)
    return summary
def _ordered(session_requests: list[TraceRequest]) -> list[TraceRequest]:
return sorted(
session_requests,
key=lambda request: (request.timestamp_s, request.turn_id, request.chat_id),
)
def _overlap_ratio(previous: TraceRequest, current: TraceRequest) -> float:
if not current.hash_ids:
return 0.0
previous_blocks = set(previous.hash_ids)
overlap = sum(1 for block in current.hash_ids if block in previous_blocks)
return overlap / len(current.hash_ids)
def _stats(values: list[float]) -> dict[str, float] | None:
if not values:
return None
ordered = sorted(values)
return {
"count": float(len(ordered)),
"mean": statistics.fmean(ordered),
"min": ordered[0],
"p50": _percentile(ordered, 0.50),
"p90": _percentile(ordered, 0.90),
"p99": _percentile(ordered, 0.99),
"max": ordered[-1],
}
def _percentile(sorted_values: list[float], percentile: float) -> float:
if len(sorted_values) == 1:
return sorted_values[0]
return sorted_values[round((len(sorted_values) - 1) * percentile)]
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()