ali-trace-tools/trace_model_meta/registry.py

"""Registry of model metadata and helpers for inferring a trace's model family."""
from __future__ import annotations

import csv
import json
from dataclasses import dataclass
from pathlib import Path

# Root directory holding per-provider model assets (tokenizer, chat template).
MODEL_META_ROOT = Path(__file__).resolve().parent


@dataclass(frozen=True)
class ModelMeta:
    """Static metadata describing where a model family's assets live on disk."""

    family: str
    provider: str
    model_name: str
    request_model_hints: tuple[str, ...]
    # Asset root; overridable so callers can point at a custom metadata checkout.
    root: Path = MODEL_META_ROOT

    @property
    def model_dir(self) -> Path:
        return self.root / self.provider / self.model_name

    @property
    def tokenizer_path(self) -> Path:
        return self.model_dir / "tokenizer.json"

    @property
    def chat_template_path(self) -> Path:
        return self.model_dir / "chat_template.jinja"


MODEL_REGISTRY = {
    "glm5": ModelMeta(
        family="glm5",
        provider="ZhipuAI",
        model_name="GLM-5-FP8",
        request_model_hints=("glm", "zhipu"),
    ),
    "qwen3-coder": ModelMeta(
        family="qwen3-coder",
        provider="Qwen",
        model_name="Qwen3-Coder-480B-A35B-Instruct",
        request_model_hints=("qwen3-coder", "qwen3 coder", "qwen3_coder"),
    ),
}

MODEL_ALIASES = {
    "glm": "glm5",
    "glm5": "glm5",
    "zhipu-glm5": "glm5",
    "zhipuai-glm5": "glm5",
    "qwen": "qwen3-coder",
    "qwen3": "qwen3-coder",
    "qwen3-coder": "qwen3-coder",
    "qwen3_coder": "qwen3-coder",
    "coder": "qwen3-coder",
}
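
# Alias resolution sketch (values follow from the tables above):
#   resolve_model_family("zhipuai-glm5")                      -> "glm5"
#   resolve_model_family("auto", request_model="Qwen3-Coder") -> "qwen3-coder"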


def infer_model_family_from_request_model(request_model: str | None) -> str | None:
    text = str(request_model or "").strip().lower()
    if not text:
        return None
    for family, meta in MODEL_REGISTRY.items():
        if any(hint in text for hint in meta.request_model_hints):
            return family
    return None


def _infer_model_family_from_path(input_path: str | Path | None) -> str | None:
    text = str(input_path or "").strip().lower()
    if not text:
        return None
    if "qwen3-coder" in text or "qwen3_coder" in text:
        return "qwen3-coder"
    if "glm5" in text or "trace-glm" in text:
        return "glm5"
    return None


def detect_model_family_from_trace_file(path: str | Path) -> str | None:
    """Inspect the first non-empty JSONL line of a trace for a model hint."""
    resolved = Path(path)
    with resolved.open("r", encoding="utf-8") as handle:
        for line in handle:
            stripped = line.strip()
            if not stripped:
                continue
            try:
                raw = json.loads(stripped)
            except json.JSONDecodeError:
                break
            if not isinstance(raw, dict):
                break
            if isinstance(raw.get("meta"), dict):
                meta = raw["meta"]
                family = str(meta.get("model_family", "")).strip()
                if family:
                    return resolve_model_family(family)
                inferred = infer_model_family_from_request_model(meta.get("request_model"))
                if inferred:
                    return inferred
            inferred = infer_model_family_from_request_model(raw.get("request_model"))
            if inferred:
                return inferred
            # Only the first non-empty line is inspected.
            break
    return _infer_model_family_from_path(path)
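
# First-line shapes the detector above understands (field values illustrative):
#   {"meta": {"model_family": "glm5", ...}, ...}
#   {"meta": {"request_model": "glm-5-fp8", ...}, ...}
#   {"request_model": "qwen3-coder", ...}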


def detect_model_family_from_features(path: str | Path) -> str | None:
    resolved = Path(path)
    with resolved.open("r", encoding="utf-8") as handle:
        reader = csv.DictReader(handle)
        for row in reader:
            inferred = infer_model_family_from_request_model(row.get("model"))
            if inferred:
                return inferred
            # Only the first data row is inspected.
            break
    return _infer_model_family_from_path(path)


def detect_model_family_from_records(records) -> str | None:
    for record in records:
        inferred = infer_model_family_from_request_model(record.meta.request_model)
        if inferred:
            return inferred
        break
    return None


def resolve_model_family(
    model_family: str | None = None,
    *,
    request_model: str | None = None,
    input_path: str | Path | None = None,
    features_path: str | Path | None = None,
    records=None,
) -> str:
    """Resolve an explicit family name or alias; otherwise infer the family
    from the provided sources in order (request model, records, features CSV,
    trace file), falling back to "glm5"."""
    candidate = str(model_family or "auto").strip().lower()
    if candidate and candidate != "auto":
        if candidate in MODEL_ALIASES:
            return MODEL_ALIASES[candidate]
        raise ValueError(f"Unsupported model family: {model_family}")
    inferred = infer_model_family_from_request_model(request_model)
    if inferred:
        return inferred
    if records is not None:
        inferred = detect_model_family_from_records(records)
        if inferred:
            return inferred
    if features_path is not None:
        inferred = detect_model_family_from_features(features_path)
        if inferred:
            return inferred
    if input_path is not None:
        inferred = detect_model_family_from_trace_file(input_path)
        if inferred:
            return inferred
    return "glm5"


def get_model_meta(
    model_family: str | None = None,
    *,
    model_meta_dir: str | Path | None = None,
    **kwargs,
) -> ModelMeta:
    family = resolve_model_family(model_family, **kwargs)
    base_meta = MODEL_REGISTRY[family]
    if model_meta_dir is None:
        return base_meta
    custom_root = Path(model_meta_dir)
    custom_model_dir = custom_root / base_meta.provider / base_meta.model_name
    if not custom_model_dir.exists():
        raise FileNotFoundError(f"Model meta directory not found for {family}: {custom_model_dir}")
    # Re-root the metadata so model_dir/tokenizer_path/chat_template_path
    # resolve under the custom directory rather than MODEL_META_ROOT.
    return ModelMeta(
        family=base_meta.family,
        provider=base_meta.provider,
        model_name=base_meta.model_name,
        request_model_hints=base_meta.request_model_hints,
        root=custom_root,
    )
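
# Custom-root sketch (path is illustrative):
#   get_model_meta("glm5", model_meta_dir="/data/model-meta")
# expects /data/model-meta/ZhipuAI/GLM-5-FP8/ to exist and re-roots all
# asset paths under it.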


def resolve_chat_template_path(
    model_family: str | None = None,
    *,
    model_meta_dir: str | Path | None = None,
    **kwargs,
) -> Path:
    family = resolve_model_family(model_family, **kwargs)
    meta = MODEL_REGISTRY[family]
    model_dir = Path(model_meta_dir) / meta.provider / meta.model_name if model_meta_dir else meta.model_dir
    return model_dir / "chat_template.jinja"


def resolve_tokenizer_path(
    tokenizer_path: str | Path | None = None,
    *,
    model_family: str | None = None,
    model_meta_dir: str | Path | None = None,
    **kwargs,
) -> str:
    if tokenizer_path:
        path = Path(tokenizer_path)
        # Callers expect a directory; reduce an explicit file path to its parent.
        return str(path.parent if path.is_file() else path)
    family = resolve_model_family(model_family, **kwargs)
    meta = MODEL_REGISTRY[family]
    model_dir = Path(model_meta_dir) / meta.provider / meta.model_name if model_meta_dir else meta.model_dir
    return str(model_dir)
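

if __name__ == "__main__":
    # Minimal smoke-test sketch of the resolution pipeline. Inputs are
    # illustrative; nothing below reads from disk, only from the registry.
    print(resolve_model_family("glm"))  # -> "glm5" via MODEL_ALIASES
    print(resolve_model_family("auto", request_model="Qwen3-Coder-480B-A35B-Instruct"))  # -> "qwen3-coder"
    meta = get_model_meta("qwen3-coder")
    print(meta.model_dir)  # .../Qwen/Qwen3-Coder-480B-A35B-Instruct
    print(resolve_tokenizer_path(model_family="glm5"))  # bundled GLM-5-FP8 directory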