Initial commit
This commit is contained in:
370
trace_formatter/raw_parser.py
Normal file
370
trace_formatter/raw_parser.py
Normal file
@@ -0,0 +1,370 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import Environment
|
||||
import re
|
||||
|
||||
from trace_analyzer.helpers import compact_json, parse_jsonish
|
||||
from trace_model_meta import infer_model_family_from_request_model, resolve_chat_template_path
|
||||
|
||||
|
||||
@dataclass
class ParsedMessageEvent:
    """Compact summary of a single chat message taken from a raw trace."""

    # Role string of the message.
    role: str
    # Label describing the shape/kind of the message content.
    content_type: str
    # Character count attributed to the message content.
    text_len: int
    # Whether a "cache_control" marker was seen on the content.
    has_cache_control: bool = False
    # Number of content items the message carried.
    item_count: int = 0
|
||||
|
||||
|
||||
@dataclass
class ParsedToolSpec:
    """Minimal description of one tool declared in a raw trace."""

    # Function name of the tool (may be an empty string).
    name: str
    # Value of the tool's "type" entry.
    tool_type: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Qwen3CoderToolParserSpec:
    """Immutable set of marker tokens used to render Qwen3-Coder tool calls."""

    # Markers wrapping one complete tool call.
    tool_call_start_token: str = "<tool_call>"
    tool_call_end_token: str = "</tool_call>"
    # Opening prefix (the function name and ">" follow it) and closing tag
    # of the function element.
    tool_call_prefix: str = "<function="
    function_end_token: str = "</function>"
    # Opening prefix (the parameter name and ">" follow it) and closing tag
    # of each parameter element.
    parameter_prefix: str = "<parameter="
    parameter_end_token: str = "</parameter>"
|
||||
|
||||
|
||||
def _template_tojson(value, ensure_ascii=True):
|
||||
return json.dumps(value, ensure_ascii=ensure_ascii)
|
||||
|
||||
|
||||
@lru_cache(maxsize=4)
def _load_chat_template(path: str):
    """Read the template file at ``path`` and compile it with Jinja2.

    The compiled template is memoized per ``path`` (at most four entries),
    so the file is only read on a cache miss.
    """
    template_source = Path(path).read_text(encoding="utf-8")
    jinja_env = Environment()
    # Replace the stock ``tojson`` filter with the plain ``json.dumps`` wrapper.
    jinja_env.filters["tojson"] = _template_tojson
    # NOTE(review): a bare Environment() keeps Jinja's default whitespace
    # handling (trim_blocks / lstrip_blocks off) -- confirm the template files
    # are authored for that.
    return jinja_env.from_string(template_source)
|
||||
|
||||
|
||||
def _load_glm5_chat_template():
    """Return the compiled chat template registered under the "glm5" key."""
    template_path = resolve_chat_template_path("glm5")
    return _load_chat_template(str(template_path))
|
||||
|
||||
|
||||
def _load_qwen3_coder_chat_template():
    """Return the compiled chat template registered under the "qwen3-coder" key."""
    template_path = resolve_chat_template_path("qwen3-coder")
    return _load_chat_template(str(template_path))
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def _load_qwen3_coder_tool_parser_spec() -> Qwen3CoderToolParserSpec:
    """Build the Qwen3-Coder token spec, preferring values from the parser file.

    Looks for ``qwen3coder_tool_parser.py`` next to the resolved "qwen3-coder"
    chat template. When the file is missing, the built-in defaults are
    returned. Otherwise each token is taken from an assignment of the form
    ``self.<name>: str = "<value>"`` found in the file's text, falling back to
    the default for any token that has no such assignment. The result is
    cached after the first call.
    """
    defaults = Qwen3CoderToolParserSpec()
    template_path = resolve_chat_template_path("qwen3-coder")
    parser_path = template_path.with_name("qwen3coder_tool_parser.py")
    if not parser_path.exists():
        return defaults

    parser_source = parser_path.read_text(encoding="utf-8")
    token_names = (
        "tool_call_start_token",
        "tool_call_end_token",
        "tool_call_prefix",
        "function_end_token",
        "parameter_prefix",
        "parameter_end_token",
    )

    resolved = {}
    for name in token_names:
        # The file is scanned as text (not imported); only double-quoted,
        # non-empty literals are recognized.
        pattern = rf'self\.{re.escape(name)}:\s*str\s*=\s*"([^"]+)"'
        found = re.search(pattern, parser_source)
        resolved[name] = found.group(1) if found else getattr(defaults, name)
    return Qwen3CoderToolParserSpec(**resolved)
|
||||
|
||||
|
||||
def _stringify_message_content_for_template(content):
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict) and isinstance(item.get("text"), str):
|
||||
parts.append(item["text"])
|
||||
else:
|
||||
parts.append(compact_json(item))
|
||||
return "".join(parts)
|
||||
if isinstance(content, dict):
|
||||
if isinstance(content.get("text"), str):
|
||||
return content["text"]
|
||||
return compact_json(content)
|
||||
if content is None:
|
||||
return ""
|
||||
return str(content)
|
||||
|
||||
|
||||
def _normalize_message_content_for_template(content, role=""):
|
||||
if role == "tool" and isinstance(content, (list, dict)):
|
||||
has_named_tool_refs = False
|
||||
if isinstance(content, list):
|
||||
has_named_tool_refs = any(isinstance(item, dict) and item.get("name") for item in content)
|
||||
elif isinstance(content, dict):
|
||||
has_named_tool_refs = bool(content.get("name"))
|
||||
if not has_named_tool_refs:
|
||||
return _stringify_message_content_for_template(content)
|
||||
if isinstance(content, list):
|
||||
normalized_items = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and "text" in item and "type" not in item:
|
||||
normalized_item = dict(item)
|
||||
normalized_item["type"] = "text"
|
||||
normalized_items.append(normalized_item)
|
||||
else:
|
||||
normalized_items.append(item)
|
||||
return normalized_items
|
||||
if isinstance(content, dict) and "text" in content:
|
||||
normalized_item = dict(content)
|
||||
normalized_item.setdefault("type", "text")
|
||||
return [normalized_item]
|
||||
return content
|
||||
|
||||
|
||||
def _normalize_tool_call_for_template(tool_call):
|
||||
if not isinstance(tool_call, dict):
|
||||
return tool_call
|
||||
normalized = dict(tool_call)
|
||||
function = normalized.get("function")
|
||||
normalized_function = dict(function) if isinstance(function, dict) else None
|
||||
if normalized_function is None and ("name" in normalized or "arguments" in normalized):
|
||||
normalized_function = {}
|
||||
if normalized_function is not None:
|
||||
if "name" not in normalized_function and normalized.get("name"):
|
||||
normalized_function["name"] = normalized["name"]
|
||||
if "arguments" not in normalized_function and "arguments" in normalized:
|
||||
normalized_function["arguments"] = normalized["arguments"]
|
||||
arguments = parse_jsonish(normalized_function.get("arguments", {}))
|
||||
normalized_function["arguments"] = arguments if isinstance(arguments, dict) else {"__raw_arguments__": arguments}
|
||||
normalized["function"] = normalized_function
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_tool_spec_for_template(tool):
|
||||
if not isinstance(tool, dict):
|
||||
return tool
|
||||
normalized = dict(tool)
|
||||
function = normalized.get("function")
|
||||
if isinstance(function, dict):
|
||||
normalized_function = dict(function)
|
||||
parameters = parse_jsonish(normalized_function.get("parameters", {}))
|
||||
if isinstance(parameters, dict):
|
||||
normalized_function["parameters"] = parameters
|
||||
normalized["function"] = normalized_function
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_qwen_message_for_template(message):
|
||||
if not isinstance(message, dict):
|
||||
return message
|
||||
normalized_message = dict(message)
|
||||
normalized_message["content"] = _stringify_message_content_for_template(message.get("content"))
|
||||
normalized_tool_calls = []
|
||||
for tool_call in message.get("tool_calls", []):
|
||||
normalized_tool_call = _normalize_tool_call_for_template(tool_call)
|
||||
if isinstance(normalized_tool_call, dict):
|
||||
function = normalized_tool_call.get("function")
|
||||
if isinstance(function, dict):
|
||||
arguments = function.get("arguments", {})
|
||||
if not isinstance(arguments, dict):
|
||||
function["arguments"] = {"__raw_arguments__": arguments}
|
||||
normalized_tool_calls.append(normalized_tool_call)
|
||||
normalized_message["tool_calls"] = normalized_tool_calls
|
||||
return normalized_message
|
||||
|
||||
|
||||
def _render_qwen3_coder_tool_call(tool_call, spec: Qwen3CoderToolParserSpec) -> str:
    """Render one tool call as newline-joined Qwen3-Coder markup.

    Returns the empty string when the normalized call is not a dict, has no
    dict ``"function"`` entry, or has a blank function name. Otherwise the
    output is the start token, the function opener, one
    opener/value/closer triple per argument, the function closer and the
    end token, joined with ``"\\n"``. Dict and list argument values are
    JSON-encoded (non-ASCII preserved); all other values go through ``str``.
    """
    call = _normalize_tool_call_for_template(tool_call)
    if not isinstance(call, dict):
        return ""

    function_entry = call.get("function", {})
    if not isinstance(function_entry, dict):
        return ""

    function_name = str(function_entry.get("name", "")).strip()
    if not function_name:
        return ""

    lines = [spec.tool_call_start_token, f"{spec.tool_call_prefix}{function_name}>"]

    arguments = function_entry.get("arguments", {})
    if isinstance(arguments, dict):
        for arg_name, arg_value in arguments.items():
            if isinstance(arg_value, (dict, list)):
                rendered_value = json.dumps(arg_value, ensure_ascii=False)
            else:
                rendered_value = str(arg_value)
            lines.append(f"{spec.parameter_prefix}{arg_name}>")
            lines.append(rendered_value)
            lines.append(spec.parameter_end_token)

    lines.append(spec.function_end_token)
    lines.append(spec.tool_call_end_token)
    return "\n".join(lines)
|
||||
|
||||
|
||||
def _normalize_qwen_message_with_tool_parser(message, spec: Qwen3CoderToolParserSpec):
    """Normalize ``message`` and inline assistant tool calls into its content.

    The message is first normalized for the Qwen template. Only assistant
    messages with a non-empty ``tool_calls`` list are changed further: each
    call is rendered to markup, the stripped original content (if any) and
    the rendered calls are joined with newlines into ``"content"``, and
    ``"tool_calls"`` is reset to an empty list. If no call renders to
    anything, the normalized message is returned untouched.
    """
    result = _normalize_qwen_message_for_template(message)
    if not isinstance(result, dict) or result.get("role") != "assistant":
        return result

    tool_calls = result.get("tool_calls", [])
    if not (isinstance(tool_calls, list) and tool_calls):
        return result

    rendered_calls = []
    for tool_call in tool_calls:
        chunk = _render_qwen3_coder_tool_call(tool_call, spec)
        if chunk:
            rendered_calls.append(chunk)
    if not rendered_calls:
        return result

    leading_text = str(result.get("content", "") or "").strip()
    pieces = ([leading_text] if leading_text else []) + rendered_calls
    result["content"] = "\n".join(pieces)
    # The calls now live in the content text, so they are cleared here.
    result["tool_calls"] = []
    return result
|
||||
|
||||
|
||||
def build_glm5_canonical_prompt(payload):
    """Render the GLM5 chat template for a request ``payload``.

    Reads ``payload["input"]["messages"]`` and
    ``payload["parameters"]["tools"]``. Dict messages get their content and
    tool calls normalized; non-dict messages are passed through untouched.
    Non-dict tools are dropped. The template is rendered with
    ``add_generation_prompt=True``.

    Missing, null, or wrongly typed ``input`` / ``parameters`` sections and
    null ``messages`` / ``tool_calls`` / ``tools`` values are treated as
    empty rather than raising.
    """
    # ``dict.get(key, default)`` only covers absent keys; explicit nulls and
    # non-dict sections are normalized to empty containers here.
    input_payload = payload.get("input") if isinstance(payload, dict) else None
    if not isinstance(input_payload, dict):
        input_payload = {}
    parameters = payload.get("parameters") if isinstance(payload, dict) else None
    if not isinstance(parameters, dict):
        parameters = {}

    messages = []
    for message in input_payload.get("messages") or []:
        if not isinstance(message, dict):
            messages.append(message)
            continue
        normalized_message = dict(message)
        normalized_message["content"] = _normalize_message_content_for_template(
            message.get("content"),
            role=str(message.get("role", "")),
        )
        normalized_message["tool_calls"] = [
            _normalize_tool_call_for_template(tool_call)
            for tool_call in message.get("tool_calls") or []
        ]
        messages.append(normalized_message)

    tools = [
        _normalize_tool_spec_for_template(tool)
        for tool in parameters.get("tools") or []
        if isinstance(tool, dict)
    ]
    return _load_glm5_chat_template().render(
        messages=messages,
        tools=tools,
        add_generation_prompt=True,
    )
|
||||
|
||||
|
||||
def build_qwen3_coder_canonical_prompt(payload):
    """Render the Qwen3-Coder chat template for a request ``payload``.

    Reads ``payload["input"]["messages"]`` and
    ``payload["parameters"]["tools"]``. Each message is normalized with the
    tool-parser spec (assistant tool calls are inlined into the content).
    When there are no messages, a single empty system message is used.
    Non-dict tools are dropped. The template is rendered with
    ``add_generation_prompt=True``.

    Missing, null, or wrongly typed ``input`` / ``parameters`` sections and
    null ``messages`` / ``tools`` values are treated as empty rather than
    raising.
    """
    # ``dict.get(key, default)`` only covers absent keys; explicit nulls and
    # non-dict sections are normalized to empty containers here.
    input_payload = payload.get("input") if isinstance(payload, dict) else None
    if not isinstance(input_payload, dict):
        input_payload = {}
    parameters = payload.get("parameters") if isinstance(payload, dict) else None
    if not isinstance(parameters, dict):
        parameters = {}

    tool_parser_spec = _load_qwen3_coder_tool_parser_spec()
    messages = [
        _normalize_qwen_message_with_tool_parser(message, tool_parser_spec)
        for message in input_payload.get("messages") or []
    ]
    if not messages:
        # Placeholder so the template always receives at least one message.
        messages = [{"role": "system", "content": ""}]

    tools = [
        _normalize_tool_spec_for_template(tool)
        for tool in parameters.get("tools") or []
        if isinstance(tool, dict)
    ]
    return _load_qwen3_coder_chat_template().render(
        messages=messages,
        tools=tools,
        add_generation_prompt=True,
    )
|
||||
|
||||
|
||||
class GLMTraceAdapter:
    """Adapter for raw traces that carry both request and response params."""

    name = "glm"

    def detect(self, raw):
        """Return True when ``raw`` has both "request_params" and "response_params"."""
        return all(key in raw for key in ("request_params", "response_params"))

    def parse_message(self, message):
        """Summarize one message dict as a ``ParsedMessageEvent``.

        String content counts its own length. List content sums the length
        of each item (string items and string ``"text"`` values directly,
        everything else via its compact JSON form) and reports the sorted,
        "+"-joined set of item types. Dict content is measured by its compact
        JSON form. Any other content is measured by ``len(str(content))``.
        """
        role = str(message.get("role", "unknown"))
        content = message.get("content")

        if isinstance(content, str):
            return ParsedMessageEvent(
                role=role,
                content_type="text",
                text_len=len(content),
                has_cache_control=False,
                item_count=1,
            )

        if isinstance(content, dict):
            return ParsedMessageEvent(
                role=role,
                content_type=str(content.get("type", "object")),
                text_len=len(compact_json(content)),
                has_cache_control="cache_control" in content,
                item_count=1,
            )

        if not isinstance(content, list):
            # NOTE(review): ``None`` content lands here and is reported with
            # content_type "NoneType" and text_len 4 (len("None")) -- confirm
            # that is what downstream consumers expect.
            return ParsedMessageEvent(
                role=role,
                content_type=type(content).__name__,
                text_len=len(str(content)),
                has_cache_control=False,
                item_count=1,
            )

        total_chars = 0
        saw_cache_control = False
        seen_types = set()
        for item in content:
            if isinstance(item, str):
                total_chars += len(item)
                seen_types.add("text")
            elif not isinstance(item, dict):
                total_chars += len(compact_json(item))
                seen_types.add(type(item).__name__)
            else:
                text_value = item.get("text")
                if isinstance(text_value, str):
                    total_chars += len(text_value)
                else:
                    total_chars += len(compact_json(item))
                if "cache_control" in item:
                    saw_cache_control = True
                fallback_type = "text" if "text" in item else "object"
                seen_types.add(str(item.get("type", fallback_type)))

        return ParsedMessageEvent(
            role=role,
            content_type="+".join(sorted(seen_types)) if seen_types else "list",
            text_len=total_chars,
            has_cache_control=saw_cache_control,
            item_count=len(content),
        )

    def parse_tool(self, tool):
        """Extract the function name and type of one tool dict."""
        function = tool.get("function", {})
        # A non-dict "function" entry yields an empty name.
        name = str(function.get("name", "")) if isinstance(function, dict) else ""
        return ParsedToolSpec(name=name, tool_type=str(tool.get("type", "function")))

    def build_canonical_prompt(self, payload):
        """Render ``payload`` with the GLM5 chat template."""
        return build_glm5_canonical_prompt(payload)
|
||||
|
||||
|
||||
class QwenTraceAdapter(GLMTraceAdapter):
    """Adapter for raw traces whose request model maps to "qwen3-coder"."""

    name = "qwen3-coder"

    def detect(self, raw):
        """Return True when ``raw`` has "request_params" and a qwen3-coder model."""
        if "request_params" in raw:
            family = infer_model_family_from_request_model(raw.get("request_model"))
            return family == "qwen3-coder"
        return False

    def build_canonical_prompt(self, payload):
        """Render ``payload`` with the Qwen3-Coder chat template."""
        return build_qwen3_coder_canonical_prompt(payload)
|
||||
|
||||
|
||||
def get_raw_adapter(raw: dict[str, Any]):
    """Return the first adapter whose ``detect`` accepts ``raw``.

    The Qwen adapter is tried before the GLM one, so a trace that satisfies
    both detectors resolves to the Qwen adapter.

    Raises:
        ValueError: if no adapter recognizes the trace.
    """
    candidates = (QwenTraceAdapter(), GLMTraceAdapter())
    matched = next((candidate for candidate in candidates if candidate.detect(raw)), None)
    if matched is None:
        raise ValueError("Unsupported raw trace format for trace_formatter.")
    return matched
|
||||
Reference in New Issue
Block a user