451 lines
16 KiB
Python
Executable File
451 lines
16 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""Prepare balanced real-Ali trace samples for KVC experiments.
|
|
|
|
The generic sampler is duration-oriented and can be dominated by one long
|
|
session. This script keeps real request lengths/timestamps but caps turns per
|
|
session so live sweeps can compare policies on a repeatable multi-session
|
|
workload.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import statistics
|
|
from collections import defaultdict
|
|
from dataclasses import asdict, dataclass
|
|
from pathlib import Path
|
|
|
|
from agentic_pd_hybrid.trace import TraceRequest, load_trace
|
|
|
|
|
|
@dataclass(frozen=True)
class SampleSummary:
    """Immutable description of one written trace sample.

    ``_write_sample`` builds this record and dumps it with ``asdict`` into the
    ``.summary.json`` file next to the sample. The ``dict``-valued fields hold
    the output of ``_stats`` and are ``None`` when there was nothing to
    summarise.
    """

    # Where the sample came from and where it was written.
    input_trace_path: str
    output_trace_path: str
    profile: str

    # Size of the sample.
    request_count: int
    session_count: int
    multi_turn_session_count: int  # sessions with more than one sampled request
    turn2plus_count: int  # sampled requests that are not the first of their session

    # Follow-up turns with 0 < appended tokens <= 2048 and a non-zero overlap
    # ratio against the previous turn, plus their share of all follow-up turns.
    direct_eligible_turn2plus_count: int
    direct_eligible_turn2plus_ratio: float

    # Requests whose non-negative parent_chat_id is not part of the sample.
    missing_parent_count: int

    # Sampling caps recorded as passed by the caller.
    max_sessions: int
    max_turns_per_session: int

    # Time span of the sample; start is 0.0 when timestamps were rebased.
    start_time_s: float
    end_time_s: float
    sampled_duration_s: float
    rebased_timestamps: bool

    # Distribution summaries (count/mean/min/p50/p90/p99/max).
    input_tokens: dict[str, float] | None
    output_tokens: dict[str, float] | None
    append_tokens: dict[str, float] | None
    inter_turn_gap_s: dict[str, float] | None
    overlap_ratio: dict[str, float] | None
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Declare every command-line option understood by ``main``."""
    cli = argparse.ArgumentParser()

    # Input/output locations and the caps shared by the profile samples.
    cli.add_argument("--trace", type=Path, required=True)
    cli.add_argument("--output-root", type=Path, required=True)
    cli.add_argument("--max-sessions", type=int, default=64)
    cli.add_argument("--max-turns-per-session", type=int, default=12)
    cli.add_argument("--start-time-s", type=float, default=0.0)

    # Options that only affect the optional continuous-window sample.
    cli.add_argument(
        "--window-duration-s",
        type=float,
        default=None,
        help=(
            "If set, also write continuous-window samples that keep only requests "
            "inside [start-time, start-time + window-duration]."
        ),
    )
    cli.add_argument(
        "--window-target-requests",
        type=int,
        default=None,
        help=(
            "For continuous-window samples, select whole sessions across time "
            "buckets until at least this many requests are included. This keeps "
            "the window span while making live runs tractable."
        ),
    )
    cli.add_argument(
        "--window-buckets",
        type=int,
        default=15,
        help="Number of time buckets used with --window-target-requests.",
    )
    cli.add_argument(
        "--window-min-turns",
        type=int,
        default=1,
        help=(
            "Minimum number of in-window turns per selected session for "
            "continuous-window samples."
        ),
    )
    cli.add_argument(
        "--window-output-name",
        default="ali-window.jsonl",
        help="Output filename for the continuous-window sample.",
    )

    # Options for the balanced profile samples.
    cli.add_argument(
        "--max-sampled-duration-s",
        type=float,
        default=None,
        help=(
            "For balanced profile samples, drop requests after the first selected "
            "timestamp plus this duration. Use only for quick smoke runs; headline "
            "runs should preserve the full sampled span."
        ),
    )
    cli.add_argument(
        "--profiles",
        nargs="+",
        default=["representative-mt", "kvc-fit-smallappend"],
        choices=["representative-mt", "kvc-fit-smallappend"],
    )
    cli.add_argument(
        "--no-rebase-timestamps",
        action="store_true",
        help="Keep original timestamps instead of shifting the sample to start at 0.",
    )
    return cli


def main() -> None:
    """Load the trace and write the requested window and profile samples.

    When ``--window-duration-s`` is given, a continuous-window sample is
    written first; the balanced profile samples listed in ``--profiles`` are
    always written afterwards. One progress line is printed per sample.
    """
    args = _build_arg_parser().parse_args()

    # Group the trace by session once; the selectors work per session.
    requests = load_trace(args.trace)
    sessions: dict[str, list[TraceRequest]] = defaultdict(list)
    for request in requests:
        sessions[request.session_id].append(request)

    args.output_root.mkdir(parents=True, exist_ok=True)

    # Keyword arguments that are identical for every _write_sample call.
    shared_write_options = {
        "input_trace_path": args.trace,
        "max_sessions": args.max_sessions,
        "max_turns_per_session": args.max_turns_per_session,
        "rebase_timestamps": not args.no_rebase_timestamps,
    }

    if args.window_duration_s is not None:
        if args.window_target_requests is not None:
            # Down-sample the window by picking sessions across time buckets.
            window_requests = _select_window_session_sample(
                sessions=sessions,
                start_time_s=args.start_time_s,
                window_duration_s=args.window_duration_s,
                target_requests=args.window_target_requests,
                bucket_count=args.window_buckets,
                min_turns=args.window_min_turns,
            )
            if args.window_min_turns <= 1:
                window_profile = "window-session-sample"
            else:
                window_profile = (
                    f"window-session-sample-min{args.window_min_turns}turns"
                )
        else:
            # No target given: keep every request inside the window.
            window_requests = _select_window(
                requests=requests,
                start_time_s=args.start_time_s,
                window_duration_s=args.window_duration_s,
            )
            window_profile = "window"

        window_path = args.output_root / args.window_output_name
        window_summary = _write_sample(
            selected=window_requests,
            output_path=window_path,
            profile=window_profile,
            **shared_write_options,
        )
        print(
            f"window: wrote {window_summary.request_count} requests from "
            f"{window_summary.session_count} sessions to {window_path}"
        )

    for profile_name in args.profiles:
        profile_requests = _select_profile(
            sessions=sessions,
            profile=profile_name,
            start_time_s=args.start_time_s,
            max_sessions=args.max_sessions,
            max_turns_per_session=args.max_turns_per_session,
            max_sampled_duration_s=args.max_sampled_duration_s,
        )
        profile_path = args.output_root / f"ali-{profile_name}.jsonl"
        profile_summary = _write_sample(
            selected=profile_requests,
            output_path=profile_path,
            profile=profile_name,
            **shared_write_options,
        )
        print(
            f"{profile_name}: wrote {profile_summary.request_count} requests from "
            f"{profile_summary.session_count} sessions to {profile_path}"
        )
|
|
|
|
|
|
def _select_profile(
|
|
*,
|
|
sessions: dict[str, list[TraceRequest]],
|
|
profile: str,
|
|
start_time_s: float,
|
|
max_sessions: int,
|
|
max_turns_per_session: int,
|
|
max_sampled_duration_s: float | None,
|
|
) -> list[TraceRequest]:
|
|
eligible: list[list[TraceRequest]] = []
|
|
for session_requests in sessions.values():
|
|
ordered = _ordered(session_requests)
|
|
if len(ordered) < 2:
|
|
continue
|
|
if ordered[0].timestamp_s < start_time_s:
|
|
continue
|
|
if profile == "kvc-fit-smallappend" and not _is_kvc_fit_smallappend(ordered):
|
|
continue
|
|
eligible.append(ordered[:max_turns_per_session])
|
|
|
|
eligible.sort(key=lambda items: (items[0].timestamp_s, items[0].session_id))
|
|
selected_sessions = eligible[:max_sessions]
|
|
selected = [request for items in selected_sessions for request in items]
|
|
selected.sort(key=lambda request: (request.timestamp_s, request.chat_id))
|
|
if selected and max_sampled_duration_s is not None:
|
|
first_ts = selected[0].timestamp_s
|
|
end_ts = first_ts + max_sampled_duration_s
|
|
selected = [
|
|
request for request in selected if request.timestamp_s <= end_ts
|
|
]
|
|
return selected
|
|
|
|
|
|
def _select_window(
|
|
*,
|
|
requests: list[TraceRequest],
|
|
start_time_s: float,
|
|
window_duration_s: float,
|
|
) -> list[TraceRequest]:
|
|
end_time_s = start_time_s + window_duration_s
|
|
selected = [
|
|
request
|
|
for request in requests
|
|
if start_time_s <= request.timestamp_s <= end_time_s
|
|
]
|
|
selected.sort(key=lambda request: (request.timestamp_s, request.chat_id))
|
|
return selected
|
|
|
|
|
|
def _select_window_session_sample(
|
|
*,
|
|
sessions: dict[str, list[TraceRequest]],
|
|
start_time_s: float,
|
|
window_duration_s: float,
|
|
target_requests: int,
|
|
bucket_count: int,
|
|
min_turns: int,
|
|
) -> list[TraceRequest]:
|
|
if target_requests <= 0:
|
|
raise ValueError("--window-target-requests must be positive")
|
|
if bucket_count <= 0:
|
|
raise ValueError("--window-buckets must be positive")
|
|
if min_turns <= 0:
|
|
raise ValueError("--window-min-turns must be positive")
|
|
|
|
end_time_s = start_time_s + window_duration_s
|
|
bucket_width_s = window_duration_s / bucket_count
|
|
buckets: list[list[list[TraceRequest]]] = [[] for _ in range(bucket_count)]
|
|
for session_requests in sessions.values():
|
|
ordered = _ordered(session_requests)
|
|
if not ordered:
|
|
continue
|
|
first = ordered[0]
|
|
if first.timestamp_s < start_time_s or first.timestamp_s > end_time_s:
|
|
continue
|
|
in_window = [
|
|
request
|
|
for request in ordered
|
|
if start_time_s <= request.timestamp_s <= end_time_s
|
|
]
|
|
if len(in_window) < min_turns:
|
|
continue
|
|
bucket_index = min(
|
|
bucket_count - 1,
|
|
int((first.timestamp_s - start_time_s) / bucket_width_s),
|
|
)
|
|
buckets[bucket_index].append(in_window)
|
|
|
|
for bucket in buckets:
|
|
bucket.sort(key=lambda items: (items[0].timestamp_s, items[0].session_id))
|
|
|
|
selected_sessions: list[list[TraceRequest]] = []
|
|
selected_count = 0
|
|
positions = [0 for _ in range(bucket_count)]
|
|
while selected_count < target_requests:
|
|
progressed = False
|
|
for index, bucket in enumerate(buckets):
|
|
if positions[index] >= len(bucket):
|
|
continue
|
|
session_requests = bucket[positions[index]]
|
|
positions[index] += 1
|
|
selected_sessions.append(session_requests)
|
|
selected_count += len(session_requests)
|
|
progressed = True
|
|
if selected_count >= target_requests:
|
|
break
|
|
if not progressed:
|
|
break
|
|
|
|
selected = [request for items in selected_sessions for request in items]
|
|
selected.sort(key=lambda request: (request.timestamp_s, request.chat_id))
|
|
if len(selected) < target_requests:
|
|
raise ValueError(
|
|
f"window session sample selected only {len(selected)} requests; "
|
|
f"target was {target_requests}"
|
|
)
|
|
return selected
|
|
|
|
|
|
def _is_kvc_fit_smallappend(session_requests: list[TraceRequest]) -> bool:
|
|
initial = session_requests[0]
|
|
if initial.input_length < 2048 or initial.input_length > 16000:
|
|
return False
|
|
for request in session_requests:
|
|
if request.output_length > 2048:
|
|
return False
|
|
for previous, current in zip(session_requests, session_requests[1:], strict=False):
|
|
append_tokens = current.input_length - (
|
|
previous.input_length + previous.output_length
|
|
)
|
|
if append_tokens <= 0 or append_tokens > 2048:
|
|
return False
|
|
if _overlap_ratio(previous, current) < 0.75:
|
|
return False
|
|
return True
|
|
|
|
|
|
def _write_sample(
|
|
*,
|
|
selected: list[TraceRequest],
|
|
input_trace_path: Path,
|
|
output_path: Path,
|
|
profile: str,
|
|
max_sessions: int,
|
|
max_turns_per_session: int,
|
|
rebase_timestamps: bool,
|
|
) -> SampleSummary:
|
|
if not selected:
|
|
raise ValueError(f"profile {profile!r} selected no requests")
|
|
|
|
first_ts = selected[0].timestamp_s
|
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
with output_path.open("w", encoding="utf-8") as handle:
|
|
for request in selected:
|
|
timestamp = request.timestamp_s - first_ts if rebase_timestamps else request.timestamp_s
|
|
payload = {
|
|
"chat_id": request.chat_id,
|
|
"parent_chat_id": request.parent_chat_id,
|
|
"timestamp": round(timestamp, 6),
|
|
"input_length": request.input_length,
|
|
"output_length": request.output_length,
|
|
"type": request.request_type,
|
|
"turn": request.turn_id,
|
|
"hash_ids": list(request.hash_ids),
|
|
}
|
|
handle.write(json.dumps(payload, sort_keys=True) + "\n")
|
|
|
|
sessions = defaultdict(list)
|
|
for request in selected:
|
|
sessions[request.session_id].append(request)
|
|
|
|
selected_chat_ids = {request.chat_id for request in selected}
|
|
missing_parent_count = sum(
|
|
1
|
|
for request in selected
|
|
if request.parent_chat_id >= 0 and request.parent_chat_id not in selected_chat_ids
|
|
)
|
|
append_values: list[float] = []
|
|
gap_values: list[float] = []
|
|
overlap_values: list[float] = []
|
|
direct_eligible_count = 0
|
|
for session_requests in sessions.values():
|
|
ordered = _ordered(session_requests)
|
|
for previous, current in zip(ordered, ordered[1:], strict=False):
|
|
append_tokens = current.input_length - (
|
|
previous.input_length + previous.output_length
|
|
)
|
|
overlap_ratio = _overlap_ratio(previous, current)
|
|
append_values.append(float(append_tokens))
|
|
gap_values.append(float(current.timestamp_s - previous.timestamp_s))
|
|
overlap_values.append(overlap_ratio)
|
|
if append_tokens > 0 and append_tokens <= 2048 and overlap_ratio > 0:
|
|
direct_eligible_count += 1
|
|
|
|
turn2plus_count = sum(max(0, len(items) - 1) for items in sessions.values())
|
|
|
|
start = min(request.timestamp_s for request in selected)
|
|
end = max(request.timestamp_s for request in selected)
|
|
summary = SampleSummary(
|
|
input_trace_path=str(input_trace_path),
|
|
output_trace_path=str(output_path),
|
|
profile=profile,
|
|
request_count=len(selected),
|
|
session_count=len(sessions),
|
|
multi_turn_session_count=sum(1 for items in sessions.values() if len(items) > 1),
|
|
turn2plus_count=turn2plus_count,
|
|
direct_eligible_turn2plus_count=direct_eligible_count,
|
|
direct_eligible_turn2plus_ratio=(
|
|
direct_eligible_count / turn2plus_count if turn2plus_count else 0.0
|
|
),
|
|
missing_parent_count=missing_parent_count,
|
|
max_sessions=max_sessions,
|
|
max_turns_per_session=max_turns_per_session,
|
|
start_time_s=0.0 if rebase_timestamps else start,
|
|
end_time_s=end - start if rebase_timestamps else end,
|
|
sampled_duration_s=end - start,
|
|
rebased_timestamps=rebase_timestamps,
|
|
input_tokens=_stats([float(request.input_length) for request in selected]),
|
|
output_tokens=_stats([float(request.output_length) for request in selected]),
|
|
append_tokens=_stats(append_values),
|
|
inter_turn_gap_s=_stats(gap_values),
|
|
overlap_ratio=_stats(overlap_values),
|
|
)
|
|
with output_path.with_suffix(output_path.suffix + ".summary.json").open(
|
|
"w", encoding="utf-8"
|
|
) as handle:
|
|
json.dump(asdict(summary), handle, indent=2, sort_keys=True)
|
|
return summary
|
|
|
|
|
|
def _ordered(session_requests: list[TraceRequest]) -> list[TraceRequest]:
|
|
return sorted(
|
|
session_requests,
|
|
key=lambda request: (request.timestamp_s, request.turn_id, request.chat_id),
|
|
)
|
|
|
|
|
|
def _overlap_ratio(previous: TraceRequest, current: TraceRequest) -> float:
|
|
if not current.hash_ids:
|
|
return 0.0
|
|
previous_blocks = set(previous.hash_ids)
|
|
overlap = sum(1 for block in current.hash_ids if block in previous_blocks)
|
|
return overlap / len(current.hash_ids)
|
|
|
|
|
|
def _stats(values: list[float]) -> dict[str, float] | None:
|
|
if not values:
|
|
return None
|
|
ordered = sorted(values)
|
|
return {
|
|
"count": float(len(ordered)),
|
|
"mean": statistics.fmean(ordered),
|
|
"min": ordered[0],
|
|
"p50": _percentile(ordered, 0.50),
|
|
"p90": _percentile(ordered, 0.90),
|
|
"p99": _percentile(ordered, 0.99),
|
|
"max": ordered[-1],
|
|
}
|
|
|
|
|
|
def _percentile(sorted_values: list[float], percentile: float) -> float:
|
|
if len(sorted_values) == 1:
|
|
return sorted_values[0]
|
|
return sorted_values[round((len(sorted_values) - 1) * percentile)]
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|