Validate served model name consistency

This commit is contained in:
2026-04-09 22:50:23 +08:00
parent baba1a3c4f
commit d582a8ed1b
3 changed files with 30 additions and 3 deletions

View File

@@ -9,7 +9,7 @@
}, },
"model": { "model": {
"model_id": "qwen3-235b-a22b-256k-0717-internal", "model_id": "qwen3-235b-a22b-256k-0717-internal",
"served_model_name": "qwen3-235b-decode-aituner" "served_model_name": "qwen3-235b-decode"
}, },
"engine": { "engine": {
"engine_name": "vllm", "engine_name": "vllm",
@@ -107,7 +107,7 @@
"base_flags": { "base_flags": {
"host": "127.0.0.1", "host": "127.0.0.1",
"port": 18120, "port": 18120,
"served-model-name": "qwen3-235b-decode-aituner", "served-model-name": "qwen3-235b-decode",
"gpu-memory-utilization": 0.75, "gpu-memory-utilization": 0.75,
"max-model-len": 262144, "max-model-len": 262144,
"enable-chunked-prefill": true, "enable-chunked-prefill": true,

View File

@@ -576,7 +576,7 @@ class StudySpec:
@classmethod @classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "StudySpec": def from_dict(cls, data: Mapping[str, Any]) -> "StudySpec":
return cls( study = cls(
study_id=_require_str(data.get("study_id"), context="study_id"), study_id=_require_str(data.get("study_id"), context="study_id"),
hardware=HardwareSpec.from_dict( hardware=HardwareSpec.from_dict(
_require_mapping(data.get("hardware"), context="hardware") _require_mapping(data.get("hardware"), context="hardware")
@@ -595,6 +595,14 @@ class StudySpec:
if data.get("capability_profile_path") if data.get("capability_profile_path")
else None, else None,
) )
served_model_name = str(study.engine.base_flags.get("served-model-name") or "").strip()
if served_model_name and served_model_name != study.model.served_model_name:
raise SpecError(
"model.served_model_name must match engine.base_flags['served-model-name'] "
f"when both are set. Got model.served_model_name={study.model.served_model_name!r} "
f"and engine.base_flags['served-model-name']={served_model_name!r}."
)
return study
@dataclass(frozen=True) @dataclass(frozen=True)

View File

@@ -256,6 +256,25 @@ class CoreFlowTests(unittest.TestCase):
self.assertIn("There is no TTFT SLO for this study.", prompt) self.assertIn("There is no TTFT SLO for this study.", prompt)
self.assertIn("decode-only", prompt) self.assertIn("decode-only", prompt)
def test_load_study_spec_rejects_mismatched_served_model_name(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
study_path = _write_study_assets(
tmp_path,
engine_overrides={
"base_flags": {
"host": "127.0.0.1",
"port": 8000,
"served-model-name": "engine-name",
}
},
)
payload = json.loads(study_path.read_text(encoding="utf-8"))
payload["model"]["served_model_name"] = "trace-name"
study_path.write_text(json.dumps(payload), encoding="utf-8")
with self.assertRaisesRegex(SpecError, "must match engine.base_flags"):
load_study_spec(study_path)
def test_bailian_endpoint_defaults(self) -> None: def test_bailian_endpoint_defaults(self) -> None:
endpoint = LLMEndpointSpec.from_dict({"provider": "bailian", "model": "qwen-plus"}) endpoint = LLMEndpointSpec.from_dict({"provider": "bailian", "model": "qwen-plus"})
self.assertEqual(endpoint.provider, "bailian") self.assertEqual(endpoint.provider, "bailian")