Honor exact output lengths in replay requests

This commit is contained in:
2026-04-04 21:33:26 +08:00
parent 647d241725
commit b33d1356e7
2 changed files with 11 additions and 1 deletions

View File

@@ -59,7 +59,15 @@ def resolve_window_record(study: StudySpec, *, study_spec_path: Path) -> WindowR
raise TraceError(f"window {study.trace.window_id} does not define trace_file") raise TraceError(f"window {study.trace.window_id} does not define trace_file")
trace_path = Path(trace_file) trace_path = Path(trace_file)
if not trace_path.is_absolute(): if not trace_path.is_absolute():
trace_path = (windows_path.parent / trace_path).resolve() candidate = (windows_path.parent / trace_path).resolve()
if candidate.exists():
trace_path = candidate
else:
parts = trace_path.parts
if parts and parts[0] == "trace_windows":
trace_path = (windows_path.parent / Path(*parts[1:])).resolve()
else:
trace_path = candidate
return WindowRecord( return WindowRecord(
window_id=study.trace.window_id, window_id=study.trace.window_id,
trace_path=trace_path, trace_path=trace_path,
@@ -154,6 +162,7 @@ def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[Win
} }
completion_tokens = _coerce_completion_tokens(row) completion_tokens = _coerce_completion_tokens(row)
if completion_tokens is not None: if completion_tokens is not None:
body["min_tokens"] = completion_tokens
body["max_tokens"] = completion_tokens body["max_tokens"] = completion_tokens
temperature = row.get("temperature") temperature = row.get("temperature")
if isinstance(temperature, (int, float)) and not isinstance(temperature, bool): if isinstance(temperature, (int, float)) and not isinstance(temperature, bool):

View File

@@ -230,6 +230,7 @@ class CoreFlowTests(unittest.TestCase):
self.assertEqual(len(requests), 1) self.assertEqual(len(requests), 1)
message = requests[0].body["messages"][0]["content"] message = requests[0].body["messages"][0]["content"]
self.assertEqual(message.count("token"), 8) self.assertEqual(message.count("token"), 8)
self.assertEqual(requests[0].body["min_tokens"], 16)
self.assertEqual(requests[0].body["max_tokens"], 16) self.assertEqual(requests[0].body["max_tokens"], 16)
def test_slo_evaluation_step_and_fixed_rules(self) -> None: def test_slo_evaluation_step_and_fixed_rules(self) -> None: