speculative: batched-GEMV kernel for verify path (Phase 24 step 1)

Add launch_gemv_bf16_batched: runs M m=1 GEMVs in a single 3D grid
launch (z = batch row) with numerically identical output to M sequential
launch_gemv_bf16 calls — same K-block partial accumulation, same
fixed-order reduction. Verified on dash5 with 10 prompts × 32 tokens:
matched=true, verify_decode_mismatches=0.

Expose as matmul_batched_gemv(a: [M,K], b: [K,N]) → [M,N] in
xserv-kernels. Replace the old matmul_rows_gemv helper in qwen3
forward_verify_paged_decode_attention; the per-row loop over matmul_2d +
concat_rows is replaced by a single matmul_batched_gemv call that
allocates the partials buffer in one shot and launches 2 kernels instead
of 2*M.

Current speedup_e2e is 0.47× (same ballpark as Phase 23 0.44×);
the batched launch saves ~3 ms overhead but this is small relative to
the total 28 ms spec cost. The path forward (per docs/24 §4) is
higher acceptance rate or cheaper draft, not further kernel optimization.
This commit is contained in:
2026-07-01 16:13:37 +08:00
parent 42e13f33dd
commit e5734b41fa
4 changed files with 155 additions and 20 deletions

View File

@@ -18,6 +18,17 @@ unsafe extern "C" {
n: i32, n: i32,
stream: *mut c_void, stream: *mut c_void,
); );
fn launch_gemv_bf16_batched(
x: *const c_void,
w: *const c_void,
y_bf16: *mut c_void,
y_fp32_buf: *mut c_void,
m: i32,
k: i32,
n: i32,
stream: *mut c_void,
);
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
@@ -31,6 +42,55 @@ pub fn gemv_scratch_elems(k: usize, n: usize) -> usize {
n * k.div_ceil(GEMV_TILE_K) n * k.div_ceil(GEMV_TILE_K)
} }
/// Batched GEMV: [M, K] × [K, N] → [M, N], all BF16.
/// Bit-exact with calling matmul on each row individually (same K-block partial
/// + fixed-order reduction path), but in a single kernel launch per phase.
pub fn matmul_batched_gemv(a: &Tensor, b: &Tensor) -> Tensor {
assert_eq!(a.ndim(), 2);
assert_eq!(b.ndim(), 2);
assert!(a.is_contiguous());
assert!(b.is_contiguous());
assert_eq!(a.dtype(), DType::BF16);
assert_eq!(b.dtype(), DType::BF16);
let m = a.shape()[0];
let k = a.shape()[1];
let n = b.shape()[1];
assert_eq!(b.shape()[0], k);
let out = Tensor::empty(&[m, n], DType::BF16, a.device());
let scratch_elems = m * gemv_scratch_elems(k, n);
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(scratch_elems * 4).unwrap();
let null_stream = xserv_cuda::current_stream_raw();
if m == 1 {
unsafe {
launch_gemv_bf16(
a.data_ptr() as *const c_void,
b.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
fp32_buf.as_mut_ptr() as *mut c_void,
k as i32,
n as i32,
null_stream,
);
}
} else {
unsafe {
launch_gemv_bf16_batched(
a.data_ptr() as *const c_void,
b.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
fp32_buf.as_mut_ptr() as *mut c_void,
m as i32,
k as i32,
n as i32,
null_stream,
);
}
}
out
}
// --- FFI: custom CUDA kernels --- // --- FFI: custom CUDA kernels ---
unsafe extern "C" { unsafe extern "C" {
fn launch_gemm_naive_f32( fn launch_gemm_naive_f32(

View File

@@ -19,7 +19,7 @@ pub use attention::{
paged_decode_attention_sinks, reshape_and_cache_batched_bf16, reshape_and_cache_bf16, paged_decode_attention_sinks, reshape_and_cache_batched_bf16, reshape_and_cache_bf16,
}; };
pub use embedding::{embedding, embedding_device_ids}; pub use embedding::{embedding, embedding_device_ids};
pub use gemm::{GemmBackend, batched_matmul, matmul}; pub use gemm::{GemmBackend, batched_matmul, matmul, matmul_batched_gemv};
pub use layernorm::layernorm; pub use layernorm::layernorm;
pub use rmsnorm::{add_rmsnorm, rmsnorm}; pub use rmsnorm::{add_rmsnorm, rmsnorm};
pub use rope::{RopeCache, rope_inplace, rope_inplace_device_pos}; pub use rope::{RopeCache, rope_inplace, rope_inplace_device_pos};

View File

@@ -923,7 +923,7 @@ impl Qwen3 {
let residual = x.clone(); let residual = x.clone();
let normed = rmsnorm(&x, &layer.input_norm, eps); let normed = rmsnorm(&x, &layer.input_norm, eps);
let qkv = matmul_rows_gemv(&normed, &layer.qkv_proj_wt); let qkv = matmul_batched_gemv(&normed, &layer.qkv_proj_wt);
let q_dim = num_heads * head_dim; let q_dim = num_heads * head_dim;
let kv_dim = num_kv_heads * head_dim; let kv_dim = num_kv_heads * head_dim;
let q_all = qkv.narrow(1, 0, q_dim); let q_all = qkv.narrow(1, 0, q_dim);
@@ -966,25 +966,25 @@ impl Qwen3 {
); );
let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]); let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]);
let attn_proj = matmul_rows_gemv(&attn_merged, &layer.o_proj_wt); let attn_proj = matmul_batched_gemv(&attn_merged, &layer.o_proj_wt);
self.all_reduce(&attn_proj); self.all_reduce(&attn_proj);
let (normed, x_new) = let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone(); let residual = x_new.clone();
let gate_up = matmul_rows_gemv(&normed, &layer.gate_up_proj_wt); let gate_up = matmul_batched_gemv(&normed, &layer.gate_up_proj_wt);
let ffn_dim = gate_up.shape()[1] / 2; let ffn_dim = gate_up.shape()[1] / 2;
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous(); let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous(); let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
let hidden_states = xserv_kernels::silu_mul(&gate, &up); let hidden_states = xserv_kernels::silu_mul(&gate, &up);
let down = matmul_rows_gemv(&hidden_states, &layer.down_proj_wt); let down = matmul_batched_gemv(&hidden_states, &layer.down_proj_wt);
self.all_reduce(&down); self.all_reduce(&down);
x = add_any(&residual, &down); x = add_any(&residual, &down);
} }
let x = rmsnorm(&x, &self.norm, eps); let x = rmsnorm(&x, &self.norm, eps);
matmul_rows_gemv(&x, &self.lm_head_t) matmul_batched_gemv(&x, &self.lm_head_t)
} }
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels. /// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
@@ -1261,20 +1261,6 @@ fn row_view(t: &Tensor, row: usize) -> Tensor {
) )
} }
/// Run a 2D matmul row by row so each row uses the same GEMV kernel as
/// single-token decode. Used by speculative verify parity, where near-tie
/// logits must follow decode's BF16 rounding path.
fn matmul_rows_gemv(a: &Tensor, b: &Tensor) -> Tensor {
assert_eq!(a.ndim(), 2);
assert!(a.is_contiguous());
let rows = a.shape()[0];
if rows == 1 {
return matmul_2d(a, b);
}
let out_rows: Vec<Tensor> = (0..rows).map(|i| matmul_2d(&row_view(a, i), b)).collect();
concat_rows(&out_rows)
}
/// Concatenate row tensors [1, cols] into a single [B, cols] tensor via D2D memcpy. /// Concatenate row tensors [1, cols] into a single [B, cols] tensor via D2D memcpy.
fn concat_rows(rows: &[Tensor]) -> Tensor { fn concat_rows(rows: &[Tensor]) -> Tensor {
assert!(!rows.is_empty()); assert!(!rows.is_empty());

View File

@@ -69,6 +69,62 @@ __global__ void gemv_reduce_to_bf16_kernel(
} }
} }
// Batched variant: M rows, same W. Grid.z = batch row index.
// Numerically identical to calling launch_gemv_bf16 M times in sequence because
// each z-slice executes the same accumulation order on the same data.
// partials buffer must be [M * num_k_blocks * N] floats.
__global__ void gemv_bf16_batched_partial_kernel(
const __nv_bfloat16* __restrict__ x, // [M, K]
const __nv_bfloat16* __restrict__ W, // [K, N]
float* __restrict__ partials, // [M, num_k_blocks, N]
int K, int N
) {
const int block_n = blockIdx.x;
const int block_k = blockIdx.y;
const int row = blockIdx.z;
const int t = threadIdx.x;
const int col = block_n * GEMV_TILE_N + t;
const int k_start = block_k * GEMV_TILE_K;
const int k_end = min(k_start + GEMV_TILE_K, K);
const int k_len = k_end - k_start;
__shared__ float x_shared[GEMV_TILE_K];
const __nv_bfloat16* x_row = x + (long long)row * K;
for (int i = t; i < k_len; i += GEMV_BLOCK) {
x_shared[i] = __bfloat162float(x_row[k_start + i]);
}
__syncthreads();
if (col >= N) return;
float sum = 0.0f;
for (int ki = 0; ki < k_len; ki++) {
sum += x_shared[ki] * __bfloat162float(W[(long long)(k_start + ki) * N + col]);
}
int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K;
partials[((long long)row * num_k_blocks + block_k) * N + col] = sum;
}
__global__ void gemv_batched_reduce_to_bf16_kernel(
const float* __restrict__ partials, // [M, num_k_blocks, N]
__nv_bfloat16* __restrict__ dst, // [M, N]
int n,
int num_k_blocks
) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
int row = blockIdx.y;
if (col >= n) return;
float sum = 0.0f;
const float* row_partials = partials + (long long)row * num_k_blocks * n;
for (int kb = 0; kb < num_k_blocks; kb++) {
sum += row_partials[(long long)kb * n + col];
}
dst[(long long)row * n + col] = __float2bfloat16(sum);
}
extern "C" { extern "C" {
void launch_gemv_bf16( void launch_gemv_bf16(
@@ -104,4 +160,37 @@ void launch_gemv_bf16(
CUDA_CHECK_LAST_ERROR(); CUDA_CHECK_LAST_ERROR();
} }
void launch_gemv_bf16_batched(
const void* x, // [M, K] BF16
const void* W, // [K, N] BF16
void* y_bf16, // [M, N] BF16
void* y_fp32_buf, // [M * num_k_blocks * N] FP32
int M, int K, int N,
void* stream
) {
cudaStream_t s = (cudaStream_t)stream;
int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K;
dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N, num_k_blocks, M);
gemv_bf16_batched_partial_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
(const __nv_bfloat16*)x,
(const __nv_bfloat16*)W,
(float*)y_fp32_buf,
K, N
);
CUDA_CHECK_LAST_ERROR();
int conv_block = 256;
int conv_grid_x = (N + conv_block - 1) / conv_block;
dim3 reduce_grid(conv_grid_x, M);
gemv_batched_reduce_to_bf16_kernel<<<reduce_grid, conv_block, 0, s>>>(
(const float*)y_fp32_buf,
(__nv_bfloat16*)y_bf16,
N,
num_k_blocks
);
CUDA_CHECK_LAST_ERROR();
}
} // extern "C" } // extern "C"