Strict code review identified 30+ issues across correctness, performance, and architecture. This commit addresses 14 of them with verified fixes, restructures Phase 12 for honest continuous batching, and updates Phase 14 to target FA2 (RTX 5090 SM120 lacks TMEM required by FA4). Bug fixes: - FIX-01: Global cuBLAS handle (thread-local singleton, was per-call) - FIX-02: Remove 19 unnecessary cudaDeviceSynchronize calls from kernels - FIX-03: Qwen3 ChatML template (was plain text concatenation) - FIX-04: EOS token from tokenizer (was hardcoded 151645) - FIX-05: Storage tracks actual GPU device ordinal (was always Cuda(0)) - FIX-06: unsqueeze stride preserves contiguous layout - FIX-08: CudaDeviceProp replaced with heap buffer (was UB-prone padding) - FIX-09: Tokenizer byte_fallback to <0xNN> tokens (was panic) Feature additions: - FIX-10: SSE streaming (/v1/chat/completions, OpenAI-compatible) - FIX-11: Correct usage statistics (prompt/completion/total tokens) - FIX-13: Temperature / top-k / top-p sampling with SamplingParams Performance improvements: - FIX-07: Caching allocator wired up (thread-local pool, pooled flag) - FIX-12: KV cache staging buffers (zero-alloc get_kv_len via borrow_raw) - FIX-14: GPU strided copy kernel (eliminates contiguous() CPU round-trip) Architecture: - Phase 12 engine restructured: prefill/decode separation, honest TODO for batched GPU forward (requires Flash Attention) - Phase 14 updated: FA2 for SM120 (FA4 requires TMEM, absent on 5090) - Qwen3-7B → Qwen3-8B typo fixed across all docs (36 layers, hidden 4096) Validated on dash5 (8x RTX 5090): - 52/52 API prompts pass (EN/CN/code), SSE streaming verified - Logits match HF transformers 9/10 top-1, 4.0/5 avg top-5 overlap - 8 concurrent requests: 5.99x scheduling speedup (batch_size=4) - Throughput: 10.3 tok/s (serial), 30% of HF baseline Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
235 lines
8.5 KiB
Plaintext
235 lines
8.5 KiB
Plaintext
#include <cuda_bf16.h>
|
|
|
|
// Transpose between [S, H, D] and [H, S, D] layouts (used for RoPE and attention).
|
|
// Also handles [S, H*D] → [H, S, D] (reshape_heads) and reverse (merge_heads).
|
|
|
|
// reshape_heads_bf16: split the hidden axis into heads and move heads outermost.
//   in : [S, H*D]     -- element (s, h*D + d) is read from    s*(H*D) + h*D + d
//   out: [1, H, S, D] -- element (0, h, s, d) is written to   h*(S*D) + s*D + d
// One thread handles one element, identified by its flat INPUT position on a
// 1D grid. Threads whose index is >= S*H*D exit without touching memory.
__global__ void reshape_heads_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int seq_len, int num_heads, int head_dim
) {
    const int row_width = num_heads * head_dim;
    const int src = blockIdx.x * blockDim.x + threadIdx.x;
    if (src >= seq_len * row_width) return;

    // Unpack the flat input index into (token, head, lane).
    const int token  = src / row_width;
    const int within = src - token * row_width;   // position inside the token's row
    const int head   = within / head_dim;
    const int lane   = within - head * head_dim;

    // Re-pack with the head index as the slowest-varying coordinate.
    const int dst = (head * seq_len + token) * head_dim + lane;
    out[dst] = in[src];
}
|
|
|
|
// merge_heads_bf16: fold the per-head layout back into one hidden axis.
//   in : [1, H, S, D] -- element (0, h, s, d) is read from    h*(S*D) + s*D + d
//   out: [S, H*D]     -- element (s, h*D + d) is written to   s*(H*D) + h*D + d
// One thread handles one element, identified by its flat OUTPUT position on a
// 1D grid. Threads whose index is >= S*H*D exit without touching memory.
__global__ void merge_heads_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int seq_len, int num_heads, int head_dim
) {
    const int row_width = num_heads * head_dim;
    const int dst = blockIdx.x * blockDim.x + threadIdx.x;
    if (dst >= seq_len * row_width) return;

    // Unpack the flat output index into (token, head, lane).
    const int token  = dst / row_width;
    const int within = dst - token * row_width;   // position inside the token's row
    const int head   = within / head_dim;
    const int lane   = within - head * head_dim;

    // Gather from the head-major source.
    const int src = (head * seq_len + token) * head_dim + lane;
    out[dst] = in[src];
}
|
|
|
|
// transpose_hsd_to_shd_bf16 ("transpose_for_rope"): [1, H, S, D] -> [S, H, D].
//   in : element (h, s, d) is read from    h*(S*D) + s*D + d
//   out: element (s, h, d) is written to   s*(H*D) + h*D + d
// One thread handles one element, identified by its flat OUTPUT position on a
// 1D grid. Threads whose index is >= S*H*D exit without touching memory.
__global__ void transpose_hsd_to_shd_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int seq_len, int num_heads, int head_dim
) {
    const int per_token = num_heads * head_dim;
    const int dst = blockIdx.x * blockDim.x + threadIdx.x;
    if (dst >= seq_len * per_token) return;

    // Output coordinates, slowest to fastest: token, head, lane.
    const int token  = dst / per_token;
    const int within = dst - token * per_token;
    const int head   = within / head_dim;
    const int lane   = within - head * head_dim;

    // Same element in the head-major source.
    const int src = (head * seq_len + token) * head_dim + lane;
    out[dst] = in[src];
}
|
|
|
|
// transpose_shd_to_hsd_bf16 ("transpose_from_rope"): [S, H, D] -> [1, H, S, D].
//   in : element (s, h, d) is read from    s*(H*D) + h*D + d
//   out: element (h, s, d) is written to   h*(S*D) + s*D + d
// One thread handles one element, identified by its flat OUTPUT position on a
// 1D grid. Threads whose index is >= S*H*D exit without touching memory.
__global__ void transpose_shd_to_hsd_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int seq_len, int num_heads, int head_dim
) {
    const int dst = blockIdx.x * blockDim.x + threadIdx.x;
    if (dst >= seq_len * num_heads * head_dim) return;

    // Output coordinates, slowest to fastest: head, token, lane.
    const int per_head = seq_len * head_dim;
    const int head     = dst / per_head;
    const int within   = dst - head * per_head;
    const int token    = within / head_dim;
    const int lane     = within - token * head_dim;

    // Same element in the token-major source.
    const int src = (token * num_heads + head) * head_dim + lane;
    out[dst] = in[src];
}
|
|
|
|
// repeat_kv_bf16: [1, KV_H, S, D] -> [1, KV_H * n_rep, S, D].
// Output head `o` is a copy of input head `o / n_rep`, so each input head
// appears n_rep times in consecutive output heads.
// One thread handles one OUTPUT element on a 1D grid. Threads whose index is
// >= KV_H * n_rep * S * D exit without touching memory.
__global__ void repeat_kv_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int kv_heads, int n_rep, int seq_len, int head_dim
) {
    const int per_head = seq_len * head_dim;
    const int dst = blockIdx.x * blockDim.x + threadIdx.x;
    if (dst >= kv_heads * n_rep * per_head) return;

    const int dst_head = dst / per_head;
    const int within   = dst - dst_head * per_head;   // (s, d) offset inside the head
    const int src_head = dst_head / n_rep;             // which KV head feeds this output head

    out[dst] = in[src_head * per_head + within];
}
|
|
|
|
// ---- Generic strided copy (up to 4D) ----
|
|
// Each thread copies one element. Maps flat contiguous output index to strided input index.
|
|
// Unused dimensions are padded with shape=1, stride=0.
|
|
|
|
// strided_copy_bf16: gather a strided (up to 4D) bf16 view into a dense buffer.
//   numel        number of output elements; threads with index >= numel exit
//   ndim         accepted for interface symmetry; this kernel never reads it
//   shape0..3    logical extents, shape3 varying fastest in the output;
//                shape1..shape3 are used as divisors
//   in_stride0..3, in_offset
//                element (not byte) strides and base offset applied to `in`
// One thread per output element on a 1D grid; output index `i` is written at out[i].
__global__ void strided_copy_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int numel,
    int ndim,
    int shape0, int shape1, int shape2, int shape3,
    int in_stride0, int in_stride1, int in_stride2, int in_stride3,
    int in_offset
) {
    const int dst = blockIdx.x * blockDim.x + threadIdx.x;
    if (dst >= numel) return;

    // Peel coordinates off the flat output index, innermost dimension first.
    // shape0 is never needed: whatever is left after dividing out the three
    // inner extents is the outermost coordinate.
    const int c3 = dst % shape3;
    const int q3 = dst / shape3;
    const int c2 = q3 % shape2;
    const int q2 = q3 / shape2;
    const int c1 = q2 % shape1;
    const int c0 = q2 / shape1;

    const int src = in_offset + c0 * in_stride0 + c1 * in_stride1 + c2 * in_stride2 + c3 * in_stride3;
    out[dst] = in[src];
}
|
|
|
|
// strided_copy_f32: gather a strided (up to 4D) float view into a dense buffer.
//   numel        number of output elements; threads with index >= numel exit
//   ndim         accepted for interface symmetry; this kernel never reads it
//   shape0..3    logical extents, shape3 varying fastest in the output;
//                shape1..shape3 are used as divisors
//   in_stride0..3, in_offset
//                element (not byte) strides and base offset applied to `in`
// One thread per output element on a 1D grid; output index `i` is written at out[i].
__global__ void strided_copy_f32(
    const float* __restrict__ in,
    float* __restrict__ out,
    int numel,
    int ndim,
    int shape0, int shape1, int shape2, int shape3,
    int in_stride0, int in_stride1, int in_stride2, int in_stride3,
    int in_offset
) {
    const int dst = blockIdx.x * blockDim.x + threadIdx.x;
    if (dst >= numel) return;

    // Peel coordinates off the flat output index, innermost dimension first.
    // shape0 is never needed: whatever is left after dividing out the three
    // inner extents is the outermost coordinate.
    const int c3 = dst % shape3;
    const int q3 = dst / shape3;
    const int c2 = q3 % shape2;
    const int q2 = q3 / shape2;
    const int c1 = q2 % shape1;
    const int c0 = q2 / shape1;

    const int src = in_offset + c0 * in_stride0 + c1 * in_stride1 + c2 * in_stride2 + c3 * in_stride3;
    out[dst] = in[src];
}
|
|
|
|
extern "C" {
|
|
|
|
// Host launcher for reshape_heads_bf16 ([S, H*D] -> [1, H, S, D]).
//   in, out   untyped pointers, reinterpreted as __nv_bfloat16 buffers
//   stream    untyped handle, cast to cudaStream_t for the launch
// Uses 256-thread blocks and ceil(S*H*D / 256) blocks, one thread per element.
// The launch is only enqueued: no synchronization or error query happens here.
void launch_reshape_heads_bf16(const void* in, void* out,
                               int seq_len, int num_heads, int head_dim, void* stream) {
    const int total = seq_len * num_heads * head_dim;
    // An empty tensor would yield a 0-block grid, which CUDA rejects as an
    // invalid launch configuration; there is nothing to copy, so do nothing.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (total + block - 1) / block;
    reshape_heads_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
}
|
|
|
|
// Host launcher for merge_heads_bf16 ([1, H, S, D] -> [S, H*D]).
//   in, out   untyped pointers, reinterpreted as __nv_bfloat16 buffers
//   stream    untyped handle, cast to cudaStream_t for the launch
// Uses 256-thread blocks and ceil(S*H*D / 256) blocks, one thread per element.
// The launch is only enqueued: no synchronization or error query happens here.
void launch_merge_heads_bf16(const void* in, void* out,
                             int seq_len, int num_heads, int head_dim, void* stream) {
    const int total = seq_len * num_heads * head_dim;
    // An empty tensor would yield a 0-block grid, which CUDA rejects as an
    // invalid launch configuration; there is nothing to copy, so do nothing.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (total + block - 1) / block;
    merge_heads_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
}
|
|
|
|
// Host launcher for transpose_hsd_to_shd_bf16 ([1, H, S, D] -> [S, H, D]).
//   in, out   untyped pointers, reinterpreted as __nv_bfloat16 buffers
//   stream    untyped handle, cast to cudaStream_t for the launch
// Uses 256-thread blocks and ceil(S*H*D / 256) blocks, one thread per element.
// The launch is only enqueued: no synchronization or error query happens here.
void launch_transpose_hsd_to_shd_bf16(const void* in, void* out,
                                      int seq_len, int num_heads, int head_dim, void* stream) {
    const int total = seq_len * num_heads * head_dim;
    // An empty tensor would yield a 0-block grid, which CUDA rejects as an
    // invalid launch configuration; there is nothing to copy, so do nothing.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (total + block - 1) / block;
    transpose_hsd_to_shd_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
}
|
|
|
|
// Host launcher for transpose_shd_to_hsd_bf16 ([S, H, D] -> [1, H, S, D]).
//   in, out   untyped pointers, reinterpreted as __nv_bfloat16 buffers
//   stream    untyped handle, cast to cudaStream_t for the launch
// Uses 256-thread blocks and ceil(S*H*D / 256) blocks, one thread per element.
// The launch is only enqueued: no synchronization or error query happens here.
void launch_transpose_shd_to_hsd_bf16(const void* in, void* out,
                                      int seq_len, int num_heads, int head_dim, void* stream) {
    const int total = seq_len * num_heads * head_dim;
    // An empty tensor would yield a 0-block grid, which CUDA rejects as an
    // invalid launch configuration; there is nothing to copy, so do nothing.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (total + block - 1) / block;
    transpose_shd_to_hsd_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
}
|
|
|
|
// Host launcher for repeat_kv_bf16 ([1, KV_H, S, D] -> [1, KV_H * n_rep, S, D]).
//   in, out   untyped pointers, reinterpreted as __nv_bfloat16 buffers
//   stream    untyped handle, cast to cudaStream_t for the launch
// Uses 256-thread blocks and ceil(KV_H*n_rep*S*D / 256) blocks, one thread per
// OUTPUT element. The launch is only enqueued: no synchronization or error
// query happens here.
void launch_repeat_kv_bf16(const void* in, void* out,
                           int kv_heads, int n_rep, int seq_len, int head_dim, void* stream) {
    const int total = kv_heads * n_rep * seq_len * head_dim;
    // An empty output would yield a 0-block grid, which CUDA rejects as an
    // invalid launch configuration; there is nothing to copy, so do nothing.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (total + block - 1) / block;
    repeat_kv_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, kv_heads, n_rep, seq_len, head_dim);
}
|
|
|
|
// Host launcher for strided_copy_bf16 (strided up-to-4D view -> dense buffer).
//   in, out        untyped pointers, reinterpreted as __nv_bfloat16 buffers
//   numel          number of output elements; also determines the grid size
//   ndim, shape*, in_stride*, in_offset
//                  forwarded to the kernel unchanged
//   stream         untyped handle, cast to cudaStream_t for the launch
// Uses 256-thread blocks and ceil(numel / 256) blocks. The launch is only
// enqueued: no synchronization or error query happens here.
void launch_strided_copy_bf16(const void* in, void* out, int numel, int ndim,
                              int shape0, int shape1, int shape2, int shape3,
                              int in_stride0, int in_stride1, int in_stride2, int in_stride3,
                              int in_offset, void* stream) {
    // An empty copy would yield a 0-block grid, which CUDA rejects as an
    // invalid launch configuration; there is nothing to copy, so do nothing.
    if (numel <= 0) return;

    const int block = 256;
    const int grid = (numel + block - 1) / block;
    strided_copy_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, numel, ndim,
        shape0, shape1, shape2, shape3,
        in_stride0, in_stride1, in_stride2, in_stride3, in_offset);
}
|
|
|
|
// Host launcher for strided_copy_f32 (strided up-to-4D view -> dense buffer).
//   in, out        untyped pointers, reinterpreted as float buffers
//   numel          number of output elements; also determines the grid size
//   ndim, shape*, in_stride*, in_offset
//                  forwarded to the kernel unchanged
//   stream         untyped handle, cast to cudaStream_t for the launch
// Uses 256-thread blocks and ceil(numel / 256) blocks. The launch is only
// enqueued: no synchronization or error query happens here.
void launch_strided_copy_f32(const void* in, void* out, int numel, int ndim,
                             int shape0, int shape1, int shape2, int shape3,
                             int in_stride0, int in_stride1, int in_stride2, int in_stride3,
                             int in_offset, void* stream) {
    // An empty copy would yield a 0-block grid, which CUDA rejects as an
    // invalid launch configuration; there is nothing to copy, so do nothing.
    if (numel <= 0) return;

    const int block = 256;
    const int grid = (numel + block - 1) / block;
    strided_copy_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const float*)in, (float*)out, numel, ndim,
        shape0, shape1, shape2, shape3,
        in_stride0, in_stride1, in_stride2, in_stride3, in_offset);
}
|
|
|
|
}
|