phase 8: add benchmark framework + baseline results
- bench-gpt2 binary: runs 50 prompts, measures TTFT/TBT per prompt, outputs JSON
- bench_compare.py: compares xserv vs transformers token-by-token + timing
- Baseline results: 50/50 correctness, 400ms TTFT / 407ms TBT (100x slower than PyTorch)
- Bottlenecks documented: no KV cache, CPU round-trips, cuBLAS handle churn

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
154
tools/bench_compare.py
Normal file
154
tools/bench_compare.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
Compare xserv GPT-2 output against HuggingFace transformers.
|
||||
Reads xserv results from JSON, runs same prompts through transformers, compares token-by-token.
|
||||
Also measures transformers timing for performance comparison.
|
||||
|
||||
Usage:
|
||||
python3 tools/bench_compare.py <xserv_results.json> <model_dir>
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import torch
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
|
||||
|
||||
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


def _first_divergence(xserv_ids, hf_ids):
    """Return (index, xserv_id, hf_id) for the first differing position.

    A side that has run out of tokens contributes None at that position.
    Returns None when both sequences are identical.
    """
    for j in range(max(len(xserv_ids), len(hf_ids))):
        x = xserv_ids[j] if j < len(xserv_ids) else None
        h = hf_ids[j] if j < len(hf_ids) else None
        if x != h:
            return j, x, h
    return None


def _greedy_generate(model, input_ids, gen_tokens, eos_id):
    """Greedy-decode with the transformers model, timing each forward pass.

    Every step feeds the full sequence back through the model; the returned
    past key/values are never passed in again, so each token is a full
    recompute.

    Args:
        model: callable returning an object with a `.logits` tensor.
        input_ids: list of prompt token ids.
        gen_tokens: number of tokens to generate; at least one token is
            always produced so that TTFT can be measured.
        eos_id: end-of-sequence token id, or None to disable early stop.

    Returns:
        (generated_ids, ttft_us, decode_times_us), where decode_times_us holds
        one entry per token generated after the first.
    """
    generated = []
    step_times_us = []

    with torch.no_grad():
        all_ids = torch.tensor([input_ids]).cuda()

        for step in range(max(gen_tokens, 1)):
            # Synchronize on both sides so only the forward pass is timed.
            torch.cuda.synchronize()
            started = time.perf_counter()
            out = model(all_ids)
            torch.cuda.synchronize()
            step_times_us.append((time.perf_counter() - started) * 1e6)

            next_id = out.logits[0, -1].argmax().item()
            generated.append(next_id)
            next_tensor = torch.tensor([[next_id]], device=all_ids.device)
            all_ids = torch.cat([all_ids, next_tensor], dim=1)

            # The first token is never treated as a stop condition; only
            # subsequent tokens end generation early on EOS.
            if step > 0 and eos_id is not None and next_id == eos_id:
                break

    return generated, step_times_us[0], step_times_us[1:]


def main():
    """Compare xserv benchmark results against HuggingFace transformers.

    Reads `sys.argv[1]` (xserv results JSON: a list of objects with the keys
    "prompt", "num_generated", "generated_ids", "ttft_us" and "tbt_us") and
    `sys.argv[2]` (model directory), regenerates every prompt greedily with
    transformers on CUDA, then prints a per-prompt table followed by
    correctness and performance summaries.

    Exits with status 1 when arguments are missing or the results file
    contains no entries.
    """
    if len(sys.argv) < 3:
        print(f"Usage: {sys.argv[0]} <xserv_results.json> <model_dir>")
        sys.exit(1)

    xserv_path = sys.argv[1]
    model_dir = sys.argv[2]

    with open(xserv_path, encoding="utf-8") as f:
        xserv_results = json.load(f)

    total = len(xserv_results)
    if total == 0:
        # Bail out before the expensive model load; every average below
        # would otherwise divide by zero.
        print(f"No results found in {xserv_path}; nothing to compare.")
        sys.exit(1)

    print(f"Loading transformers model from {model_dir}...")
    model = GPT2LMHeadModel.from_pretrained(model_dir)
    tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
    model.eval()
    model.cuda()
    eos_id = tokenizer.eos_token_id

    # Warmup
    with torch.no_grad():
        model(torch.tensor([[tokenizer.encode("warmup")[0]]]).cuda())
    torch.cuda.synchronize()

    match_count = 0
    mismatch_count = 0
    xserv_ttft_sum = 0.0
    xserv_tbt_sum = 0.0
    hf_ttft_sum = 0.0
    hf_tbt_sum = 0.0
    # Separate sample counts: a prompt contributes to a side's TBT average
    # only if that side actually produced a TBT measurement.
    xserv_tbt_count = 0
    hf_tbt_count = 0

    print(f"\n{'='*100}")
    print(f"{'#':>3} {'Match':>5} {'Prompt':<45} {'xserv TTFT':>10} {'HF TTFT':>10} {'xserv TBT':>10} {'HF TBT':>10}")
    print(f"{'='*100}")

    for i, xr in enumerate(xserv_results):
        prompt = xr["prompt"]
        gen_tokens = xr["num_generated"]
        xserv_ids = xr["generated_ids"]

        # Generate with transformers, measuring timing
        hf_generated, hf_ttft_us, hf_token_times = _greedy_generate(
            model, tokenizer.encode(prompt), gen_tokens, eos_id
        )
        hf_tbt_us = sum(hf_token_times) / len(hf_token_times) if hf_token_times else 0

        # Compare
        match = xserv_ids == hf_generated
        if match:
            match_count += 1
            status = " OK "
        else:
            mismatch_count += 1
            status = "FAIL!"

        xserv_ttft_ms = xr["ttft_us"] / 1000.0
        xserv_tbt_ms = xr["tbt_us"] / 1000.0
        hf_ttft_ms = hf_ttft_us / 1000.0
        hf_tbt_ms = hf_tbt_us / 1000.0

        prompt_short = prompt[:43] + ".." if len(prompt) > 45 else prompt
        print(f"{i+1:>3} {status:<5} {prompt_short:<45} {xserv_ttft_ms:>8.1f}ms {hf_ttft_ms:>8.1f}ms {xserv_tbt_ms:>8.1f}ms {hf_tbt_ms:>8.1f}ms")

        if not match:
            # Show first divergence
            divergence = _first_divergence(xserv_ids, hf_generated)
            if divergence is not None:
                j, x, h = divergence
                x_tok = tokenizer.decode([x]) if x is not None else "<none>"
                h_tok = tokenizer.decode([h]) if h is not None else "<none>"
                print(f" ↳ diverge at token {j}: xserv={x}({repr(x_tok)}) vs hf={h}({repr(h_tok)})")

        xserv_ttft_sum += xr["ttft_us"]
        hf_ttft_sum += hf_ttft_us
        if xr["tbt_us"] > 0:
            xserv_tbt_sum += xr["tbt_us"]
            xserv_tbt_count += 1
        if hf_token_times:
            hf_tbt_sum += hf_tbt_us
            hf_tbt_count += 1

    print(f"{'='*100}")
    print(f"\n=== CORRECTNESS ===")
    print(f"Total prompts: {total}")
    print(f"Match: {match_count}/{total} ({match_count/total*100:.1f}%)")
    print(f"Mismatch: {mismatch_count}/{total}")

    print(f"\n=== PERFORMANCE (average) ===")
    print(f"{'Metric':<20} {'xserv':>12} {'transformers':>12} {'ratio':>10}")
    print(f"{'-'*54}")
    avg_x_ttft = xserv_ttft_sum / total / 1000
    avg_h_ttft = hf_ttft_sum / total / 1000
    avg_x_tbt = _safe_ratio(xserv_tbt_sum, xserv_tbt_count) / 1000
    avg_h_tbt = _safe_ratio(hf_tbt_sum, hf_tbt_count) / 1000
    print(f"{'TTFT (ms)':<20} {avg_x_ttft:>10.1f}ms {avg_h_ttft:>10.1f}ms {_safe_ratio(avg_x_ttft, avg_h_ttft):>9.1f}x")
    print(f"{'TBT (ms)':<20} {avg_x_tbt:>10.1f}ms {avg_h_tbt:>10.1f}ms {_safe_ratio(avg_x_tbt, avg_h_tbt):>9.1f}x")
    xserv_tps = _safe_ratio(1000.0, avg_x_tbt)
    hf_tps = _safe_ratio(1000.0, avg_h_tbt)
    print(f"{'Throughput (tok/s)':<20} {xserv_tps:>10.1f} {hf_tps:>10.1f} {_safe_ratio(xserv_tps, hf_tps):>9.2f}x")

    print(f"\nNote: xserv currently has no KV cache — full recompute per token.")
    print(f" transformers also runs without KV cache in this benchmark for fair comparison.")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user