phase 9: KV cache + autoregressive generation
- KVCache: per-layer, per-head storage with append + reconstruct - forward_with_cache: prefill (full prompt) + decode (single token) modes - Fixed data layout bug: per-head vectors avoid cross-head interleaving - CLI updated to use KV cache by default - bench-gpt2 supports --no-cache flag for comparison Benchmark results (50 prompts × 20 tokens): - KV cache vs no-cache: 50/50 bit-identical (cache is correct) - 18x speedup: TTFT 400→24ms, TBT 407→22ms, throughput 2.5→44 tok/s - vs HF transformers: 40/50 match (10 are FP divergence, avg logit gap 0.20) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
40
tools/analyze_divergence.py
Normal file
40
tools/analyze_divergence.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import json
|
||||
import sys
|
||||
import torch
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
|
||||
model = GPT2LMHeadModel.from_pretrained(sys.argv[2]).eval().cuda()
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(sys.argv[2])
|
||||
|
||||
with open(sys.argv[1]) as f:
|
||||
xr = json.load(f)
|
||||
|
||||
mismatches = []
|
||||
for i in range(len(xr)):
|
||||
ids = tokenizer.encode(xr[i]["prompt"])
|
||||
all_ids = list(ids)
|
||||
xserv_gen = xr[i]["generated_ids"]
|
||||
with torch.no_grad():
|
||||
for j in range(len(xserv_gen)):
|
||||
out = model(torch.tensor([all_ids]).cuda())
|
||||
logits = out.logits[0, -1]
|
||||
hf_next = logits.argmax().item()
|
||||
xs_next = xserv_gen[j]
|
||||
if hf_next != xs_next:
|
||||
xs_logit = logits[xs_next].item()
|
||||
hf_logit = logits[hf_next].item()
|
||||
hf_tok = tokenizer.decode([hf_next])
|
||||
xs_tok = tokenizer.decode([xs_next])
|
||||
gap = hf_logit - xs_logit
|
||||
print(
|
||||
f'[{i+1}] "{xr[i]["prompt"][:42]}" @ tok {j}: '
|
||||
f'hf={repr(hf_tok)}({hf_logit:.3f}) xserv={repr(xs_tok)}({xs_logit:.3f}) '
|
||||
f'gap={gap:.4f}'
|
||||
)
|
||||
mismatches.append(gap)
|
||||
break
|
||||
all_ids.append(hf_next)
|
||||
|
||||
print(f"\nTotal: {len(mismatches)}/{len(xr)} mismatches")
|
||||
if mismatches:
|
||||
print(f"Logit gaps: min={min(mismatches):.4f} max={max(mismatches):.4f} avg={sum(mismatches)/len(mismatches):.4f}")
|
||||
Reference in New Issue
Block a user