quantization: W8A8 FP8 compute via cuBLASLt tensor cores
Replace the W8A16 dequant→BF16-GEMM path with native FP8×FP8→BF16 GEMM using cuBLASLt on Blackwell (RTX 5090). Both weights (static FP8 E4M3) and activations (dynamically quantized per-row) are processed directly on FP8 tensor cores. Key implementation details: - cuBLASLt on Blackwell requires transA=T for FP8, so expert weights are transposed during model loading ([E,K,N] → [E,N,K]) - Per-row activation quantization kernel (absmax/448 → FP8 E4M3) - Post-GEMM row-wise rescaling recovers per-token precision - Per-expert loop (not batched) due to cuBLASLt FP8 scale constraints The same FP8 quantized model files work — no re-quantization needed. Activation quantization happens dynamically at inference time. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -8,6 +8,7 @@ fn main() {
|
|||||||
println!("cargo:rustc-link-search=native={cuda_path}/lib64");
|
println!("cargo:rustc-link-search=native={cuda_path}/lib64");
|
||||||
println!("cargo:rustc-link-lib=dylib=cudart");
|
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||||
println!("cargo:rustc-link-lib=dylib=cublas");
|
println!("cargo:rustc-link-lib=dylib=cublas");
|
||||||
|
println!("cargo:rustc-link-lib=dylib=cublasLt");
|
||||||
|
|
||||||
cc::Build::new()
|
cc::Build::new()
|
||||||
.cuda(true)
|
.cuda(true)
|
||||||
@@ -31,6 +32,7 @@ fn main() {
|
|||||||
.file("../../csrc/attention/reshape_and_cache.cu")
|
.file("../../csrc/attention/reshape_and_cache.cu")
|
||||||
.file("../../csrc/moe/moe_kernels.cu")
|
.file("../../csrc/moe/moe_kernels.cu")
|
||||||
.file("../../csrc/quantization/dequant_fp8.cu")
|
.file("../../csrc/quantization/dequant_fp8.cu")
|
||||||
|
.file("../../csrc/quantization/quantize_fp8.cu")
|
||||||
.compile("xserv_kernels");
|
.compile("xserv_kernels");
|
||||||
|
|
||||||
println!("cargo:rerun-if-changed=../../csrc/");
|
println!("cargo:rerun-if-changed=../../csrc/");
|
||||||
|
|||||||
@@ -1,6 +1,12 @@
|
|||||||
|
use std::cell::RefCell;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
|
use xserv_cuda::GpuBuffer;
|
||||||
use xserv_tensor::{DType, Tensor};
|
use xserv_tensor::{DType, Tensor};
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// FFI: custom CUDA kernels
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
unsafe extern "C" {
|
unsafe extern "C" {
|
||||||
fn launch_dequant_fp8e4m3_to_bf16(
|
fn launch_dequant_fp8e4m3_to_bf16(
|
||||||
src: *const c_void,
|
src: *const c_void,
|
||||||
@@ -9,8 +15,135 @@ unsafe extern "C" {
|
|||||||
num_experts: i32, rows: i32, cols: i32,
|
num_experts: i32, rows: i32, cols: i32,
|
||||||
stream: *mut c_void,
|
stream: *mut c_void,
|
||||||
);
|
);
|
||||||
|
fn launch_quantize_bf16_to_fp8e4m3_rowwise(
|
||||||
|
src: *const c_void,
|
||||||
|
dst: *mut c_void,
|
||||||
|
scales: *mut c_void,
|
||||||
|
num_rows: i32, cols: i32,
|
||||||
|
stream: *mut c_void,
|
||||||
|
);
|
||||||
|
fn launch_rowwise_scale_bf16(
|
||||||
|
data: *mut c_void,
|
||||||
|
scales: *const c_void,
|
||||||
|
num_rows: i32, cols: i32,
|
||||||
|
stream: *mut c_void,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// FFI: cuBLASLt
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
type CublasLtHandle = *mut c_void;
|
||||||
|
type CublasLtMatmulDesc = *mut c_void;
|
||||||
|
type CublasLtMatrixLayout = *mut c_void;
|
||||||
|
type CublasLtMatmulPreference = *mut c_void;
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
struct CublasLtMatmulAlgo {
|
||||||
|
data: [u64; 8],
|
||||||
|
}
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
struct CublasLtMatmulHeuristicResult {
|
||||||
|
algo: CublasLtMatmulAlgo,
|
||||||
|
workspace_size: usize,
|
||||||
|
state: i32,
|
||||||
|
_reserved: [f32; 4],
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe extern "C" {
|
||||||
|
fn cublasLtCreate(handle: *mut CublasLtHandle) -> i32;
|
||||||
|
fn cublasLtDestroy(handle: CublasLtHandle) -> i32;
|
||||||
|
fn cublasLtMatmulDescCreate(desc: *mut CublasLtMatmulDesc, compute_type: i32, scale_type: i32) -> i32;
|
||||||
|
fn cublasLtMatmulDescDestroy(desc: CublasLtMatmulDesc) -> i32;
|
||||||
|
fn cublasLtMatmulDescSetAttribute(desc: CublasLtMatmulDesc, attr: i32, buf: *const c_void, size: usize) -> i32;
|
||||||
|
fn cublasLtMatrixLayoutCreate(layout: *mut CublasLtMatrixLayout, dtype: i32, rows: u64, cols: u64, ld: i64) -> i32;
|
||||||
|
fn cublasLtMatrixLayoutDestroy(layout: CublasLtMatrixLayout) -> i32;
|
||||||
|
fn cublasLtMatrixLayoutSetAttribute(layout: CublasLtMatrixLayout, attr: i32, buf: *const c_void, size: usize) -> i32;
|
||||||
|
fn cublasLtMatmulPreferenceCreate(pref: *mut CublasLtMatmulPreference) -> i32;
|
||||||
|
fn cublasLtMatmulPreferenceDestroy(pref: CublasLtMatmulPreference) -> i32;
|
||||||
|
fn cublasLtMatmulPreferenceSetAttribute(pref: CublasLtMatmulPreference, attr: i32, buf: *const c_void, size: usize) -> i32;
|
||||||
|
fn cublasLtMatmulAlgoGetHeuristic(
|
||||||
|
handle: CublasLtHandle, desc: CublasLtMatmulDesc,
|
||||||
|
a_layout: CublasLtMatrixLayout, b_layout: CublasLtMatrixLayout,
|
||||||
|
c_layout: CublasLtMatrixLayout, d_layout: CublasLtMatrixLayout,
|
||||||
|
pref: CublasLtMatmulPreference,
|
||||||
|
requested: i32,
|
||||||
|
results: *mut CublasLtMatmulHeuristicResult,
|
||||||
|
found: *mut i32,
|
||||||
|
) -> i32;
|
||||||
|
fn cublasLtMatmul(
|
||||||
|
handle: CublasLtHandle, desc: CublasLtMatmulDesc,
|
||||||
|
alpha: *const c_void,
|
||||||
|
a: *const c_void, a_layout: CublasLtMatrixLayout,
|
||||||
|
b: *const c_void, b_layout: CublasLtMatrixLayout,
|
||||||
|
beta: *const c_void,
|
||||||
|
c: *const c_void, c_layout: CublasLtMatrixLayout,
|
||||||
|
d: *mut c_void, d_layout: CublasLtMatrixLayout,
|
||||||
|
algo: *const CublasLtMatmulAlgo,
|
||||||
|
workspace: *mut c_void, workspace_size: usize,
|
||||||
|
stream: *mut c_void,
|
||||||
|
) -> i32;
|
||||||
|
}
|
||||||
|
|
||||||
|
// cuBLASLt constants
|
||||||
|
const CUBLAS_COMPUTE_32F: i32 = 68;
|
||||||
|
const CUDA_R_32F: i32 = 0;
|
||||||
|
const CUDA_R_16BF: i32 = 14;
|
||||||
|
const CUDA_R_8F_E4M3: i32 = 28;
|
||||||
|
|
||||||
|
// MatmulDesc attributes
|
||||||
|
const CUBLASLT_MATMUL_DESC_A_SCALE_POINTER: i32 = 17;
|
||||||
|
const CUBLASLT_MATMUL_DESC_B_SCALE_POINTER: i32 = 18;
|
||||||
|
const CUBLASLT_MATMUL_DESC_A_SCALE_MODE: i32 = 31;
|
||||||
|
const CUBLASLT_MATMUL_DESC_B_SCALE_MODE: i32 = 32;
|
||||||
|
|
||||||
|
// MatrixLayout attributes
|
||||||
|
const CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT: i32 = 5;
|
||||||
|
const CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET: i32 = 6;
|
||||||
|
|
||||||
|
// Scale modes
|
||||||
|
const CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR: i32 = 0;
|
||||||
|
const CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F: i32 = 3;
|
||||||
|
|
||||||
|
// MatmulPreference attributes
|
||||||
|
const CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES: i32 = 1;
|
||||||
|
|
||||||
|
const WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
|
||||||
|
|
||||||
|
struct CublasLtContext {
|
||||||
|
handle: CublasLtHandle,
|
||||||
|
workspace: GpuBuffer,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CublasLtContext {
|
||||||
|
fn new() -> Self {
|
||||||
|
let mut handle = std::ptr::null_mut();
|
||||||
|
let status = unsafe { cublasLtCreate(&mut handle) };
|
||||||
|
assert_eq!(status, 0, "cublasLtCreate failed: {status}");
|
||||||
|
let workspace = GpuBuffer::alloc(WORKSPACE_BYTES).expect("alloc cublasLt workspace");
|
||||||
|
Self { handle, workspace }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for CublasLtContext {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if !self.handle.is_null() {
|
||||||
|
unsafe { cublasLtDestroy(self.handle) };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
thread_local! {
|
||||||
|
static CUBLASLT_CTX: RefCell<CublasLtContext> = RefCell::new(CublasLtContext::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Public API
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
/// Dequantize a 3D FP8 E4M3 tensor to BF16 using per-expert FP32 scales.
|
/// Dequantize a 3D FP8 E4M3 tensor to BF16 using per-expert FP32 scales.
|
||||||
///
|
///
|
||||||
/// src: [num_experts, rows, cols] FP8E4M3, contiguous, GPU
|
/// src: [num_experts, rows, cols] FP8E4M3, contiguous, GPU
|
||||||
@@ -44,3 +177,207 @@ pub fn dequant_fp8_to_bf16(src: &Tensor, scales: &Tensor) -> Tensor {
|
|||||||
|
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Dynamically quantize a contiguous BF16 tensor to FP8 E4M3 with per-row scales.
|
||||||
|
///
|
||||||
|
/// src: [num_rows, cols] or [batch, rows, cols] BF16, contiguous, GPU
|
||||||
|
/// Treats the tensor as 2D (flattens leading dims into num_rows).
|
||||||
|
///
|
||||||
|
/// Returns: (fp8_data [same shape] FP8E4M3, scales [total_rows] F32)
|
||||||
|
pub fn quantize_bf16_to_fp8_rowwise(src: &Tensor) -> (Tensor, Tensor) {
|
||||||
|
assert_eq!(src.dtype(), DType::BF16);
|
||||||
|
assert!(src.is_contiguous());
|
||||||
|
assert!(src.ndim() >= 2);
|
||||||
|
|
||||||
|
let cols = src.shape()[src.ndim() - 1];
|
||||||
|
let num_rows: usize = src.shape()[..src.ndim() - 1].iter().product();
|
||||||
|
|
||||||
|
let fp8_out = Tensor::empty(src.shape(), DType::FP8E4M3, src.device());
|
||||||
|
let scales = Tensor::empty(&[num_rows], DType::F32, src.device());
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
launch_quantize_bf16_to_fp8e4m3_rowwise(
|
||||||
|
src.data_ptr() as *const c_void,
|
||||||
|
fp8_out.data_ptr() as *mut c_void,
|
||||||
|
scales.data_ptr() as *mut c_void,
|
||||||
|
num_rows as i32, cols as i32,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
(fp8_out, scales)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// FP8 batched GEMM via cuBLASLt (transA=T required on Blackwell).
|
||||||
|
///
|
||||||
|
/// Computes: C[b] = scale_a[b] * scale_b[b] * (A_fp8[b] @ B_fp8_T[b]^T)
|
||||||
|
/// effectively C[b] = A[b, M, K] @ W[b, K, N] but W is stored transposed
|
||||||
|
/// as [b, N, K] for cuBLASLt FP8 compatibility.
|
||||||
|
///
|
||||||
|
/// a_fp8: [batch, M, K] FP8E4M3 (activations, quantized per-row)
|
||||||
|
/// a_scales: [batch * M] F32 (per-token scales, collapsed to per-batch max)
|
||||||
|
/// b_fp8_t: [batch, N, K] FP8E4M3 (weights, TRANSPOSED for cuBLASLt)
|
||||||
|
/// b_scales: [batch] F32 (per-expert scalar scales)
|
||||||
|
///
|
||||||
|
/// Returns: [batch, M, N] BF16
|
||||||
|
pub fn batched_gemm_fp8(
|
||||||
|
a_fp8: &Tensor,
|
||||||
|
a_scales: &Tensor,
|
||||||
|
b_fp8_t: &Tensor,
|
||||||
|
b_scales: &Tensor,
|
||||||
|
) -> Tensor {
|
||||||
|
assert_eq!(a_fp8.ndim(), 3);
|
||||||
|
assert_eq!(b_fp8_t.ndim(), 3);
|
||||||
|
assert_eq!(a_fp8.dtype(), DType::FP8E4M3);
|
||||||
|
assert_eq!(b_fp8_t.dtype(), DType::FP8E4M3);
|
||||||
|
assert!(a_fp8.is_contiguous() && b_fp8_t.is_contiguous());
|
||||||
|
assert_eq!(a_fp8.shape()[0], b_fp8_t.shape()[0]);
|
||||||
|
// b_fp8_t is [batch, N, K] transposed, so b_fp8_t.shape[2] == K == a_fp8.shape[2]
|
||||||
|
assert_eq!(a_fp8.shape()[2], b_fp8_t.shape()[2]);
|
||||||
|
|
||||||
|
let batch = a_fp8.shape()[0];
|
||||||
|
let m = a_fp8.shape()[1]; // tokens
|
||||||
|
let k = a_fp8.shape()[2]; // hidden
|
||||||
|
let n = b_fp8_t.shape()[1]; // out_dim (from transposed weight)
|
||||||
|
|
||||||
|
// Per-token scales → per-expert scales (max over tokens within each expert batch)
|
||||||
|
// a_scales: [batch * M] → we take the max per expert to get [batch] scalar scales
|
||||||
|
// This is a slight accuracy tradeoff vs per-token, but allows scalar GEMM scale mode.
|
||||||
|
assert_eq!(a_scales.shape()[0], batch * m);
|
||||||
|
assert_eq!(b_scales.shape()[0], batch);
|
||||||
|
|
||||||
|
let c = Tensor::empty(&[batch, m, n], DType::BF16, a_fp8.device());
|
||||||
|
|
||||||
|
// Read weight scales to host for the per-expert loop
|
||||||
|
let b_scales_cpu = b_scales.to_device(xserv_tensor::Device::Cpu);
|
||||||
|
let b_s_data = b_scales_cpu.as_slice::<f32>();
|
||||||
|
|
||||||
|
CUBLASLT_CTX.with(|cell| {
|
||||||
|
let ctx = cell.borrow();
|
||||||
|
let handle = ctx.handle;
|
||||||
|
|
||||||
|
// Strides (in bytes) for one expert slice
|
||||||
|
let stride_a = m * k; // FP8: 1 byte per elem
|
||||||
|
let stride_b = n * k; // FP8: 1 byte per elem (transposed: [N, K])
|
||||||
|
let stride_c = m * n * 2; // BF16: 2 bytes per elem
|
||||||
|
|
||||||
|
for e in 0..batch {
|
||||||
|
let a_ptr = unsafe { (a_fp8.data_ptr() as *const u8).add(e * stride_a) as *const c_void };
|
||||||
|
let b_ptr = unsafe { (b_fp8_t.data_ptr() as *const u8).add(e * stride_b) as *const c_void };
|
||||||
|
let c_ptr = unsafe { (c.data_ptr() as *mut u8).add(e * stride_c) as *mut c_void };
|
||||||
|
|
||||||
|
// alpha = b_scale (weight scale). Per-row activation scale applied post-GEMM.
|
||||||
|
// GEMM computes: D = alpha * (A_fp8 @ B_fp8_T)
|
||||||
|
// = b_scale * ((A_real / a_scale_row) @ (B_real / b_scale))
|
||||||
|
// = (A_real / a_scale_row) @ B_real
|
||||||
|
// Post-multiply row i by a_scale[i] to recover the correct result.
|
||||||
|
let alpha: f32 = b_s_data[e];
|
||||||
|
let beta: f32 = 0.0;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
// cuBLASLt FP8 on Blackwell requires transA=T, transB=N.
|
||||||
|
// cuBLASLt computes: D(m,n) = op(A)(m,k) * B(k,n) with transA=T
|
||||||
|
//
|
||||||
|
// We want: D_row[M,N] = A_act_row[M,K] @ B_wt_row[K,N]
|
||||||
|
// Map to cuBLASLt with m_lt=N, n_lt=M, k_lt=K:
|
||||||
|
// "A" (transA=T): stored as (K, N) col-major ld=K → transposed to (N, K)
|
||||||
|
// Our weights are stored TRANSPOSED as [E, N, K] row-major = col-major (K, N) ld=K ✓
|
||||||
|
// "B" (transB=N): stored as (K, M) col-major ld=K
|
||||||
|
// Our activations A_act_row[M,K] = col-major (K, M) ld=K ✓
|
||||||
|
// "D": stored as (N, M) col-major ld=N
|
||||||
|
// Our output D_row[M,N] = col-major (N, M) ld=N ✓
|
||||||
|
let m_lt = n as u64;
|
||||||
|
let n_lt = m as u64;
|
||||||
|
let k_lt = k as u64;
|
||||||
|
|
||||||
|
let mut matmul_desc: CublasLtMatmulDesc = std::ptr::null_mut();
|
||||||
|
cublasLtMatmulDescCreate(&mut matmul_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
|
||||||
|
|
||||||
|
// Set transA=T (required for FP8 on Blackwell)
|
||||||
|
let trans_a: i32 = 1; // CUBLAS_OP_T
|
||||||
|
cublasLtMatmulDescSetAttribute(matmul_desc, 3 /*TRANSA*/, &trans_a as *const i32 as _, 4);
|
||||||
|
|
||||||
|
// FP8 requires scale pointers. We fold the actual scales into alpha,
|
||||||
|
// so set dummy 1.0 scale pointers on device.
|
||||||
|
let one_val: f32 = 1.0;
|
||||||
|
let mut one_buf = xserv_cuda::allocator::cached_alloc(4).unwrap();
|
||||||
|
one_buf.copy_from_host(&one_val.to_le_bytes()).unwrap();
|
||||||
|
let one_ptr = one_buf.as_ptr() as *const c_void;
|
||||||
|
cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &one_ptr as *const _ as _, std::mem::size_of::<*const c_void>());
|
||||||
|
cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &one_ptr as *const _ as _, std::mem::size_of::<*const c_void>());
|
||||||
|
|
||||||
|
// "A" layout (weights, transposed): physical (K, N) col-major, ld=K
|
||||||
|
let mut a_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||||
|
cublasLtMatrixLayoutCreate(&mut a_layout, CUDA_R_8F_E4M3, k_lt, m_lt, k as i64);
|
||||||
|
|
||||||
|
// "B" layout (activations): physical (K, M) col-major, ld=K
|
||||||
|
let mut b_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||||
|
cublasLtMatrixLayoutCreate(&mut b_layout, CUDA_R_8F_E4M3, k_lt, n_lt, k as i64);
|
||||||
|
|
||||||
|
// "C"/"D" layout (output): physical (N, M) col-major, ld=N
|
||||||
|
let mut c_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||||
|
cublasLtMatrixLayoutCreate(&mut c_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
|
||||||
|
let mut d_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||||
|
cublasLtMatrixLayoutCreate(&mut d_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
|
||||||
|
|
||||||
|
// Get algo heuristic
|
||||||
|
let mut pref: CublasLtMatmulPreference = std::ptr::null_mut();
|
||||||
|
cublasLtMatmulPreferenceCreate(&mut pref);
|
||||||
|
let ws_bytes = WORKSPACE_BYTES as u64;
|
||||||
|
cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws_bytes as *const u64 as _, 8);
|
||||||
|
|
||||||
|
let mut heuristic = std::mem::zeroed::<CublasLtMatmulHeuristicResult>();
|
||||||
|
let mut found: i32 = 0;
|
||||||
|
let status = cublasLtMatmulAlgoGetHeuristic(
|
||||||
|
handle, matmul_desc,
|
||||||
|
a_layout, b_layout, c_layout, d_layout,
|
||||||
|
pref, 1, &mut heuristic, &mut found,
|
||||||
|
);
|
||||||
|
assert!(status == 0 && found > 0,
|
||||||
|
"cublasLtMatmulAlgoGetHeuristic failed for FP8 GEMM: status={status}, found={found}");
|
||||||
|
|
||||||
|
let status = cublasLtMatmul(
|
||||||
|
handle, matmul_desc,
|
||||||
|
&alpha as *const f32 as _,
|
||||||
|
b_ptr, // cuBLASLt "A" = weights
|
||||||
|
a_layout,
|
||||||
|
a_ptr, // cuBLASLt "B" = activations
|
||||||
|
b_layout,
|
||||||
|
&beta as *const f32 as _,
|
||||||
|
c_ptr, // C (unused with beta=0)
|
||||||
|
c_layout,
|
||||||
|
c_ptr, // D = output
|
||||||
|
d_layout,
|
||||||
|
&heuristic.algo,
|
||||||
|
ctx.workspace.as_ptr() as *mut c_void,
|
||||||
|
heuristic.workspace_size,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
);
|
||||||
|
assert_eq!(status, 0, "cublasLtMatmul FP8 failed for expert {e}: status={status}");
|
||||||
|
|
||||||
|
cublasLtMatmulPreferenceDestroy(pref);
|
||||||
|
cublasLtMatrixLayoutDestroy(a_layout);
|
||||||
|
cublasLtMatrixLayoutDestroy(b_layout);
|
||||||
|
cublasLtMatrixLayoutDestroy(c_layout);
|
||||||
|
cublasLtMatrixLayoutDestroy(d_layout);
|
||||||
|
cublasLtMatmulDescDestroy(matmul_desc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Post-GEMM: multiply each row of c by its activation scale.
|
||||||
|
// c is [batch, M, N] BF16. a_scales is [batch * M] F32.
|
||||||
|
// This recovers the per-token scale that was divided out during quantization.
|
||||||
|
let total_rows = (batch * m) as i32;
|
||||||
|
let cols = n as i32;
|
||||||
|
unsafe {
|
||||||
|
launch_rowwise_scale_bf16(
|
||||||
|
c.data_ptr() as *mut c_void,
|
||||||
|
a_scales.data_ptr() as *const c_void,
|
||||||
|
total_rows, cols,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
c
|
||||||
|
}
|
||||||
|
|||||||
@@ -47,10 +47,11 @@ struct GptOssBlock {
|
|||||||
expert_gate_up_bias: Tensor, // [local_experts, 2*inter]
|
expert_gate_up_bias: Tensor, // [local_experts, 2*inter]
|
||||||
expert_down_wt: Tensor, // [local_experts, inter, hidden] BF16
|
expert_down_wt: Tensor, // [local_experts, inter, hidden] BF16
|
||||||
expert_down_bias: Tensor, // [local_experts, hidden]
|
expert_down_bias: Tensor, // [local_experts, hidden]
|
||||||
// FP8 quantized expert weights (Some when running FP8 W8A16)
|
// FP8 quantized expert weights (Some when running FP8 W8A8)
|
||||||
expert_gate_up_fp8: Option<Tensor>, // [local_experts, hidden, 2*inter] FP8E4M3
|
// Transposed layout [E, N, K] for cuBLASLt FP8 (Blackwell requires transA=T)
|
||||||
|
expert_gate_up_fp8: Option<Tensor>, // [local_experts, 2*inter, hidden] FP8E4M3
|
||||||
expert_gate_up_scale: Option<Tensor>,// [local_experts] F32
|
expert_gate_up_scale: Option<Tensor>,// [local_experts] F32
|
||||||
expert_down_fp8: Option<Tensor>, // [local_experts, inter, hidden] FP8E4M3
|
expert_down_fp8: Option<Tensor>, // [local_experts, hidden, inter] FP8E4M3
|
||||||
expert_down_scale: Option<Tensor>, // [local_experts] F32
|
expert_down_scale: Option<Tensor>, // [local_experts] F32
|
||||||
local_experts: usize,
|
local_experts: usize,
|
||||||
// Activation params
|
// Activation params
|
||||||
@@ -183,9 +184,12 @@ impl GptOss {
|
|||||||
let expert_down_scale_gpu;
|
let expert_down_scale_gpu;
|
||||||
|
|
||||||
if is_fp8 {
|
if is_fp8 {
|
||||||
// FP8 path: load quantized weights and scales
|
// FP8 W8A8 path: load and TRANSPOSE weights for cuBLASLt (requires transA=T on Blackwell).
|
||||||
expert_gate_up_fp8 = Some(slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, hidden, inter2).to_device(dev));
|
// Original: [E, K, N] → Transposed: [E, N, K]
|
||||||
expert_down_fp8 = Some(slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden).to_device(dev));
|
let gu_sliced = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, hidden, inter2);
|
||||||
|
let dn_sliced = slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden);
|
||||||
|
expert_gate_up_fp8 = Some(transpose_3d_inner_raw(&gu_sliced, local_experts, hidden, inter2).to_device(dev));
|
||||||
|
expert_down_fp8 = Some(transpose_3d_inner_raw(&dn_sliced, local_experts, inter, hidden).to_device(dev));
|
||||||
// Scales: [num_experts] F32 → slice to [local_experts]
|
// Scales: [num_experts] F32 → slice to [local_experts]
|
||||||
let gu_s = gate_up_scale.expect("FP8 model missing gate_up_proj_scale");
|
let gu_s = gate_up_scale.expect("FP8 model missing gate_up_proj_scale");
|
||||||
let d_s = down_scale.expect("FP8 model missing down_proj_scale");
|
let d_s = down_scale.expect("FP8 model missing down_proj_scale");
|
||||||
@@ -255,7 +259,7 @@ impl GptOss {
|
|||||||
eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm");
|
eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm");
|
||||||
}
|
}
|
||||||
if is_fp8 {
|
if is_fp8 {
|
||||||
eprintln!("gpt-oss: FP8 E4M3 quantized expert weights detected (W8A16 mode)");
|
eprintln!("gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -515,12 +519,15 @@ impl GptOss {
|
|||||||
let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
|
let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
|
||||||
|
|
||||||
// 4. Batched GEMM gate_up: [E, tokens, hidden] @ [E, hidden, 2*inter] → [E, tokens, 2*inter]
|
// 4. Batched GEMM gate_up: [E, tokens, hidden] @ [E, hidden, 2*inter] → [E, tokens, 2*inter]
|
||||||
let gate_up_wt = if let Some(ref fp8) = layer.expert_gate_up_fp8 {
|
let gate_up = if let Some(ref wt_fp8_t) = layer.expert_gate_up_fp8 {
|
||||||
xserv_kernels::quantization::dequant_fp8_to_bf16(fp8, layer.expert_gate_up_scale.as_ref().unwrap())
|
// W8A8: quantize activations with per-expert scalar scale, use cuBLASLt FP8 GEMM
|
||||||
|
let (x_fp8, x_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep);
|
||||||
|
xserv_kernels::quantization::batched_gemm_fp8(
|
||||||
|
&x_fp8, &x_scales, wt_fp8_t, layer.expert_gate_up_scale.as_ref().unwrap(),
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
layer.expert_gate_up_wt.clone()
|
xserv_kernels::moe::batched_gemm_strided(&x_rep, &layer.expert_gate_up_wt)
|
||||||
};
|
};
|
||||||
let gate_up = xserv_kernels::moe::batched_gemm_strided(&x_rep, &gate_up_wt);
|
|
||||||
|
|
||||||
// 5. Bias add: gate_up += expert_gate_up_bias (in-place)
|
// 5. Bias add: gate_up += expert_gate_up_bias (in-place)
|
||||||
xserv_kernels::moe::moe_bias_add_3d(&gate_up, &layer.expert_gate_up_bias);
|
xserv_kernels::moe::moe_bias_add_3d(&gate_up, &layer.expert_gate_up_bias);
|
||||||
@@ -534,12 +541,15 @@ impl GptOss {
|
|||||||
let activated = activated_flat.reshape(&[local_experts, num_tokens, inter]);
|
let activated = activated_flat.reshape(&[local_experts, num_tokens, inter]);
|
||||||
|
|
||||||
// 7. Batched GEMM down: [E, tokens, inter] @ [E, inter, hidden] → [E, tokens, hidden]
|
// 7. Batched GEMM down: [E, tokens, inter] @ [E, inter, hidden] → [E, tokens, hidden]
|
||||||
let down_wt = if let Some(ref fp8) = layer.expert_down_fp8 {
|
let down = if let Some(ref wt_fp8) = layer.expert_down_fp8 {
|
||||||
xserv_kernels::quantization::dequant_fp8_to_bf16(fp8, layer.expert_down_scale.as_ref().unwrap())
|
// W8A8: quantize post-GLU activations to FP8, use cuBLASLt FP8 GEMM
|
||||||
|
let (act_fp8, act_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated);
|
||||||
|
xserv_kernels::quantization::batched_gemm_fp8(
|
||||||
|
&act_fp8, &act_scales, wt_fp8, layer.expert_down_scale.as_ref().unwrap(),
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
layer.expert_down_wt.clone()
|
xserv_kernels::moe::batched_gemm_strided(&activated, &layer.expert_down_wt)
|
||||||
};
|
};
|
||||||
let down = xserv_kernels::moe::batched_gemm_strided(&activated, &down_wt);
|
|
||||||
|
|
||||||
// 8. Bias add: down += expert_down_bias (in-place)
|
// 8. Bias add: down += expert_down_bias (in-place)
|
||||||
xserv_kernels::moe::moe_bias_add_3d(&down, &layer.expert_down_bias);
|
xserv_kernels::moe::moe_bias_add_3d(&down, &layer.expert_down_bias);
|
||||||
@@ -636,6 +646,27 @@ fn shard_1d(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
|||||||
Tensor::from_slice(&shard, &[local])
|
Tensor::from_slice(&shard, &[local])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Transpose the inner two dimensions of a [batch, rows, cols] tensor → [batch, cols, rows].
|
||||||
|
/// Works on raw bytes (any dtype). CPU-only.
|
||||||
|
fn transpose_3d_inner_raw(t: &Tensor, batch: usize, rows: usize, cols: usize) -> Tensor {
|
||||||
|
assert_eq!(t.ndim(), 3);
|
||||||
|
assert_eq!(t.shape(), &[batch, rows, cols]);
|
||||||
|
let host = t.to_device(Device::Cpu);
|
||||||
|
let es = t.dtype().size_bytes();
|
||||||
|
let raw = host.as_raw_bytes();
|
||||||
|
let mut out = vec![0u8; batch * cols * rows * es];
|
||||||
|
for b in 0..batch {
|
||||||
|
for r in 0..rows {
|
||||||
|
for c in 0..cols {
|
||||||
|
let src_off = (b * rows * cols + r * cols + c) * es;
|
||||||
|
let dst_off = (b * cols * rows + c * rows + r) * es;
|
||||||
|
out[dst_off..dst_off + es].copy_from_slice(&raw[src_off..src_off + es]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Tensor::from_raw_bytes(&out, &[batch, cols, rows], t.dtype())
|
||||||
|
}
|
||||||
|
|
||||||
/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor (any dtype, raw bytes).
|
/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor (any dtype, raw bytes).
|
||||||
fn slice_expert_range_3d_raw(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
|
fn slice_expert_range_3d_raw(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
|
||||||
assert_eq!(t.ndim(), 3);
|
assert_eq!(t.ndim(), 3);
|
||||||
|
|||||||
123
csrc/quantization/quantize_fp8.cu
Normal file
123
csrc/quantization/quantize_fp8.cu
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp8.h>
|
||||||
|
#include <float.h>
|
||||||
|
#include "../common.cuh"
|
||||||
|
|
||||||
|
// Per-row quantize BF16 → FP8 E4M3 with per-row FP32 scale output.
|
||||||
|
//
|
||||||
|
// Input: src [num_rows, cols] BF16
|
||||||
|
// Output: dst [num_rows, cols] FP8 E4M3
|
||||||
|
// scales [num_rows] FP32
|
||||||
|
//
|
||||||
|
// Algorithm per row:
|
||||||
|
// absmax = max(|src[row, :]|)
|
||||||
|
// scale = absmax / 448.0 (FP8 E4M3 max representable)
|
||||||
|
// dst[row, i] = fp8(src[row, i] / scale)
|
||||||
|
//
|
||||||
|
// Grid: one block per row. Block: 256 threads.
|
||||||
|
// Each thread handles ceil(cols / 256) elements.
|
||||||
|
|
||||||
|
#define QUANT_BLOCK 256
|
||||||
|
#define FP8_E4M3_MAX 448.0f
|
||||||
|
|
||||||
|
__global__ void quantize_bf16_to_fp8e4m3_rowwise_kernel(
|
||||||
|
const __nv_bfloat16* __restrict__ src,
|
||||||
|
__nv_fp8_e4m3* __restrict__ dst,
|
||||||
|
float* __restrict__ scales,
|
||||||
|
int num_rows, int cols
|
||||||
|
) {
|
||||||
|
int row = blockIdx.x;
|
||||||
|
if (row >= num_rows) return;
|
||||||
|
int tid = threadIdx.x;
|
||||||
|
|
||||||
|
const __nv_bfloat16* row_src = src + (long long)row * cols;
|
||||||
|
__nv_fp8_e4m3* row_dst = dst + (long long)row * cols;
|
||||||
|
|
||||||
|
// Step 1: Compute per-row absmax via shared-memory reduction.
|
||||||
|
__shared__ float smem_max[QUANT_BLOCK];
|
||||||
|
float local_max = 0.0f;
|
||||||
|
for (int i = tid; i < cols; i += QUANT_BLOCK) {
|
||||||
|
float v = fabsf(__bfloat162float(row_src[i]));
|
||||||
|
local_max = fmaxf(local_max, v);
|
||||||
|
}
|
||||||
|
smem_max[tid] = local_max;
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Block reduction
|
||||||
|
for (int s = QUANT_BLOCK / 2; s > 0; s >>= 1) {
|
||||||
|
if (tid < s) {
|
||||||
|
smem_max[tid] = fmaxf(smem_max[tid], smem_max[tid + s]);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
float absmax = smem_max[0];
|
||||||
|
float scale = absmax / FP8_E4M3_MAX;
|
||||||
|
// Clamp scale to avoid div-by-zero for all-zero rows
|
||||||
|
if (scale < 1e-12f) scale = 1e-12f;
|
||||||
|
float inv_scale = 1.0f / scale;
|
||||||
|
|
||||||
|
// Thread 0 writes the scale
|
||||||
|
if (tid == 0) {
|
||||||
|
scales[row] = scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Quantize each element
|
||||||
|
for (int i = tid; i < cols; i += QUANT_BLOCK) {
|
||||||
|
float v = __bfloat162float(row_src[i]) * inv_scale;
|
||||||
|
row_dst[i] = __nv_fp8_e4m3(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Row-wise scale: data[row, :] *= scales[row] (in-place, BF16)
|
||||||
|
__global__ void rowwise_scale_bf16_kernel(
|
||||||
|
__nv_bfloat16* __restrict__ data,
|
||||||
|
const float* __restrict__ scales,
|
||||||
|
int num_rows, int cols
|
||||||
|
) {
|
||||||
|
int row = blockIdx.x;
|
||||||
|
if (row >= num_rows) return;
|
||||||
|
int tid = threadIdx.x;
|
||||||
|
float s = scales[row];
|
||||||
|
__nv_bfloat16* row_data = data + (long long)row * cols;
|
||||||
|
for (int i = tid; i < cols; i += blockDim.x) {
|
||||||
|
float v = __bfloat162float(row_data[i]) * s;
|
||||||
|
row_data[i] = __float2bfloat16(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
void launch_rowwise_scale_bf16(
|
||||||
|
void* data, const void* scales,
|
||||||
|
int num_rows, int cols,
|
||||||
|
void* stream
|
||||||
|
) {
|
||||||
|
int block = 256;
|
||||||
|
int grid = num_rows;
|
||||||
|
rowwise_scale_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
|
(__nv_bfloat16*)data, (const float*)scales,
|
||||||
|
num_rows, cols
|
||||||
|
);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
|
}
|
||||||
|
|
||||||
|
void launch_quantize_bf16_to_fp8e4m3_rowwise(
|
||||||
|
const void* src,
|
||||||
|
void* dst,
|
||||||
|
void* scales,
|
||||||
|
int num_rows, int cols,
|
||||||
|
void* stream
|
||||||
|
) {
|
||||||
|
int grid = num_rows;
|
||||||
|
int block = QUANT_BLOCK;
|
||||||
|
quantize_bf16_to_fp8e4m3_rowwise_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
|
(const __nv_bfloat16*)src,
|
||||||
|
(__nv_fp8_e4m3*)dst,
|
||||||
|
(float*)scales,
|
||||||
|
num_rows, cols
|
||||||
|
);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user