fix: GEMV NaN bug — skip custom kernel for small N (<256)
The custom launch_gemv_bf16 kernel produces NaN when the output dimension N is small (e.g. N=32 for the MoE router). Fall back to cuBLAS GemmEx when N < 256, and remove the now-unnecessary padding workaround in the gpt_oss MoE forward pass. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -173,7 +173,7 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
}
|
||||
}
|
||||
GemmBackend::CuBlas => {
|
||||
if m == 1 && dtype == DType::BF16 {
|
||||
if m == 1 && dtype == DType::BF16 && n >= 256 {
|
||||
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(n * 4).unwrap();
|
||||
unsafe {
|
||||
launch_gemv_bf16(
|
||||
|
||||
@@ -411,30 +411,13 @@ impl GptOss {
|
||||
|
||||
// Router: [tokens, hidden] @ [hidden, num_experts] + bias → [tokens, num_experts]
|
||||
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
|
||||
// Pad to 2 rows to avoid GEMV path (workaround for GEMV NaN bug with small N)
|
||||
let x_padded = if num_tokens == 1 {
|
||||
let x_cpu_tmp = x.to_device(Device::Cpu);
|
||||
let xd = x_cpu_tmp.as_slice::<bf16>();
|
||||
let mut padded = xd.to_vec();
|
||||
padded.extend(vec![bf16::ZERO; hidden]);
|
||||
Tensor::from_slice(&padded, &[2, hidden]).to_device(x.device())
|
||||
} else {
|
||||
x.clone()
|
||||
};
|
||||
let router_logits_full = add_bias(
|
||||
&matmul_2d(&x_padded, &layer.router_wt),
|
||||
let router_logits = add_bias(
|
||||
&matmul_2d(x, &layer.router_wt),
|
||||
&layer.router_bias,
|
||||
);
|
||||
let router_logits = if num_tokens == 1 {
|
||||
router_logits_full.narrow(0, 0, 1).contiguous()
|
||||
} else {
|
||||
router_logits_full
|
||||
};
|
||||
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
|
||||
let router_cpu = router_logits.to_device(Device::Cpu);
|
||||
let router_data = router_cpu.as_slice::<bf16>();
|
||||
|
||||
// Copy x to CPU after all GPU ops are synced
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let x_data = x_cpu.as_slice::<bf16>();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user