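"""Model metadata registry and model-family detection helpers.

Supported families (e.g. ``glm5``, ``qwen3-coder``) map to a provider/model
directory under this file's directory, each expected to contain a
``tokenizer.json`` and a ``chat_template.jinja``. The ``resolve_*`` helpers
pick a family from an explicit name or alias, a request-model string, a trace
JSONL file, a feature CSV, or already-parsed records.
"""
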
from __future__ import annotations

import csv
import json
from dataclasses import dataclass
from pathlib import Path

MODEL_META_ROOT = Path(__file__).resolve().parent


@dataclass(frozen=True)
class ModelMeta:
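    """Static metadata for one supported model family.

    The path properties resolve under ``MODEL_META_ROOT`` (the directory of
    this file) using the ``<provider>/<model_name>/`` layout.
    """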
    family: str
    provider: str
    model_name: str
    request_model_hints: tuple[str, ...]

    @property
    def model_dir(self) -> Path:
        return MODEL_META_ROOT / self.provider / self.model_name

    @property
    def tokenizer_path(self) -> Path:
        return self.model_dir / "tokenizer.json"

    @property
    def chat_template_path(self) -> Path:
        return self.model_dir / "chat_template.jinja"


MODEL_REGISTRY = {
    "glm5": ModelMeta(
        family="glm5",
        provider="ZhipuAI",
        model_name="GLM-5-FP8",
        request_model_hints=("glm", "zhipu"),
    ),
    "qwen3-coder": ModelMeta(
        family="qwen3-coder",
        provider="Qwen",
        model_name="Qwen3-Coder-480B-A35B-Instruct",
        request_model_hints=("qwen3-coder", "qwen3 coder", "qwen3_coder"),
    ),
}

MODEL_ALIASES = {
    "glm": "glm5",
    "glm5": "glm5",
    "zhipu-glm5": "glm5",
    "zhipuai-glm5": "glm5",
    "qwen": "qwen3-coder",
    "qwen3": "qwen3-coder",
    "qwen3-coder": "qwen3-coder",
    "qwen3_coder": "qwen3-coder",
    "coder": "qwen3-coder",
}


def infer_model_family_from_request_model(request_model: str | None) -> str | None:
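    """Infer a family from a request-model string via case-insensitive substring hints."""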
    text = str(request_model or "").strip().lower()
    if not text:
        return None
    for family, meta in MODEL_REGISTRY.items():
        if any(hint in text for hint in meta.request_model_hints):
            return family
    return None


def _infer_model_family_from_path(input_path: str | Path | None) -> str | None:
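    """Fall back to filename heuristics when file contents yield no family."""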
    text = str(input_path or "").strip().lower()
    if not text:
        return None
    if "qwen3-coder" in text or "qwen3_coder" in text:
        return "qwen3-coder"
    if "glm5" in text or "trace-glm" in text:
        return "glm5"
    return None


def detect_model_family_from_trace_file(path: str | Path) -> str | None:
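    """Detect the family from the first non-empty JSONL record in ``path``.

    Checks ``meta.model_family``, then ``meta.request_model``, then the
    top-level ``request_model``; falls back to filename heuristics.
    """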
    resolved = Path(path)
    with resolved.open("r", encoding="utf-8") as handle:
        for line in handle:
            stripped = line.strip()
            if not stripped:
                continue
            raw = json.loads(stripped)
            if isinstance(raw.get("meta"), dict):
                meta = raw["meta"]
                family = str(meta.get("model_family", "")).strip()
                if family:
                    return resolve_model_family(family)
                inferred = infer_model_family_from_request_model(meta.get("request_model"))
                if inferred:
                    return inferred
            inferred = infer_model_family_from_request_model(raw.get("request_model"))
            if inferred:
                return inferred
            break
    return _infer_model_family_from_path(path)


def detect_model_family_from_features(path: str | Path) -> str | None:
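    """Detect the family from the ``model`` column of the first CSV row.

    Falls back to filename heuristics when the row yields nothing.
    """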
    resolved = Path(path)
    with resolved.open("r", encoding="utf-8") as handle:
        reader = csv.DictReader(handle)
        for row in reader:
            inferred = infer_model_family_from_request_model(row.get("model"))
            if inferred:
                return inferred
            break
    return _infer_model_family_from_path(path)


def detect_model_family_from_records(records) -> str | None:
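    """Detect the family from the first record's ``meta.request_model``."""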
    for record in records:
        inferred = infer_model_family_from_request_model(record.meta.request_model)
        if inferred:
            return inferred
        break
    return None


def resolve_model_family(
    model_family: str | None = None,
    *,
    request_model: str | None = None,
    input_path: str | Path | None = None,
    features_path: str | Path | None = None,
    records=None,
) -> str:
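    """Resolve an explicit family/alias, or auto-detect one.

    An explicit ``model_family`` other than ``"auto"`` must be a known alias,
    otherwise ``ValueError`` is raised. With ``"auto"`` (or ``None``) the
    family is inferred from ``request_model``, ``records``, ``features_path``
    and ``input_path`` in that order, defaulting to ``"glm5"``.
    """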
    candidate = str(model_family or "auto").strip().lower()
    if candidate and candidate != "auto":
        if candidate in MODEL_ALIASES:
            return MODEL_ALIASES[candidate]
        raise ValueError(f"Unsupported model family: {model_family}")

    inferred = infer_model_family_from_request_model(request_model)
    if inferred:
        return inferred
    if records is not None:
        inferred = detect_model_family_from_records(records)
        if inferred:
            return inferred
    if features_path is not None:
        inferred = detect_model_family_from_features(features_path)
        if inferred:
            return inferred
    if input_path is not None:
        inferred = detect_model_family_from_trace_file(input_path)
        if inferred:
            return inferred
    return "glm5"


def get_model_meta(model_family: str | None = None, *, model_meta_dir: str | Path | None = None, **kwargs) -> ModelMeta:
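    """Return the registry entry for the resolved family.

    When ``model_meta_dir`` is given it is only validated to contain
    ``<provider>/<model_name>``; the returned ``ModelMeta`` still resolves its
    path properties under ``MODEL_META_ROOT``. Use ``resolve_chat_template_path``
    or ``resolve_tokenizer_path`` with ``model_meta_dir`` when the custom root
    itself is needed.
    """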
    family = resolve_model_family(model_family, **kwargs)
    base_meta = MODEL_REGISTRY[family]
    if model_meta_dir is None:
        return base_meta

    custom_root = Path(model_meta_dir)
    custom_model_dir = custom_root / base_meta.provider / base_meta.model_name
    if not custom_model_dir.exists():
        raise FileNotFoundError(f"Model meta directory not found for {family}: {custom_model_dir}")
    return ModelMeta(
        family=base_meta.family,
        provider=base_meta.provider,
        model_name=base_meta.model_name,
        request_model_hints=base_meta.request_model_hints,
    )


def resolve_chat_template_path(
    model_family: str | None = None,
    *,
    model_meta_dir: str | Path | None = None,
    **kwargs,
) -> Path:
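    """Return the ``chat_template.jinja`` path for the resolved family, under ``model_meta_dir`` if given."""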
    family = resolve_model_family(model_family, **kwargs)
    meta = MODEL_REGISTRY[family]
    model_dir = Path(model_meta_dir) / meta.provider / meta.model_name if model_meta_dir else meta.model_dir
    return model_dir / "chat_template.jinja"


def resolve_tokenizer_path(
    tokenizer_path: str | Path | None = None,
    *,
    model_family: str | None = None,
    model_meta_dir: str | Path | None = None,
    **kwargs,
) -> str:
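    """Return the tokenizer *directory* for the resolved family.

    An explicit ``tokenizer_path`` wins; if it points at a file, its parent
    directory is returned instead.
    """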
    if tokenizer_path:
        path = Path(tokenizer_path)
        return str(path.parent if path.is_file() else path)

    family = resolve_model_family(model_family, **kwargs)
    meta = MODEL_REGISTRY[family]
    model_dir = Path(model_meta_dir) / meta.provider / meta.model_name if model_meta_dir else meta.model_dir
    return str(model_dir)
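

# Illustrative usage (a sketch added for documentation, not part of the
# original module; the request-model string below is hypothetical):
if __name__ == "__main__":
    # Aliases resolve to canonical family names.
    print(resolve_model_family("qwen3"))  # -> "qwen3-coder"
    # Substring hints map request-model strings onto families.
    print(infer_model_family_from_request_model("zhipu/glm-5-fp8"))  # -> "glm5"
    # Asset paths are built from the registry entry.
    print(get_model_meta("glm").chat_template_path)
    print(resolve_tokenizer_path(model_family="qwen3-coder"))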