perf: GPU transpose/reshape/repeat_kv kernels (eliminate CPU round-trips)

New CUDA kernels (csrc/embedding/transpose.cu):
- reshape_heads_bf16: [S, H*D] → [1, H, S, D]
- merge_heads_bf16: [1, H, S, D] → [S, H*D]
- transpose_hsd_to_shd_bf16: [1, H, S, D] → [S, H, D] (for RoPE)
- transpose_shd_to_hsd_bf16: [S, H, D] → [1, H, S, D] (from RoPE)
- repeat_kv_bf16: [1, KV_H, S, D] → [1, KV_H*n_rep, S, D]

Rust wrappers (xserv-kernels/src/transpose.rs):
- reshape_heads_gpu, merge_heads_gpu, transpose_for/from_rope_gpu, repeat_kv_gpu

Qwen3 forward_gpu_cache now uses all GPU kernels — zero CPU data round-trips.

Result: 50/50 self-consistent, 3-5% faster (TBT 142→137ms)
Remaining bottleneck: ~900 device::synchronize() calls + 252 cuBLAS handle
creations per token (Phase 15 targets)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-22 12:01:07 +08:00
parent 2d48f25e66
commit 2be27d6d94
5 changed files with 273 additions and 12 deletions

161
csrc/embedding/transpose.cu Normal file
View File

@@ -0,0 +1,161 @@
#include <cuda_bf16.h>
// Transpose between [S, H, D] and [H, S, D] layouts (used for RoPE and attention).
// Also handles [S, H*D] → [H, S, D] (reshape_heads) and reverse (merge_heads).
// reshape_heads: [S, H*D] → [1, H, S, D]
// Input layout: element at [s, h*D + d] = flat[s * H*D + h*D + d]
// Output layout: element at [0, h, s, d] = flat[h * S*D + s*D + d]
__global__ void reshape_heads_bf16(
const __nv_bfloat16* __restrict__ in,
__nv_bfloat16* __restrict__ out,
int seq_len, int num_heads, int head_dim
) {
int hidden = num_heads * head_dim;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = seq_len * hidden;
if (idx >= total) return;
int s = idx / hidden;
int rem = idx % hidden;
int h = rem / head_dim;
int d = rem % head_dim;
int out_idx = h * seq_len * head_dim + s * head_dim + d;
out[out_idx] = in[idx];
}
// merge_heads: [1, H, S, D] → [S, H*D]
// Input layout: element at [0, h, s, d] = flat[h * S*D + s*D + d]
// Output layout: element at [s, h*D + d] = flat[s * H*D + h*D + d]
__global__ void merge_heads_bf16(
const __nv_bfloat16* __restrict__ in,
__nv_bfloat16* __restrict__ out,
int seq_len, int num_heads, int head_dim
) {
int hidden = num_heads * head_dim;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = seq_len * hidden;
if (idx >= total) return;
// idx is output index: [s, h*D + d]
int s = idx / hidden;
int rem = idx % hidden;
int h = rem / head_dim;
int d = rem % head_dim;
int in_idx = h * seq_len * head_dim + s * head_dim + d;
out[idx] = in[in_idx];
}
// transpose_for_rope: [1, H, S, D] → [S, H, D]
// Input: [h, s, d] at h*S*D + s*D + d
// Output: [s, h, d] at s*H*D + h*D + d
__global__ void transpose_hsd_to_shd_bf16(
const __nv_bfloat16* __restrict__ in,
__nv_bfloat16* __restrict__ out,
int seq_len, int num_heads, int head_dim
) {
int total = seq_len * num_heads * head_dim;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= total) return;
// idx = output flat index: s*H*D + h*D + d
int s = idx / (num_heads * head_dim);
int rem = idx % (num_heads * head_dim);
int h = rem / head_dim;
int d = rem % head_dim;
int in_idx = h * seq_len * head_dim + s * head_dim + d;
out[idx] = in[in_idx];
}
// transpose_from_rope: [S, H, D] → [1, H, S, D]
// Input: [s, h, d] at s*H*D + h*D + d
// Output: [h, s, d] at h*S*D + s*D + d
__global__ void transpose_shd_to_hsd_bf16(
const __nv_bfloat16* __restrict__ in,
__nv_bfloat16* __restrict__ out,
int seq_len, int num_heads, int head_dim
) {
int total = seq_len * num_heads * head_dim;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= total) return;
// idx = output flat index: h*S*D + s*D + d
int h = idx / (seq_len * head_dim);
int rem = idx % (seq_len * head_dim);
int s = rem / head_dim;
int d = rem % head_dim;
int in_idx = s * num_heads * head_dim + h * head_dim + d;
out[idx] = in[in_idx];
}
// repeat_kv: [1, KV_H, S, D] → [1, KV_H * n_rep, S, D]
__global__ void repeat_kv_bf16(
const __nv_bfloat16* __restrict__ in,
__nv_bfloat16* __restrict__ out,
int kv_heads, int n_rep, int seq_len, int head_dim
) {
int total_heads = kv_heads * n_rep;
int total = total_heads * seq_len * head_dim;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= total) return;
int out_h = idx / (seq_len * head_dim);
int rem = idx % (seq_len * head_dim);
int kv_h = out_h / n_rep;
int in_idx = kv_h * seq_len * head_dim + rem;
out[idx] = in[in_idx];
}
extern "C" {
void launch_reshape_heads_bf16(const void* in, void* out,
int seq_len, int num_heads, int head_dim, void* stream) {
int total = seq_len * num_heads * head_dim;
int block = 256;
int grid = (total + block - 1) / block;
reshape_heads_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
}
void launch_merge_heads_bf16(const void* in, void* out,
int seq_len, int num_heads, int head_dim, void* stream) {
int total = seq_len * num_heads * head_dim;
int block = 256;
int grid = (total + block - 1) / block;
merge_heads_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
}
void launch_transpose_hsd_to_shd_bf16(const void* in, void* out,
int seq_len, int num_heads, int head_dim, void* stream) {
int total = seq_len * num_heads * head_dim;
int block = 256;
int grid = (total + block - 1) / block;
transpose_hsd_to_shd_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
}
void launch_transpose_shd_to_hsd_bf16(const void* in, void* out,
int seq_len, int num_heads, int head_dim, void* stream) {
int total = seq_len * num_heads * head_dim;
int block = 256;
int grid = (total + block - 1) / block;
transpose_shd_to_hsd_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
}
void launch_repeat_kv_bf16(const void* in, void* out,
int kv_heads, int n_rep, int seq_len, int head_dim, void* stream) {
int total = kv_heads * n_rep * seq_len * head_dim;
int block = 256;
int grid = (total + block - 1) / block;
repeat_kv_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, kv_heads, n_rep, seq_len, head_dim);
}
}