#!/usr/bin/env python3 """Fast GSM8K evaluation — keeps xserv-chat running, pipes problems via stdin. Usage: python eval_gsm8k_fast.py [--limit N] [--max-tokens N] [--tp N] """ import argparse import json import os import re import select import subprocess import sys import time from pathlib import Path SCRIPT_DIR = Path(__file__).parent DATA_PATH = SCRIPT_DIR / "bench" / "data" / "gsm8k.json" XSERV_CHAT = Path(__file__).parent.parent / "target" / "release" / "xserv-chat" SYSTEM_PROMPT = ( "You are a careful math problem solver. Solve the problem step by step. " "Put your final numeric answer inside \\boxed{}." ) _BOXED_RE = re.compile(r"\\boxed\s*\{([^{}]*)\}") _NUM_RE = re.compile(r"-?\d+(?:,\d{3})*(?:\.\d+)?") def normalize_num(s: str) -> str | None: s = s.replace(",", "").strip() try: f = float(s) except ValueError: return None return str(int(f)) if f.is_integer() else f"{f:g}" def extract_answer(text: str) -> str | None: if not text: return None boxed = _BOXED_RE.findall(text) if boxed: nums = _NUM_RE.findall(boxed[-1]) if nums: return normalize_num(nums[-1]) nums = _NUM_RE.findall(text) if nums: return normalize_num(nums[-1]) return None def read_until_prompt(proc, timeout=120): """Read from proc.stdout until we see 'user> ' prompt, return collected text.""" import io buf = [] deadline = time.time() + timeout fd = proc.stdout.fileno() while time.time() < deadline: remaining = deadline - time.time() ready, _, _ = select.select([fd], [], [], min(remaining, 0.1)) if ready: chunk = os.read(fd, 4096) if not chunk: break text = chunk.decode("utf-8", errors="replace") buf.append(text) joined = "".join(buf) if "user> " in joined.split("assistant>")[-1] if "assistant>" in joined else "user> " in joined: # Check if we have a complete response (ends with "user> ") if joined.rstrip().endswith("user>"): break return "".join(buf) def main(): parser = argparse.ArgumentParser(description="Fast GSM8K eval via persistent xserv-chat") parser.add_argument("model_dir", help="Model directory") parser.add_argument("--limit", type=int, default=50, help="Number of problems") parser.add_argument("--max-tokens", type=int, default=512, help="Max generation tokens") parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism") parser.add_argument("--offset", type=int, default=0, help="Start from problem N") parser.add_argument("--gpu", type=str, default="0", help="CUDA_VISIBLE_DEVICES value, e.g. '0' or '2,3' (must cover --tp ranks)") args = parser.parse_args() if not DATA_PATH.exists(): print(f"Error: {DATA_PATH} not found", file=sys.stderr) sys.exit(1) with open(DATA_PATH) as f: problems = json.load(f) problems = problems[args.offset:args.offset + args.limit] # Start xserv-chat as persistent subprocess cmd = [ str(XSERV_CHAT), args.model_dir, "--max-tokens", str(args.max_tokens), "--max-seq-len", "2048", "--system", SYSTEM_PROMPT, "--no-color", ] if args.tp > 1: cmd += ["--tp", str(args.tp)] env = {**os.environ, "CUDA_VISIBLE_DEVICES": str(args.gpu)} print(f"GSM8K evaluation: {len(problems)} problems, model={args.model_dir}") print(f"max_tokens={args.max_tokens}, tp={args.tp}, gpu={args.gpu}") print(f"Starting xserv-chat...", file=sys.stderr) proc = subprocess.Popen( cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env, bufsize=0, ) # Wait for the "Ready" message on stderr, and first "user> " on stdout # Read stderr in background to avoid blocking import threading stderr_lines = [] def read_stderr(): while True: line = proc.stderr.readline() if not line: break stderr_lines.append(line.decode("utf-8", errors="replace")) t = threading.Thread(target=read_stderr, daemon=True) t.start() # Wait for first prompt startup_text = read_until_prompt(proc, timeout=120) time.sleep(0.5) # small settle print(f"Model loaded. Starting evaluation.", file=sys.stderr) print("-" * 72) correct = 0 errors = 0 total_gen_time = 0.0 for i, prob in enumerate(problems): question = prob["problem"].replace("\n", " ") # Send question + newline try: proc.stdin.write((question + "\n").encode("utf-8")) proc.stdin.flush() except BrokenPipeError: print(f"[E] Process died at problem {i}", file=sys.stderr) break t0 = time.time() response_text = read_until_prompt(proc, timeout=120) elapsed = time.time() - t0 total_gen_time += elapsed # Extract the assistant response response = "" if "assistant>" in response_text: parts = response_text.split("assistant>", 1) if len(parts) > 1: rest = parts[1] if "user>" in rest: response = rest[:rest.rindex("user>")].strip() else: response = rest.strip() pred = extract_answer(response) gold = normalize_num(prob["answer"]) is_correct = pred is not None and gold is not None and pred == gold if is_correct: correct += 1 # Send /clear to reset context for next problem try: proc.stdin.write(b"/clear\n") proc.stdin.flush() # Read the "history cleared" response clear_resp = read_until_prompt(proc, timeout=10) except BrokenPipeError: pass mark = "✓" if is_correct else "✗" print(f"[{mark}] {i+1:3d}/{len(problems)} " f"id={prob['id']:>4s} gold={prob['answer']:>8s} " f"pred={str(pred):>8s} {elapsed:.1f}s") # Cleanup try: proc.stdin.write(b"/exit\n") proc.stdin.flush() except: pass proc.wait(timeout=5) print("-" * 72) n_scored = len(problems) - errors accuracy = correct / max(n_scored, 1) print(f"Results: {correct}/{n_scored} correct = {accuracy*100:.1f}% accuracy") if errors: print(f" ({errors} errors/timeouts)") print(f"Generation time: {total_gen_time:.1f}s, avg {total_gen_time/max(len(problems),1):.1f}s/problem") if __name__ == "__main__": main()