#!/usr/bin/env python3
"""Compute inter-turn T_external on the chatbot trace, v2.

Uses formatted's parent_chat_id chains for sessions, and matches each
formatted record to a raw input/output pair by (timestamp, output_length)
rather than by index (index-by-index drifted up to 120s in v1).

The raw trace's prompt_token_num is anonymized to '0', so the raw
generate_token_num (equivalent to the formatted output_length) is the only
usable length key.

Run on dash0. Writes /tmp/chatbot_inter_turn_gap.json.
"""
import bisect
import json
from collections import defaultdict
from datetime import datetime
from typing import Optional

import numpy as np

RAW_IN = "/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260321-260327/qwen3-max-input-032309-032311.jsonl"
RAW_OUT = "/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260321-260327/qwen3-max-output-032109-032711.jsonl"
FMT = "/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260321-260327-formatted/qwen_chat_blksz_64_032309-032311.jsonl"
OUT_JSON = "/tmp/chatbot_inter_turn_gap.json"


def parse_time_str_to_ms(s: str) -> Optional[int]:
    """Convert a ``YYYY-MM-DD HH:MM:SS[.ffffff]`` string to epoch milliseconds.

    Args:
        s: Timestamp string, with or without a fractional-seconds part.

    Returns:
        Milliseconds since the epoch (truncated to an int), or ``None`` when
        *s* is not a string in one of the two accepted formats.

    NOTE: the parsed datetime is naive, so ``timestamp()`` interprets it in
    the local timezone of the machine running the script.
    """
    try:
        # The presence of a "." selects the fractional-seconds format.
        fmt = "%Y-%m-%d %H:%M:%S.%f" if "." in s else "%Y-%m-%d %H:%M:%S"
        parsed = datetime.strptime(s, fmt)
        return int(parsed.timestamp() * 1000)
    except (TypeError, ValueError, OverflowError, OSError):
        # Best-effort helper: non-string input, a malformed string, or a date
        # the platform cannot convert all mean "no usable timestamp".
        return None


print("Reading raw output (joining by request_id)...")
# In this trace prompt_token_num is anonymized to '0'; use generate_token_num
# as the matching key (matches formatted output_length). For end_ms we use
# time_to_finish_token (ms duration from request start) — the "time" string
# field is log-write time, not request completion time.
# Thresholds (seconds) for the "fraction of gaps below X" report; shared by
# the stdout summary and the JSON output so the two cannot drift apart.
GAP_THRESHOLDS_S = [0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0, 300.0, 1800.0]

# Conversion failures that mean "this field is unusable" rather than
# "abort the run". Deliberately excludes KeyboardInterrupt/SystemExit.
_BAD_NUMBER = (TypeError, ValueError, OverflowError)


def _iter_jsonl(path):
    """Yield every JSON object found in *path*, one per line.

    Lines that do not parse as JSON, or that parse to something other than
    an object, are skipped silently (the trace files are treated
    best-effort).
    """
    with open(path) as f:
        for line in f:
            try:
                rec = json.loads(line)
            except ValueError:  # json.JSONDecodeError is a ValueError
                continue
            if isinstance(rec, dict):
                yield rec


def _to_int(value):
    """Return ``int(value)``, or ``None`` when *value* cannot be converted."""
    try:
        return int(value)
    except _BAD_NUMBER:
        return None


# ---------------------------------------------------------------------------
# 1. Raw output: request_id -> (ttf_ms, generate_token_num)
# ---------------------------------------------------------------------------
out_info = {}
n_out_seen = 0
for rec in _iter_jsonl(RAW_OUT):
    rid = rec.get("request_id")
    if rid is None:
        continue
    n_out_seen += 1
    ttf = rec.get("time_to_finish_token")
    if ttf is None:
        continue
    try:
        ttf_ms = int(float(ttf))
    except _BAD_NUMBER:
        continue
    # A missing/unparseable token count is kept as None; such a record can
    # never equal a formatted output_length, so it simply never matches.
    gtn = _to_int(rec.get("generate_token_num"))
    # Keep the largest ttf_ms if duplicates (multiple log lines per request)
    prev = out_info.get(rid)
    if prev is None or ttf_ms > prev[0]:
        out_info[rid] = (ttf_ms, gtn)
print(f" scanned: {n_out_seen}, unique req with ttf: {len(out_info)}")

# ---------------------------------------------------------------------------
# 2. Raw input: join with raw output on request_id
# ---------------------------------------------------------------------------
print("Reading raw input + joining...")
joined = []  # list of (start_ms, end_ms, generate_token_num, request_id)
n_in_seen = 0
seen_rids = set()
for rec in _iter_jsonl(RAW_IN):
    n_in_seen += 1
    rid = rec.get("request_id")
    ts = rec.get("timestamp")
    if rid is None or ts is None or rid in seen_rids:
        continue
    # Only the first input line per request_id is considered.
    seen_rids.add(rid)
    info = out_info.get(rid)
    if info is None:
        continue
    ttf_ms, gtn = info
    start_ms = _to_int(ts)
    if start_ms is None:
        continue
    # end = start + time_to_finish_token (a duration measured from start).
    joined.append((start_ms, start_ms + ttf_ms, gtn, rid))
print(f" input scanned: {n_in_seen}, joined start+end+gtn: {len(joined)}")

if not joined:
    # Everything below indexes into the joined records; fail with a clear
    # message instead of an IndexError.
    raise SystemExit("No raw input/output records could be joined; nothing to match.")

joined.sort(key=lambda x: x[0])
starts = [j[0] for j in joined]
gtns = [j[2] for j in joined]  # generate_token_num (output_length-equivalent)
ends = [j[1] for j in joined]
print(f"start_ms range: [{starts[0]}, {starts[-1]}], duration {(starts[-1]-starts[0])/1000:.0f}s")

# ---------------------------------------------------------------------------
# 3. Formatted trace
# ---------------------------------------------------------------------------
print("Reading formatted...")
# Each row: (chat_id, parent_chat_id, timestamp, input_length, output_length).
# chat_id / parent_chat_id / timestamp are required; a record lacking them
# raises rather than being skipped.
fmt_rows = [
    (
        int(rec["chat_id"]),
        int(rec["parent_chat_id"]),
        float(rec["timestamp"]),
        int(rec.get("input_length", 0)),
        int(rec.get("output_length", 0)),
    )
    for rec in _iter_jsonl(FMT)
]
print(f" formatted records: {len(fmt_rows)}")
if not fmt_rows:
    raise SystemExit("Formatted trace contains no records; nothing to match.")
print(f"fmt timestamp range: [{fmt_rows[0][2]}, {fmt_rows[-1][2]}]s "
      f"(duration {fmt_rows[-1][2] - fmt_rows[0][2]:.0f}s)")

# ---------------------------------------------------------------------------
# 4. Calibrate T0
# ---------------------------------------------------------------------------
# Calibrate T0 by matching first few formatted records with raw records.
# We use output_length (formatted) vs generate_token_num (raw output) as the
# matching key — prompt_token_num is anonymized to 0.
print("Calibrating T0 (raw_ms anchor for formatted ts=0)...")
# For each token count, the start time of the earliest of the first 2000 raw
# records that has it (setdefault keeps the first occurrence).
first_start_by_gtn = {}
for start_ms, gtn in zip(starts[:2000], gtns[:2000]):
    first_start_by_gtn.setdefault(gtn, start_ms)
T0_candidates = sorted(
    first_start_by_gtn[ol] - int(ts_rel * 1000)
    for _cid, _pcid, ts_rel, _il, ol in fmt_rows[:200]
    if ol in first_start_by_gtn
)
# Take the median candidate; fall back to the first raw start if none matched.
T0 = T0_candidates[len(T0_candidates) // 2] if T0_candidates else starts[0]
print(f" T0 from {len(T0_candidates)} candidates -> {T0} ms")
print(f" candidate T0 distribution: min={min(T0_candidates) if T0_candidates else 'n/a'} "
      f"max={max(T0_candidates) if T0_candidates else 'n/a'}")

# ---------------------------------------------------------------------------
# 5. Match formatted -> raw
# ---------------------------------------------------------------------------
print("Matching formatted -> raw by (ts_rel, output_length)...")
TOLERANCE_MS = 60_000  # ±60 s window
fmt_to_timing = {}  # chat_id -> (start_ms, end_ms)
matched = 0
ambiguous = 0  # matched records that had more than one candidate in the window
unmatched = 0
for chat_id, _pcid, ts_rel, _il, ol in fmt_rows:
    target_ms = T0 + int(ts_rel * 1000)
    lo = bisect.bisect_left(starts, target_ms - TOLERANCE_MS)
    hi = bisect.bisect_right(starts, target_ms + TOLERANCE_MS)
    candidates = [k for k in range(lo, hi) if gtns[k] == ol]
    if not candidates:
        unmatched += 1
        continue
    if len(candidates) > 1:
        ambiguous += 1
    # Nearest start time wins; min() keeps the earliest index on ties.
    best = min(candidates, key=lambda k: abs(starts[k] - target_ms))
    # NOTE(review): one raw record may be picked for several formatted
    # records; no one-to-one constraint is enforced here.
    fmt_to_timing[chat_id] = (starts[best], ends[best])
    matched += 1
print(f" matched: {matched}, unmatched: {unmatched}, ambiguous: {ambiguous}")
print(f" match rate: {matched/len(fmt_rows)*100:.1f}%")

# ---------------------------------------------------------------------------
# 6. Sessions and inter-turn gaps
# ---------------------------------------------------------------------------
# Build session structure from parent_chat_id chains. A negative parent marks
# a session root. A child inherits its parent's session id only if the parent
# was seen earlier in the file; otherwise the parent id itself is used.
chat_to_session = {}
for chat_id, pcid, *_rest in fmt_rows:
    chat_to_session[chat_id] = chat_id if pcid < 0 else chat_to_session.get(pcid, pcid)

sessions = defaultdict(list)  # session id -> [(start_ms, end_ms), ...]
for chat_id, *_rest in fmt_rows:
    timing = fmt_to_timing.get(chat_id)
    if timing is not None:
        sessions[chat_to_session[chat_id]].append(timing)

gaps_ms = []
neg = 0  # gaps dropped because the next turn started before this one ended
multi = 0  # sessions with at least two matched turns
for turns in sessions.values():
    if len(turns) < 2:
        continue
    multi += 1
    turns.sort(key=lambda x: x[0])
    for cur, nxt in zip(turns, turns[1:]):
        gap = nxt[0] - cur[1]  # next start minus current end
        if gap < 0:
            neg += 1
            continue
        gaps_ms.append(gap)

# ---------------------------------------------------------------------------
# 7. Summary + JSON output
# ---------------------------------------------------------------------------
gaps = np.array(gaps_ms, dtype=np.float64) / 1000.0
print(f"multi_turn_sessions: {multi}, gaps_kept: {len(gaps)}, neg_dropped: {neg}")
if len(gaps) == 0:
    print("No gaps to summarize.")
else:
    n = len(gaps)
    pcts = [1, 5, 25, 50, 75, 90, 95, 99]
    ps = {f"p{p}": float(np.percentile(gaps, p)) for p in pcts}
    print(f"stats_s: min={gaps.min():.3f} mean={gaps.mean():.3f} max={gaps.max():.3f} {ps}")
    fraction_below = {thr: (gaps < thr).sum() / n for thr in GAP_THRESHOLDS_S}
    for thr, frac in fraction_below.items():
        pct = frac * 100
        print(f"frac < {thr:7.1f}s : {pct:5.1f}%")

    # CDF sample ranks: geometrically spaced over the lowest 1% of ranks,
    # linearly spaced over the rest, plus both endpoints.
    arr = np.sort(gaps)
    idx_top = np.unique(np.round(np.geomspace(1, max(1, n // 100), 200)).astype(int)) - 1
    idx_rest = np.unique(np.linspace(n // 100, n - 1, 300).astype(int))
    idx = np.unique(np.concatenate([[0], idx_top, idx_rest, [n - 1]]))
    idx = idx[idx < n]
    samples = [{"rank_pct": float((i + 1) / n * 100), "gap_s": float(arr[i])} for i in idx]

    out = {
        "trace": "chatbot",
        "n_gaps": n,
        "n_sessions": multi,
        "negative_dropped": neg,
        "matched_formatted_to_raw": matched,
        "unmatched_formatted": unmatched,
        "stats_s": {
            "min": float(gaps.min()),
            "max": float(gaps.max()),
            "mean": float(gaps.mean()),
            **ps,
        },
        "fraction_below": {f"{thr}s": float(frac) for thr, frac in fraction_below.items()},
        "cdf_samples": samples,
    }
    # Context manager so the file is flushed and closed before we report it.
    with open(OUT_JSON, "w") as f:
        f.write(json.dumps(out))
    print(f"wrote {OUT_JSON}")