speculative: batched-GEMV kernel for verify path (Phase 24 step 1)
Add launch_gemv_bf16_batched: runs M m=1 GEMVs in a single 3D grid launch (z = batch row) with numerically identical output to M sequential launch_gemv_bf16 calls — same K-block partial accumulation, same fixed-order reduction. Verified on dash5 with 10 prompts × 32 tokens: matched=true, verify_decode_mismatches=0. Expose as matmul_batched_gemv(a: [M,K], b: [K,N]) → [M,N] in xserv-kernels. Replace the old matmul_rows_gemv helper in qwen3 forward_verify_paged_decode_attention; the per-row loop over matmul_2d + concat_rows is replaced by a single matmul_batched_gemv call that allocates the partials buffer in one shot and launches 2 kernels instead of 2*M. Current speedup_e2e is 0.47× (same ballpark as Phase 23 0.44×); the batched launch saves ~3 ms overhead but this is small relative to the total 28 ms spec cost. The path forward (per docs/24 §4) is higher acceptance rate or cheaper draft, not further kernel optimization.
This commit is contained in:
@@ -18,6 +18,17 @@ unsafe extern "C" {
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
fn launch_gemv_bf16_batched(
|
||||
x: *const c_void,
|
||||
w: *const c_void,
|
||||
y_bf16: *mut c_void,
|
||||
y_fp32_buf: *mut c_void,
|
||||
m: i32,
|
||||
k: i32,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
@@ -31,6 +42,55 @@ pub fn gemv_scratch_elems(k: usize, n: usize) -> usize {
|
||||
n * k.div_ceil(GEMV_TILE_K)
|
||||
}
|
||||
|
||||
/// Batched GEMV: [M, K] × [K, N] → [M, N], all BF16.
|
||||
/// Bit-exact with calling matmul on each row individually (same K-block partial
|
||||
/// + fixed-order reduction path), but in a single kernel launch per phase.
|
||||
pub fn matmul_batched_gemv(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert_eq!(a.ndim(), 2);
|
||||
assert_eq!(b.ndim(), 2);
|
||||
assert!(a.is_contiguous());
|
||||
assert!(b.is_contiguous());
|
||||
assert_eq!(a.dtype(), DType::BF16);
|
||||
assert_eq!(b.dtype(), DType::BF16);
|
||||
let m = a.shape()[0];
|
||||
let k = a.shape()[1];
|
||||
let n = b.shape()[1];
|
||||
assert_eq!(b.shape()[0], k);
|
||||
|
||||
let out = Tensor::empty(&[m, n], DType::BF16, a.device());
|
||||
let scratch_elems = m * gemv_scratch_elems(k, n);
|
||||
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(scratch_elems * 4).unwrap();
|
||||
|
||||
let null_stream = xserv_cuda::current_stream_raw();
|
||||
if m == 1 {
|
||||
unsafe {
|
||||
launch_gemv_bf16(
|
||||
a.data_ptr() as *const c_void,
|
||||
b.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
fp32_buf.as_mut_ptr() as *mut c_void,
|
||||
k as i32,
|
||||
n as i32,
|
||||
null_stream,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
unsafe {
|
||||
launch_gemv_bf16_batched(
|
||||
a.data_ptr() as *const c_void,
|
||||
b.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
fp32_buf.as_mut_ptr() as *mut c_void,
|
||||
m as i32,
|
||||
k as i32,
|
||||
n as i32,
|
||||
null_stream,
|
||||
);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
// --- FFI: custom CUDA kernels ---
|
||||
unsafe extern "C" {
|
||||
fn launch_gemm_naive_f32(
|
||||
|
||||
@@ -19,7 +19,7 @@ pub use attention::{
|
||||
paged_decode_attention_sinks, reshape_and_cache_batched_bf16, reshape_and_cache_bf16,
|
||||
};
|
||||
pub use embedding::{embedding, embedding_device_ids};
|
||||
pub use gemm::{GemmBackend, batched_matmul, matmul};
|
||||
pub use gemm::{GemmBackend, batched_matmul, matmul, matmul_batched_gemv};
|
||||
pub use layernorm::layernorm;
|
||||
pub use rmsnorm::{add_rmsnorm, rmsnorm};
|
||||
pub use rope::{RopeCache, rope_inplace, rope_inplace_device_pos};
|
||||
|
||||
@@ -923,7 +923,7 @@ impl Qwen3 {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
let qkv = matmul_rows_gemv(&normed, &layer.qkv_proj_wt);
|
||||
let qkv = matmul_batched_gemv(&normed, &layer.qkv_proj_wt);
|
||||
let q_dim = num_heads * head_dim;
|
||||
let kv_dim = num_kv_heads * head_dim;
|
||||
let q_all = qkv.narrow(1, 0, q_dim);
|
||||
@@ -966,25 +966,25 @@ impl Qwen3 {
|
||||
);
|
||||
|
||||
let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]);
|
||||
let attn_proj = matmul_rows_gemv(&attn_merged, &layer.o_proj_wt);
|
||||
let attn_proj = matmul_batched_gemv(&attn_merged, &layer.o_proj_wt);
|
||||
self.all_reduce(&attn_proj);
|
||||
|
||||
let (normed, x_new) =
|
||||
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate_up = matmul_rows_gemv(&normed, &layer.gate_up_proj_wt);
|
||||
let gate_up = matmul_batched_gemv(&normed, &layer.gate_up_proj_wt);
|
||||
let ffn_dim = gate_up.shape()[1] / 2;
|
||||
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
|
||||
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_rows_gemv(&hidden_states, &layer.down_proj_wt);
|
||||
let down = matmul_batched_gemv(&hidden_states, &layer.down_proj_wt);
|
||||
self.all_reduce(&down);
|
||||
x = add_any(&residual, &down);
|
||||
}
|
||||
|
||||
let x = rmsnorm(&x, &self.norm, eps);
|
||||
matmul_rows_gemv(&x, &self.lm_head_t)
|
||||
matmul_batched_gemv(&x, &self.lm_head_t)
|
||||
}
|
||||
|
||||
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
|
||||
@@ -1261,20 +1261,6 @@ fn row_view(t: &Tensor, row: usize) -> Tensor {
|
||||
)
|
||||
}
|
||||
|
||||
/// Run a 2D matmul row by row so each row uses the same GEMV kernel as
|
||||
/// single-token decode. Used by speculative verify parity, where near-tie
|
||||
/// logits must follow decode's BF16 rounding path.
|
||||
fn matmul_rows_gemv(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert_eq!(a.ndim(), 2);
|
||||
assert!(a.is_contiguous());
|
||||
let rows = a.shape()[0];
|
||||
if rows == 1 {
|
||||
return matmul_2d(a, b);
|
||||
}
|
||||
let out_rows: Vec<Tensor> = (0..rows).map(|i| matmul_2d(&row_view(a, i), b)).collect();
|
||||
concat_rows(&out_rows)
|
||||
}
|
||||
|
||||
/// Concatenate row tensors [1, cols] into a single [B, cols] tensor via D2D memcpy.
|
||||
fn concat_rows(rows: &[Tensor]) -> Tensor {
|
||||
assert!(!rows.is_empty());
|
||||
|
||||
@@ -69,6 +69,62 @@ __global__ void gemv_reduce_to_bf16_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// Batched variant: M rows, same W. Grid.z = batch row index.
|
||||
// Numerically identical to calling launch_gemv_bf16 M times in sequence because
|
||||
// each z-slice executes the same accumulation order on the same data.
|
||||
// partials buffer must be [M * num_k_blocks * N] floats.
|
||||
__global__ void gemv_bf16_batched_partial_kernel(
|
||||
const __nv_bfloat16* __restrict__ x, // [M, K]
|
||||
const __nv_bfloat16* __restrict__ W, // [K, N]
|
||||
float* __restrict__ partials, // [M, num_k_blocks, N]
|
||||
int K, int N
|
||||
) {
|
||||
const int block_n = blockIdx.x;
|
||||
const int block_k = blockIdx.y;
|
||||
const int row = blockIdx.z;
|
||||
const int t = threadIdx.x;
|
||||
const int col = block_n * GEMV_TILE_N + t;
|
||||
|
||||
const int k_start = block_k * GEMV_TILE_K;
|
||||
const int k_end = min(k_start + GEMV_TILE_K, K);
|
||||
const int k_len = k_end - k_start;
|
||||
|
||||
__shared__ float x_shared[GEMV_TILE_K];
|
||||
const __nv_bfloat16* x_row = x + (long long)row * K;
|
||||
for (int i = t; i < k_len; i += GEMV_BLOCK) {
|
||||
x_shared[i] = __bfloat162float(x_row[k_start + i]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (col >= N) return;
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int ki = 0; ki < k_len; ki++) {
|
||||
sum += x_shared[ki] * __bfloat162float(W[(long long)(k_start + ki) * N + col]);
|
||||
}
|
||||
|
||||
int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K;
|
||||
partials[((long long)row * num_k_blocks + block_k) * N + col] = sum;
|
||||
}
|
||||
|
||||
__global__ void gemv_batched_reduce_to_bf16_kernel(
|
||||
const float* __restrict__ partials, // [M, num_k_blocks, N]
|
||||
__nv_bfloat16* __restrict__ dst, // [M, N]
|
||||
int n,
|
||||
int num_k_blocks
|
||||
) {
|
||||
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int row = blockIdx.y;
|
||||
if (col >= n) return;
|
||||
|
||||
float sum = 0.0f;
|
||||
const float* row_partials = partials + (long long)row * num_k_blocks * n;
|
||||
for (int kb = 0; kb < num_k_blocks; kb++) {
|
||||
sum += row_partials[(long long)kb * n + col];
|
||||
}
|
||||
dst[(long long)row * n + col] = __float2bfloat16(sum);
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_gemv_bf16(
|
||||
@@ -104,4 +160,37 @@ void launch_gemv_bf16(
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
void launch_gemv_bf16_batched(
|
||||
const void* x, // [M, K] BF16
|
||||
const void* W, // [K, N] BF16
|
||||
void* y_bf16, // [M, N] BF16
|
||||
void* y_fp32_buf, // [M * num_k_blocks * N] FP32
|
||||
int M, int K, int N,
|
||||
void* stream
|
||||
) {
|
||||
cudaStream_t s = (cudaStream_t)stream;
|
||||
|
||||
int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K;
|
||||
dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N, num_k_blocks, M);
|
||||
|
||||
gemv_bf16_batched_partial_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
|
||||
(const __nv_bfloat16*)x,
|
||||
(const __nv_bfloat16*)W,
|
||||
(float*)y_fp32_buf,
|
||||
K, N
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
|
||||
int conv_block = 256;
|
||||
int conv_grid_x = (N + conv_block - 1) / conv_block;
|
||||
dim3 reduce_grid(conv_grid_x, M);
|
||||
gemv_batched_reduce_to_bf16_kernel<<<reduce_grid, conv_block, 0, s>>>(
|
||||
(const float*)y_fp32_buf,
|
||||
(__nv_bfloat16*)y_bf16,
|
||||
N,
|
||||
num_k_blocks
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
|
||||
Reference in New Issue
Block a user