perf: GPU transpose/reshape/repeat_kv kernels (eliminate CPU round-trips)
New CUDA kernels (csrc/embedding/transpose.cu): - reshape_heads_bf16: [S, H*D] → [1, H, S, D] - merge_heads_bf16: [1, H, S, D] → [S, H*D] - transpose_hsd_to_shd_bf16: [1, H, S, D] → [S, H, D] (for RoPE) - transpose_shd_to_hsd_bf16: [S, H, D] → [1, H, S, D] (from RoPE) - repeat_kv_bf16: [1, KV_H, S, D] → [1, KV_H*n_rep, S, D] Rust wrappers (xserv-kernels/src/transpose.rs): - reshape_heads_gpu, merge_heads_gpu, transpose_for/from_rope_gpu, repeat_kv_gpu Qwen3 forward_gpu_cache now uses all GPU kernels — zero CPU data round-trips. Result: 50/50 self-consistent, 3-5% faster (TBT 142→137ms) Remaining bottleneck: ~900 device::synchronize() calls + 252 cuBLAS handle creations per token (Phase 15 targets) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -23,6 +23,7 @@ fn main() {
|
||||
.file("../../csrc/embedding/embedding.cu")
|
||||
.file("../../csrc/embedding/rope.cu")
|
||||
.file("../../csrc/attention/causal_mask.cu")
|
||||
.file("../../csrc/embedding/transpose.cu")
|
||||
.compile("xserv_kernels");
|
||||
|
||||
println!("cargo:rerun-if-changed=../../csrc/");
|
||||
|
||||
@@ -6,8 +6,10 @@ pub mod layernorm;
|
||||
pub mod rmsnorm;
|
||||
pub mod rope;
|
||||
pub mod softmax;
|
||||
pub mod transpose;
|
||||
|
||||
pub use activation::{add, gelu, mul, scale, silu};
|
||||
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
|
||||
pub use attention::attention;
|
||||
pub use embedding::embedding;
|
||||
pub use gemm::{batched_matmul, matmul, GemmBackend};
|
||||
|
||||
91
crates/xserv-kernels/src/transpose.rs
Normal file
91
crates/xserv-kernels/src/transpose.rs
Normal file
@@ -0,0 +1,91 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_reshape_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_merge_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_repeat_kv_bf16(inp: *const c_void, out: *mut c_void, kv_heads: i32, n_rep: i32, seq_len: i32, head_dim: i32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
/// [S, H*D] → [1, H, S, D] on GPU (BF16)
|
||||
pub fn reshape_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::zeros(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_reshape_heads_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
/// [1, H, S, D] → [S, H*D] on GPU (BF16)
|
||||
pub fn merge_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let hidden = num_heads * head_dim;
|
||||
let out = Tensor::zeros(&[seq_len, hidden], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_merge_heads_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
/// [1, H, S, D] → [S, H, D] for RoPE on GPU (BF16)
|
||||
pub fn transpose_for_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::zeros(&[seq_len, num_heads, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_transpose_hsd_to_shd_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
/// [S, H, D] → [1, H, S, D] after RoPE on GPU (BF16)
|
||||
pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::zeros(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_transpose_shd_to_hsd_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
/// [1, KV_H, S, D] → [1, KV_H*n_rep, S, D] on GPU (BF16)
|
||||
pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor {
|
||||
if n_rep == 1 { return x.clone(); }
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let kv_heads = x.shape()[1];
|
||||
let seq_len = x.shape()[2];
|
||||
let head_dim = x.shape()[3];
|
||||
let new_heads = kv_heads * n_rep;
|
||||
let out = Tensor::zeros(&[1, new_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_repeat_kv_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
kv_heads as i32, n_rep as i32, seq_len as i32, head_dim as i32, std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
@@ -147,7 +147,7 @@ impl Qwen3 {
|
||||
matmul_2d(&x, &self.lm_head_t)
|
||||
}
|
||||
|
||||
/// Forward with GPU-resident KV cache (no CPU round-trips for KV).
|
||||
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
|
||||
pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
let pos_offset = cache.seq_len();
|
||||
@@ -168,30 +168,36 @@ impl Qwen3 {
|
||||
let k = matmul_2d(&normed, &layer.k_proj_wt);
|
||||
let v = matmul_2d(&normed, &layer.v_proj_wt);
|
||||
|
||||
let q = reshape_heads(&q, new_tokens, num_heads, head_dim);
|
||||
let k = reshape_heads(&k, new_tokens, num_kv_heads, head_dim);
|
||||
let v = reshape_heads(&v, new_tokens, num_kv_heads, head_dim);
|
||||
// GPU reshape: [S, H*D] → [1, H, S, D]
|
||||
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
let v = xserv_kernels::reshape_heads_gpu(&v, new_tokens, num_kv_heads, head_dim);
|
||||
|
||||
// QK norm (reshape to [H*S, D], rmsnorm, reshape back — stays on GPU)
|
||||
let q = head_rmsnorm(&q, &layer.q_norm, eps);
|
||||
let k = head_rmsnorm(&k, &layer.k_norm, eps);
|
||||
|
||||
let q = transpose_for_rope(&q, new_tokens, num_heads, head_dim);
|
||||
let k = transpose_for_rope(&k, new_tokens, num_kv_heads, head_dim);
|
||||
// GPU transpose for RoPE: [1, H, S, D] → [S, H, D]
|
||||
let q = xserv_kernels::transpose_for_rope_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_for_rope_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
rope_inplace(&q, &self.rope_cache, &positions);
|
||||
rope_inplace(&k, &self.rope_cache, &positions);
|
||||
let q = transpose_from_rope(&q, new_tokens, num_heads, head_dim);
|
||||
let k = transpose_from_rope(&k, new_tokens, num_kv_heads, head_dim);
|
||||
// GPU transpose back: [S, H, D] → [1, H, S, D]
|
||||
let q = xserv_kernels::transpose_from_rope_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_from_rope_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
|
||||
// GPU KV cache: D2D append, no CPU round-trip
|
||||
// GPU KV cache
|
||||
cache.append(layer_idx, &k, &v, new_tokens, pos_offset);
|
||||
let (k_full, v_full) = cache.get_kv_len(layer_idx, pos_offset + new_tokens);
|
||||
|
||||
// GPU repeat KV for GQA
|
||||
let n_rep = num_heads / num_kv_heads;
|
||||
let k_full = repeat_kv(&k_full, n_rep);
|
||||
let v_full = repeat_kv(&v_full, n_rep);
|
||||
let k_full = xserv_kernels::repeat_kv_gpu(&k_full, n_rep);
|
||||
let v_full = xserv_kernels::repeat_kv_gpu(&v_full, n_rep);
|
||||
|
||||
let attn_out = attention(&q, &k_full, &v_full, true);
|
||||
let attn_merged = merge_heads_any(&attn_out, new_tokens, hidden);
|
||||
// GPU merge_heads: [1, H, S, D] → [S, H*D]
|
||||
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
x = add_any(&residual, &attn_proj);
|
||||
|
||||
|
||||
161
csrc/embedding/transpose.cu
Normal file
161
csrc/embedding/transpose.cu
Normal file
@@ -0,0 +1,161 @@
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
// Transpose between [S, H, D] and [H, S, D] layouts (used for RoPE and attention).
|
||||
// Also handles [S, H*D] → [H, S, D] (reshape_heads) and reverse (merge_heads).
|
||||
|
||||
// reshape_heads: [S, H*D] → [1, H, S, D]
|
||||
// Input layout: element at [s, h*D + d] = flat[s * H*D + h*D + d]
|
||||
// Output layout: element at [0, h, s, d] = flat[h * S*D + s*D + d]
|
||||
__global__ void reshape_heads_bf16(
|
||||
const __nv_bfloat16* __restrict__ in,
|
||||
__nv_bfloat16* __restrict__ out,
|
||||
int seq_len, int num_heads, int head_dim
|
||||
) {
|
||||
int hidden = num_heads * head_dim;
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = seq_len * hidden;
|
||||
if (idx >= total) return;
|
||||
|
||||
int s = idx / hidden;
|
||||
int rem = idx % hidden;
|
||||
int h = rem / head_dim;
|
||||
int d = rem % head_dim;
|
||||
|
||||
int out_idx = h * seq_len * head_dim + s * head_dim + d;
|
||||
out[out_idx] = in[idx];
|
||||
}
|
||||
|
||||
// merge_heads: [1, H, S, D] → [S, H*D]
|
||||
// Input layout: element at [0, h, s, d] = flat[h * S*D + s*D + d]
|
||||
// Output layout: element at [s, h*D + d] = flat[s * H*D + h*D + d]
|
||||
__global__ void merge_heads_bf16(
|
||||
const __nv_bfloat16* __restrict__ in,
|
||||
__nv_bfloat16* __restrict__ out,
|
||||
int seq_len, int num_heads, int head_dim
|
||||
) {
|
||||
int hidden = num_heads * head_dim;
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = seq_len * hidden;
|
||||
if (idx >= total) return;
|
||||
|
||||
// idx is output index: [s, h*D + d]
|
||||
int s = idx / hidden;
|
||||
int rem = idx % hidden;
|
||||
int h = rem / head_dim;
|
||||
int d = rem % head_dim;
|
||||
|
||||
int in_idx = h * seq_len * head_dim + s * head_dim + d;
|
||||
out[idx] = in[in_idx];
|
||||
}
|
||||
|
||||
// transpose_for_rope: [1, H, S, D] → [S, H, D]
|
||||
// Input: [h, s, d] at h*S*D + s*D + d
|
||||
// Output: [s, h, d] at s*H*D + h*D + d
|
||||
__global__ void transpose_hsd_to_shd_bf16(
|
||||
const __nv_bfloat16* __restrict__ in,
|
||||
__nv_bfloat16* __restrict__ out,
|
||||
int seq_len, int num_heads, int head_dim
|
||||
) {
|
||||
int total = seq_len * num_heads * head_dim;
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= total) return;
|
||||
|
||||
// idx = output flat index: s*H*D + h*D + d
|
||||
int s = idx / (num_heads * head_dim);
|
||||
int rem = idx % (num_heads * head_dim);
|
||||
int h = rem / head_dim;
|
||||
int d = rem % head_dim;
|
||||
|
||||
int in_idx = h * seq_len * head_dim + s * head_dim + d;
|
||||
out[idx] = in[in_idx];
|
||||
}
|
||||
|
||||
// transpose_from_rope: [S, H, D] → [1, H, S, D]
|
||||
// Input: [s, h, d] at s*H*D + h*D + d
|
||||
// Output: [h, s, d] at h*S*D + s*D + d
|
||||
__global__ void transpose_shd_to_hsd_bf16(
|
||||
const __nv_bfloat16* __restrict__ in,
|
||||
__nv_bfloat16* __restrict__ out,
|
||||
int seq_len, int num_heads, int head_dim
|
||||
) {
|
||||
int total = seq_len * num_heads * head_dim;
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= total) return;
|
||||
|
||||
// idx = output flat index: h*S*D + s*D + d
|
||||
int h = idx / (seq_len * head_dim);
|
||||
int rem = idx % (seq_len * head_dim);
|
||||
int s = rem / head_dim;
|
||||
int d = rem % head_dim;
|
||||
|
||||
int in_idx = s * num_heads * head_dim + h * head_dim + d;
|
||||
out[idx] = in[in_idx];
|
||||
}
|
||||
|
||||
// repeat_kv: [1, KV_H, S, D] → [1, KV_H * n_rep, S, D]
|
||||
__global__ void repeat_kv_bf16(
|
||||
const __nv_bfloat16* __restrict__ in,
|
||||
__nv_bfloat16* __restrict__ out,
|
||||
int kv_heads, int n_rep, int seq_len, int head_dim
|
||||
) {
|
||||
int total_heads = kv_heads * n_rep;
|
||||
int total = total_heads * seq_len * head_dim;
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= total) return;
|
||||
|
||||
int out_h = idx / (seq_len * head_dim);
|
||||
int rem = idx % (seq_len * head_dim);
|
||||
int kv_h = out_h / n_rep;
|
||||
|
||||
int in_idx = kv_h * seq_len * head_dim + rem;
|
||||
out[idx] = in[in_idx];
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_reshape_heads_bf16(const void* in, void* out,
|
||||
int seq_len, int num_heads, int head_dim, void* stream) {
|
||||
int total = seq_len * num_heads * head_dim;
|
||||
int block = 256;
|
||||
int grid = (total + block - 1) / block;
|
||||
reshape_heads_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
|
||||
}
|
||||
|
||||
void launch_merge_heads_bf16(const void* in, void* out,
|
||||
int seq_len, int num_heads, int head_dim, void* stream) {
|
||||
int total = seq_len * num_heads * head_dim;
|
||||
int block = 256;
|
||||
int grid = (total + block - 1) / block;
|
||||
merge_heads_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
|
||||
}
|
||||
|
||||
void launch_transpose_hsd_to_shd_bf16(const void* in, void* out,
|
||||
int seq_len, int num_heads, int head_dim, void* stream) {
|
||||
int total = seq_len * num_heads * head_dim;
|
||||
int block = 256;
|
||||
int grid = (total + block - 1) / block;
|
||||
transpose_hsd_to_shd_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
|
||||
}
|
||||
|
||||
void launch_transpose_shd_to_hsd_bf16(const void* in, void* out,
|
||||
int seq_len, int num_heads, int head_dim, void* stream) {
|
||||
int total = seq_len * num_heads * head_dim;
|
||||
int block = 256;
|
||||
int grid = (total + block - 1) / block;
|
||||
transpose_shd_to_hsd_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
|
||||
}
|
||||
|
||||
void launch_repeat_kv_bf16(const void* in, void* out,
|
||||
int kv_heads, int n_rep, int seq_len, int head_dim, void* stream) {
|
||||
int total = kv_heads * n_rep * seq_len * head_dim;
|
||||
int block = 256;
|
||||
int grid = (total + block - 1) / block;
|
||||
repeat_kv_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, kv_heads, n_rep, seq_len, head_dim);
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user