gpt-oss: drop debug syncs from forward; GPU broadcast bias-add

Decode carried three leftover cudaDeviceSynchronize (prefill one) from
NaN debugging — the Qwen3 path has none and the logits D2H in sample()
already orders against the null stream.

add_bias for rows>1 round-tripped the bias through the CPU (D2H + host
tile loop + H2D) on every call — 96 times per prefill across q/k/v/o.
Replace with a bias_add_2d broadcast kernel.

dash5, FP8 TP=2, warm-server: TTFT 35/49/94 -> 29/42/79 ms
(short/medium/long), TPOT 7.19-7.32 -> 6.99-7.21 ms. Greedy tokens
unchanged; GSM8K-50 94%.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
2026-06-12 17:02:59 +08:00
parent 63f5599717
commit 1897b2e17a
4 changed files with 50 additions and 32 deletions

View File

@@ -15,6 +15,8 @@ unsafe extern "C" {
fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_gpt_oss_glu_bf16(gate_up: *const c_void, out: *mut c_void, n_elements: i32,
alpha: f32, limit: f32, stream: *mut c_void);
fn launch_bias_add_2d_bf16(x: *const c_void, bias: *const c_void, out: *mut c_void,
rows: i32, cols: i32, stream: *mut c_void);
}
fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
@@ -77,6 +79,28 @@ pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
pub fn add(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_add_f32, launch_add_bf16) }
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16) }
/// Row-broadcast bias add: out[r, c] = x[r, c] + bias[c] (BF16 only).
pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor {
assert_eq!(x.ndim(), 2);
assert_eq!(bias.ndim(), 1);
assert_eq!(x.dtype(), DType::BF16);
assert_eq!(bias.dtype(), DType::BF16);
assert!(x.is_contiguous() && bias.is_contiguous());
assert!(matches!(x.device(), Device::Cuda(_)));
let rows = x.shape()[0];
let cols = x.shape()[1];
assert_eq!(bias.shape()[0], cols, "bias size {} != cols {cols}", bias.shape()[0]);
assert!(rows * cols <= i32::MAX as usize);
let out = Tensor::empty(&[rows, cols], DType::BF16, x.device());
unsafe {
launch_bias_add_2d_bf16(
x.data_ptr() as _, bias.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, cols as i32, std::ptr::null_mut(),
);
}
out
}
/// Fused SiLU×Mul: out = silu(gate) * up (BF16 only)
/// Saves one HBM read + one HBM write compared to separate silu + mul.
pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {

View File

@@ -12,7 +12,7 @@ pub mod rope;
pub mod softmax;
pub mod transpose;
pub use activation::{add, gelu, gpt_oss_glu, mul, scale, silu, silu_mul};
pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu_mul};
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
pub use attention::{attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16};

View File

@@ -450,12 +450,8 @@ impl GptOss {
paged_cache.advance_seq_len(slot, 1);
}
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
let x = Self::norm(&x, &self.norm, &self.norm_bias, eps);
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
let logits = matmul_2d(&x, &self.lm_head_t);
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
logits
matmul_2d(&x, &self.lm_head_t)
}
/// Paged prefill: process full prompt tokens.
@@ -519,9 +515,7 @@ impl GptOss {
}
let x = Self::norm(&x, &self.norm, &self.norm_bias, eps);
let logits = matmul_2d(&x, &self.lm_head_t);
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
logits
matmul_2d(&x, &self.lm_head_t)
}
/// MoE forward pass — fully on GPU via batched GEMM.
@@ -691,31 +685,12 @@ fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
matmul(a, b, GemmBackend::CuBlas)
}
/// Add bias to a 2D tensor: [rows, cols] + [cols] → [rows, cols]
/// Add bias to a 2D tensor: [rows, cols] + [cols] → [rows, cols].
/// Single GPU broadcast kernel — the old rows>1 path tiled the bias on the
/// CPU (D2H + host loop + H2D) on every call, 96×/prefill in the hot path.
fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor {
assert_eq!(x.ndim(), 2);
assert_eq!(bias.ndim(), 1);
let rows = x.shape()[0];
let cols = x.shape()[1];
assert_eq!(bias.shape()[0], cols, "bias size {} != cols {}", bias.shape()[0], cols);
let x_c = x.contiguous();
if rows == 1 {
// Fast path: reshape bias [cols] → [1, cols] (zero-copy), add directly on GPU
let bias_2d = bias.reshape(&[1, cols]);
return xserv_kernels::add(&x_c, &bias_2d);
}
// General path: tile bias to [rows, cols] via CPU, then add on GPU
let bias_cpu = bias.to_device(Device::Cpu);
let bias_data = bias_cpu.as_slice::<bf16>();
let mut tiled = Vec::with_capacity(rows * cols);
for _ in 0..rows {
tiled.extend_from_slice(bias_data);
}
let bias_tiled = Tensor::from_slice(&tiled, &[rows, cols]).to_device(x.device());
xserv_kernels::add(&x_c, &bias_tiled)
xserv_kernels::bias_add_2d(&x_c, bias)
}
fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {

View File

@@ -87,6 +87,17 @@ __global__ void add_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b,
if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(a[idx]) + __bfloat162float(b[idx]));
}
// Row-broadcast bias add: out[r, c] = x[r, c] + bias[c]
__global__ void bias_add_2d_bf16_kernel(
const __nv_bfloat16* __restrict__ x, const __nv_bfloat16* __restrict__ bias,
__nv_bfloat16* __restrict__ out, int rows, int cols
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= rows * cols) return;
float v = __bfloat162float(x[idx]) + __bfloat162float(bias[idx % cols]);
out[idx] = __float2bfloat16(v);
}
// Element-wise mul: out = a * b
__global__ void mul_f32_kernel(const float* a, const float* b, float* out, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -159,6 +170,14 @@ void launch_add_bf16(const void* a, const void* b, void* out, int n, void* strea
(const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
CUDA_CHECK_LAST_ERROR();
}
void launch_bias_add_2d_bf16(const void* x, const void* bias, void* out, int rows, int cols, void* stream) {
int n = rows * cols;
int block = 256;
int grid = (n + block - 1) / block;
bias_add_2d_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (const __nv_bfloat16*)bias, (__nv_bfloat16*)out, rows, cols);
CUDA_CHECK_LAST_ERROR();
}
void launch_mul_f32(const void* a, const void* b, void* out, int n, void* stream) {
int block = 256;
int grid = (n + block - 1) / block;