Files
xserv/tools/eval_gsm8k_fast.py
Gahow Wang 63f5599717 server: serve gpt-oss on a single GPU via the TP engine (world=1)
gpt-oss has no single-GPU engine path, so --tp 1 fell through to the
Qwen3-only engine and every request 503'd. Route gpt_oss to run_tp
even at tp=1: NCCL world-1 init works and all_reduce already no-ops
(bench-gpt-oss --tp 1 exercised this path). Quantized gpt-oss (22 GB
FP8 / 13 GB MXFP4) now serves on one 32 GB 5090.

Also fix eval_gsm8k_fast.py --gpu to accept a device list ("2,3"):
it was type=int, so any --tp 2 run pinned CUDA_VISIBLE_DEVICES to one
GPU and rank 1's set_device panicked while rank 0 spun in NCCL init.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 16:29:10 +08:00

207 lines
6.6 KiB
Python

#!/usr/bin/env python3
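"""Fast GSM8K evaluation — keeps xserv-chat running, pipes problems via stdin.

Usage:
    python eval_gsm8k_fast.py <model-dir> [--limit N] [--offset N]
        [--max-tokens N] [--tp N] [--gpu DEVICES]

--gpu sets CUDA_VISIBLE_DEVICES and must list at least --tp devices
(e.g. --gpu 2,3 for --tp 2).
"""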
import argparse
import json
import os
import re
import select
import subprocess
import sys
import time
from pathlib import Path
# All paths are anchored at this script's directory, so the tool can be run
# from any working directory.
SCRIPT_DIR = Path(__file__).parent
DATA_PATH = SCRIPT_DIR / "bench" / "data" / "gsm8k.json"
# Release binary under <script dir>/../target/release.
XSERV_CHAT = SCRIPT_DIR.parent / "target" / "release" / "xserv-chat"

# Passed to xserv-chat via --system; asks for a boxed final answer, which is
# what extract_answer() looks for first.
SYSTEM_PROMPT = (
    "You are a careful math problem solver. Solve the problem step by step. "
    "Put your final numeric answer inside \\boxed{}."
)

# Body of a boxed answer; nested braces inside the box are not supported.
_BOXED_RE = re.compile(
    r"""
    \\boxed \s* \{     # the boxed command, optional spaces, opening brace
    ( [^{}]* )         # captured contents (no braces allowed)
    \}                 # closing brace
    """,
    re.VERBOSE,
)

# A number as it may appear in prose: optional sign, comma thousands groups,
# optional decimal part.
_NUM_RE = re.compile(
    r"""
    -?                 # optional minus sign
    \d+                # leading digits
    (?:,\d{3})*        # comma-separated thousands groups
    (?:\.\d+)?         # optional fractional part
    """,
    re.VERBOSE,
)
def normalize_num(s: str) -> str | None:
s = s.replace(",", "").strip()
try:
f = float(s)
except ValueError:
return None
return str(int(f)) if f.is_integer() else f"{f:g}"
def extract_answer(text: str) -> str | None:
    """Pull the model's final numeric answer out of a reply.

    Prefers the last number inside the last ``\\boxed{...}``. If there is no
    box, or the last box holds no number, it falls back to the last number
    anywhere in *text*. The chosen number goes through ``normalize_num``.

    Returns:
        The normalized answer, or None for empty text or when no number is found.
    """
    if not text:
        return None
    boxes = _BOXED_RE.findall(text)
    candidates = _NUM_RE.findall(boxes[-1]) if boxes else []
    if not candidates:
        candidates = _NUM_RE.findall(text)
    return normalize_num(candidates[-1]) if candidates else None
def read_until_prompt(proc, timeout=120):
    """Read from proc.stdout until we see the 'user> ' prompt, return collected text.

    Reads raw bytes from ``proc.stdout`` (polling with ``select``) and stops
    when one of these happens:
      * the output ends with the ``user>`` prompt followed by whitespace;
      * the stream reaches EOF (the child closed stdout);
      * *timeout* seconds elapse.
    The text read so far is returned in every case. To tell a complete reply
    from a timeout or a dead process, callers must check for the trailing
    prompt themselves.

    Args:
        proc: a ``subprocess.Popen`` (anything whose ``stdout.fileno()`` is
            a readable file descriptor).
        timeout: overall wall-clock budget in seconds.

    Returns:
        The UTF-8-decoded text read during this call (undecodable bytes
        become U+FFFD).
    """
    import codecs

    # Decode incrementally: a multi-byte UTF-8 character that straddles two
    # os.read() chunks must not turn into two replacement characters.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    parts = []
    tail = ""  # last few characters only; enough to recognise the prompt
    fd = proc.stdout.fileno()
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        ready, _, _ = select.select([fd], [], [], min(remaining, 0.1))
        if not ready:
            continue
        chunk = os.read(fd, 4096)
        if not chunk:
            break  # EOF
        text = decoder.decode(chunk)
        parts.append(text)
        tail = (tail + text)[-64:]
        # The prompt is "user> "; waiting for the trailing whitespace avoids
        # stopping between ">" and " " when a read splits them.
        if tail.rstrip().endswith("user>") and tail[-1:].isspace():
            break
    parts.append(decoder.decode(b"", final=True))
    return "".join(parts)
def _drain_stderr(proc, sink: list[str]) -> None:
    """Append each decoded line of ``proc.stderr`` to *sink* until EOF.

    Runs on a background thread so the child's stderr is consumed while the
    main thread blocks on stdout.
    """
    for raw in iter(proc.stderr.readline, b""):
        sink.append(raw.decode("utf-8", errors="replace"))


def _send_line(proc, line: str) -> bool:
    """Write *line* plus a newline to the child's stdin.

    Returns False, instead of raising, when the pipe is closed (e.g. the
    child has exited).
    """
    try:
        proc.stdin.write((line + "\n").encode("utf-8"))
        proc.stdin.flush()
    except OSError:  # includes BrokenPipeError
        return False
    return True


def _at_prompt(text: str) -> bool:
    """Return True when *text* ends at xserv-chat's ``user>`` input prompt."""
    return text.rstrip().endswith("user>")


def _extract_reply(text: str) -> str:
    """Return the assistant's reply from one captured exchange.

    Keeps the text after the first ``assistant>`` marker, cut at the last
    ``user>`` prompt. Returns "" when there is no ``assistant>`` marker.
    """
    _, marker, rest = text.partition("assistant>")
    if not marker:
        return ""
    if "user>" in rest:
        rest = rest[:rest.rindex("user>")]
    return rest.strip()


def _shutdown(proc) -> None:
    """Send ``/exit`` to the child; kill it if it has not stopped within 5 s."""
    _send_line(proc, "/exit")
    try:
        proc.wait(timeout=5)
    except subprocess.TimeoutExpired:
        # Never leave the child (and whatever it has loaded) running behind us.
        proc.kill()
        proc.wait()


def main():
    """Evaluate a model on GSM8K through one persistent xserv-chat process.

    Starts xserv-chat once and sends each problem as a single stdin line. It
    reads the reply up to the next ``user>`` prompt, scores the extracted
    answer against the gold answer, then sends ``/clear`` so problems do not
    share context.

    Replies that do not reach the prompt within 120 s count as errors and are
    excluded from the accuracy denominator. If the prompt never returns after
    that, the run stops early. The child is always shut down, including on
    errors and Ctrl-C.
    """
    import threading

    parser = argparse.ArgumentParser(description="Fast GSM8K eval via persistent xserv-chat")
    parser.add_argument("model_dir", help="Model directory")
    parser.add_argument("--limit", type=int, default=50, help="Number of problems")
    parser.add_argument("--max-tokens", type=int, default=512, help="Max generation tokens")
    parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism")
    parser.add_argument("--offset", type=int, default=0, help="Start from problem N")
    parser.add_argument("--gpu", type=str, default="0",
                        help="CUDA_VISIBLE_DEVICES value, e.g. '0' or '2,3' (must cover --tp ranks)")
    args = parser.parse_args()

    # Each TP rank needs its own visible device. With too few devices, some
    # ranks fail while the others wait in NCCL init, so reject this up front.
    n_devices = len([d for d in args.gpu.split(",") if d.strip()])
    if n_devices < args.tp:
        parser.error(f"--gpu {args.gpu!r} exposes {n_devices} device(s) but --tp is {args.tp}")

    if not DATA_PATH.exists():
        print(f"Error: {DATA_PATH} not found", file=sys.stderr)
        sys.exit(1)
    if not XSERV_CHAT.exists():
        print(f"Error: {XSERV_CHAT} not found", file=sys.stderr)
        sys.exit(1)
    with open(DATA_PATH, encoding="utf-8") as f:
        problems = json.load(f)
    problems = problems[args.offset:args.offset + args.limit]

    # Start xserv-chat as a persistent subprocess.
    cmd = [
        str(XSERV_CHAT), args.model_dir,
        "--max-tokens", str(args.max_tokens),
        "--max-seq-len", "2048",
        "--system", SYSTEM_PROMPT,
        "--no-color",
    ]
    if args.tp > 1:
        cmd += ["--tp", str(args.tp)]
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": args.gpu}

    print(f"GSM8K evaluation: {len(problems)} problems, model={args.model_dir}")
    print(f"max_tokens={args.max_tokens}, tp={args.tp}, gpu={args.gpu}")
    print("Starting xserv-chat...", file=sys.stderr)
    proc = subprocess.Popen(
        cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        env=env, bufsize=0,
    )
    stderr_lines: list[str] = []
    stderr_thread = threading.Thread(
        target=_drain_stderr, args=(proc, stderr_lines), daemon=True,
    )
    stderr_thread.start()

    correct = errors = attempted = 0
    total_gen_time = 0.0
    try:
        # The first prompt appears once xserv-chat is ready for input.
        if not _at_prompt(read_until_prompt(proc, timeout=120)):
            stderr_thread.join(timeout=1)  # give the reader a moment to catch the last lines
            code = proc.poll()
            status = "still running" if code is None else f"exit code {code}"
            print(f"Error: xserv-chat never showed its prompt ({status}); "
                  "last stderr lines:", file=sys.stderr)
            sys.stderr.writelines(stderr_lines[-20:])
            sys.exit(1)
        time.sleep(0.5)  # small settle
        print("Model loaded. Starting evaluation.", file=sys.stderr)
        print("-" * 72)

        for i, prob in enumerate(problems):
            # One stdin line per problem, so embedded newlines become spaces.
            question = prob["problem"].replace("\n", " ")
            if not _send_line(proc, question):
                print(f"[E] Process died at problem {i}", file=sys.stderr)
                break
            attempted += 1
            t0 = time.monotonic()
            response_text = read_until_prompt(proc, timeout=120)
            elapsed = time.monotonic() - t0
            total_gen_time += elapsed

            if _at_prompt(response_text):
                pred = extract_answer(_extract_reply(response_text))
                is_correct = pred is not None and pred == normalize_num(prob["answer"])
                if is_correct:
                    correct += 1
                mark = "+" if is_correct else "-"
                in_sync = True
            else:
                # Timed out (or the child died) mid-reply: count it as an error,
                # then wait once more for the prompt. Otherwise the rest of this
                # reply would be read as the next problem's answer.
                errors += 1
                pred, mark = None, "E"
                in_sync = _at_prompt(response_text + read_until_prompt(proc, timeout=120))

            print(f"[{mark}] {i + 1:3d}/{len(problems)} "
                  f"id={prob['id']!s:>4} gold={prob['answer']!s:>8} "
                  f"pred={pred!s:>8} {elapsed:.1f}s")
            if not in_sync:
                print(f"[E] xserv-chat stopped responding at problem {i}", file=sys.stderr)
                break

            # Reset the conversation so the next problem starts fresh.
            if _send_line(proc, "/clear"):
                read_until_prompt(proc, timeout=10)  # consume the acknowledgement + prompt
    finally:
        _shutdown(proc)

    print("-" * 72)
    n_scored = attempted - errors
    accuracy = correct / max(n_scored, 1)
    print(f"Results: {correct}/{n_scored} correct = {accuracy*100:.1f}% accuracy")
    if errors:
        print(f" ({errors} errors/timeouts)")
    skipped = len(problems) - attempted
    if skipped:
        print(f" ({skipped} problems not attempted)")
    print(f"Generation time: {total_gen_time:.1f}s, "
          f"avg {total_gen_time / max(attempted, 1):.1f}s/problem")


if __name__ == "__main__":
    main()