xserv/crates/xserv-kernels
Gahow Wang 5a16225c1f quantization: cache cuBLASLt FP8 plan per shape — fix per-expert heuristic churn
batched_gemm_fp8 rebuilt the cuBLASLt matmul descriptor, four matrix
layouts, a preference, and a 4-byte scale allocation, AND ran the algo
heuristic search: once per expert, per GEMM, per layer, on every
forward (~1500 heuristic searches per decoded token). FP8 decode ran at
27.0 ms/tok vs. 18.8 ms/tok for BF16, i.e. slower than the path it was
meant to accelerate.

Cache the full plan (descriptor + layouts + heuristically-chosen algo)
in a thread-local map keyed by (M, N, K) so the heuristic runs once per
shape and is reused across experts and forwards; allocate the 1.0 scale
buffer once; pass each expert's weight scale via the cuBLASLt B-scale
device pointer instead of folding it into alpha (identical FP32-epilogue
precision, and no host readback of b_scales). The per-expert loop now
issues only cublasLtMatmul.
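A minimal sketch of the per-shape caching pattern described above, in Rust to match the crate. It is not the actual kernel code: `Fp8GemmPlan`, `build_plan`, `plan_for`, and the `HEURISTIC_RUNS` counter are hypothetical names, and the plan holds placeholder fields instead of real cuBLASLt handles (descriptor, layouts, chosen algo), so it runs without a GPU. The point it shows is that the expensive setup runs once per (M, N, K) per thread, no matter how many experts, layers, or forwards reuse that shape.

```rust
use std::cell::{Cell, RefCell};
use std::collections::HashMap;

/// Hypothetical stand-in for the cached cuBLASLt state. Real code would
/// hold the matmul descriptor, four matrix layouts, and the chosen algo
/// (and destroy those handles on drop); this sketch omits them.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Fp8GemmPlan {
    m: i32,
    n: i32,
    k: i32,
    algo_id: u32, // placeholder for the heuristically chosen algorithm
}

thread_local! {
    // One plan per (M, N, K) shape, per thread.
    static PLAN_CACHE: RefCell<HashMap<(i32, i32, i32), Fp8GemmPlan>> =
        RefCell::new(HashMap::new());
    // Counts how often the expensive setup path ran on this thread.
    static HEURISTIC_RUNS: Cell<usize> = Cell::new(0);
}

/// Hypothetical expensive setup: where the real kernel would create the
/// descriptor/layouts and query the algo heuristic.
fn build_plan(m: i32, n: i32, k: i32) -> Fp8GemmPlan {
    HEURISTIC_RUNS.with(|c| c.set(c.get() + 1));
    Fp8GemmPlan { m, n, k, algo_id: 0 }
}

/// Returns the cached plan for this shape, building it only on first use.
fn plan_for(m: i32, n: i32, k: i32) -> Fp8GemmPlan {
    PLAN_CACHE.with(|cache| {
        let mut map = cache.borrow_mut();
        let plan = map.entry((m, n, k)).or_insert_with(|| build_plan(m, n, k));
        *plan
    })
}

fn main() {
    // Illustrative shapes only: 2 forwards x 3 layers x 32 experts,
    // two GEMMs per expert.
    for _forward in 0..2 {
        for _layer in 0..3 {
            for _expert in 0..32 {
                let _up = plan_for(8, 4096, 2048);
                let _down = plan_for(8, 2048, 4096);
                // The per-expert loop would issue only the matmul here.
            }
        }
    }
    let runs = HEURISTIC_RUNS.with(|c| c.get());
    println!("heuristic runs: {runs}"); // prints "heuristic runs: 2"
}
```

Keying by shape rather than by expert is what removes the churn: all experts in a layer share the same GEMM shapes, so 384 lookups in the loop above trigger only two setups. A thread-local map avoids locking, at the cost of one plan copy per thread.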

Measured on dash5 (gpt-oss-20b, TP=2, 5090): FP8 decode TPOT 27.0 -> 17.9
ms, now faster than BF16 (18.8 ms); GSM8K-200 accuracy unchanged
(FP8 93.0% vs BF16 90.5%, within noise).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-12 00:58:46 +08:00