Store expert gate_up_proj and down_proj weights in FP8 E4M3 (1 byte/elem) with per-expert FP32 scale factors. At inference, a fused CUDA kernel dequantizes to BF16 before the existing cuBLAS batched GEMM. Results on gpt-oss-20b (50-problem GSM8K subset): - FP8 TP=1: 47/50 = 94.0% (single RTX 5090, ~25 GB VRAM) - BF16 TP=2: 47/50 = 94.0% (requires 2× RTX 5090, ~39 GB total) No measurable accuracy degradation. Model size: 41.8 GB → 22.7 GB (−46%). New files: - tools/quantize_fp8.py: offline BF16→FP8 conversion script - csrc/quantization/dequant_fp8.cu: per-expert-scale dequant kernel - crates/xserv-kernels/src/quantization.rs: Rust FFI wrapper - tools/eval_gsm8k_batch.sh: GSM8K accuracy evaluation harness Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
108 lines
3.2 KiB
Bash
Executable File
108 lines
3.2 KiB
Bash
Executable File
#!/bin/bash
|
|
# GSM8K evaluation via repeated xserv-chat invocations.
|
|
# Usage: eval_gsm8k_batch.sh <model-dir> <limit> [gpu_id] [tp]
|
|
set -uo pipefail
|
|
export PATH=/usr/local/cuda/bin:$PATH
|
|
source ~/.cargo/env 2>/dev/null || true
|
|
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH:-}
|
|
|
|
MODEL_DIR="${1:?Usage: $0 <model-dir> <limit> [gpu_id] [tp]}"
|
|
LIMIT="${2:-50}"
|
|
GPU="${3:-0}"
|
|
TP="${4:-1}"
|
|
export CUDA_VISIBLE_DEVICES=$GPU
|
|
|
|
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
|
XSERV_CHAT="$SCRIPT_DIR/../target/release/xserv-chat"
|
|
DATA="$SCRIPT_DIR/bench/data/gsm8k.json"
|
|
SYSTEM='Solve the problem step by step. Put your final numeric answer inside \\boxed{}.'
|
|
|
|
echo "=== GSM8K Eval: model=$MODEL_DIR, limit=$LIMIT, gpu=$GPU, tp=$TP ==="
|
|
|
|
TMPDIR=$(mktemp -d)
|
|
trap "rm -rf $TMPDIR" EXIT
|
|
|
|
# Generate problem files
|
|
python3 -c "
|
|
import json
|
|
problems = json.load(open('$DATA'))[:$LIMIT]
|
|
for i, p in enumerate(problems):
|
|
with open(f'$TMPDIR/{i:04d}.txt', 'w') as f:
|
|
f.write(p['problem'].replace(chr(10), ' '))
|
|
with open(f'$TMPDIR/{i:04d}.gold', 'w') as f:
|
|
f.write(p['answer'])
|
|
print(f'{len(problems)} problems prepared')
|
|
"
|
|
|
|
TOTAL=$(ls "$TMPDIR"/*.txt 2>/dev/null | wc -l)
|
|
CORRECT=0
|
|
SCORED=0
|
|
START_TIME=$(date +%s)
|
|
|
|
TP_FLAG=""
|
|
if [ "$TP" -gt 1 ]; then
|
|
TP_FLAG="--tp $TP"
|
|
fi
|
|
|
|
for f in $(ls "$TMPDIR"/*.txt | sort); do
|
|
IDX=$(basename "$f" .txt)
|
|
GOLD=$(cat "$TMPDIR/${IDX}.gold")
|
|
QUESTION=$(cat "$f")
|
|
|
|
# Run single-question xserv-chat
|
|
RAW_OUT=$(echo "$QUESTION" | timeout 120 "$XSERV_CHAT" "$MODEL_DIR" \
|
|
--max-tokens 512 --max-seq-len 1024 \
|
|
--system "$SYSTEM" --no-color $TP_FLAG 2>/dev/null || true)
|
|
|
|
# Extract predicted answer
|
|
PRED=$(echo "$RAW_OUT" | python3 -c "
|
|
import re, sys
|
|
text = sys.stdin.read()
|
|
# Extract everything after 'assistant>'
|
|
if 'assistant>' in text:
|
|
text = text.split('assistant>', 1)[1]
|
|
if 'user>' in text:
|
|
text = text[:text.rindex('user>')]
|
|
boxed = re.findall(r'\\\\boxed\s*\{([^{}]*)\}', text)
|
|
if boxed:
|
|
nums = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', boxed[-1])
|
|
if nums:
|
|
s = nums[-1].replace(',','')
|
|
f = float(s)
|
|
print(str(int(f)) if f == int(f) else f'{f:g}')
|
|
sys.exit(0)
|
|
nums = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', text)
|
|
if nums:
|
|
s = nums[-1].replace(',','')
|
|
f = float(s)
|
|
print(str(int(f)) if f == int(f) else f'{f:g}')
|
|
else:
|
|
print('NONE')
|
|
" 2>/dev/null || echo "NONE")
|
|
|
|
# Normalize gold
|
|
GOLD_NORM=$(python3 -c "
|
|
s='$GOLD'.replace(',','').strip()
|
|
f=float(s)
|
|
print(str(int(f)) if f==int(f) else f'{f:g}')
|
|
" 2>/dev/null || echo "$GOLD")
|
|
|
|
SCORED=$((SCORED + 1))
|
|
if [ "$PRED" = "$GOLD_NORM" ]; then
|
|
CORRECT=$((CORRECT + 1))
|
|
echo "[✓] $IDX gold=$GOLD_NORM pred=$PRED"
|
|
elif [ "$PRED" = "NONE" ]; then
|
|
echo "[E] $IDX gold=$GOLD_NORM pred=NONE (no output)"
|
|
else
|
|
echo "[✗] $IDX gold=$GOLD_NORM pred=$PRED"
|
|
fi
|
|
done
|
|
|
|
END_TIME=$(date +%s)
|
|
ELAPSED=$((END_TIME - START_TIME))
|
|
|
|
echo "------------------------------------------------------------------------"
|
|
python3 -c "print(f'Results: $CORRECT/$SCORED correct = {$CORRECT/$SCORED*100:.1f}% accuracy')"
|
|
echo "Wall time: ${ELAPSED}s (avg $((ELAPSED / TOTAL))s/problem)"
|
|
echo "=== Done ==="
|