371 lines
14 KiB
Python
371 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import dataclass
|
|
from functools import lru_cache
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from jinja2 import Environment
|
|
import re
|
|
|
|
from trace_analyzer.helpers import compact_json, parse_jsonish
|
|
from trace_model_meta import infer_model_family_from_request_model, resolve_chat_template_path
|
|
|
|
|
|
@dataclass
class ParsedMessageEvent:
    """Compact per-message summary produced by a trace adapter's ``parse_message``."""

    # Message role as a string (``"unknown"`` when the message has no role).
    role: str
    # Content kind: ``"text"`` for plain strings, a ``+``-joined set of part
    # types for list content, the ``"type"`` of dict content, or a type name.
    content_type: str
    # Character count of the content's text (non-text parts are measured by
    # their compact JSON encoding).
    text_len: int
    # True when any content part carried a ``cache_control`` key.
    has_cache_control: bool = False
    # Number of content parts.
    item_count: int = 0
|
|
|
|
|
|
@dataclass
class ParsedToolSpec:
    """Compact summary of one tool declared in a request."""

    # Function name taken from ``tool["function"]["name"]`` (may be empty).
    name: str
    # Tool ``"type"`` field, e.g. ``"function"``.
    tool_type: str
|
|
|
|
|
|
@dataclass(frozen=True)
class Qwen3CoderToolParserSpec:
    """Immutable set of markers used to serialise Qwen3-Coder tool calls as text.

    The defaults are used whenever ``qwen3coder_tool_parser.py`` (next to the
    qwen3-coder chat template) is missing or does not define a marker.
    """

    # Opens / closes one whole tool call.
    tool_call_start_token: str = "<tool_call>"
    tool_call_end_token: str = "</tool_call>"
    # Opens the function section; the function name and ">" follow it.
    tool_call_prefix: str = "<function="
    function_end_token: str = "</function>"
    # Opens one parameter; the parameter name and ">" follow it.
    parameter_prefix: str = "<parameter="
    parameter_end_token: str = "</parameter>"
|
|
|
|
|
|
def _template_tojson(value, ensure_ascii=True):
|
|
return json.dumps(value, ensure_ascii=ensure_ascii)
|
|
|
|
|
|
@lru_cache(maxsize=4)
def _load_chat_template(path: str):
    """Compile the Jinja chat template stored at *path*.

    The environment overrides the ``tojson`` filter with ``_template_tojson``.
    It also exposes a ``raise_exception(message)`` global so a template can
    abort rendering with its own message. Without that global, such a call
    fails with Jinja's generic ``UndefinedError``. ``TemplateError`` is that
    error's base class, so existing handlers still catch the new error.

    Results are memoised per path string with ``lru_cache(maxsize=4)``.
    Changes to the file on disk are therefore not seen until the cache is
    cleared.
    """
    from jinja2 import TemplateError

    def _raise_exception(message):
        raise TemplateError(message)

    environment = Environment()
    environment.filters["tojson"] = _template_tojson
    environment.globals["raise_exception"] = _raise_exception
    return environment.from_string(Path(path).read_text(encoding="utf-8"))
|
|
|
|
|
|
def _load_glm5_chat_template():
    """Return the compiled GLM-5 chat template (memoised by ``_load_chat_template``)."""
    template_path = resolve_chat_template_path("glm5")
    return _load_chat_template(str(template_path))
|
|
|
|
|
|
def _load_qwen3_coder_chat_template():
    """Return the compiled Qwen3-Coder chat template (memoised by ``_load_chat_template``)."""
    template_path = resolve_chat_template_path("qwen3-coder")
    return _load_chat_template(str(template_path))
|
|
|
|
|
|
@lru_cache(maxsize=1)
|
|
def _load_qwen3_coder_tool_parser_spec() -> Qwen3CoderToolParserSpec:
|
|
defaults = Qwen3CoderToolParserSpec()
|
|
parser_path = resolve_chat_template_path("qwen3-coder").with_name("qwen3coder_tool_parser.py")
|
|
if not parser_path.exists():
|
|
return defaults
|
|
text = parser_path.read_text(encoding="utf-8")
|
|
|
|
def _extract(name: str, fallback: str) -> str:
|
|
pattern = rf'self\.{re.escape(name)}:\s*str\s*=\s*"([^"]+)"'
|
|
match = re.search(pattern, text)
|
|
return match.group(1) if match else fallback
|
|
|
|
return Qwen3CoderToolParserSpec(
|
|
tool_call_start_token=_extract("tool_call_start_token", defaults.tool_call_start_token),
|
|
tool_call_end_token=_extract("tool_call_end_token", defaults.tool_call_end_token),
|
|
tool_call_prefix=_extract("tool_call_prefix", defaults.tool_call_prefix),
|
|
function_end_token=_extract("function_end_token", defaults.function_end_token),
|
|
parameter_prefix=_extract("parameter_prefix", defaults.parameter_prefix),
|
|
parameter_end_token=_extract("parameter_end_token", defaults.parameter_end_token),
|
|
)
|
|
|
|
|
|
def _stringify_message_content_for_template(content):
|
|
if isinstance(content, list):
|
|
parts = []
|
|
for item in content:
|
|
if isinstance(item, str):
|
|
parts.append(item)
|
|
elif isinstance(item, dict) and isinstance(item.get("text"), str):
|
|
parts.append(item["text"])
|
|
else:
|
|
parts.append(compact_json(item))
|
|
return "".join(parts)
|
|
if isinstance(content, dict):
|
|
if isinstance(content.get("text"), str):
|
|
return content["text"]
|
|
return compact_json(content)
|
|
if content is None:
|
|
return ""
|
|
return str(content)
|
|
|
|
|
|
def _normalize_message_content_for_template(content, role=""):
|
|
if role == "tool" and isinstance(content, (list, dict)):
|
|
has_named_tool_refs = False
|
|
if isinstance(content, list):
|
|
has_named_tool_refs = any(isinstance(item, dict) and item.get("name") for item in content)
|
|
elif isinstance(content, dict):
|
|
has_named_tool_refs = bool(content.get("name"))
|
|
if not has_named_tool_refs:
|
|
return _stringify_message_content_for_template(content)
|
|
if isinstance(content, list):
|
|
normalized_items = []
|
|
for item in content:
|
|
if isinstance(item, dict) and "text" in item and "type" not in item:
|
|
normalized_item = dict(item)
|
|
normalized_item["type"] = "text"
|
|
normalized_items.append(normalized_item)
|
|
else:
|
|
normalized_items.append(item)
|
|
return normalized_items
|
|
if isinstance(content, dict) and "text" in content:
|
|
normalized_item = dict(content)
|
|
normalized_item.setdefault("type", "text")
|
|
return [normalized_item]
|
|
return content
|
|
|
|
|
|
def _normalize_tool_call_for_template(tool_call):
|
|
if not isinstance(tool_call, dict):
|
|
return tool_call
|
|
normalized = dict(tool_call)
|
|
function = normalized.get("function")
|
|
normalized_function = dict(function) if isinstance(function, dict) else None
|
|
if normalized_function is None and ("name" in normalized or "arguments" in normalized):
|
|
normalized_function = {}
|
|
if normalized_function is not None:
|
|
if "name" not in normalized_function and normalized.get("name"):
|
|
normalized_function["name"] = normalized["name"]
|
|
if "arguments" not in normalized_function and "arguments" in normalized:
|
|
normalized_function["arguments"] = normalized["arguments"]
|
|
arguments = parse_jsonish(normalized_function.get("arguments", {}))
|
|
normalized_function["arguments"] = arguments if isinstance(arguments, dict) else {"__raw_arguments__": arguments}
|
|
normalized["function"] = normalized_function
|
|
return normalized
|
|
|
|
|
|
def _normalize_tool_spec_for_template(tool):
|
|
if not isinstance(tool, dict):
|
|
return tool
|
|
normalized = dict(tool)
|
|
function = normalized.get("function")
|
|
if isinstance(function, dict):
|
|
normalized_function = dict(function)
|
|
parameters = parse_jsonish(normalized_function.get("parameters", {}))
|
|
if isinstance(parameters, dict):
|
|
normalized_function["parameters"] = parameters
|
|
normalized["function"] = normalized_function
|
|
return normalized
|
|
|
|
|
|
def _normalize_qwen_message_for_template(message):
|
|
if not isinstance(message, dict):
|
|
return message
|
|
normalized_message = dict(message)
|
|
normalized_message["content"] = _stringify_message_content_for_template(message.get("content"))
|
|
normalized_tool_calls = []
|
|
for tool_call in message.get("tool_calls", []):
|
|
normalized_tool_call = _normalize_tool_call_for_template(tool_call)
|
|
if isinstance(normalized_tool_call, dict):
|
|
function = normalized_tool_call.get("function")
|
|
if isinstance(function, dict):
|
|
arguments = function.get("arguments", {})
|
|
if not isinstance(arguments, dict):
|
|
function["arguments"] = {"__raw_arguments__": arguments}
|
|
normalized_tool_calls.append(normalized_tool_call)
|
|
normalized_message["tool_calls"] = normalized_tool_calls
|
|
return normalized_message
|
|
|
|
|
|
def _render_qwen3_coder_tool_call(tool_call, spec: Qwen3CoderToolParserSpec) -> str:
    """Serialise one tool call into Qwen3-Coder's XML-like text form.

    Output lines (joined with ``"\\n"``)::

        <tool_call>
        <function=NAME>
        <parameter=ARG>
        VALUE
        </parameter>
        ...
        </function>
        </tool_call>

    The markers come from *spec*. Dict/list argument values are JSON-encoded
    with ``ensure_ascii=False``; any other value uses ``str``. Returns ``""``
    when the call has no usable ``function`` dict or its stripped name is empty.
    """
    normalized = _normalize_tool_call_for_template(tool_call)
    function = normalized.get("function", {}) if isinstance(normalized, dict) else None
    if not isinstance(function, dict):
        return ""

    function_name = str(function.get("name", "")).strip()
    if not function_name:
        return ""

    lines = [spec.tool_call_start_token, f"{spec.tool_call_prefix}{function_name}>"]
    arguments = function.get("arguments", {})
    if isinstance(arguments, dict):
        for arg_name, arg_value in arguments.items():
            if isinstance(arg_value, (dict, list)):
                value_text = json.dumps(arg_value, ensure_ascii=False)
            else:
                value_text = str(arg_value)
            lines += [f"{spec.parameter_prefix}{arg_name}>", value_text, spec.parameter_end_token]

    lines += [spec.function_end_token, spec.tool_call_end_token]
    return "\n".join(lines)
|
|
|
|
|
|
def _normalize_qwen_message_with_tool_parser(message, spec: Qwen3CoderToolParserSpec):
    """Normalise a message and inline assistant tool calls as Qwen3-Coder text.

    Every message goes through ``_normalize_qwen_message_for_template``. For
    assistant messages with a non-empty list of tool calls, each call that
    renders to non-empty text is appended to the stripped content, one per
    line (the content is omitted when empty). ``tool_calls`` is then cleared,
    presumably so the chat template does not render the calls a second time.
    If no call renders, the normalised message is returned as-is.
    """
    normalized = _normalize_qwen_message_for_template(message)
    if not isinstance(normalized, dict) or normalized.get("role") != "assistant":
        return normalized

    tool_calls = normalized.get("tool_calls", [])
    if not (isinstance(tool_calls, list) and tool_calls):
        return normalized

    rendered = []
    for call in tool_calls:
        chunk = _render_qwen3_coder_tool_call(call, spec)
        if chunk:
            rendered.append(chunk)
    if not rendered:
        return normalized

    leading_text = str(normalized.get("content", "") or "").strip()
    pieces = ([leading_text] if leading_text else []) + rendered
    normalized["content"] = "\n".join(pieces)
    normalized["tool_calls"] = []
    return normalized
|
|
|
|
|
|
def build_glm5_canonical_prompt(payload):
    """Render the canonical GLM-5 prompt string for a request *payload*.

    Inputs are read from ``payload["input"]["messages"]`` and
    ``payload["parameters"]["tools"]``. Each dict message is processed as
    follows:

    * ``content`` is normalised with
      ``_normalize_message_content_for_template``, using the message role;
    * ``tool_calls`` is normalised with ``_normalize_tool_call_for_template``.

    Non-dict messages pass through unchanged, and non-dict tools are dropped.
    The template is rendered with ``add_generation_prompt=True``.

    Missing or ``None`` sections are treated as empty rather than raising
    ``AttributeError``/``TypeError``. This covers ``input``, ``parameters``,
    ``messages``, ``tools`` and a message's ``tool_calls``.
    """
    if isinstance(payload, dict):
        input_payload = payload.get("input") or {}
        parameters = payload.get("parameters") or {}
    else:
        input_payload, parameters = {}, {}

    messages = []
    for message in input_payload.get("messages") or []:
        if not isinstance(message, dict):
            messages.append(message)
            continue
        normalized_message = dict(message)
        normalized_message["content"] = _normalize_message_content_for_template(
            message.get("content"),
            role=str(message.get("role", "")),
        )
        normalized_message["tool_calls"] = [
            _normalize_tool_call_for_template(tool_call)
            # ``tool_calls`` may be present but None; treat that as "no calls".
            for tool_call in message.get("tool_calls") or []
        ]
        messages.append(normalized_message)

    tools = [
        _normalize_tool_spec_for_template(tool)
        for tool in parameters.get("tools") or []
        if isinstance(tool, dict)
    ]
    return _load_glm5_chat_template().render(
        messages=messages,
        tools=tools,
        add_generation_prompt=True,
    )
|
|
|
|
|
|
def build_qwen3_coder_canonical_prompt(payload):
    """Render the canonical Qwen3-Coder prompt string for a request *payload*.

    Messages come from ``payload["input"]["messages"]``. Each is normalised by
    ``_normalize_qwen_message_with_tool_parser``, which inlines assistant tool
    calls as text. Tool specs come from ``payload["parameters"]["tools"]``,
    with non-dict entries dropped.

    When there are no messages, a single empty system message is used,
    presumably because the template expects at least one message. The template
    is rendered with ``add_generation_prompt=True``.

    Missing or ``None`` sections are treated as empty rather than raising
    ``AttributeError``/``TypeError``. This covers ``input``, ``parameters``,
    ``messages`` and ``tools``.
    """
    if isinstance(payload, dict):
        input_payload = payload.get("input") or {}
        parameters = payload.get("parameters") or {}
    else:
        input_payload, parameters = {}, {}

    tool_parser_spec = _load_qwen3_coder_tool_parser_spec()
    messages = [
        _normalize_qwen_message_with_tool_parser(message, tool_parser_spec)
        for message in input_payload.get("messages") or []
    ]
    if not messages:
        messages = [{"role": "system", "content": ""}]

    tools = [
        _normalize_tool_spec_for_template(tool)
        for tool in parameters.get("tools") or []
        if isinstance(tool, dict)
    ]
    return _load_qwen3_coder_chat_template().render(
        messages=messages,
        tools=tools,
        add_generation_prompt=True,
    )
|
|
|
|
|
|
class GLMTraceAdapter:
    """Adapter for raw GLM traces.

    A raw trace is recognised by having both ``request_params`` and
    ``response_params`` keys. The adapter summarises chat messages and tool
    specs into ``ParsedMessageEvent`` / ``ParsedToolSpec`` records. It renders
    the canonical prompt with the GLM-5 chat template.
    """

    name = "glm"

    def detect(self, raw):
        """Return True when *raw* contains both ``request_params`` and ``response_params``."""
        return "request_params" in raw and "response_params" in raw

    def parse_message(self, message):
        """Summarise one chat *message* dict as a ``ParsedMessageEvent``.

        How each content shape is summarised:

        * ``str``: ``content_type="text"``, its length, one item.
        * ``list``: see ``_parse_list_content``.
        * ``dict``: its ``"type"`` (default ``"object"``), the length of its
          ``compact_json`` encoding, and its ``cache_control`` flag; one item.
        * ``None``: ``content_type="NoneType"``, zero length, zero items.
        * anything else: its type name, ``len(str(content))``, one item.
        """
        role = str(message.get("role", "unknown"))
        content = message.get("content")

        if isinstance(content, str):
            return ParsedMessageEvent(
                role=role,
                content_type="text",
                text_len=len(content),
                has_cache_control=False,
                item_count=1,
            )

        if isinstance(content, list):
            return self._parse_list_content(role, content)

        if isinstance(content, dict):
            return ParsedMessageEvent(
                role=role,
                content_type=str(content.get("type", "object")),
                text_len=len(compact_json(content)),
                has_cache_control="cache_control" in content,
                item_count=1,
            )

        if content is None:
            # Absent content carries no text. Measuring str(None) would report
            # a spurious length of 4 ("None") and one phantom item.
            return ParsedMessageEvent(
                role=role,
                content_type="NoneType",
                text_len=0,
                has_cache_control=False,
                item_count=0,
            )

        return ParsedMessageEvent(
            role=role,
            content_type=type(content).__name__,
            text_len=len(str(content)),
            has_cache_control=False,
            item_count=1,
        )

    def _parse_list_content(self, role, content):
        """Summarise list-shaped (multi-part) message *content*.

        ``text_len`` adds the length of each string part and of each string
        ``"text"`` field. Any other part adds the length of its ``compact_json``
        encoding.

        ``content_type`` is the sorted, ``+``-joined set of part types, or
        ``"list"`` for an empty list. A dict part's type is its ``"type"``
        value, defaulting to ``"text"`` when it has a ``"text"`` key and to
        ``"object"`` otherwise.

        ``has_cache_control`` is True if any dict part has a
        ``"cache_control"`` key.
        """
        text_len = 0
        has_cache_control = False
        item_types = set()
        for item in content:
            if isinstance(item, str):
                text_len += len(item)
                item_types.add("text")
                continue
            if not isinstance(item, dict):
                text_len += len(compact_json(item))
                item_types.add(type(item).__name__)
                continue
            text = item.get("text")
            text_len += len(text) if isinstance(text, str) else len(compact_json(item))
            if "cache_control" in item:
                has_cache_control = True
            item_types.add(str(item.get("type", "text" if "text" in item else "object")))

        return ParsedMessageEvent(
            role=role,
            content_type="+".join(sorted(item_types)) if item_types else "list",
            text_len=text_len,
            has_cache_control=has_cache_control,
            item_count=len(content),
        )

    def parse_tool(self, tool):
        """Summarise a tool spec dict as a ``ParsedToolSpec``.

        The name is ``tool["function"]["name"]``, or empty when ``function`` is
        missing or not a dict. ``tool_type`` defaults to ``"function"``.
        """
        function = tool.get("function", {})
        name = str(function.get("name", "")) if isinstance(function, dict) else ""
        return ParsedToolSpec(
            name=name,
            tool_type=str(tool.get("type", "function")),
        )

    def build_canonical_prompt(self, payload):
        """Render the canonical GLM-5 prompt for *payload*."""
        return build_glm5_canonical_prompt(payload)
|
|
|
|
|
|
class QwenTraceAdapter(GLMTraceAdapter):
    """Adapter for Qwen3-Coder traces.

    Reuses the GLM adapter's message and tool summarisation. Detection and
    prompt rendering are specific to the qwen3-coder model family.
    """

    name = "qwen3-coder"

    def detect(self, raw):
        """Return True when *raw* has ``request_params`` and its ``request_model`` maps to qwen3-coder."""
        has_request = "request_params" in raw
        # Short-circuit: the model family is only inferred when request_params exists.
        return has_request and infer_model_family_from_request_model(raw.get("request_model")) == "qwen3-coder"

    def build_canonical_prompt(self, payload):
        """Render the canonical Qwen3-Coder prompt for *payload*."""
        return build_qwen3_coder_canonical_prompt(payload)
|
|
|
|
|
|
def get_raw_adapter(raw: dict[str, Any]):
    """Return the first trace adapter whose ``detect`` accepts *raw*.

    Qwen is probed before GLM. GLM's check only looks for the
    ``request_params``/``response_params`` keys, so a Qwen trace with both
    keys would otherwise be claimed by GLM.

    Raises:
        ValueError: if no adapter recognises *raw*.
    """
    candidates = (QwenTraceAdapter(), GLMTraceAdapter())
    matched = next((candidate for candidate in candidates if candidate.detect(raw)), None)
    if matched is None:
        raise ValueError("Unsupported raw trace format for trace_formatter.")
    return matched
|