Initial commit
This commit is contained in:
201
trace_model_meta/registry.py
Normal file
201
trace_model_meta/registry.py
Normal file
@@ -0,0 +1,201 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Root of the packaged model metadata tree: the directory containing this
# module, with per-model assets laid out as <provider>/<model_name>/.
MODEL_META_ROOT = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ModelMeta:
    """Static metadata for one supported model family.

    Frozen so registry entries can be shared safely; asset paths are
    derived properties rooted at MODEL_META_ROOT.
    """

    # Canonical family identifier (registry key), e.g. "glm5".
    family: str
    # Provider directory name under MODEL_META_ROOT, e.g. "Qwen".
    provider: str
    # Model directory name under the provider directory.
    model_name: str
    # Lowercase substrings matched against request-model strings to
    # recognize this family (see infer_model_family_from_request_model).
    request_model_hints: tuple[str, ...]

    @property
    def model_dir(self) -> Path:
        """Directory holding this model's packaged assets."""
        return MODEL_META_ROOT / self.provider / self.model_name

    @property
    def tokenizer_path(self) -> Path:
        """Path of the packaged tokenizer.json for this model."""
        return self.model_dir / "tokenizer.json"

    @property
    def chat_template_path(self) -> Path:
        """Path of the packaged chat_template.jinja for this model."""
        return self.model_dir / "chat_template.jinja"
|
||||
|
||||
|
||||
# Canonical registry of supported model families, keyed by family name.
# Values describe where each model's packaged assets live and which
# request-model substrings identify it.
MODEL_REGISTRY = {
    "glm5": ModelMeta(
        family="glm5",
        provider="ZhipuAI",
        model_name="GLM-5-FP8",
        request_model_hints=("glm", "zhipu"),
    ),
    "qwen3-coder": ModelMeta(
        family="qwen3-coder",
        provider="Qwen",
        model_name="Qwen3-Coder-480B-A35B-Instruct",
        request_model_hints=("qwen3-coder", "qwen3 coder", "qwen3_coder"),
    ),
}
|
||||
|
||||
# User-facing spellings accepted for each family, mapped to the canonical
# MODEL_REGISTRY key. Each canonical name also maps to itself.
MODEL_ALIASES = {
    "glm": "glm5",
    "glm5": "glm5",
    "zhipu-glm5": "glm5",
    "zhipuai-glm5": "glm5",
    "qwen": "qwen3-coder",
    "qwen3": "qwen3-coder",
    "qwen3-coder": "qwen3-coder",
    "qwen3_coder": "qwen3-coder",
    "coder": "qwen3-coder",
}
|
||||
|
||||
|
||||
def infer_model_family_from_request_model(request_model: str | None) -> str | None:
    """Map a raw request-model string to a registered family.

    Matching is case-insensitive substring search against each family's
    request_model_hints; returns None when nothing matches or the input
    is empty/None.
    """
    needle = str(request_model or "").strip().lower()
    if not needle:
        return None
    return next(
        (
            family
            for family, meta in MODEL_REGISTRY.items()
            if any(hint in needle for hint in meta.request_model_hints)
        ),
        None,
    )
|
||||
|
||||
|
||||
def _infer_model_family_from_path(input_path: str | Path | None) -> str | None:
|
||||
text = str(input_path or "").strip().lower()
|
||||
if not text:
|
||||
return None
|
||||
if "qwen3-coder" in text or "qwen3_coder" in text:
|
||||
return "qwen3-coder"
|
||||
if "glm5" in text or "trace-glm" in text:
|
||||
return "glm5"
|
||||
return None
|
||||
|
||||
|
||||
def detect_model_family_from_trace_file(path: str | Path) -> str | None:
    """Detect the model family from the first JSONL record of a trace file.

    Looks, in order, at meta.model_family, meta.request_model, and the
    top-level request_model of the first non-empty line. Falls back to
    path-based inference when that line is missing, malformed, or carries
    no usable model information.

    Raises:
        OSError: if the file cannot be opened/read.
        ValueError: if meta.model_family names an unsupported family
            (propagated from resolve_model_family).
    """
    resolved = Path(path)
    with resolved.open("r", encoding="utf-8") as handle:
        for line in handle:
            stripped = line.strip()
            if not stripped:
                continue
            # Fix: previously a non-JSON header line raised JSONDecodeError and
            # a non-object line (e.g. a JSON array) raised AttributeError,
            # defeating the explicit path-based fallback below.
            try:
                raw = json.loads(stripped)
            except json.JSONDecodeError:
                break
            if not isinstance(raw, dict):
                break
            if isinstance(raw.get("meta"), dict):
                meta = raw["meta"]
                family = str(meta.get("model_family", "")).strip()
                if family:
                    # Normalize via the alias table (raises on unknown names).
                    return resolve_model_family(family)
                inferred = infer_model_family_from_request_model(meta.get("request_model"))
                if inferred:
                    return inferred
            inferred = infer_model_family_from_request_model(raw.get("request_model"))
            if inferred:
                return inferred
            # Only the first non-empty record is consulted.
            break
    return _infer_model_family_from_path(path)
|
||||
|
||||
|
||||
def detect_model_family_from_features(path: str | Path) -> str | None:
    """Detect the model family from the first data row of a features CSV.

    Only the "model" column of the first row is consulted; falls back to
    path-based inference when the file has no rows or the row yields
    nothing.
    """
    with Path(path).open("r", encoding="utf-8") as handle:
        first_row = next(csv.DictReader(handle), None)
    if first_row is not None:
        inferred = infer_model_family_from_request_model(first_row.get("model"))
        if inferred:
            return inferred
    return _infer_model_family_from_path(path)
|
||||
|
||||
|
||||
def detect_model_family_from_records(records) -> str | None:
    """Detect the model family from the first record's request model.

    Only the first record is consulted; returns None for an empty
    iterable or when its request model matches no known family.
    """
    first = next(iter(records), None)
    if first is None:
        return None
    # infer_... already returns None on no match, so pass it through.
    return infer_model_family_from_request_model(first.meta.request_model)
|
||||
|
||||
|
||||
def resolve_model_family(
    model_family: str | None = None,
    *,
    request_model: str | None = None,
    input_path: str | Path | None = None,
    features_path: str | Path | None = None,
    records=None,
) -> str:
    """Resolve an explicit or auto-detected model family to a registry key.

    An explicit (non-"auto") name must appear in MODEL_ALIASES, otherwise
    ValueError is raised. In auto mode, detection runs in priority order:
    request_model hints, records, features CSV, trace file, and finally
    the "glm5" default.
    """
    explicit = str(model_family or "auto").strip().lower()
    if explicit and explicit != "auto":
        if explicit not in MODEL_ALIASES:
            raise ValueError(f"Unsupported model family: {model_family}")
        return MODEL_ALIASES[explicit]

    # Build the detector chain in priority order; each source is only
    # consulted when it was supplied.
    detectors = [lambda: infer_model_family_from_request_model(request_model)]
    if records is not None:
        detectors.append(lambda: detect_model_family_from_records(records))
    if features_path is not None:
        detectors.append(lambda: detect_model_family_from_features(features_path))
    if input_path is not None:
        detectors.append(lambda: detect_model_family_from_trace_file(input_path))

    for detect in detectors:
        found = detect()
        if found:
            return found
    return "glm5"
|
||||
|
||||
|
||||
def get_model_meta(model_family: str | None = None, *, model_meta_dir: str | Path | None = None, **kwargs) -> ModelMeta:
    """Return the ModelMeta for a (possibly auto-detected) model family.

    When *model_meta_dir* is given, the returned meta resolves model_dir
    (and therefore tokenizer_path / chat_template_path) under that root,
    consistent with resolve_chat_template_path and resolve_tokenizer_path.

    Raises:
        FileNotFoundError: if *model_meta_dir* is given but lacks the
            expected <provider>/<model_name> subdirectory.
        ValueError: if the family cannot be resolved (from
            resolve_model_family).
    """
    family = resolve_model_family(model_family, **kwargs)
    base_meta = MODEL_REGISTRY[family]
    if model_meta_dir is None:
        return base_meta

    custom_model_dir = Path(model_meta_dir) / base_meta.provider / base_meta.model_name
    if not custom_model_dir.exists():
        raise FileNotFoundError(f"Model meta directory not found for {family}: {custom_model_dir}")

    # Fix: the custom root used to be validated and then discarded — the
    # returned meta was a field-for-field copy whose path properties still
    # resolved under MODEL_META_ROOT. Root the returned meta's paths at the
    # caller-supplied directory instead (still a ModelMeta for callers).
    class _CustomRootModelMeta(ModelMeta):
        @property
        def model_dir(self) -> Path:
            return custom_model_dir

    return _CustomRootModelMeta(
        family=base_meta.family,
        provider=base_meta.provider,
        model_name=base_meta.model_name,
        request_model_hints=base_meta.request_model_hints,
    )
|
||||
|
||||
|
||||
def resolve_chat_template_path(
    model_family: str | None = None,
    *,
    model_meta_dir: str | Path | None = None,
    **kwargs,
) -> Path:
    """Return the chat_template.jinja path for the resolved model family.

    The path is rooted at *model_meta_dir* when given, otherwise at the
    packaged metadata directory.
    """
    family = resolve_model_family(model_family, **kwargs)
    meta = MODEL_REGISTRY[family]
    if model_meta_dir:
        base_dir = Path(model_meta_dir) / meta.provider / meta.model_name
    else:
        base_dir = meta.model_dir
    return base_dir / "chat_template.jinja"
|
||||
|
||||
|
||||
def resolve_tokenizer_path(
    tokenizer_path: str | Path | None = None,
    *,
    model_family: str | None = None,
    model_meta_dir: str | Path | None = None,
    **kwargs,
) -> str:
    """Return the tokenizer directory for a model, as a string.

    An explicit *tokenizer_path* wins; when it names an existing file it
    is normalized to its parent directory. Otherwise the directory is
    derived from the resolved model family, rooted at *model_meta_dir*
    when given.
    """
    if tokenizer_path:
        explicit = Path(tokenizer_path)
        if explicit.is_file():
            return str(explicit.parent)
        return str(explicit)

    family = resolve_model_family(model_family, **kwargs)
    meta = MODEL_REGISTRY[family]
    if model_meta_dir:
        return str(Path(model_meta_dir) / meta.provider / meta.model_name)
    return str(meta.model_dir)
|
||||
Reference in New Issue
Block a user