#!/usr/bin/env python3 """Direct GSM8K evaluation using xserv-chat CLI. Usage: python eval_gsm8k.py [--limit N] [--max-tokens N] [--tp N] Runs xserv-chat on each GSM8K problem, extracts the numeric answer, and reports accuracy. Uses temperature=0 (greedy) for determinism. """ import argparse import json import os import re import subprocess import sys import time from pathlib import Path SCRIPT_DIR = Path(__file__).parent DATA_PATH = SCRIPT_DIR / "bench" / "data" / "gsm8k.json" XSERV_CHAT = Path(__file__).parent.parent / "target" / "release" / "xserv-chat" SYSTEM_PROMPT = ( "You are a careful math problem solver. Solve the problem step by step. " "Put your final numeric answer inside \\boxed{}." ) _BOXED_RE = re.compile(r"\\boxed\s*\{([^{}]*)\}") _NUM_RE = re.compile(r"-?\d+(?:,\d{3})*(?:\.\d+)?") def normalize_num(s: str) -> str | None: s = s.replace(",", "").strip() try: f = float(s) except ValueError: return None return str(int(f)) if f.is_integer() else f"{f:g}" def extract_answer(text: str) -> str | None: if not text: return None boxed = _BOXED_RE.findall(text) if boxed: nums = _NUM_RE.findall(boxed[-1]) if nums: return normalize_num(nums[-1]) nums = _NUM_RE.findall(text) if nums: return normalize_num(nums[-1]) return None def run_one(model_dir: str, problem: str, max_tokens: int, tp: int) -> tuple[str, float]: """Run xserv-chat on a single problem, return (response_text, elapsed_s).""" cmd = [ str(XSERV_CHAT), model_dir, "--max-tokens", str(max_tokens), "--max-seq-len", "2048", "--system", SYSTEM_PROMPT, "--no-color", ] if tp > 1: cmd += ["--tp", str(tp)] t0 = time.time() proc = subprocess.run( cmd, input=problem + "\n/exit\n", capture_output=True, text=True, timeout=120, env={**os.environ, "CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES", "0")}, ) elapsed = time.time() - t0 # Parse the assistant response from stdout output = proc.stdout # The output format is: "assistant> \nuser>" response = "" for line in output.split("\n"): if line.startswith("assistant> "): response = line[len("assistant> "):] elif response and not line.startswith("user>"): response += "\n" + line # Also capture multi-line responses between "assistant>" and next "user>" if "assistant>" in output: parts = output.split("assistant>", 1) if len(parts) > 1: rest = parts[1] if "user>" in rest: response = rest[:rest.index("user>")].strip() else: response = rest.strip() return response, elapsed def main(): parser = argparse.ArgumentParser(description="GSM8K evaluation via xserv-chat") parser.add_argument("model_dir", help="Model directory") parser.add_argument("--limit", type=int, default=50, help="Number of problems (default: 50)") parser.add_argument("--max-tokens", type=int, default=512, help="Max generation tokens") parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism") parser.add_argument("--offset", type=int, default=0, help="Start from problem N") args = parser.parse_args() if not DATA_PATH.exists(): print(f"Error: {DATA_PATH} not found", file=sys.stderr) sys.exit(1) with open(DATA_PATH) as f: problems = json.load(f) problems = problems[args.offset:args.offset + args.limit] print(f"GSM8K evaluation: {len(problems)} problems, model={args.model_dir}") print(f"max_tokens={args.max_tokens}, tp={args.tp}") print("-" * 72) correct = 0 errors = 0 total_time = 0.0 for i, prob in enumerate(problems): try: response, elapsed = run_one(args.model_dir, prob["problem"], args.max_tokens, args.tp) total_time += elapsed pred = extract_answer(response) gold = normalize_num(prob["answer"]) is_correct = pred is not None and gold is not None and pred == gold if is_correct: correct += 1 mark = "✓" if is_correct else "✗" print(f"[{mark}] {i+1:3d}/{len(problems)} " f"id={prob['id']:>4s} gold={prob['answer']:>8s} " f"pred={str(pred):>8s} {elapsed:.1f}s") except subprocess.TimeoutExpired: errors += 1 print(f"[T] {i+1:3d}/{len(problems)} id={prob['id']:>4s} TIMEOUT") except Exception as e: errors += 1 print(f"[E] {i+1:3d}/{len(problems)} id={prob['id']:>4s} {e}") print("-" * 72) n_scored = len(problems) - errors accuracy = correct / max(n_scored, 1) print(f"Results: {correct}/{n_scored} correct = {accuracy*100:.1f}% accuracy") if errors: print(f" ({errors} errors/timeouts)") print(f"Total time: {total_time:.1f}s, avg {total_time/max(len(problems),1):.1f}s/problem") if __name__ == "__main__": main()