Files
xserv/tools/bench_compare_qwen3.py
Gahow Wang 268e40d764 phase 10: add Qwen3-8B benchmark + performance fix
Benchmark infrastructure:
- bench-qwen3 binary: 50 prompts × 20 tokens with KV cache
- bench_compare_qwen3.py: comparison against HF transformers (BF16)

Performance fix:
- Precompute transposed weights at model load time (eliminated per-token
  weight transpose CPU round-trip: was 252 transposes × 32MB each = 8GB/token)
- Result: from "infinite" (>10 min/token) to 144ms/token

Results (50 prompts):
- Prefill top-1: 42/50 (84%), top-5: 50/50 (100%) vs HF transformers
- Greedy sequence: 0/50 exact match (BF16 precision drift over 36 layers)
- Performance: TTFT=138ms, TBT=144ms, 6.9 tok/s (HF: 21ms, 45.6 tok/s)
- All outputs are coherent English/Chinese

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-22 10:25:33 +08:00

138 lines
5.0 KiB
Python

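"""
Compare xserv Qwen3 output against HuggingFace transformers.

Usage: python3 tools/bench_compare_qwen3.py <xserv_results.json> <model_dir>

<xserv_results.json> is a JSON array with one object per prompt; each object
must provide the keys "prompt", "num_generated", "generated_ids", "ttft_us"
and "tbt_us" (timings in microseconds). <model_dir> is loaded with
transformers in BF16 and moved to a CUDA device, so a GPU is required.
"""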
import json
import sys
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def _mean(total, count):
    """Return total / count, or 0.0 when count is zero."""
    return total / count if count > 0 else 0.0


def _first_divergence(xserv_ids, hf_ids):
    """Return (index, xserv_id, hf_id) at the first differing position.

    A side that has run out of tokens is reported as None. Returns None when
    both lists are identical.
    """
    for j in range(max(len(xserv_ids), len(hf_ids))):
        x = xserv_ids[j] if j < len(xserv_ids) else None
        h = hf_ids[j] if j < len(hf_ids) else None
        if x != h:
            return j, x, h
    return None


def _hf_greedy_decode(model, tokenizer, prompt, gen_tokens):
    """Greedy-decode up to gen_tokens tokens with the HF model, timing each step.

    Returns (generated_ids, ttft_us, token_times_us): the generated token ids,
    the prefill latency in microseconds, and one latency per decode step.
    Decoding stops early when a decode step emits the tokenizer's EOS id.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
    generated = []
    token_times_us = []

    with torch.no_grad():
        # Prefill: one forward pass over the whole prompt.
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        out = model(input_ids, use_cache=True)
        torch.cuda.synchronize()
        ttft_us = (time.perf_counter() - t0) * 1e6

        # Nothing to compare against if xserv generated no tokens.
        if gen_tokens <= 0:
            return generated, ttft_us, token_times_us

        past = out.past_key_values
        next_id = out.logits[0, -1].argmax().item()
        generated.append(next_id)

        for _ in range(1, gen_tokens):
            # Build the single-token input outside the timed region.
            step_ids = torch.tensor([[next_id]], device=input_ids.device)

            # Decode with the KV cache so per-token cost does not grow with
            # the sequence and the timing is comparable to xserv's cached path.
            torch.cuda.synchronize()
            t_start = time.perf_counter()
            out = model(step_ids, past_key_values=past, use_cache=True)
            torch.cuda.synchronize()
            token_times_us.append((time.perf_counter() - t_start) * 1e6)

            past = out.past_key_values
            next_id = out.logits[0, -1].argmax().item()
            generated.append(next_id)
            if next_id == tokenizer.eos_token_id:
                break

    return generated, ttft_us, token_times_us


def main():
    """Replay the xserv benchmark prompts through HF transformers and compare.

    Reads the xserv results JSON named by sys.argv[1], loads the model in
    sys.argv[2] in BF16 on CUDA, greedy-decodes every prompt, and prints a
    per-prompt table followed by correctness and performance summaries.

    Exits with status 1 when arguments are missing or the results file
    contains no entries.
    """
    if len(sys.argv) < 3:
        print(f"Usage: {sys.argv[0]} <xserv_results.json> <model_dir>")
        sys.exit(1)

    xserv_path = sys.argv[1]
    model_dir = sys.argv[2]

    # Prompts may contain non-ASCII text; do not depend on the locale encoding.
    with open(xserv_path, encoding="utf-8") as f:
        xserv_results = json.load(f)

    # Bail out before the expensive model load; the summary below would
    # otherwise divide by zero.
    if not xserv_results:
        print(f"No results found in {xserv_path}; nothing to compare.")
        sys.exit(1)

    print(f"Loading transformers model from {model_dir}...")
    model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model.eval()
    model.cuda()

    # Warmup
    with torch.no_grad():
        ids = tokenizer.encode("warmup", return_tensors="pt").cuda()
        model(ids)
        torch.cuda.synchronize()

    total = len(xserv_results)
    match_count = 0
    mismatch_count = 0
    xserv_ttft_sum = 0.0
    xserv_tbt_sum = 0.0
    hf_ttft_sum = 0.0
    hf_tbt_sum = 0.0
    # Each side's mean TBT is taken only over prompts where that side
    # actually produced a TBT sample.
    xserv_num_with_tbt = 0
    hf_num_with_tbt = 0

    print(f"\n{'='*100}")
    print(f"{'#':>3} {'Match':>5} {'Prompt':<45} {'xserv TTFT':>10} {'HF TTFT':>10} {'xserv TBT':>10} {'HF TBT':>10}")
    print(f"{'='*100}")

    for i, xr in enumerate(xserv_results):
        prompt = xr["prompt"]
        gen_tokens = xr["num_generated"]
        xserv_ids = xr["generated_ids"]

        hf_generated, hf_ttft_us, hf_token_times = _hf_greedy_decode(
            model, tokenizer, prompt, gen_tokens
        )
        hf_tbt_us = _mean(sum(hf_token_times), len(hf_token_times))

        match = xserv_ids == hf_generated
        if match:
            match_count += 1
            status = " OK "
        else:
            mismatch_count += 1
            status = "FAIL!"

        xserv_ttft_ms = xr["ttft_us"] / 1000.0
        xserv_tbt_ms = xr["tbt_us"] / 1000.0
        hf_ttft_ms = hf_ttft_us / 1000.0
        hf_tbt_ms = hf_tbt_us / 1000.0

        prompt_short = prompt[:43] + ".." if len(prompt) > 45 else prompt
        print(f"{i+1:>3} {status} {prompt_short:<45} {xserv_ttft_ms:>8.1f}ms {hf_ttft_ms:>8.1f}ms {xserv_tbt_ms:>8.1f}ms {hf_tbt_ms:>8.1f}ms")

        if not match:
            divergence = _first_divergence(xserv_ids, hf_generated)
            if divergence is not None:
                j, x, h = divergence
                x_tok = tokenizer.decode([x]) if x is not None else "<none>"
                h_tok = tokenizer.decode([h]) if h is not None else "<none>"
                print(f" diverge@{j}: xserv={x}({repr(x_tok)}) hf={h}({repr(h_tok)})")

        xserv_ttft_sum += xr["ttft_us"]
        hf_ttft_sum += hf_ttft_us
        if xr["tbt_us"] > 0:
            xserv_tbt_sum += xr["tbt_us"]
            xserv_num_with_tbt += 1
        if hf_token_times:
            hf_tbt_sum += hf_tbt_us
            hf_num_with_tbt += 1

    print(f"{'='*100}")
    print(f"\n=== CORRECTNESS ===")
    print(f"Total: {total}, Match: {match_count}/{total} ({match_count/total*100:.1f}%), Mismatch: {mismatch_count}")

    print(f"\n=== PERFORMANCE ===")
    print(f"{'Metric':<20} {'xserv':>12} {'transformers':>12} {'ratio':>10}")
    print(f"{'-'*54}")

    # Sums are in microseconds; the table reports milliseconds.
    avg_x_ttft = xserv_ttft_sum / total / 1000
    avg_h_ttft = hf_ttft_sum / total / 1000
    avg_x_tbt = _mean(xserv_tbt_sum, xserv_num_with_tbt) / 1000
    avg_h_tbt = _mean(hf_tbt_sum, hf_num_with_tbt) / 1000

    print(f"{'TTFT (ms)':<20} {avg_x_ttft:>10.1f}ms {avg_h_ttft:>10.1f}ms {avg_x_ttft/avg_h_ttft if avg_h_ttft>0 else 0:>9.1f}x")
    print(f"{'TBT (ms)':<20} {avg_x_tbt:>10.1f}ms {avg_h_tbt:>10.1f}ms {avg_x_tbt/avg_h_tbt if avg_h_tbt>0 else 0:>9.1f}x")

    xserv_tps = 1000.0 / avg_x_tbt if avg_x_tbt > 0 else 0
    hf_tps = 1000.0 / avg_h_tbt if avg_h_tbt > 0 else 0
    print(f"{'Throughput (tok/s)':<20} {xserv_tps:>10.1f} {hf_tps:>10.1f} {xserv_tps/hf_tps if hf_tps>0 else 0:>9.2f}x")
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()