#!/usr/bin/env python3
"""Compare xserv prefill logits with HuggingFace transformers on 10 prompts."""

import os
import re
import subprocess
import sys

MODEL_DIR = "/opt/wjh/models/qwen3-8b"
TOP_K = 10

PROMPTS = [
    "What is the capital of France?",
    "Explain quantum computing.",
    "Hello world",
    "def fibonacci(n):",
    "The weather today is",
    "1 + 1 =",
    "Machine learning is",
    "Once upon a time",
    "Paris is known for",
    "How does gravity work?",
]

# One line of dump-logits output, e.g.
#   " [ 0] id= 3555 logit= 24.5000 token=..."
# The logit accepts plain decimals, scientific notation and inf/nan. The
# trailing lookahead rejects a numeric token that would otherwise be matched
# only partially (e.g. "1.2.3"), so a value is never silently truncated.
_TOPK_LINE = re.compile(
    r"\s*\[\s*\d+\]\s+id=\s*(\d+)\s+logit=\s*"
    r"([-+]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf(?:inity)?|nan))"
    r"(?![\w.])",
    re.IGNORECASE,
)


def get_hf_topk(prompt, tokenizer, model, k=10):
    """Return the top-``k`` next-token logits from the HuggingFace model.

    The prompt is tokenized, run through ``model`` without gradients, and the
    logits at the last position of the first batch entry are ranked.

    Args:
        prompt: Text to feed to the model.
        tokenizer: Callable tokenizer returning tensors movable with ``.to()``.
        model: Model exposing ``.device`` and returning an object with
            ``.logits`` when called.
        k: Number of entries to return.

    Returns:
        A list of ``(token_id, logit)`` tuples as produced by ``torch.topk``.
    """
    import torch

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Cast to float32 on CPU before ranking so the reported values do not
    # depend on the model's compute dtype.
    logits = outputs.logits[0, -1, :].float().cpu()
    topk = torch.topk(logits, k)
    return list(zip(topk.indices.tolist(), topk.values.tolist()))


def parse_topk_output(stdout, k=10):
    """Parse dump-logits text output into at most ``k`` ``(id, logit)`` pairs.

    Lines that do not look like ``[rank] id=<int> logit=<float>`` are ignored.

    Args:
        stdout: Full captured standard output of the dump-logits binary.
        k: Maximum number of entries to return; non-positive yields ``[]``.

    Returns:
        A list of ``(token_id, logit)`` tuples in the order they were printed.
    """
    if k <= 0:
        return []
    topk = []
    for line in stdout.splitlines():
        m = _TOPK_LINE.match(line)
        if m is None:
            continue
        topk.append((int(m.group(1)), float(m.group(2))))
        if len(topk) >= k:
            break
    return topk


def get_xserv_topk(prompt, k=10):
    """Run the xserv dump-logits binary on ``prompt`` and return its top-``k``.

    The binary is invoked with ``CUDA_VISIBLE_DEVICES=0`` and the CUDA 12.9
    bin directory prepended to ``PATH``.

    Args:
        prompt: Text passed as the second command-line argument.
        k: Maximum number of entries to return.

    Returns:
        A list of ``(token_id, logit)`` tuples, or ``[]`` if the binary could
        not be started, timed out, or printed no parsable lines. Failures are
        reported on stderr rather than raised so one bad prompt does not abort
        the whole comparison.
    """
    xserv_bin = "/opt/wjh/projects/xserv/target/release/dump-logits"
    env = {
        **os.environ,
        "CUDA_VISIBLE_DEVICES": "0",
        "PATH": "/usr/local/cuda-12.9/bin:" + os.environ.get("PATH", ""),
    }
    try:
        result = subprocess.run(
            [xserv_bin, MODEL_DIR, prompt],
            capture_output=True,
            text=True,
            timeout=180,
            env=env,
        )
    except (OSError, subprocess.TimeoutExpired) as err:
        print(f"xserv: failed to run {xserv_bin}: {err}", file=sys.stderr)
        return []

    if result.returncode != 0:
        # Still parse whatever was printed, but surface the failure reason.
        print(
            f"xserv: exit status {result.returncode}: {result.stderr.strip()}",
            file=sys.stderr,
        )
    return parse_topk_output(result.stdout, k)


def main():
    """Load the HF model, compare top-k logits per prompt, print a verdict.

    Prompts for which xserv yields no output are reported and skipped; they
    still count in the denominator of the top-1 match rate.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print("Loading HF model on GPU 1...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR, dtype=torch.bfloat16, device_map="cuda:1", trust_remote_code=True)
    model.eval()
    print("HF model loaded.\n")

    total = len(PROMPTS)
    top1_matches = 0
    top5_overlaps = []

    for i, prompt in enumerate(PROMPTS):
        print(f"[{i+1}/{total}] \"{prompt}\"")

        hf_top = get_hf_topk(prompt, tokenizer, model, TOP_K)
        xs_top = get_xserv_topk(prompt, TOP_K)

        if not xs_top:
            print(" xserv: NO OUTPUT")
            continue

        hf_ids = [t[0] for t in hf_top]
        xs_ids = [t[0] for t in xs_top]

        top1_match = hf_ids[0] == xs_ids[0]
        if top1_match:
            top1_matches += 1

        top5_overlap = len(set(hf_ids[:5]) & set(xs_ids[:5]))
        top5_overlaps.append(top5_overlap)

        # Show comparison
        hf_tok = tokenizer.decode([hf_ids[0]])
        xs_tok = tokenizer.decode([xs_ids[0]])
        status = "MATCH" if top1_match else "DIFF"
        print(f" Top-1: HF={hf_ids[0]:>6}({hf_tok!r:>10}) | xserv={xs_ids[0]:>6}({xs_tok!r:>10}) [{status}]")
        print(f" Top-5 overlap: {top5_overlap}/5")

        # Show top-5 side by side
        print(f" {'HF':>25} | {'xserv':>25}")
        for (h_id, h_val), (x_id, x_val) in zip(hf_top[:5], xs_top[:5]):
            h_tok = tokenizer.decode([h_id])
            x_tok = tokenizer.decode([x_id])
            print(f" {h_id:>6} {h_val:>8.3f} {h_tok!r:>8} | {x_id:>6} {x_val:>8.3f} {x_tok!r:>8}")
        print()

    print("=" * 50)
    # max(..., 1) keeps the summary printable even with an empty prompt list.
    print(f"Top-1 match rate: {top1_matches}/{total} ({100*top1_matches/max(total, 1):.0f}%)")
    avg_overlap = sum(top5_overlaps) / max(len(top5_overlaps), 1)
    print(f"Avg top-5 overlap: {avg_overlap:.1f}/5")
    print(f"Verdict: {'PASS' if top1_matches >= total * 0.7 else 'FAIL'}")


if __name__ == "__main__":
    main()