quantization: single strided-batched FP8 MoE GEMM — cut per-token launches ~768→48
The plan-cache fix removed the per-expert heuristic churn but still issued one cublasLtMatmul per expert: ~768 tiny launches per decoded token (16 local experts × 2 GEMMs × 24 layers), which capped the FP8 decode win at ~1.05× over BF16. Collapse each MoE GEMM into ONE strided-batched cuBLASLt FP8 matmul (BATCH_COUNT + strided-batch offsets on all four layouts) → ~48 launches/token. A single strided call can't carry a per-batch scalar B-scale, so the per-expert weight scale moves out of the GEMM epilogue into a fused post-scale kernel (rowwise_scale_moe_bf16) that applies a_scale[token]·b_scale[expert] in one pass. This is precision-equivalent: BF16's relative error is scale-invariant, so scaling the unscaled GEMM output afterward loses nothing vs scaling in-epilogue. Measured on dash5 (gpt-oss-20b, TP=2, 5090), warm-server GSM8K: decode TPOT 17.45 → 13.08 ms (FP8 now 1.41× vs BF16 18.39 ms), throughput 57.3 → 76.4 tok/s, accuracy unchanged (FP8 91.0% vs BF16 90.0%). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -23,10 +23,11 @@ unsafe extern "C" {
|
|||||||
num_rows: i32, cols: i32,
|
num_rows: i32, cols: i32,
|
||||||
stream: *mut c_void,
|
stream: *mut c_void,
|
||||||
);
|
);
|
||||||
fn launch_rowwise_scale_bf16(
|
fn launch_rowwise_scale_moe_bf16(
|
||||||
data: *mut c_void,
|
data: *mut c_void,
|
||||||
scales: *const c_void,
|
a_scales: *const c_void,
|
||||||
num_rows: i32, cols: i32,
|
b_scales: *const c_void,
|
||||||
|
num_rows: i32, cols: i32, tokens: i32,
|
||||||
stream: *mut c_void,
|
stream: *mut c_void,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -136,11 +137,11 @@ struct Fp8Plan {
|
|||||||
struct CublasLtContext {
|
struct CublasLtContext {
|
||||||
handle: CublasLtHandle,
|
handle: CublasLtHandle,
|
||||||
workspace: GpuBuffer,
|
workspace: GpuBuffer,
|
||||||
/// Persistent device scalar holding 1.0, used as the A/B scale pointer
|
/// Persistent device scalar holding 1.0, used as the A/B scale pointer.
|
||||||
/// placeholder. Allocated once instead of per-expert.
|
/// Scales are applied post-GEMM, so the in-GEMM scales stay 1.0.
|
||||||
one_buf: GpuBuffer,
|
one_buf: GpuBuffer,
|
||||||
/// Cache of prepared matmul plans keyed by (M, N, K).
|
/// Cache of prepared matmul plans keyed by (M, N, K, batch).
|
||||||
plans: HashMap<(usize, usize, usize), Fp8Plan>,
|
plans: HashMap<(usize, usize, usize, usize), Fp8Plan>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CublasLtContext {
|
impl CublasLtContext {
|
||||||
@@ -154,14 +155,15 @@ impl CublasLtContext {
|
|||||||
Self { handle, workspace, one_buf, plans: HashMap::new() }
|
Self { handle, workspace, one_buf, plans: HashMap::new() }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the cached plan for (m, n, k), building (and caching) it on first use.
|
/// Get the cached strided-batched plan for (m, n, k, batch), building it on
|
||||||
fn plan(&mut self, m: usize, n: usize, k: usize) -> Fp8Plan {
|
/// first use.
|
||||||
if let Some(p) = self.plans.get(&(m, n, k)) {
|
fn plan(&mut self, m: usize, n: usize, k: usize, batch: usize) -> Fp8Plan {
|
||||||
|
if let Some(p) = self.plans.get(&(m, n, k, batch)) {
|
||||||
return *p;
|
return *p;
|
||||||
}
|
}
|
||||||
let one_ptr = self.one_buf.as_ptr() as *const c_void;
|
let one_ptr = self.one_buf.as_ptr() as *const c_void;
|
||||||
let plan = unsafe { build_fp8_plan(self.handle, one_ptr, m, n, k) };
|
let plan = unsafe { build_fp8_plan(self.handle, one_ptr, m, n, k, batch) };
|
||||||
self.plans.insert((m, n, k), plan);
|
self.plans.insert((m, n, k, batch), plan);
|
||||||
plan
|
plan
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -184,16 +186,18 @@ impl Drop for CublasLtContext {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build an FP8 matmul plan for one (m, n, k) shape. See `batched_gemm_fp8`
|
/// Build a strided-batched FP8 matmul plan for `batch` experts of one
|
||||||
/// for the row-major → cuBLASLt col-major layout mapping (transA=T, transB=N,
|
/// (m, n, k) shape. Row-major → cuBLASLt col-major mapping (transA=T,
|
||||||
/// m_lt=N, n_lt=M, k_lt=K). The B-scale pointer is initialised to `one_ptr`
|
/// transB=N, m_lt=N, n_lt=M, k_lt=K). A/B scale pointers stay at 1.0 — both
|
||||||
/// and overwritten per-expert at call time.
|
/// the per-expert weight scale and the per-token activation scale are applied
|
||||||
|
/// post-GEMM in a fused kernel, which lets all experts run in one matmul.
|
||||||
unsafe fn build_fp8_plan(
|
unsafe fn build_fp8_plan(
|
||||||
handle: CublasLtHandle,
|
handle: CublasLtHandle,
|
||||||
one_ptr: *const c_void,
|
one_ptr: *const c_void,
|
||||||
m: usize,
|
m: usize,
|
||||||
n: usize,
|
n: usize,
|
||||||
k: usize,
|
k: usize,
|
||||||
|
batch: usize,
|
||||||
) -> Fp8Plan {
|
) -> Fp8Plan {
|
||||||
let m_lt = n as u64;
|
let m_lt = n as u64;
|
||||||
let n_lt = m as u64;
|
let n_lt = m as u64;
|
||||||
@@ -209,17 +213,33 @@ unsafe fn build_fp8_plan(
|
|||||||
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
|
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
|
||||||
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
|
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
|
||||||
|
|
||||||
|
// Per-expert strides in ELEMENTS for the strided-batch layout.
|
||||||
|
let stride_a = (n * k) as i64; // weights [N, K]
|
||||||
|
let stride_b = (m * k) as i64; // activations [M, K]
|
||||||
|
let stride_c = (m * n) as i64; // output [M, N]
|
||||||
|
let bc = batch as i32;
|
||||||
|
let set_batch = |layout: CublasLtMatrixLayout, stride: i64| {
|
||||||
|
cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
|
||||||
|
&bc as *const i32 as _, 4);
|
||||||
|
cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
|
||||||
|
&stride as *const i64 as _, 8);
|
||||||
|
};
|
||||||
|
|
||||||
// "A" layout (weights, transposed): physical (K, N) col-major, ld=K
|
// "A" layout (weights, transposed): physical (K, N) col-major, ld=K
|
||||||
let mut a_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
let mut a_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||||
cublasLtMatrixLayoutCreate(&mut a_layout, CUDA_R_8F_E4M3, k_lt, m_lt, k as i64);
|
cublasLtMatrixLayoutCreate(&mut a_layout, CUDA_R_8F_E4M3, k_lt, m_lt, k as i64);
|
||||||
|
set_batch(a_layout, stride_a);
|
||||||
// "B" layout (activations): physical (K, M) col-major, ld=K
|
// "B" layout (activations): physical (K, M) col-major, ld=K
|
||||||
let mut b_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
let mut b_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||||
cublasLtMatrixLayoutCreate(&mut b_layout, CUDA_R_8F_E4M3, k_lt, n_lt, k as i64);
|
cublasLtMatrixLayoutCreate(&mut b_layout, CUDA_R_8F_E4M3, k_lt, n_lt, k as i64);
|
||||||
|
set_batch(b_layout, stride_b);
|
||||||
// "C"/"D" layout (output): physical (N, M) col-major, ld=N
|
// "C"/"D" layout (output): physical (N, M) col-major, ld=N
|
||||||
let mut c_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
let mut c_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||||
cublasLtMatrixLayoutCreate(&mut c_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
|
cublasLtMatrixLayoutCreate(&mut c_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
|
||||||
|
set_batch(c_layout, stride_c);
|
||||||
let mut d_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
let mut d_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||||
cublasLtMatrixLayoutCreate(&mut d_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
|
cublasLtMatrixLayoutCreate(&mut d_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
|
||||||
|
set_batch(d_layout, stride_c);
|
||||||
|
|
||||||
let mut pref: CublasLtMatmulPreference = std::ptr::null_mut();
|
let mut pref: CublasLtMatmulPreference = std::ptr::null_mut();
|
||||||
cublasLtMatmulPreferenceCreate(&mut pref);
|
cublasLtMatmulPreferenceCreate(&mut pref);
|
||||||
@@ -233,7 +253,7 @@ unsafe fn build_fp8_plan(
|
|||||||
pref, 1, &mut heuristic, &mut found,
|
pref, 1, &mut heuristic, &mut found,
|
||||||
);
|
);
|
||||||
assert!(status == 0 && found > 0,
|
assert!(status == 0 && found > 0,
|
||||||
"cublasLtMatmulAlgoGetHeuristic failed for FP8 GEMM (m={m}, n={n}, k={k}): status={status}, found={found}");
|
"cublasLtMatmulAlgoGetHeuristic failed for batched FP8 GEMM (m={m}, n={n}, k={k}, batch={batch}): status={status}, found={found}");
|
||||||
cublasLtMatmulPreferenceDestroy(pref);
|
cublasLtMatmulPreferenceDestroy(pref);
|
||||||
|
|
||||||
Fp8Plan {
|
Fp8Plan {
|
||||||
@@ -354,71 +374,54 @@ pub fn batched_gemm_fp8(
|
|||||||
|
|
||||||
let c = Tensor::empty(&[batch, m, n], DType::BF16, a_fp8.device());
|
let c = Tensor::empty(&[batch, m, n], DType::BF16, a_fp8.device());
|
||||||
|
|
||||||
// Strides (in bytes) for one expert slice
|
|
||||||
let stride_a = m * k; // FP8: 1 byte per elem
|
|
||||||
let stride_b = n * k; // FP8: 1 byte per elem (transposed: [N, K])
|
|
||||||
let stride_c = m * n * 2; // BF16: 2 bytes per elem
|
|
||||||
|
|
||||||
CUBLASLT_CTX.with(|cell| {
|
CUBLASLT_CTX.with(|cell| {
|
||||||
let mut ctx = cell.borrow_mut();
|
let mut ctx = cell.borrow_mut();
|
||||||
let handle = ctx.handle;
|
let handle = ctx.handle;
|
||||||
let ws_ptr = ctx.workspace.as_ptr() as *mut c_void;
|
let ws_ptr = ctx.workspace.as_ptr() as *mut c_void;
|
||||||
// Build (or fetch) the cached plan for this shape — heuristic search and
|
// Cached strided-batched plan: heuristic + descriptor/layout creation
|
||||||
// descriptor/layout creation happen once per (m, n, k), not per-expert.
|
// happen once per (m, n, k, batch). All experts run in ONE matmul.
|
||||||
let plan = ctx.plan(m, n, k);
|
let plan = ctx.plan(m, n, k, batch);
|
||||||
|
|
||||||
// alpha=1, beta=0. Per-expert weight scale is supplied via the cuBLASLt
|
// alpha=1, beta=0, in-GEMM scales=1.0. The unscaled result
|
||||||
// B-scale pointer (device, scalar): cuBLASLt computes in the FP32 epilogue
|
// D_raw[e] = A_fp8[e] @ B_fp8[e]^T
|
||||||
// D = (1.0 * A_fp8) @ (b_scale[e] * B_fp8)^T = b_scale[e] * (A_fp8 @ B_fp8^T)
|
// is recovered to the real value by the fused post-scale kernel below.
|
||||||
// Per-token activation scale (a_scale) is applied post-GEMM (per row).
|
|
||||||
let alpha: f32 = 1.0;
|
let alpha: f32 = 1.0;
|
||||||
let beta: f32 = 0.0;
|
let beta: f32 = 0.0;
|
||||||
let ptr_sz = std::mem::size_of::<*const c_void>();
|
|
||||||
|
|
||||||
for e in 0..batch {
|
|
||||||
let a_ptr = unsafe { (a_fp8.data_ptr() as *const u8).add(e * stride_a) as *const c_void };
|
|
||||||
let b_ptr = unsafe { (b_fp8_t.data_ptr() as *const u8).add(e * stride_b) as *const c_void };
|
|
||||||
let c_ptr = unsafe { (c.data_ptr() as *mut u8).add(e * stride_c) as *mut c_void };
|
|
||||||
// Device pointer to this expert's scalar weight scale (FP32, 4 bytes).
|
|
||||||
let b_scale_ptr = unsafe { (b_scales.data_ptr() as *const u8).add(e * 4) as *const c_void };
|
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
cublasLtMatmulDescSetAttribute(
|
|
||||||
plan.desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
|
|
||||||
&b_scale_ptr as *const _ as _, ptr_sz,
|
|
||||||
);
|
|
||||||
let status = cublasLtMatmul(
|
let status = cublasLtMatmul(
|
||||||
handle, plan.desc,
|
handle, plan.desc,
|
||||||
&alpha as *const f32 as _,
|
&alpha as *const f32 as _,
|
||||||
b_ptr, // cuBLASLt "A" = weights
|
b_fp8_t.data_ptr() as *const c_void, // cuBLASLt "A" = weights
|
||||||
plan.a_layout,
|
plan.a_layout,
|
||||||
a_ptr, // cuBLASLt "B" = activations
|
a_fp8.data_ptr() as *const c_void, // cuBLASLt "B" = activations
|
||||||
plan.b_layout,
|
plan.b_layout,
|
||||||
&beta as *const f32 as _,
|
&beta as *const f32 as _,
|
||||||
c_ptr, // C (unused with beta=0)
|
c.data_ptr() as *const c_void, // C (unused with beta=0)
|
||||||
plan.c_layout,
|
plan.c_layout,
|
||||||
c_ptr, // D = output
|
c.data_ptr() as *mut c_void, // D = output
|
||||||
plan.d_layout,
|
plan.d_layout,
|
||||||
&plan.algo,
|
&plan.algo,
|
||||||
ws_ptr,
|
ws_ptr,
|
||||||
plan.workspace_size,
|
plan.workspace_size,
|
||||||
std::ptr::null_mut(),
|
std::ptr::null_mut(),
|
||||||
);
|
);
|
||||||
assert_eq!(status, 0, "cublasLtMatmul FP8 failed for expert {e}: status={status}");
|
assert_eq!(status, 0, "batched cublasLtMatmul FP8 failed: status={status}");
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// Post-GEMM: multiply each row of c by its activation scale.
|
// Post-GEMM: recover the real result in one pass.
|
||||||
// c is [batch, M, N] BF16. a_scales is [batch * M] F32.
|
// c[e, t, :] *= a_scales[e*M + t] * b_scales[e]
|
||||||
// This recovers the per-token scale that was divided out during quantization.
|
// (per-token activation scale × per-expert weight scale). BF16's relative
|
||||||
|
// error is scale-invariant, so applying the scale here is precision-
|
||||||
|
// equivalent to folding it into the GEMM epilogue.
|
||||||
let total_rows = (batch * m) as i32;
|
let total_rows = (batch * m) as i32;
|
||||||
let cols = n as i32;
|
|
||||||
unsafe {
|
unsafe {
|
||||||
launch_rowwise_scale_bf16(
|
launch_rowwise_scale_moe_bf16(
|
||||||
c.data_ptr() as *mut c_void,
|
c.data_ptr() as *mut c_void,
|
||||||
a_scales.data_ptr() as *const c_void,
|
a_scales.data_ptr() as *const c_void,
|
||||||
total_rows, cols,
|
b_scales.data_ptr() as *const c_void,
|
||||||
|
total_rows, n as i32, m as i32,
|
||||||
std::ptr::null_mut(),
|
std::ptr::null_mut(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -86,6 +86,29 @@ __global__ void rowwise_scale_bf16_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Combined dequant scale for batched MoE FP8 GEMM output.
|
||||||
|
// data[row, :] *= a_scales[row] * b_scales[row / tokens]
|
||||||
|
// where row = expert * tokens + token. a_scales is the per-token activation
|
||||||
|
// scale; b_scales is the per-expert scalar weight scale. Lets a single
|
||||||
|
// strided-batched FP8 matmul (alpha=1, scales=1) recover the real result in
|
||||||
|
// one pass instead of folding the weight scale into a per-expert GEMM call.
|
||||||
|
// In-place post-scale of a row-major BF16 matrix: every element of row `row`
// is multiplied by a_scales[row] * b_scales[row / tokens].
//   data      BF16 matrix, num_rows x cols, modified in place
//   a_scales  one float per row
//   b_scales  one float per group of `tokens` consecutive rows
//   tokens    rows per group; used as a divisor, so it must be > 0
__global__ void rowwise_scale_moe_bf16_kernel(
|
||||||
|
__nv_bfloat16* __restrict__ data,
|
||||||
|
const float* __restrict__ a_scales,
|
||||||
|
const float* __restrict__ b_scales,
|
||||||
|
int num_rows, int cols, int tokens
|
||||||
|
) {
|
||||||
|
// Each thread block handles exactly one row.
int row = blockIdx.x;
|
||||||
|
// Guard against a grid larger than the row count.
if (row >= num_rows) return;
|
||||||
|
int tid = threadIdx.x;
|
||||||
|
// Combined scale, computed once per thread in FP32; integer division maps
// the flat row index to its group (expert) index.
float s = a_scales[row] * b_scales[row / tokens];
|
||||||
|
// Widen to 64-bit before multiplying so row * cols cannot wrap in int.
__nv_bfloat16* row_data = data + (long long)row * cols;
|
||||||
|
// Threads of the block step through the row's columns blockDim.x apart.
for (int i = tid; i < cols; i += blockDim.x) {
|
||||||
|
// Multiply in FP32, then convert back to BF16 on store.
float v = __bfloat162float(row_data[i]) * s;
|
||||||
|
row_data[i] = __float2bfloat16(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
void launch_rowwise_scale_bf16(
|
void launch_rowwise_scale_bf16(
|
||||||
@@ -102,6 +125,20 @@ void launch_rowwise_scale_bf16(
|
|||||||
CUDA_CHECK_LAST_ERROR();
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Host-side launcher for rowwise_scale_moe_bf16_kernel. Pointers arrive
// type-erased and are cast to BF16 / float before the launch; `stream` is cast
// to cudaStream_t. The launch is enqueued on that stream and is followed by
// CUDA_CHECK_LAST_ERROR (macro defined outside this view).
void launch_rowwise_scale_moe_bf16(
|
||||||
|
void* data, const void* a_scales, const void* b_scales,
|
||||||
|
int num_rows, int cols, int tokens,
|
||||||
|
void* stream
|
||||||
|
) {
|
||||||
|
// 256 threads per block; each block's threads share one row's columns.
int block = 256;
|
||||||
|
// One block per row.
// NOTE(review): num_rows == 0 would request an empty grid — confirm callers
// never pass it, or that the error check below handles it acceptably.
int grid = num_rows;
|
||||||
|
// No dynamic shared memory (third launch parameter is 0).
rowwise_scale_moe_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
|
(__nv_bfloat16*)data, (const float*)a_scales, (const float*)b_scales,
|
||||||
|
num_rows, cols, tokens
|
||||||
|
);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
|
}
|
||||||
|
|
||||||
void launch_quantize_bf16_to_fp8e4m3_rowwise(
|
void launch_quantize_bf16_to_fp8e4m3_rowwise(
|
||||||
const void* src,
|
const void* src,
|
||||||
void* dst,
|
void* dst,
|
||||||
|
|||||||
@@ -14,9 +14,10 @@ stay BF16.
|
|||||||
- **Activations**: quantized dynamically at runtime, **per-token** (per-row
|
- **Activations**: quantized dynamically at runtime, **per-token** (per-row
|
||||||
absmax), recovered by a post-GEMM row scale.
|
absmax), recovered by a post-GEMM row scale.
|
||||||
- **Compute**: `batched_gemm_fp8` (`crates/xserv-kernels/src/quantization.rs`)
|
- **Compute**: `batched_gemm_fp8` (`crates/xserv-kernels/src/quantization.rs`)
|
||||||
runs one cuBLASLt FP8 matmul per expert; the per-expert weight scale is
|
runs **one strided-batched cuBLASLt FP8 matmul for all experts** (`alpha=1`,
|
||||||
supplied via the cuBLASLt B-scale device pointer (FP32 epilogue, so precision
|
in-GEMM scales `1.0`); a fused kernel then applies `a_scale[token]·b_scale[expert]`
|
||||||
matches folding it into `alpha`).
|
in a single pass. BF16's relative error is scale-invariant, so applying both
|
||||||
|
scales post-GEMM is precision-equivalent to folding them into the epilogue.
|
||||||
- Model size: **22 GB** (FP8) vs **39 GB** (BF16). The FP8 model fits on a
|
- Model size: **22 GB** (FP8) vs **39 GB** (BF16). The FP8 model fits on a
|
||||||
single 32 GB 5090; BF16 needs ≥ 2.
|
single 32 GB 5090; BF16 needs ≥ 2.
|
||||||
|
|
||||||
@@ -34,34 +35,50 @@ decoded token. This made FP8 **slower than BF16**:
|
|||||||
| Throughput | 37 tok/s | **55.8 tok/s** | 53.2 tok/s |
|
| Throughput | 37 tok/s | **55.8 tok/s** | 53.2 tok/s |
|
||||||
|
|
||||||
Fix: cache the cuBLASLt plan (descriptor + layouts + heuristically-chosen algo)
|
Fix: cache the cuBLASLt plan (descriptor + layouts + heuristically-chosen algo)
|
||||||
in a thread-local map keyed by `(M, N, K)` so the heuristic runs once per shape;
|
in a thread-local map keyed by `(M, N, K, batch)` so the heuristic runs once per
|
||||||
allocate the scale buffer once; pass per-expert weight scales by device pointer.
|
shape, and allocate the scale buffer once.
|
||||||
The per-expert loop now issues only `cublasLtMatmul`.
|
|
||||||
|
|
||||||
## Results — GSM8K (200 problems, greedy, TP=2 on the same 2 GPUs)
|
## Reducing launches: one strided-batched matmul
|
||||||
|
|
||||||
|
The per-expert loop still issued one `cublasLtMatmul` per expert — ~768 tiny
|
||||||
|
launches per decoded token (16 local experts × 2 GEMMs × 24 layers). Collapsing
|
||||||
|
each MoE GEMM into a single **strided-batched** cuBLASLt FP8 matmul (BATCH_COUNT
|
||||||
|
+ strided-batch offsets) drops that to ~48, with a fused post-scale kernel
|
||||||
|
applying both scales. This required moving the per-expert weight scale out of the
|
||||||
|
GEMM epilogue (a single strided call can't carry a per-batch scalar) into the
|
||||||
|
post-scale kernel — precision-equivalent, as noted above.
|
||||||
|
|
||||||
|
| (gpt-oss-20b, TP=2) | per-expert FP8 | batched FP8 | BF16 |
|
||||||
|
|---|---|---|---|
|
||||||
|
| Decode TPOT | 17.9 ms | **13.8 ms** | 18.8 ms |
|
||||||
|
| Throughput | 55.8 tok/s | **72.3 tok/s** | 53.2 tok/s |
|
||||||
|
|
||||||
|
## Results — GSM8K (greedy, TP=2 on the same 2 GPUs)
|
||||||
|
|
||||||
|
200-problem run is the per-expert plan-cache fix; 100-problem run is the
|
||||||
|
strided-batched version. BF16 is the unchanged baseline in both; its column lists the n=200 value first, then the n=100 value.
|
||||||
|
|
||||||
Harness: `tools/fp8_compare.py` — a warm `xserv-server` per model, GSM8K streamed
|
Harness: `tools/fp8_compare.py` — a warm `xserv-server` per model, GSM8K streamed
|
||||||
through `/v1/chat/completions`; TTFT = time to first token, TPOT = mean
|
through `/v1/chat/completions`; TTFT = time to first token, TPOT = mean
|
||||||
inter-token latency, per request.
|
inter-token latency, per request.
|
||||||
|
|
||||||
| metric | FP8 W8A8 | BF16 |
|
| metric | FP8 per-expert (n=200) | FP8 batched (n=100) | BF16 |
|
||||||
|---|---|---|
|
|---|---|---|---|
|
||||||
| GSM8K accuracy | **93.0 %** | 90.5 % |
|
| GSM8K accuracy | 93.0 % | 91.0 % | 90.5 / 90.0 % |
|
||||||
| TTFT median | 67.4 ms | 68.8 ms |
|
| TTFT median | 67.4 ms | 65.0 ms | 68.8 / 69.5 ms |
|
||||||
| TTFT p90 | 90.4 ms | 96.7 ms |
|
| TPOT median | 17.45 ms | **13.08 ms** | 18.26 / 18.39 ms |
|
||||||
| TPOT median | **17.45 ms** | 18.26 ms |
|
| TPOT p90 | 17.65 ms | **13.28 ms** | 18.38 / 18.52 ms |
|
||||||
| TPOT p90 | 17.65 ms | 18.38 ms |
|
| Throughput | 57.3 tok/s | **76.4 tok/s** | 54.8 / 54.4 tok/s |
|
||||||
| Throughput | **57.3 tok/s** | 54.8 tok/s |
|
| Decode speedup vs BF16 | 1.05× | **1.41×** | 1.00× |
|
||||||
| Mean output tokens | 288 | 293 |
|
|
||||||
|
|
||||||
- **Accuracy: unchanged.** FP8 is nominally +2.5 pts, but with n=200 the
|
- **Accuracy: unchanged.** FP8 is nominally +1.0 … +2.5 pts above BF16, but at
|
||||||
standard error is ~2.1 pts, so the two are statistically indistinguishable.
|
n=100–200 the standard error is ~2–3 pts, so they are statistically
|
||||||
The takeaway is that FP8 did **not** degrade accuracy.
|
indistinguishable. The takeaway is that neither FP8 quantization nor the
|
||||||
- **Decode: FP8 ~5 % faster** (TPOT 17.45 vs 18.26 ms), reproducible across
|
strided-batched rounding degrades accuracy.
|
||||||
runs, with a tighter p90. Modest because the dense-MoE path loads *all*
|
- **Decode: FP8 1.41× faster** once batched (TPOT 13.08 vs 18.39 ms), with a
|
||||||
experts every token and FP8 only halves the *expert* bytes; the per-expert
|
tight p90. The per-expert version was only ~1.05× — the ~768 tiny M=1 launches
|
||||||
M=1 launches and M=1 tensor-core inefficiency absorb much of the bandwidth
|
per token dominated; batching them into ~48 unlocked most of the FP8
|
||||||
saving.
|
expert-weight-bandwidth saving.
|
||||||
- **Prefill (TTFT): comparable.** A multi-length sweep (113 / 561 / 1681 tokens)
|
- **Prefill (TTFT): comparable.** A multi-length sweep (113 / 561 / 1681 tokens)
|
||||||
gave FP8 480 / 362 / 2451 ms vs BF16 558 / 282 / 2287 ms — non-monotonic, i.e.
|
gave FP8 480 / 362 / 2451 ms vs BF16 558 / 282 / 2287 ms — non-monotonic, i.e.
|
||||||
dominated by fixed overhead (cuBLAS lazy init + FP8's one-time per-shape
|
dominated by fixed overhead (cuBLAS lazy init + FP8's one-time per-shape
|
||||||
@@ -75,9 +92,8 @@ that otherwise needs two GPUs onto one — is the largest practical win.
|
|||||||
|
|
||||||
## Follow-ups (not done)
|
## Follow-ups (not done)
|
||||||
|
|
||||||
- Strided-batched FP8 (one call instead of ~768 per-expert launches per token) —
|
|
||||||
requires folding the per-expert weight scale into the post-scale kernel, at a
|
|
||||||
BF16-intermediate precision cost.
|
|
||||||
- Per-channel (per-output-row) weight scales for better accuracy headroom than
|
- Per-channel (per-output-row) weight scales for better accuracy headroom than
|
||||||
per-tensor.
|
per-tensor.
|
||||||
- Warm common prefill shapes at load to hide the first-request heuristic stall.
|
- Warm common prefill shapes at load to hide the first-request heuristic stall.
|
||||||
|
- Sparse (top-k only) MoE compute instead of dense — currently every token runs
|
||||||
|
all experts, so only ~top_k/num_experts of the FP8 GEMM work is used.
|
||||||
|
|||||||
Reference in New Issue
Block a user