review cleanups: pp+gpt-oss guard, sparse GEMV asserts, warnings

- --pp with gpt-oss now fails with a clear message instead of a
  cryptic missing-weight panic inside the Qwen3-only PP engine.
- Sparse GEMV wrappers assert K%16==0 (FP8) / K%32==0 (MXFP4) — the
  uint4-vectorized kernels would silently drop a tail otherwise.
- Document the topk_ids buffer holding i32 under an F32 dtype label
  (DType has no I32).
- Drop unused imports/locals and the cuBLASLt scale-mode constants
  orphaned by the strided-batched FP8 rework (e631a71).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
2026-06-12 17:02:59 +08:00
parent 1897b2e17a
commit 5343391dbd
8 changed files with 21 additions and 17 deletions

View File

@@ -1,5 +1,4 @@
use std::ffi::c_void; use std::ffi::c_void;
use xserv_cuda::GpuBuffer;
use xserv_tensor::{DType, Device, Tensor}; use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" { unsafe extern "C" {

View File

@@ -1,5 +1,5 @@
use std::ffi::c_void; use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor}; use xserv_tensor::{DType, Tensor};
use crate::gemm::{cublas_handle, CublasHandle}; use crate::gemm::{cublas_handle, CublasHandle};
@@ -88,6 +88,9 @@ pub fn moe_topk_softmax(
let num_tokens = router_logits.shape()[0]; let num_tokens = router_logits.shape()[0];
assert_eq!(router_logits.shape()[1], num_experts); assert_eq!(router_logits.shape()[1], num_experts);
// NOTE: topk_ids actually holds i32 expert indices; DType has no I32, so
// this is a raw 4-byte buffer mislabeled F32. Never read it as floats —
// all consumers (weighted-sum / sparse GEMV kernels) cast to int*.
let topk_ids = Tensor::empty(&[num_tokens, top_k], DType::F32, router_logits.device()); let topk_ids = Tensor::empty(&[num_tokens, top_k], DType::F32, router_logits.device());
let topk_weights = Tensor::empty(&[num_tokens, top_k], DType::F32, router_logits.device()); let topk_weights = Tensor::empty(&[num_tokens, top_k], DType::F32, router_logits.device());
@@ -201,8 +204,12 @@ pub fn moe_sparse_gemv_fp8(
) -> Tensor { ) -> Tensor {
assert_eq!(x.dtype(), DType::BF16); assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous()); assert!(x.is_contiguous());
assert_eq!(w_fp8_t.dtype(), DType::FP8E4M3);
let n = w_fp8_t.shape()[1]; let n = w_fp8_t.shape()[1];
let k = w_fp8_t.shape()[2]; let k = w_fp8_t.shape()[2];
// The kernel reads weights as uint4 (16 FP8 values per lane) and would
// silently skip a K%16 tail.
assert_eq!(k % 16, 0, "sparse FP8 GEMV requires K % 16 == 0, got {k}");
assert_eq!(x.shape()[x.ndim() - 1], k); assert_eq!(x.shape()[x.ndim() - 1], k);
assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens }); assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens });
@@ -233,6 +240,8 @@ pub fn moe_sparse_gemv_mxfp4(
) -> Tensor { ) -> Tensor {
assert_eq!(x.dtype(), DType::BF16); assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous()); assert!(x.is_contiguous());
// 32-element MXFP4 blocks, read as uint4 (32 nibbles) per lane.
assert_eq!(k % 32, 0, "sparse MXFP4 GEMV requires K % 32 == 0, got {k}");
assert_eq!(x.shape()[x.ndim() - 1], k); assert_eq!(x.shape()[x.ndim() - 1], k);
assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens }); assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens });

View File

@@ -107,17 +107,11 @@ const CUDA_R_8F_E4M3: i32 = 28;
// MatmulDesc attributes // MatmulDesc attributes
const CUBLASLT_MATMUL_DESC_A_SCALE_POINTER: i32 = 17; const CUBLASLT_MATMUL_DESC_A_SCALE_POINTER: i32 = 17;
const CUBLASLT_MATMUL_DESC_B_SCALE_POINTER: i32 = 18; const CUBLASLT_MATMUL_DESC_B_SCALE_POINTER: i32 = 18;
const CUBLASLT_MATMUL_DESC_A_SCALE_MODE: i32 = 31;
const CUBLASLT_MATMUL_DESC_B_SCALE_MODE: i32 = 32;
// MatrixLayout attributes // MatrixLayout attributes
const CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT: i32 = 5; const CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT: i32 = 5;
const CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET: i32 = 6; const CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET: i32 = 6;
// Scale modes
const CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR: i32 = 0;
const CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F: i32 = 3;
// MatmulPreference attributes // MatmulPreference attributes
const CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES: i32 = 1; const CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES: i32 = 1;

View File

@@ -199,7 +199,7 @@ impl GPT2 {
layer: &GPT2Block, layer: &GPT2Block,
x: &Tensor, x: &Tensor,
cache: Option<(&mut KVCache, usize)>, cache: Option<(&mut KVCache, usize)>,
pos_offset: usize, _pos_offset: usize,
new_tokens: usize, new_tokens: usize,
num_heads: usize, num_heads: usize,
head_dim: usize, head_dim: usize,

View File

@@ -1,5 +1,5 @@
use xserv_cuda::GpuBuffer; use xserv_cuda::GpuBuffer;
use xserv_tensor::{DType, Device, Tensor}; use xserv_tensor::{DType, Tensor};
use crate::config::ModelConfig; use crate::config::ModelConfig;
/// GPU-resident KV cache. Pre-allocates max_seq_len on GPU, /// GPU-resident KV cache. Pre-allocates max_seq_len on GPU,

View File

@@ -1,7 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use half::bf16; use half::bf16;
use xserv_kernels::*; use xserv_kernels::*;
use xserv_tensor::{DType, Device, Tensor}; use xserv_tensor::{Device, Tensor};
use crate::config::ModelConfig; use crate::config::ModelConfig;
use crate::gpt2::KVCache; use crate::gpt2::KVCache;
@@ -798,7 +798,7 @@ impl Qwen3 {
pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor { pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
let new_tokens = token_ids.len(); let new_tokens = token_ids.len();
let pos_offset = cache.seq_len(); let pos_offset = cache.seq_len();
let hidden = self.config.hidden();
let num_heads = self.config.num_heads(); let num_heads = self.config.num_heads();
let num_kv_heads = self.config.num_kv_heads(); let num_kv_heads = self.config.num_kv_heads();
let head_dim = self.config.head_dim(); let head_dim = self.config.head_dim();

View File

@@ -65,6 +65,13 @@ async fn main() {
std::process::exit(1); std::process::exit(1);
} }
let model_config = ModelConfig::from_file(&model_dir.join("config.json")); let model_config = ModelConfig::from_file(&model_dir.join("config.json"));
// gpt-oss is only implemented in the TP engine; route it there even at
// tp=1 (single-rank world) so quantized models can serve on one GPU.
let is_gpt_oss = model_config.model_type.as_deref() == Some("gpt_oss");
if pp > 1 && is_gpt_oss {
eprintln!("gpt-oss is not supported by the pipeline-parallel engine (Qwen3 only); use --tp instead");
std::process::exit(1);
}
let model_max_seq_len = model_config.max_seq_len(); let model_max_seq_len = model_config.max_seq_len();
if model_max_seq_len == 0 { if model_max_seq_len == 0 {
eprintln!("model config has invalid max_seq_len=0"); eprintln!("model config has invalid max_seq_len=0");
@@ -87,9 +94,6 @@ async fn main() {
let (tx, rx) = mpsc::channel::<GenerateRequest>(); let (tx, rx) = mpsc::channel::<GenerateRequest>();
let model_dir_clone = model_dir.clone(); let model_dir_clone = model_dir.clone();
// gpt-oss is only implemented in the TP engine; route it there even at
// tp=1 (single-rank world) so quantized models can serve on one GPU.
let is_gpt_oss = model_config.model_type.as_deref() == Some("gpt_oss");
std::thread::spawn(move || { std::thread::spawn(move || {
if pp > 1 { if pp > 1 {
// Pipeline-parallel path: stage-0 coordinator + worker stage threads. // Pipeline-parallel path: stage-0 coordinator + worker stage threads.

View File

@@ -93,11 +93,9 @@ __global__ void moe_replicate_bf16_kernel(
int total = local_experts * num_tokens * hidden; int total = local_experts * num_tokens * hidden;
if (idx >= total) return; if (idx >= total) return;
int expert = idx / (num_tokens * hidden);
int remainder = idx % (num_tokens * hidden); int remainder = idx % (num_tokens * hidden);
// x_rep[expert, token, dim] = x[token, dim] // x_rep[expert, token, dim] = x[token, dim]
x_rep[idx] = x[remainder]; x_rep[idx] = x[remainder];
(void)expert; // suppress unused warning
} }
// ============================================================ // ============================================================