Files
agentic-pd-hybrid/scripts/convert_inferact_to_trace.py
tim d11a66d11b feat(scripts): cu12.8 env wrapper + Inferact trace converter
setup_env.sh: source-able shell snippet that points tvm_ffi (vendor
sglang JIT compiler) at $HOME/cuda-12.8/bin/nvcc and exposes both
libcudart.so.12 (for mooncake.engine, a cu12 wheel) and cu12.8 lib64
(for tvm_ffi compile-time linker) on LD_LIBRARY_PATH. Without this,
JIT-compiled kernels NEEDED libcudart.so.13 and driver 570 rejected
them at every JIT call.

convert_inferact_to_trace.py: turns Inferact codex_swebenchpro_traces
(ShareGPT {"from","value"} pairs) into the chat_id/parent_chat_id/
turn/hash_ids JSONL schema replay.py expects. Tokenizes with the
model's own tokenizer, builds prefix-sharing 24-token block hashes,
synthesizes timestamps. The output cross-checks at 20,230 LLM calls, which
exactly matches the Inferact README count for 610 successful trials.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 00:10:06 +08:00

190 lines
6.3 KiB
Python

"""Convert Inferact codex_swebenchpro_traces (ShareGPT) to agentic-pd-hybrid trace JSONL.
Output schema (one JSON object per line, matching src/agentic_pd_hybrid/trace.py):
chat_id, parent_chat_id, timestamp, input_length, output_length, type, turn, hash_ids
Each trial in the input becomes one session. Each (human, gpt) pair within a trial
becomes one turn. The prefix at turn N is the concatenation of all (human, gpt) pairs
from turns 0..N-1 plus the current human message — this mirrors how agentic coding
agents grow context across calls.
hash_ids are derived per 24-token block via sha256 of the block's comma-joined token ids
plus the previous block's hash, which gives stable, deterministic, prefix-shared hashes
across turns of the same session.
"""
from __future__ import annotations
import argparse
import hashlib
import json
import sys
import time
from pathlib import Path
# Number of tokens grouped into one block when building hash_ids
# (see _build_hash_ids); each block contributes exactly one hash.
BLOCK_TOKEN_BUDGET = 24
def _block_hash(text: str, prev_hash: int) -> int:
    """Return a chained 63-bit hash for one block.

    The digest is SHA-256 over the UTF-8 bytes of ``text`` followed by
    ``prev_hash`` encoded as 8 big-endian bytes. The first 8 digest bytes are
    read as a big-endian integer and the top bit is cleared, so the result is
    always in ``[0, 2**63)``.
    """
    hasher = hashlib.sha256()
    hasher.update(text.encode("utf-8"))
    # Feeding the previous hash in after the text is equivalent to hashing
    # the concatenation of both byte strings.
    hasher.update(prev_hash.to_bytes(8, "big"))
    leading_bytes = hasher.digest()[:8]
    sign_free_mask = (1 << 63) - 1
    return int.from_bytes(leading_bytes, "big") & sign_free_mask
def _build_hash_ids(token_ids: list[int]) -> list[int]:
    """Return one chained hash per ``BLOCK_TOKEN_BUDGET``-token block.

    The token list is cut into consecutive blocks (the last one may be
    shorter). Each block is rendered as its comma-joined token ids and hashed
    together with the previous block's hash, starting from 0. Two token lists
    that share a prefix of full blocks therefore share the same leading hashes.
    An empty token list yields an empty result.
    """
    hash_chain: list[int] = []
    for begin in range(0, len(token_ids), BLOCK_TOKEN_BUDGET):
        chunk = token_ids[begin : begin + BLOCK_TOKEN_BUDGET]
        # The previous link of the chain is the last hash emitted, or 0 at the start.
        previous = hash_chain[-1] if hash_chain else 0
        hash_chain.append(_block_hash(",".join(map(str, chunk)), previous))
    return hash_chain
def _pair_turns(conv: list[dict]) -> list[tuple[str, str]]:
    """Pair consecutive (human, gpt) messages, skipping malformed entries.

    The conversation is scanned left to right. Whenever a dict with
    ``"from" == "human"`` is immediately followed by a dict with
    ``"from" == "gpt"``, their ``"value"`` fields are emitted as one pair and
    the scan jumps past both; otherwise the scan advances by one message.

    Args:
        conv: Sequence of ShareGPT-style message dicts. Anything that is not a
            list or tuple is treated as an empty conversation.

    Returns:
        A list of ``(human_text, gpt_text)`` tuples. A missing or ``None``
        ``"value"`` becomes the empty string; other non-string values are
        converted with ``str``.
    """
    if not isinstance(conv, (list, tuple)):
        return []

    def _text_of(message: dict) -> str:
        # An explicit JSON null must not turn into the literal text "None".
        value = message.get("value")
        return "" if value is None else str(value)

    pairs: list[tuple[str, str]] = []
    i = 0
    while i + 1 < len(conv):
        a, b = conv[i], conv[i + 1]
        if (
            isinstance(a, dict)
            and isinstance(b, dict)
            and a.get("from") == "human"
            and b.get("from") == "gpt"
        ):
            pairs.append((_text_of(a), _text_of(b)))
            i += 2
        else:
            # Slide by one so a later message can still start a valid pair.
            i += 1
    return pairs
def convert(
    input_path: Path,
    output_path: Path,
    *,
    tokenizer_path: str,
    max_trials: int | None,
    inter_turn_gap_s: float,
    session_stagger_s: float,
    request_type: str,
) -> None:
    """Convert a ShareGPT-style trial dump into trace JSONL, one record per turn.

    Every trial that yields at least one (human, gpt) pair becomes a session.
    For each turn the input is the accumulated session text plus the current
    human message; it is tokenized to obtain ``input_length`` and ``hash_ids``.
    The assistant reply is tokenized only to obtain ``output_length`` (at
    least 1). Progress and a final summary are printed to stderr.

    Args:
        input_path: UTF-8 JSON file whose top-level value is a list of trials;
            each trial is expected to be an object with a ``"conversations"``
            list. Trials of any other shape are counted as skipped.
        output_path: Destination JSONL file; it is overwritten.
        tokenizer_path: Path or hub id handed to ``AutoTokenizer.from_pretrained``.
        max_trials: If not ``None``, only the first ``max_trials`` trials are used.
        inter_turn_gap_s: Synthetic delay added between consecutive turns.
        session_stagger_s: Synthetic offset between session starts, multiplied
            by the trial's index in the input.
        request_type: Value written to every record's ``"type"`` field.

    Raises:
        ValueError: If the top-level JSON value is not a list.
    """
    # Imported here rather than at module level, so importing this module does
    # not itself require transformers.
    from transformers import AutoTokenizer

    print(f"loading tokenizer from {tokenizer_path}", file=sys.stderr)
    # NOTE(review): trust_remote_code=True allows the tokenizer repository to
    # execute its own Python code at load time; only use trusted sources.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)

    print(f"loading {input_path}", file=sys.stderr)
    # Read explicitly as UTF-8 so the result does not depend on the platform
    # locale (the output file is written as UTF-8 as well).
    data = json.loads(input_path.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError(
            f"expected a JSON list of trials in {input_path}, "
            f"got {type(data).__name__}"
        )
    if max_trials is not None:
        data = data[:max_trials]
    print(f"{len(data)} trials to process", file=sys.stderr)

    next_chat_id = 1_000_000
    written = 0
    skipped_trials = 0
    t0 = time.time()
    with output_path.open("w", encoding="utf-8") as out_f:
        for trial_idx, trial in enumerate(data):
            # Malformed trials (non-object, or conversations not a list) are
            # treated like empty conversations instead of aborting the run.
            conv = trial.get("conversations") if isinstance(trial, dict) else None
            if not isinstance(conv, list):
                conv = []
            turns = _pair_turns(conv)
            if not turns:
                skipped_trials += 1
                continue

            # Sessions start staggered by their position in the input.
            ts = trial_idx * session_stagger_s
            parent_chat_id = -1  # -1 marks the first turn of a session
            prefix_text = ""
            for turn_idx, (human, assistant) in enumerate(turns):
                # Input at this turn = full prior context + current human message.
                current_text = (
                    prefix_text + ("\n\n[USER]\n" if prefix_text else "[USER]\n") + human
                )
                input_ids = tokenizer.encode(current_text, add_special_tokens=False)
                output_ids = tokenizer.encode(assistant, add_special_tokens=False)

                chat_id = next_chat_id
                next_chat_id += 1
                record = {
                    "chat_id": chat_id,
                    "parent_chat_id": parent_chat_id,
                    "timestamp": round(ts, 6),
                    "input_length": len(input_ids),
                    # An empty reply is still recorded as one output token.
                    "output_length": max(1, len(output_ids)),
                    "type": request_type,
                    "turn": turn_idx,
                    "hash_ids": _build_hash_ids(input_ids),
                }
                out_f.write(json.dumps(record) + "\n")
                written += 1

                # The next turn chains to this one and sees its reply as context.
                parent_chat_id = chat_id
                ts += inter_turn_gap_s
                prefix_text = current_text + "\n\n[ASSISTANT]\n" + assistant

            if (trial_idx + 1) % 20 == 0:
                elapsed = time.time() - t0
                rate = (trial_idx + 1) / elapsed if elapsed > 0 else 0
                eta = (len(data) - trial_idx - 1) / rate if rate > 0 else 0
                print(
                    f" trial {trial_idx + 1}/{len(data)} reqs={written} "
                    f"rate={rate:.1f} trial/s eta={eta:.0f}s",
                    file=sys.stderr,
                )

    elapsed = time.time() - t0
    print(
        f"done: wrote {written} requests across {len(data) - skipped_trials} sessions "
        f"({skipped_trials} trials skipped, empty conversations) in {elapsed:.1f}s "
        f"to {output_path}",
        file=sys.stderr,
    )
def main() -> None:
    """Parse the command-line options and run the conversion.

    The parent directory of ``--output`` is created if it does not exist
    before ``convert`` is called with the parsed options.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        (
            "--input",
            {
                "type": Path,
                "default": Path(
                    "third_party/codex_swebenchpro_traces/codex_swebenchpro.json"
                ),
            },
        ),
        ("--output", {"type": Path, "required": True}),
        (
            "--tokenizer",
            {
                "default": "/mnt/models/Qwen/Qwen3-30B-A3B-Instruct-2507",
                "help": "Path or HF id for the tokenizer. Default matches v2 sweep model.",
            },
        ),
        (
            "--max-trials",
            {
                "type": int,
                "default": None,
                "help": "Cap number of trials processed (useful for smoke / quick tests).",
            },
        ),
        ("--inter-turn-gap-s", {"type": float, "default": 2.5}),
        ("--session-stagger-s", {"type": float, "default": 1.0}),
        ("--request-type", {"default": "chat"}),
    )

    parser = argparse.ArgumentParser(description=__doc__)
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    opts = parser.parse_args()

    destination = opts.output
    destination.parent.mkdir(parents=True, exist_ok=True)
    convert(
        opts.input,
        destination,
        tokenizer_path=opts.tokenizer,
        max_trials=opts.max_trials,
        inter_turn_gap_s=opts.inter_turn_gap_s,
        session_stagger_s=opts.session_stagger_s,
        request_type=opts.request_type,
    )
# Run the converter only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()