Validate served model name consistency
This commit is part of the repository history.
@@ -256,6 +256,25 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertIn("There is no TTFT SLO for this study.", prompt)
|
||||
self.assertIn("decode-only", prompt)
|
||||
|
||||
def test_load_study_spec_rejects_mismatched_served_model_name(self) -> None:
    """Loading must fail when the study's served_model_name disagrees with engine.base_flags."""
    base_flags = {
        "host": "127.0.0.1",
        "port": 8000,
        "served-model-name": "engine-name",
    }
    with tempfile.TemporaryDirectory() as tmp_dir:
        spec_path = _write_study_assets(
            Path(tmp_dir),
            engine_overrides={"base_flags": base_flags},
        )
        # Rewrite the generated spec so the trace-side name no longer
        # matches the engine flag, then confirm the loader rejects it.
        doc = json.loads(spec_path.read_text(encoding="utf-8"))
        doc["model"]["served_model_name"] = "trace-name"
        spec_path.write_text(json.dumps(doc), encoding="utf-8")
        with self.assertRaisesRegex(SpecError, "must match engine.base_flags"):
            load_study_spec(spec_path)
|
||||
|
||||
def test_bailian_endpoint_defaults(self) -> None:
|
||||
endpoint = LLMEndpointSpec.from_dict({"provider": "bailian", "model": "qwen-plus"})
|
||||
self.assertEqual(endpoint.provider, "bailian")
|
||||
|
||||
Reference in New Issue
Block a user