Validate served model name consistency
@@ -9,7 +9,7 @@
   },
   "model": {
     "model_id": "qwen3-235b-a22b-256k-0717-internal",
-    "served_model_name": "qwen3-235b-decode-aituner"
+    "served_model_name": "qwen3-235b-decode"
   },
   "engine": {
     "engine_name": "vllm",
@@ -107,7 +107,7 @@
     "base_flags": {
       "host": "127.0.0.1",
       "port": 18120,
-      "served-model-name": "qwen3-235b-decode-aituner",
+      "served-model-name": "qwen3-235b-decode",
       "gpu-memory-utilization": 0.75,
       "max-model-len": 262144,
       "enable-chunked-prefill": true,
@@ -576,7 +576,7 @@ class StudySpec:
 
     @classmethod
     def from_dict(cls, data: Mapping[str, Any]) -> "StudySpec":
-        return cls(
+        study = cls(
             study_id=_require_str(data.get("study_id"), context="study_id"),
             hardware=HardwareSpec.from_dict(
                 _require_mapping(data.get("hardware"), context="hardware")
@@ -595,6 +595,14 @@ class StudySpec:
             if data.get("capability_profile_path")
             else None,
         )
+        served_model_name = str(study.engine.base_flags.get("served-model-name") or "").strip()
+        if served_model_name and served_model_name != study.model.served_model_name:
+            raise SpecError(
+                "model.served_model_name must match engine.base_flags['served-model-name'] "
+                f"when both are set. Got model.served_model_name={study.model.served_model_name!r} "
+                f"and engine.base_flags['served-model-name']={served_model_name!r}."
+            )
+        return study
 
 
 @dataclass(frozen=True)
@@ -256,6 +256,25 @@ class CoreFlowTests(unittest.TestCase):
         self.assertIn("There is no TTFT SLO for this study.", prompt)
         self.assertIn("decode-only", prompt)
 
+    def test_load_study_spec_rejects_mismatched_served_model_name(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            tmp_path = Path(tmp)
+            study_path = _write_study_assets(
+                tmp_path,
+                engine_overrides={
+                    "base_flags": {
+                        "host": "127.0.0.1",
+                        "port": 8000,
+                        "served-model-name": "engine-name",
+                    }
+                },
+            )
+            payload = json.loads(study_path.read_text(encoding="utf-8"))
+            payload["model"]["served_model_name"] = "trace-name"
+            study_path.write_text(json.dumps(payload), encoding="utf-8")
+            with self.assertRaisesRegex(SpecError, "must match engine.base_flags"):
+                load_study_spec(study_path)
+
     def test_bailian_endpoint_defaults(self) -> None:
         endpoint = LLMEndpointSpec.from_dict({"provider": "bailian", "model": "qwen-plus"})
         self.assertEqual(endpoint.provider, "bailian")
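
For reference, a minimal sketch of how the new check surfaces to callers, assuming load_study_spec and SpecError are importable from the project's spec module (the study_spec import path below is a hypothetical placeholder; the real module name is not shown in this diff):

    from pathlib import Path

    from study_spec import SpecError, load_study_spec  # hypothetical import path

    try:
        spec = load_study_spec(Path("study.json"))
    except SpecError as exc:
        # Raised when model.served_model_name and engine.base_flags["served-model-name"]
        # are both set but do not agree.
        print(f"Rejected study spec: {exc}")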