speculative: batched-GEMV kernel for verify path (Phase 24 step 1)
Add launch_gemv_bf16_batched: runs M m=1 GEMVs in a single 3D grid launch (z = batch row) with numerically identical output to M sequential launch_gemv_bf16 calls — same K-block partial accumulation, same fixed-order reduction. Verified on dash5 with 10 prompts × 32 tokens: matched=true, verify_decode_mismatches=0. Expose as matmul_batched_gemv(a: [M,K], b: [K,N]) → [M,N] in xserv-kernels. Replace the old matmul_rows_gemv helper in qwen3 forward_verify_paged_decode_attention; the per-row loop over matmul_2d + concat_rows is replaced by a single matmul_batched_gemv call that allocates the partials buffer in one shot and launches 2 kernels instead of 2*M. Current speedup_e2e is 0.47× (same ballpark as Phase 23 0.44×); the batched launch saves ~3 ms overhead but this is small relative to the total 28 ms spec cost. The path forward (per docs/24 §4) is higher acceptance rate or cheaper draft, not further kernel optimization.
This commit is contained in:
@@ -18,6 +18,17 @@ unsafe extern "C" {
|
|||||||
n: i32,
|
n: i32,
|
||||||
stream: *mut c_void,
|
stream: *mut c_void,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
fn launch_gemv_bf16_batched(
|
||||||
|
x: *const c_void,
|
||||||
|
w: *const c_void,
|
||||||
|
y_bf16: *mut c_void,
|
||||||
|
y_fp32_buf: *mut c_void,
|
||||||
|
m: i32,
|
||||||
|
k: i32,
|
||||||
|
n: i32,
|
||||||
|
stream: *mut c_void,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
@@ -31,6 +42,55 @@ pub fn gemv_scratch_elems(k: usize, n: usize) -> usize {
|
|||||||
n * k.div_ceil(GEMV_TILE_K)
|
n * k.div_ceil(GEMV_TILE_K)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Batched GEMV: [M, K] × [K, N] → [M, N], all BF16.
|
||||||
|
/// Bit-exact with calling matmul on each row individually (same K-block partial
|
||||||
|
/// + fixed-order reduction path), but in a single kernel launch per phase.
|
||||||
|
pub fn matmul_batched_gemv(a: &Tensor, b: &Tensor) -> Tensor {
|
||||||
|
assert_eq!(a.ndim(), 2);
|
||||||
|
assert_eq!(b.ndim(), 2);
|
||||||
|
assert!(a.is_contiguous());
|
||||||
|
assert!(b.is_contiguous());
|
||||||
|
assert_eq!(a.dtype(), DType::BF16);
|
||||||
|
assert_eq!(b.dtype(), DType::BF16);
|
||||||
|
let m = a.shape()[0];
|
||||||
|
let k = a.shape()[1];
|
||||||
|
let n = b.shape()[1];
|
||||||
|
assert_eq!(b.shape()[0], k);
|
||||||
|
|
||||||
|
let out = Tensor::empty(&[m, n], DType::BF16, a.device());
|
||||||
|
let scratch_elems = m * gemv_scratch_elems(k, n);
|
||||||
|
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(scratch_elems * 4).unwrap();
|
||||||
|
|
||||||
|
let null_stream = xserv_cuda::current_stream_raw();
|
||||||
|
if m == 1 {
|
||||||
|
unsafe {
|
||||||
|
launch_gemv_bf16(
|
||||||
|
a.data_ptr() as *const c_void,
|
||||||
|
b.data_ptr() as *const c_void,
|
||||||
|
out.data_ptr() as *mut c_void,
|
||||||
|
fp32_buf.as_mut_ptr() as *mut c_void,
|
||||||
|
k as i32,
|
||||||
|
n as i32,
|
||||||
|
null_stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unsafe {
|
||||||
|
launch_gemv_bf16_batched(
|
||||||
|
a.data_ptr() as *const c_void,
|
||||||
|
b.data_ptr() as *const c_void,
|
||||||
|
out.data_ptr() as *mut c_void,
|
||||||
|
fp32_buf.as_mut_ptr() as *mut c_void,
|
||||||
|
m as i32,
|
||||||
|
k as i32,
|
||||||
|
n as i32,
|
||||||
|
null_stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
// --- FFI: custom CUDA kernels ---
|
// --- FFI: custom CUDA kernels ---
|
||||||
unsafe extern "C" {
|
unsafe extern "C" {
|
||||||
fn launch_gemm_naive_f32(
|
fn launch_gemm_naive_f32(
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ pub use attention::{
|
|||||||
paged_decode_attention_sinks, reshape_and_cache_batched_bf16, reshape_and_cache_bf16,
|
paged_decode_attention_sinks, reshape_and_cache_batched_bf16, reshape_and_cache_bf16,
|
||||||
};
|
};
|
||||||
pub use embedding::{embedding, embedding_device_ids};
|
pub use embedding::{embedding, embedding_device_ids};
|
||||||
pub use gemm::{GemmBackend, batched_matmul, matmul};
|
pub use gemm::{GemmBackend, batched_matmul, matmul, matmul_batched_gemv};
|
||||||
pub use layernorm::layernorm;
|
pub use layernorm::layernorm;
|
||||||
pub use rmsnorm::{add_rmsnorm, rmsnorm};
|
pub use rmsnorm::{add_rmsnorm, rmsnorm};
|
||||||
pub use rope::{RopeCache, rope_inplace, rope_inplace_device_pos};
|
pub use rope::{RopeCache, rope_inplace, rope_inplace_device_pos};
|
||||||
|
|||||||
@@ -923,7 +923,7 @@ impl Qwen3 {
|
|||||||
let residual = x.clone();
|
let residual = x.clone();
|
||||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||||
|
|
||||||
let qkv = matmul_rows_gemv(&normed, &layer.qkv_proj_wt);
|
let qkv = matmul_batched_gemv(&normed, &layer.qkv_proj_wt);
|
||||||
let q_dim = num_heads * head_dim;
|
let q_dim = num_heads * head_dim;
|
||||||
let kv_dim = num_kv_heads * head_dim;
|
let kv_dim = num_kv_heads * head_dim;
|
||||||
let q_all = qkv.narrow(1, 0, q_dim);
|
let q_all = qkv.narrow(1, 0, q_dim);
|
||||||
@@ -966,25 +966,25 @@ impl Qwen3 {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]);
|
let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]);
|
||||||
let attn_proj = matmul_rows_gemv(&attn_merged, &layer.o_proj_wt);
|
let attn_proj = matmul_batched_gemv(&attn_merged, &layer.o_proj_wt);
|
||||||
self.all_reduce(&attn_proj);
|
self.all_reduce(&attn_proj);
|
||||||
|
|
||||||
let (normed, x_new) =
|
let (normed, x_new) =
|
||||||
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||||
let residual = x_new.clone();
|
let residual = x_new.clone();
|
||||||
|
|
||||||
let gate_up = matmul_rows_gemv(&normed, &layer.gate_up_proj_wt);
|
let gate_up = matmul_batched_gemv(&normed, &layer.gate_up_proj_wt);
|
||||||
let ffn_dim = gate_up.shape()[1] / 2;
|
let ffn_dim = gate_up.shape()[1] / 2;
|
||||||
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
|
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
|
||||||
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
|
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
|
||||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||||
let down = matmul_rows_gemv(&hidden_states, &layer.down_proj_wt);
|
let down = matmul_batched_gemv(&hidden_states, &layer.down_proj_wt);
|
||||||
self.all_reduce(&down);
|
self.all_reduce(&down);
|
||||||
x = add_any(&residual, &down);
|
x = add_any(&residual, &down);
|
||||||
}
|
}
|
||||||
|
|
||||||
let x = rmsnorm(&x, &self.norm, eps);
|
let x = rmsnorm(&x, &self.norm, eps);
|
||||||
matmul_rows_gemv(&x, &self.lm_head_t)
|
matmul_batched_gemv(&x, &self.lm_head_t)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
|
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
|
||||||
@@ -1261,20 +1261,6 @@ fn row_view(t: &Tensor, row: usize) -> Tensor {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run a 2D matmul row by row so each row uses the same GEMV kernel as
|
|
||||||
/// single-token decode. Used by speculative verify parity, where near-tie
|
|
||||||
/// logits must follow decode's BF16 rounding path.
|
|
||||||
fn matmul_rows_gemv(a: &Tensor, b: &Tensor) -> Tensor {
|
|
||||||
assert_eq!(a.ndim(), 2);
|
|
||||||
assert!(a.is_contiguous());
|
|
||||||
let rows = a.shape()[0];
|
|
||||||
if rows == 1 {
|
|
||||||
return matmul_2d(a, b);
|
|
||||||
}
|
|
||||||
let out_rows: Vec<Tensor> = (0..rows).map(|i| matmul_2d(&row_view(a, i), b)).collect();
|
|
||||||
concat_rows(&out_rows)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Concatenate row tensors [1, cols] into a single [B, cols] tensor via D2D memcpy.
|
/// Concatenate row tensors [1, cols] into a single [B, cols] tensor via D2D memcpy.
|
||||||
fn concat_rows(rows: &[Tensor]) -> Tensor {
|
fn concat_rows(rows: &[Tensor]) -> Tensor {
|
||||||
assert!(!rows.is_empty());
|
assert!(!rows.is_empty());
|
||||||
|
|||||||
@@ -69,6 +69,62 @@ __global__ void gemv_reduce_to_bf16_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Batched variant: M rows, same W. Grid.z = batch row index.
|
||||||
|
// Numerically identical to calling launch_gemv_bf16 M times in sequence because
|
||||||
|
// each z-slice executes the same accumulation order on the same data.
|
||||||
|
// partials buffer must be [M * num_k_blocks * N] floats.
|
||||||
|
__global__ void gemv_bf16_batched_partial_kernel(
|
||||||
|
const __nv_bfloat16* __restrict__ x, // [M, K]
|
||||||
|
const __nv_bfloat16* __restrict__ W, // [K, N]
|
||||||
|
float* __restrict__ partials, // [M, num_k_blocks, N]
|
||||||
|
int K, int N
|
||||||
|
) {
|
||||||
|
const int block_n = blockIdx.x;
|
||||||
|
const int block_k = blockIdx.y;
|
||||||
|
const int row = blockIdx.z;
|
||||||
|
const int t = threadIdx.x;
|
||||||
|
const int col = block_n * GEMV_TILE_N + t;
|
||||||
|
|
||||||
|
const int k_start = block_k * GEMV_TILE_K;
|
||||||
|
const int k_end = min(k_start + GEMV_TILE_K, K);
|
||||||
|
const int k_len = k_end - k_start;
|
||||||
|
|
||||||
|
__shared__ float x_shared[GEMV_TILE_K];
|
||||||
|
const __nv_bfloat16* x_row = x + (long long)row * K;
|
||||||
|
for (int i = t; i < k_len; i += GEMV_BLOCK) {
|
||||||
|
x_shared[i] = __bfloat162float(x_row[k_start + i]);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (col >= N) return;
|
||||||
|
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (int ki = 0; ki < k_len; ki++) {
|
||||||
|
sum += x_shared[ki] * __bfloat162float(W[(long long)(k_start + ki) * N + col]);
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K;
|
||||||
|
partials[((long long)row * num_k_blocks + block_k) * N + col] = sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void gemv_batched_reduce_to_bf16_kernel(
|
||||||
|
const float* __restrict__ partials, // [M, num_k_blocks, N]
|
||||||
|
__nv_bfloat16* __restrict__ dst, // [M, N]
|
||||||
|
int n,
|
||||||
|
int num_k_blocks
|
||||||
|
) {
|
||||||
|
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
int row = blockIdx.y;
|
||||||
|
if (col >= n) return;
|
||||||
|
|
||||||
|
float sum = 0.0f;
|
||||||
|
const float* row_partials = partials + (long long)row * num_k_blocks * n;
|
||||||
|
for (int kb = 0; kb < num_k_blocks; kb++) {
|
||||||
|
sum += row_partials[(long long)kb * n + col];
|
||||||
|
}
|
||||||
|
dst[(long long)row * n + col] = __float2bfloat16(sum);
|
||||||
|
}
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
void launch_gemv_bf16(
|
void launch_gemv_bf16(
|
||||||
@@ -104,4 +160,37 @@ void launch_gemv_bf16(
|
|||||||
CUDA_CHECK_LAST_ERROR();
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void launch_gemv_bf16_batched(
|
||||||
|
const void* x, // [M, K] BF16
|
||||||
|
const void* W, // [K, N] BF16
|
||||||
|
void* y_bf16, // [M, N] BF16
|
||||||
|
void* y_fp32_buf, // [M * num_k_blocks * N] FP32
|
||||||
|
int M, int K, int N,
|
||||||
|
void* stream
|
||||||
|
) {
|
||||||
|
cudaStream_t s = (cudaStream_t)stream;
|
||||||
|
|
||||||
|
int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K;
|
||||||
|
dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N, num_k_blocks, M);
|
||||||
|
|
||||||
|
gemv_bf16_batched_partial_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
|
||||||
|
(const __nv_bfloat16*)x,
|
||||||
|
(const __nv_bfloat16*)W,
|
||||||
|
(float*)y_fp32_buf,
|
||||||
|
K, N
|
||||||
|
);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
|
|
||||||
|
int conv_block = 256;
|
||||||
|
int conv_grid_x = (N + conv_block - 1) / conv_block;
|
||||||
|
dim3 reduce_grid(conv_grid_x, M);
|
||||||
|
gemv_batched_reduce_to_bf16_kernel<<<reduce_grid, conv_block, 0, s>>>(
|
||||||
|
(const float*)y_fp32_buf,
|
||||||
|
(__nv_bfloat16*)y_bf16,
|
||||||
|
N,
|
||||||
|
num_k_blocks
|
||||||
|
);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
|
}
|
||||||
|
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
|
|||||||
Reference in New Issue
Block a user