"""GSM8K — 1319 grade-school math problems with integer/decimal answers. Gold answers in the dataset are in the form `... #### 42`. We score by exact-match of the final number, with the same `\\boxed{}` / last-number extraction used for AIME, since for instruction-tuned models the response follows the prompt instructions, not the dataset's `####` convention. """ from __future__ import annotations import re from typing import Any from . import load_local TASK_NAME = "gsm8k" def load() -> list[dict[str, Any]]: local = load_local(TASK_NAME) if local is not None: return local return load_remote() def load_remote() -> list[dict[str, Any]]: """Fetch from HuggingFace. Requires network — used by fetch_datasets.py.""" from datasets import load_dataset # noqa: PLC0415 ds = load_dataset("openai/gsm8k", "main", split="test") out: list[dict[str, Any]] = [] for i, row in enumerate(ds): ans_full: str = row["answer"] # gold format: "\n#### 42" gold = ans_full.split("####")[-1].strip().replace(",", "") out.append({ "id": str(i), "problem": row["question"], "answer": gold, "source": "openai/gsm8k", }) return out SYSTEM_PROMPT = ( "You are a careful math problem solver. Solve the problem step by step. " "Put your final numeric answer inside \\boxed{}." ) def make_messages(problem: str) -> list[dict[str, str]]: return [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": problem}, ] _BOXED_RE = re.compile(r"\\boxed\s*\{([^{}]*)\}") # Allow comma-grouped thousands (e.g. "3,500"); _normalize_num strips them. 
_NUM_RE = re.compile(r"-?\d+(?:,\d{3})*(?:\.\d+)?")

# Opening ``\boxed{`` (optional whitespace before the brace). The matching
# close brace is located by _boxed_contents with a depth counter, because a
# regular expression cannot balance nested braces such as ``\text{...}``.
_BOXED_OPEN_RE = re.compile(r"\\boxed\s*\{")


def _normalize_num(s: str) -> str | None:
    """Canonicalize a numeric string so equal values compare equal as strings.

    Commas and surrounding whitespace are removed and the value is parsed as
    an exact ``Decimal``. Integral values render without a fractional part
    ("3.0" -> "3", "-0" -> "0"). Other values render in plain fixed-point
    notation with trailing zeros dropped ("0.50" -> "0.5").

    Returns None when *s* is not a finite number.
    """
    from decimal import Decimal, InvalidOperation  # noqa: PLC0415

    cleaned = s.replace(",", "").strip()
    try:
        # Decimal rather than float + "%g": "%g" keeps only 6 significant
        # digits, which made e.g. 12345.67 and 12345.7 normalize identically.
        value = Decimal(cleaned)
    except InvalidOperation:
        return None
    if not value.is_finite():
        return None
    if value.is_zero():
        # Collapse "-0", "0.00", ... to a single spelling.
        return "0"
    integral = value.to_integral_value()
    if value == integral:
        # format(..., "f") expands exponents ("1E+2" -> "100") without rounding.
        return format(integral, "f")
    # Non-integral: a non-zero fractional digit always remains, so stripping
    # trailing zeros never leaves a dangling ".".
    return format(value, "f").rstrip("0")


def _boxed_contents(text: str) -> list[str]:
    """Return the brace-balanced body of every ``\\boxed{...}`` in *text*, in order.

    An opener whose braces never close (e.g. truncated output) contributes
    nothing.
    """
    contents: list[str] = []
    for match in _BOXED_OPEN_RE.finditer(text):
        start = match.end()
        depth = 1
        for pos in range(start, len(text)):
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    contents.append(text[start:pos])
                    break
    return contents


def extract_answer(text: str) -> str | None:
    """Pull the final numeric answer out of a model response.

    Prefers the last number inside the last ``\\boxed{...}`` group. Nested
    braces are allowed, so ``\\boxed{18 \\text{ dollars}}`` yields "18". If
    there is no boxed group, or it holds no number, falls back to the last
    number anywhere in *text*.

    Returns the number normalized by ``_normalize_num``, or None when *text*
    is empty or contains no number.
    """
    if not text:
        return None
    # LaTeX often writes thousands separators as "1{,}000"; fold them to
    # "1,000" so _NUM_RE sees one number instead of "1" and "000".
    text = text.replace("{,}", ",")
    boxed = _boxed_contents(text)
    if boxed:
        nums = _NUM_RE.findall(boxed[-1])
        if nums:
            return _normalize_num(nums[-1])
    nums = _NUM_RE.findall(text)
    if nums:
        return _normalize_num(nums[-1])
    return None


def score(pred: str | None, gold: str) -> bool:
    """Return True when *pred* and *gold* denote the same number.

    Both sides go through ``_normalize_num``, so raw strings such as "42.0"
    or "3,500" match "42" / "3500". A None or unparseable *pred*, or an
    unparseable *gold*, scores False.
    """
    if pred is None:
        return False
    gold_norm = _normalize_num(gold)
    if gold_norm is None:
        return False
    return _normalize_num(pred) == gold_norm