""" Compare xserv Qwen3 output against HuggingFace transformers. Usage: python3 tools/bench_compare_qwen3.py """ import json import sys import time import torch from transformers import AutoModelForCausalLM, AutoTokenizer def main(): if len(sys.argv) < 3: print(f"Usage: {sys.argv[0]} ") sys.exit(1) xserv_path = sys.argv[1] model_dir = sys.argv[2] with open(xserv_path) as f: xserv_results = json.load(f) print(f"Loading transformers model from {model_dir}...") model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained(model_dir) model.eval() model.cuda() # Warmup with torch.no_grad(): ids = tokenizer.encode("warmup", return_tensors="pt").cuda() model(ids) torch.cuda.synchronize() total = len(xserv_results) match_count = 0 mismatch_count = 0 xserv_ttft_sum = 0.0 xserv_tbt_sum = 0.0 hf_ttft_sum = 0.0 hf_tbt_sum = 0.0 num_with_tbt = 0 print(f"\n{'='*100}") print(f"{'#':>3} {'Match':>5} {'Prompt':<45} {'xserv TTFT':>10} {'HF TTFT':>10} {'xserv TBT':>10} {'HF TBT':>10}") print(f"{'='*100}") for i, xr in enumerate(xserv_results): prompt = xr["prompt"] gen_tokens = xr["num_generated"] xserv_ids = xr["generated_ids"] input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda() hf_generated = [] hf_token_times = [] with torch.no_grad(): all_ids = input_ids.clone() torch.cuda.synchronize() t0 = time.perf_counter() out = model(all_ids) torch.cuda.synchronize() hf_ttft_us = (time.perf_counter() - t0) * 1e6 next_id = out.logits[0, -1].argmax().item() hf_generated.append(next_id) all_ids = torch.cat([all_ids, torch.tensor([[next_id]]).cuda()], dim=1) for _ in range(1, gen_tokens): torch.cuda.synchronize() t_start = time.perf_counter() out = model(all_ids) torch.cuda.synchronize() elapsed = (time.perf_counter() - t_start) * 1e6 hf_token_times.append(elapsed) next_id = out.logits[0, -1].argmax().item() hf_generated.append(next_id) all_ids = torch.cat([all_ids, torch.tensor([[next_id]]).cuda()], dim=1) if next_id == tokenizer.eos_token_id: break hf_tbt_us = sum(hf_token_times) / len(hf_token_times) if hf_token_times else 0 match = xserv_ids == hf_generated if match: match_count += 1 status = " OK " else: mismatch_count += 1 status = "FAIL!" xserv_ttft_ms = xr["ttft_us"] / 1000.0 xserv_tbt_ms = xr["tbt_us"] / 1000.0 hf_ttft_ms = hf_ttft_us / 1000.0 hf_tbt_ms = hf_tbt_us / 1000.0 prompt_short = prompt[:43] + ".." if len(prompt) > 45 else prompt print(f"{i+1:>3} {status} {prompt_short:<45} {xserv_ttft_ms:>8.1f}ms {hf_ttft_ms:>8.1f}ms {xserv_tbt_ms:>8.1f}ms {hf_tbt_ms:>8.1f}ms") if not match: for j in range(max(len(xserv_ids), len(hf_generated))): x = xserv_ids[j] if j < len(xserv_ids) else None h = hf_generated[j] if j < len(hf_generated) else None if x != h: x_tok = tokenizer.decode([x]) if x is not None else "" h_tok = tokenizer.decode([h]) if h is not None else "" print(f" diverge@{j}: xserv={x}({repr(x_tok)}) hf={h}({repr(h_tok)})") break xserv_ttft_sum += xr["ttft_us"] xserv_tbt_sum += xr["tbt_us"] hf_ttft_sum += hf_ttft_us hf_tbt_sum += hf_tbt_us if xr["tbt_us"] > 0: num_with_tbt += 1 print(f"{'='*100}") print(f"\n=== CORRECTNESS ===") print(f"Total: {total}, Match: {match_count}/{total} ({match_count/total*100:.1f}%), Mismatch: {mismatch_count}") print(f"\n=== PERFORMANCE ===") print(f"{'Metric':<20} {'xserv':>12} {'transformers':>12} {'ratio':>10}") print(f"{'-'*54}") avg_x_ttft = xserv_ttft_sum / total / 1000 avg_h_ttft = hf_ttft_sum / total / 1000 avg_x_tbt = xserv_tbt_sum / num_with_tbt / 1000 if num_with_tbt > 0 else 0 avg_h_tbt = hf_tbt_sum / num_with_tbt / 1000 if num_with_tbt > 0 else 0 print(f"{'TTFT (ms)':<20} {avg_x_ttft:>10.1f}ms {avg_h_ttft:>10.1f}ms {avg_x_ttft/avg_h_ttft if avg_h_ttft>0 else 0:>9.1f}x") print(f"{'TBT (ms)':<20} {avg_x_tbt:>10.1f}ms {avg_h_tbt:>10.1f}ms {avg_x_tbt/avg_h_tbt if avg_h_tbt>0 else 0:>9.1f}x") xserv_tps = 1000.0 / avg_x_tbt if avg_x_tbt > 0 else 0 hf_tps = 1000.0 / avg_h_tbt if avg_h_tbt > 0 else 0 print(f"{'Throughput (tok/s)':<20} {xserv_tps:>10.1f} {hf_tps:>10.1f} {xserv_tps/hf_tps if hf_tps>0 else 0:>9.2f}x") if __name__ == "__main__": main()