diff --git a/tests/test_ali_trace_pipeline.py b/tests/test_ali_trace_pipeline.py index 45aa96a..766f991 100644 --- a/tests/test_ali_trace_pipeline.py +++ b/tests/test_ali_trace_pipeline.py @@ -3,11 +3,12 @@ import tempfile import unittest from datetime import datetime, timezone from pathlib import Path +from unittest import mock from trace_analyzer.cli import main as analyzer_main from trace_analyzer.parser import load_records from trace_formatter.cli import main as formatter_main -from trace_formatter.formatting import build_unified_row, discover_source_files, format_and_sort_trace +from trace_formatter.formatting import build_unified_row, discover_source_files, export_release_ready_trace, format_and_sort_trace def utc_ms(value: str) -> int: @@ -476,6 +477,74 @@ class AliTracePipelineTest(unittest.TestCase): formatted_rows = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] self.assertEqual([row["meta"]["request_id"] for row in formatted_rows], ["req-kept"]) + def test_format_and_sort_trace_normalizes_invalid_surrogates_before_chunking(self): + with tempfile.TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + input_dir = root / "raw" + input_dir.mkdir() + output_path = root / "formatted.jsonl" + + row = make_raw_row( + "req-\ud83c", + utc_ms("2026-04-17 15:00:03.000"), + messages=[{"role": "user", "content": "bad \ud83c content"}], + ) + + with (input_dir / "0417-1500-1530.jsonl").open("w", encoding="utf-8") as handle: + handle.write(json.dumps(row) + "\n") + + stats = format_and_sort_trace(input_dir=input_dir, output_path=output_path, chunk_bytes=256) + + self.assertEqual(stats["row_count"], 1) + formatted_rows = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] + self.assertEqual(formatted_rows[0]["meta"]["request_id"], "req-\uFFFD") + self.assertEqual(formatted_rows[0]["message_events"][0]["text_len"], len("bad \uFFFD content")) + 
self.assertEqual(formatted_rows[0]["raw_messages"][0]["content"], "bad \uFFFD content") + self.assertIn("\uFFFD", formatted_rows[0]["canonical_prompt"]) + + def test_format_and_sort_trace_normalizes_nonstandard_glm_tool_call_shapes(self): + with tempfile.TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + input_dir = root / "raw" + input_dir.mkdir() + output_path = root / "formatted.jsonl" + + row = make_raw_row( + "req-tool-call-shape", + utc_ms("2026-04-17 15:00:04.000"), + messages=[ + {"role": "user", "content": "hello"}, + { + "role": "assistant", + "content": "calling tool", + "tool_calls": { + "id": "call-1", + "type": "function", + "arguments": "{\"path\":\"/tmp/a.txt\"}", + }, + }, + ], + ) + request_params = json.loads(row["request_params"]) + request_params["payload"]["parameters"]["tools"] = { + "type": "function", + "name": "read_file", + "parameters": {"type": "object", "properties": {"path": {"type": "string"}}}, + } + row["request_params"] = json.dumps(request_params) + + with (input_dir / "0417-1500-1530.jsonl").open("w", encoding="utf-8") as handle: + handle.write(json.dumps(row) + "\n") + + stats = format_and_sort_trace(input_dir=input_dir, output_path=output_path, chunk_bytes=256) + + self.assertEqual(stats["row_count"], 1) + formatted_rows = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] + self.assertIn("", formatted_rows[0]["canonical_prompt"]) + self.assertIn("\"name\": \"read_file\"", formatted_rows[0]["canonical_prompt"]) + self.assertIn("", formatted_rows[0]["canonical_prompt"]) + self.assertIn("path", formatted_rows[0]["canonical_prompt"]) + def test_trace_formatter_cli_formats_one_raw_jsonl_file(self): with tempfile.TemporaryDirectory() as temp_dir: root = Path(temp_dir) @@ -564,6 +633,66 @@ class AliTracePipelineTest(unittest.TestCase): self.assertTrue(log_path.exists()) self.assertIn("Scan raw trace", log_path.read_text(encoding="utf-8")) + def 
test_format_and_sort_trace_defaults_temp_dir_to_output_parent(self): + with tempfile.TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + input_dir = root / "raw" + input_dir.mkdir() + output_dir = root / "formatted" + output_path = output_dir / "formatted.jsonl" + with (input_dir / "0417-1500-1530.jsonl").open("w", encoding="utf-8") as handle: + handle.write(json.dumps(make_raw_row("req-1", utc_ms("2026-04-17 15:00:01.000"))) + "\n") + + captured = {} + real_temporary_directory = tempfile.TemporaryDirectory + + def recording_temporary_directory(*args, **kwargs): + captured["dir"] = kwargs.get("dir") + return real_temporary_directory(*args, **kwargs) + + with mock.patch("trace_formatter.formatting.tempfile.TemporaryDirectory", side_effect=recording_temporary_directory): + format_and_sort_trace(input_dir=input_dir, output_path=output_path, chunk_bytes=256) + + self.assertEqual(Path(captured["dir"]), output_dir) + + def test_export_release_ready_trace_defaults_temp_dir_to_output_parent(self): + with tempfile.TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + raw_input_path = root / "trace-raw.jsonl" + release_output_dir = root / "release" + release_output_path = release_output_dir / "trace.jsonl" + raw_row = { + "schema_version": "2026.04.21", + "sort_time_ms": utc_ms("2026-04-17 15:00:01.000"), + "meta": { + "model_family": "glm5", + "request_ready_time_ms": utc_ms("2026-04-17 15:00:01.000"), + "chat_id": 0, + "parent_chat_id": -1, + "turn": 1, + }, + "canonical_prompt": "hello", + "usage": {"input_tokens": 1, "output_tokens": 1}, + } + raw_input_path.write_text(json.dumps(raw_row) + "\n", encoding="utf-8") + + captured = {} + real_temporary_directory = tempfile.TemporaryDirectory + + def recording_temporary_directory(*args, **kwargs): + captured["dir"] = kwargs.get("dir") + return real_temporary_directory(*args, **kwargs) + + with mock.patch("trace_formatter.formatting.tempfile.TemporaryDirectory", side_effect=recording_temporary_directory): + 
def _text_has_surrogates(value):
    """Return True when *value* contains a code point in U+D800..U+DFFF.

    Strict UTF-8 encoding rejects exactly the surrogate range, so a failed
    encode is a membership test that runs in C instead of a per-character
    Python loop.
    """
    try:
        value.encode("utf-8")
    except UnicodeEncodeError:
        return True
    return False


def _json_text_may_yield_surrogates(text):
    """Return True when parsing *text* as JSON could produce surrogates.

    Surrogates can reach the parsed value two ways: through a ``\\uDxxx``
    escape (producers emit the hex digits in either case) or through a
    literal surrogate character already present in the text, which happens
    when an outer JSON layer has been decoded before this one.
    """
    return "\\ud" in text or "\\uD" in text or _text_has_surrogates(text)


def normalize_unicode_text(value):
    """Return *value* with surrogate code points made encodable.

    A high surrogate immediately followed by a low surrogate is combined
    into the single supplementary-plane character the pair encodes; every
    other surrogate is replaced with U+FFFD.

    Non-string values, empty strings and strings without surrogates are
    returned unchanged (the very same object), which lets callers detect
    "nothing changed" with an identity check.
    """
    if not isinstance(value, str) or not value:
        return value
    if not _text_has_surrogates(value):
        return value

    pieces = []
    length = len(value)
    index = 0
    while index < length:
        codepoint = ord(value[index])
        if 0xD800 <= codepoint <= 0xDBFF and index + 1 < length:
            low = ord(value[index + 1])
            if 0xDC00 <= low <= 0xDFFF:
                # Valid pair: fold both halves into one code point.
                combined = 0x10000 + ((codepoint - 0xD800) << 10) + (low - 0xDC00)
                pieces.append(chr(combined))
                index += 2
                continue
        if 0xD800 <= codepoint <= 0xDFFF:
            # Unpaired high or stray low surrogate.
            pieces.append("\uFFFD")
        else:
            pieces.append(value[index])
        index += 1
    return "".join(pieces)


def _normalize_sequence_items(items):
    """Return a normalized list copy of *items*, or None if nothing changed.

    The copy is created lazily at the first element that differs, so
    surrogate-free sequences cost no allocation.
    """
    normalized = None
    for index, item in enumerate(items):
        normalized_item = normalize_unicode_value(item)
        if normalized is None:
            if normalized_item is item:
                continue
            normalized = list(items[:index])
        normalized.append(normalized_item)
    return normalized


def normalize_unicode_value(value):
    """Recursively apply :func:`normalize_unicode_text` to a JSON-like value.

    Strings, lists, tuples and dicts (string keys and all values) are
    walked; any other type is returned as is.  Containers are copied only
    when something inside them actually changes; otherwise the original
    object is returned, so ``result is value`` means "already clean".
    A changed list or tuple comes back as a plain ``list`` / ``tuple`` and a
    changed dict as a plain ``dict`` with the original key order.
    """
    if isinstance(value, str):
        return normalize_unicode_text(value)
    if isinstance(value, list):
        normalized = _normalize_sequence_items(value)
        return normalized if normalized is not None else value
    if isinstance(value, tuple):
        normalized = _normalize_sequence_items(value)
        return tuple(normalized) if normalized is not None else value
    if isinstance(value, dict):
        normalized = None
        for key, item in value.items():
            normalized_key = normalize_unicode_text(key) if isinstance(key, str) else key
            normalized_item = normalize_unicode_value(item)
            if normalized is None:
                if normalized_key is key and normalized_item is item:
                    continue
                # First difference: carry over the untouched leading entries.
                normalized = {}
                for earlier_key, earlier_item in value.items():
                    if earlier_key is key:
                        break
                    normalized[earlier_key] = earlier_item
            normalized[normalized_key] = normalized_item
        return normalized if normalized is not None else value
    return value


def parse_jsonish(value):
    """Parse nested JSON strings until a non-string value is reached.

    A string is stripped and decoded with ``json.loads`` repeatedly for as
    long as the result is itself a string.  When the text is blank or is
    not valid JSON, the current string is returned (not stripped).  Non-string
    input is not parsed at all.

    Every return path yields surrogate-free data: the parsed value is
    normalized whenever the JSON text could have produced surrogates,
    whether through ``\\ud``/``\\uD`` escapes or through literal surrogate
    characters already present in the text.
    """
    if not isinstance(value, str):
        return normalize_unicode_value(value)

    current = value
    while isinstance(current, str):
        text = current.strip()
        if not text:
            return normalize_unicode_text(current)
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            return normalize_unicode_text(current)
        # Skip the recursive walk when the text cannot contain surrogates.
        if _json_text_may_yield_surrogates(text):
            parsed = normalize_unicode_value(parsed)
        current = parsed
    return current


def compact_json(data):
    """Serialize *data* to compact JSON text that is safe to encode as UTF-8.

    ``ensure_ascii=False`` leaves surrogate characters literal in the
    output, so the serialized text is normalized before it is returned.
    """
    return normalize_unicode_text(json.dumps(data, ensure_ascii=False, separators=(",", ":")))
+15,7 @@ from dataclasses import asdict from pathlib import Path from typing import Iterator, TextIO -from trace_analyzer.helpers import parse_jsonish, safe_int +from trace_analyzer.helpers import normalize_unicode_text, parse_jsonish, safe_int from tokenizers import Tokenizer from tqdm.auto import tqdm from trace_model_meta import infer_model_family_from_request_model, resolve_tokenizer_path @@ -225,18 +225,18 @@ def _build_unified_row_from_components( "sort_time_ms": sort_time_ms, "meta": { "model_family": model_family, - "request_id": str(raw.get("request_id", "")), + "request_id": normalize_unicode_text(str(raw.get("request_id", ""))), "session_id": "", - "raw_session_id": str(raw.get("session_id", "")), + "raw_session_id": normalize_unicode_text(str(raw.get("session_id", ""))), "user_id": user_id, "parent_request_id": "", "parent_chat_id": -1, "chat_id": -1, "turn": 0, - "request_model": str(raw.get("request_model", "")), - "time": str(raw.get("time", "")), - "status_code": str(raw.get("status_code", "")), - "status_name": str(raw.get("status_name", "")), + "request_model": normalize_unicode_text(str(raw.get("request_model", ""))), + "time": normalize_unicode_text(str(raw.get("time", ""))), + "status_code": normalize_unicode_text(str(raw.get("status_code", ""))), + "status_name": normalize_unicode_text(str(raw.get("status_name", ""))), "request_ready_time_ms": sort_time_ms, "request_end_time_ms": request_end_time_ms, "total_cost_time_ms": total_cost_time_ms, @@ -417,6 +417,15 @@ def _open_progress_stream(log_file: str | Path | None): yield _TeeStream(sys.stderr, handle) +def _resolve_temp_root_dir(*, tmp_dir: str | Path | None, output_path: str | Path) -> Path: + if tmp_dir is not None: + root = Path(tmp_dir) + else: + root = Path(output_path).parent + root.mkdir(parents=True, exist_ok=True) + return root + + def _block_digest(block: list[int]) -> bytes: digest = hashlib.blake2b(digest_size=16) digest.update(len(block).to_bytes(4, "little", signed=False)) @@ 
-564,6 +573,7 @@ def export_release_ready_trace( input_path = Path(raw_input_path) release_destination = Path(release_output_path) release_destination.parent.mkdir(parents=True, exist_ok=True) + temp_root_dir = _resolve_temp_root_dir(tmp_dir=tmp_dir, output_path=release_destination) requested_jobs = jobs if jobs is not None else min(os.cpu_count() or 1, 16) shard_jobs = max(1, requested_jobs) @@ -573,7 +583,7 @@ def export_release_ready_trace( block_ids_by_digest: dict[str, int] = {} row_count = 0 - with tempfile.TemporaryDirectory(dir=tmp_dir) as temp_root: + with tempfile.TemporaryDirectory(dir=temp_root_dir) as temp_root: shard_root = Path(temp_root) / "release-shards" shard_root.mkdir(parents=True, exist_ok=True) shard_specs = [ @@ -682,12 +692,13 @@ def format_and_sort_trace( source_files = discover_source_files(input_dir) destination = Path(output_path) destination.parent.mkdir(parents=True, exist_ok=True) + temp_root_dir = _resolve_temp_root_dir(tmp_dir=tmp_dir, output_path=destination) time_offset_ms = infer_time_offset_ms(source_files[0]) if source_files else 0 time_window = infer_time_window(source_files, start_time=start_time, end_time=end_time) if truncate_to_window else None total_input_bytes = sum(path.stat().st_size for path in source_files if path.suffix != ".zst") has_zst = any(path.suffix == ".zst" for path in source_files) - with _open_progress_stream(log_file) as progress_stream, tempfile.TemporaryDirectory(dir=tmp_dir) as temp_root: + with _open_progress_stream(log_file) as progress_stream, tempfile.TemporaryDirectory(dir=temp_root_dir) as temp_root: temp_raw_destination = Path(temp_root) / "formatted-raw.tmp.jsonl" chunk_root = Path(temp_root) chunk_paths: list[Path] = [] diff --git a/trace_formatter/raw_parser.py b/trace_formatter/raw_parser.py index 239bd39..8c7f52e 100644 --- a/trace_formatter/raw_parser.py +++ b/trace_formatter/raw_parser.py @@ -128,20 +128,33 @@ def _normalize_message_content_for_template(content, role=""): def 
def _normalize_tool_call_for_template(tool_call):
    """Coerce one tool-call entry into a dict with a ``function`` mapping.

    A non-dict entry is wrapped in a nameless function whose arguments
    carry the parsed original under ``__raw_tool_call__``.

    For a dict entry, a shallow copy is returned whose ``function`` value
    is a dict with:

    * ``name`` - first truthy value among ``function.name`` and the
      top-level ``name`` / ``tool_name`` / ``function_name``, as a string
      (empty when none is set);
    * ``arguments`` - ``function.arguments`` when that key exists,
      otherwise the top-level ``arguments`` / ``parameters`` / ``args``
      (first key present, default ``{}``), passed through
      ``parse_jsonish``; a non-dict result is wrapped under
      ``__raw_arguments__``.
    """
    if not isinstance(tool_call, dict):
        return {
            "function": {
                "name": "",
                "arguments": {"__raw_tool_call__": parse_jsonish(tool_call)},
            }
        }

    result = dict(tool_call)
    nested = result.get("function")
    function_spec = dict(nested) if isinstance(nested, dict) else {}

    # Name: nested value first, then the top-level aliases in priority order.
    name = function_spec.get("name")
    for alias in ("name", "tool_name", "function_name"):
        if name:
            break
        name = result.get(alias)

    # Arguments: key presence (not truthiness) decides which source wins.
    if "arguments" in function_spec:
        raw_arguments = function_spec["arguments"]
    else:
        raw_arguments = next(
            (result[alias] for alias in ("arguments", "parameters", "args") if alias in result),
            {},
        )

    parsed_arguments = parse_jsonish(raw_arguments)
    if not isinstance(parsed_arguments, dict):
        parsed_arguments = {"__raw_arguments__": parsed_arguments}

    function_spec["name"] = str(name or "")
    function_spec["arguments"] = parsed_arguments
    result["function"] = function_spec
    return result
{})) + normalized_function = dict(function) if isinstance(function, dict) else {} + tool_name = ( + normalized_function.get("name") + or normalized.get("name") + or normalized.get("tool_name") + or normalized.get("function_name") + or "" + ) + if tool_name or function is not None or "name" in normalized or "parameters" in normalized: + parameters = parse_jsonish( + normalized_function.get("parameters", normalized.get("parameters", normalized.get("args", {}))) + ) + normalized_function["name"] = str(tool_name or "") if isinstance(parameters, dict): normalized_function["parameters"] = parameters + else: + normalized_function["parameters"] = {"__raw_parameters__": parameters} normalized["function"] = normalized_function return normalized @@ -165,7 +190,12 @@ def _normalize_qwen_message_for_template(message): normalized_message = dict(message) normalized_message["content"] = _stringify_message_content_for_template(message.get("content")) normalized_tool_calls = [] - for tool_call in message.get("tool_calls", []): + tool_calls = message.get("tool_calls", []) + if isinstance(tool_calls, dict): + tool_calls = [tool_calls] + elif not isinstance(tool_calls, list): + tool_calls = [tool_calls] + for tool_call in tool_calls: normalized_tool_call = _normalize_tool_call_for_template(tool_call) if isinstance(normalized_tool_call, dict): function = normalized_tool_call.get("function") @@ -244,11 +274,21 @@ def build_glm5_canonical_prompt(payload): message.get("content"), role=str(message.get("role", "")), ) + tool_calls = message.get("tool_calls", []) + if isinstance(tool_calls, dict): + tool_calls = [tool_calls] + elif not isinstance(tool_calls, list): + tool_calls = [tool_calls] normalized_message["tool_calls"] = [ - _normalize_tool_call_for_template(tool_call) for tool_call in message.get("tool_calls", []) + _normalize_tool_call_for_template(tool_call) for tool_call in tool_calls ] messages.append(normalized_message) - tools = [_normalize_tool_spec_for_template(tool) for tool 
in parameters.get("tools", []) if isinstance(tool, dict)] + tools_payload = parameters.get("tools", []) + if isinstance(tools_payload, dict): + tools_payload = [tools_payload] + elif not isinstance(tools_payload, list): + tools_payload = [tools_payload] + tools = [_normalize_tool_spec_for_template(tool) for tool in tools_payload if isinstance(tool, dict)] return _load_glm5_chat_template().render( messages=messages, tools=tools, @@ -266,7 +306,12 @@ def build_qwen3_coder_canonical_prompt(payload): ] if not messages: messages = [{"role": "system", "content": ""}] - tools = [_normalize_tool_spec_for_template(tool) for tool in parameters.get("tools", []) if isinstance(tool, dict)] + tools_payload = parameters.get("tools", []) + if isinstance(tools_payload, dict): + tools_payload = [tools_payload] + elif not isinstance(tools_payload, list): + tools_payload = [tools_payload] + tools = [_normalize_tool_spec_for_template(tool) for tool in tools_payload if isinstance(tool, dict)] return _load_qwen3_coder_chat_template().render( messages=messages, tools=tools,