diff --git a/csrc/reduce/softmax.cu b/csrc/reduce/softmax.cu index 1be36de..c3fe584 100644 --- a/csrc/reduce/softmax.cu +++ b/csrc/reduce/softmax.cu @@ -90,7 +90,7 @@ __global__ void softmax_bf16( extern "C" { void launch_softmax_f32(const void* x, void* out, int rows, int cols, void* stream) { - int block = (cols < 1024) ? cols : 1024; + int block = (cols < 512) ? cols : 512; if (block < 32) block = 32; softmax_f32<<>>( (const float*)x, (float*)out, cols); @@ -98,7 +98,7 @@ void launch_softmax_f32(const void* x, void* out, int rows, int cols, void* stre } void launch_softmax_bf16(const void* x, void* out, int rows, int cols, void* stream) { - int block = (cols < 1024) ? cols : 1024; + int block = (cols < 512) ? cols : 512; if (block < 32) block = 32; softmax_bf16<<>>( (const __nv_bfloat16*)x, (__nv_bfloat16*)out, cols);