Validate served model name consistency
This commit is contained in:
@@ -9,7 +9,7 @@
|
||||
},
|
||||
"model": {
|
||||
"model_id": "qwen3-235b-a22b-256k-0717-internal",
|
||||
"served_model_name": "qwen3-235b-decode-aituner"
|
||||
"served_model_name": "qwen3-235b-decode"
|
||||
},
|
||||
"engine": {
|
||||
"engine_name": "vllm",
|
||||
@@ -107,7 +107,7 @@
|
||||
"base_flags": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 18120,
|
||||
"served-model-name": "qwen3-235b-decode-aituner",
|
||||
"served-model-name": "qwen3-235b-decode",
|
||||
"gpu-memory-utilization": 0.75,
|
||||
"max-model-len": 262144,
|
||||
"enable-chunked-prefill": true,
|
||||
|
||||
@@ -576,7 +576,7 @@ class StudySpec:
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Mapping[str, Any]) -> "StudySpec":
|
||||
return cls(
|
||||
study = cls(
|
||||
study_id=_require_str(data.get("study_id"), context="study_id"),
|
||||
hardware=HardwareSpec.from_dict(
|
||||
_require_mapping(data.get("hardware"), context="hardware")
|
||||
@@ -595,6 +595,14 @@ class StudySpec:
|
||||
if data.get("capability_profile_path")
|
||||
else None,
|
||||
)
|
||||
served_model_name = str(study.engine.base_flags.get("served-model-name") or "").strip()
|
||||
if served_model_name and served_model_name != study.model.served_model_name:
|
||||
raise SpecError(
|
||||
"model.served_model_name must match engine.base_flags['served-model-name'] "
|
||||
f"when both are set. Got model.served_model_name={study.model.served_model_name!r} "
|
||||
f"and engine.base_flags['served-model-name']={served_model_name!r}."
|
||||
)
|
||||
return study
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@@ -256,6 +256,25 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertIn("There is no TTFT SLO for this study.", prompt)
|
||||
self.assertIn("decode-only", prompt)
|
||||
|
||||
def test_load_study_spec_rejects_mismatched_served_model_name(self) -> None:
    """Loading fails when the engine flag and model name disagree.

    Builds a study spec whose engine ``served-model-name`` base flag is set
    to one value, then rewrites ``model.served_model_name`` to a different
    value and expects ``load_study_spec`` to raise ``SpecError``.
    """
    with tempfile.TemporaryDirectory() as workdir:
        base_flags = {
            "host": "127.0.0.1",
            "port": 8000,
            "served-model-name": "engine-name",
        }
        spec_file = _write_study_assets(
            Path(workdir),
            engine_overrides={"base_flags": base_flags},
        )
        # Rewrite the generated spec so model.served_model_name conflicts
        # with the engine-level served-model-name flag set above.
        doc = json.loads(spec_file.read_text(encoding="utf-8"))
        doc["model"]["served_model_name"] = "trace-name"
        spec_file.write_text(json.dumps(doc), encoding="utf-8")
        with self.assertRaisesRegex(SpecError, "must match engine.base_flags"):
            load_study_spec(spec_file)
|
||||
|
||||
def test_bailian_endpoint_defaults(self) -> None:
|
||||
endpoint = LLMEndpointSpec.from_dict({"provider": "bailian", "model": "qwen-plus"})
|
||||
self.assertEqual(endpoint.provider, "bailian")
|
||||
|
||||
Reference in New Issue
Block a user