quantization: W8A8 FP8 compute via cuBLASLt tensor cores

Replace the W8A16 dequant→BF16-GEMM path with native FP8×FP8→BF16 GEMM
using cuBLASLt on Blackwell (RTX 5090). Both weights (static FP8 E4M3)
and activations (dynamically quantized per-row) are processed directly
on FP8 tensor cores.

Key implementation details:
- cuBLASLt on Blackwell requires transA=T for FP8, so expert weights
  are transposed during model loading ([E,K,N] → [E,N,K])
- Per-row activation quantization kernel (absmax/448 → FP8 E4M3)
- Post-GEMM row-wise rescaling recovers per-token precision
- Per-expert loop (not batched) due to cuBLASLt FP8 scale constraints

The same FP8 quantized model files work — no re-quantization needed.
Activation quantization happens dynamically at inference time.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-07 20:38:26 +08:00
parent 9f1fbbb98b
commit 76487b7963
4 changed files with 508 additions and 15 deletions

View File

@@ -8,6 +8,7 @@ fn main() {
println!("cargo:rustc-link-search=native={cuda_path}/lib64");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=cublas");
println!("cargo:rustc-link-lib=dylib=cublasLt");
cc::Build::new()
.cuda(true)
@@ -31,6 +32,7 @@ fn main() {
.file("../../csrc/attention/reshape_and_cache.cu")
.file("../../csrc/moe/moe_kernels.cu")
.file("../../csrc/quantization/dequant_fp8.cu")
.file("../../csrc/quantization/quantize_fp8.cu")
.compile("xserv_kernels");
println!("cargo:rerun-if-changed=../../csrc/");

View File

@@ -1,6 +1,12 @@
use std::cell::RefCell;
use std::ffi::c_void;
use xserv_cuda::GpuBuffer;
use xserv_tensor::{DType, Tensor};
// ============================================================
// FFI: custom CUDA kernels
// ============================================================
unsafe extern "C" {
fn launch_dequant_fp8e4m3_to_bf16(
src: *const c_void,
@@ -9,8 +15,135 @@ unsafe extern "C" {
num_experts: i32, rows: i32, cols: i32,
stream: *mut c_void,
);
fn launch_quantize_bf16_to_fp8e4m3_rowwise(
src: *const c_void,
dst: *mut c_void,
scales: *mut c_void,
num_rows: i32, cols: i32,
stream: *mut c_void,
);
fn launch_rowwise_scale_bf16(
data: *mut c_void,
scales: *const c_void,
num_rows: i32, cols: i32,
stream: *mut c_void,
);
}
// ============================================================
// FFI: cuBLASLt
// ============================================================
type CublasLtHandle = *mut c_void;
type CublasLtMatmulDesc = *mut c_void;
type CublasLtMatrixLayout = *mut c_void;
type CublasLtMatmulPreference = *mut c_void;
#[repr(C)]
#[derive(Clone, Copy)]
struct CublasLtMatmulAlgo {
data: [u64; 8],
}
#[repr(C)]
struct CublasLtMatmulHeuristicResult {
algo: CublasLtMatmulAlgo,
workspace_size: usize,
state: i32,
_reserved: [f32; 4],
}
unsafe extern "C" {
fn cublasLtCreate(handle: *mut CublasLtHandle) -> i32;
fn cublasLtDestroy(handle: CublasLtHandle) -> i32;
fn cublasLtMatmulDescCreate(desc: *mut CublasLtMatmulDesc, compute_type: i32, scale_type: i32) -> i32;
fn cublasLtMatmulDescDestroy(desc: CublasLtMatmulDesc) -> i32;
fn cublasLtMatmulDescSetAttribute(desc: CublasLtMatmulDesc, attr: i32, buf: *const c_void, size: usize) -> i32;
fn cublasLtMatrixLayoutCreate(layout: *mut CublasLtMatrixLayout, dtype: i32, rows: u64, cols: u64, ld: i64) -> i32;
fn cublasLtMatrixLayoutDestroy(layout: CublasLtMatrixLayout) -> i32;
fn cublasLtMatrixLayoutSetAttribute(layout: CublasLtMatrixLayout, attr: i32, buf: *const c_void, size: usize) -> i32;
fn cublasLtMatmulPreferenceCreate(pref: *mut CublasLtMatmulPreference) -> i32;
fn cublasLtMatmulPreferenceDestroy(pref: CublasLtMatmulPreference) -> i32;
fn cublasLtMatmulPreferenceSetAttribute(pref: CublasLtMatmulPreference, attr: i32, buf: *const c_void, size: usize) -> i32;
fn cublasLtMatmulAlgoGetHeuristic(
handle: CublasLtHandle, desc: CublasLtMatmulDesc,
a_layout: CublasLtMatrixLayout, b_layout: CublasLtMatrixLayout,
c_layout: CublasLtMatrixLayout, d_layout: CublasLtMatrixLayout,
pref: CublasLtMatmulPreference,
requested: i32,
results: *mut CublasLtMatmulHeuristicResult,
found: *mut i32,
) -> i32;
fn cublasLtMatmul(
handle: CublasLtHandle, desc: CublasLtMatmulDesc,
alpha: *const c_void,
a: *const c_void, a_layout: CublasLtMatrixLayout,
b: *const c_void, b_layout: CublasLtMatrixLayout,
beta: *const c_void,
c: *const c_void, c_layout: CublasLtMatrixLayout,
d: *mut c_void, d_layout: CublasLtMatrixLayout,
algo: *const CublasLtMatmulAlgo,
workspace: *mut c_void, workspace_size: usize,
stream: *mut c_void,
) -> i32;
}
// cuBLASLt constants
const CUBLAS_COMPUTE_32F: i32 = 68;
const CUDA_R_32F: i32 = 0;
const CUDA_R_16BF: i32 = 14;
const CUDA_R_8F_E4M3: i32 = 28;
// MatmulDesc attributes
const CUBLASLT_MATMUL_DESC_A_SCALE_POINTER: i32 = 17;
const CUBLASLT_MATMUL_DESC_B_SCALE_POINTER: i32 = 18;
const CUBLASLT_MATMUL_DESC_A_SCALE_MODE: i32 = 31;
const CUBLASLT_MATMUL_DESC_B_SCALE_MODE: i32 = 32;
// MatrixLayout attributes
const CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT: i32 = 5;
const CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET: i32 = 6;
// Scale modes
const CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR: i32 = 0;
const CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F: i32 = 3;
// MatmulPreference attributes
const CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES: i32 = 1;
const WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
struct CublasLtContext {
handle: CublasLtHandle,
workspace: GpuBuffer,
}
impl CublasLtContext {
fn new() -> Self {
let mut handle = std::ptr::null_mut();
let status = unsafe { cublasLtCreate(&mut handle) };
assert_eq!(status, 0, "cublasLtCreate failed: {status}");
let workspace = GpuBuffer::alloc(WORKSPACE_BYTES).expect("alloc cublasLt workspace");
Self { handle, workspace }
}
}
impl Drop for CublasLtContext {
fn drop(&mut self) {
if !self.handle.is_null() {
unsafe { cublasLtDestroy(self.handle) };
}
}
}
thread_local! {
static CUBLASLT_CTX: RefCell<CublasLtContext> = RefCell::new(CublasLtContext::new());
}
// ============================================================
// Public API
// ============================================================
/// Dequantize a 3D FP8 E4M3 tensor to BF16 using per-expert FP32 scales.
///
/// src: [num_experts, rows, cols] FP8E4M3, contiguous, GPU
@@ -44,3 +177,207 @@ pub fn dequant_fp8_to_bf16(src: &Tensor, scales: &Tensor) -> Tensor {
out
}
/// Dynamically quantize a contiguous BF16 tensor to FP8 E4M3 with per-row scales.
///
/// src: [num_rows, cols] or [batch, rows, cols] BF16, contiguous, GPU
/// Treats the tensor as 2D (flattens leading dims into num_rows).
///
/// Returns: (fp8_data [same shape] FP8E4M3, scales [total_rows] F32)
pub fn quantize_bf16_to_fp8_rowwise(src: &Tensor) -> (Tensor, Tensor) {
assert_eq!(src.dtype(), DType::BF16);
assert!(src.is_contiguous());
assert!(src.ndim() >= 2);
let cols = src.shape()[src.ndim() - 1];
let num_rows: usize = src.shape()[..src.ndim() - 1].iter().product();
let fp8_out = Tensor::empty(src.shape(), DType::FP8E4M3, src.device());
let scales = Tensor::empty(&[num_rows], DType::F32, src.device());
unsafe {
launch_quantize_bf16_to_fp8e4m3_rowwise(
src.data_ptr() as *const c_void,
fp8_out.data_ptr() as *mut c_void,
scales.data_ptr() as *mut c_void,
num_rows as i32, cols as i32,
std::ptr::null_mut(),
);
}
(fp8_out, scales)
}
/// FP8 batched GEMM via cuBLASLt (transA=T required on Blackwell).
///
/// Computes: C[b] = scale_a[b] * scale_b[b] * (A_fp8[b] @ B_fp8_T[b]^T)
/// effectively C[b] = A[b, M, K] @ W[b, K, N] but W is stored transposed
/// as [b, N, K] for cuBLASLt FP8 compatibility.
///
/// a_fp8: [batch, M, K] FP8E4M3 (activations, quantized per-row)
/// a_scales: [batch * M] F32 (per-token scales, collapsed to per-batch max)
/// b_fp8_t: [batch, N, K] FP8E4M3 (weights, TRANSPOSED for cuBLASLt)
/// b_scales: [batch] F32 (per-expert scalar scales)
///
/// Returns: [batch, M, N] BF16
pub fn batched_gemm_fp8(
a_fp8: &Tensor,
a_scales: &Tensor,
b_fp8_t: &Tensor,
b_scales: &Tensor,
) -> Tensor {
assert_eq!(a_fp8.ndim(), 3);
assert_eq!(b_fp8_t.ndim(), 3);
assert_eq!(a_fp8.dtype(), DType::FP8E4M3);
assert_eq!(b_fp8_t.dtype(), DType::FP8E4M3);
assert!(a_fp8.is_contiguous() && b_fp8_t.is_contiguous());
assert_eq!(a_fp8.shape()[0], b_fp8_t.shape()[0]);
// b_fp8_t is [batch, N, K] transposed, so b_fp8_t.shape[2] == K == a_fp8.shape[2]
assert_eq!(a_fp8.shape()[2], b_fp8_t.shape()[2]);
let batch = a_fp8.shape()[0];
let m = a_fp8.shape()[1]; // tokens
let k = a_fp8.shape()[2]; // hidden
let n = b_fp8_t.shape()[1]; // out_dim (from transposed weight)
// Per-token scales → per-expert scales (max over tokens within each expert batch)
// a_scales: [batch * M] → we take the max per expert to get [batch] scalar scales
// This is a slight accuracy tradeoff vs per-token, but allows scalar GEMM scale mode.
assert_eq!(a_scales.shape()[0], batch * m);
assert_eq!(b_scales.shape()[0], batch);
let c = Tensor::empty(&[batch, m, n], DType::BF16, a_fp8.device());
// Read weight scales to host for the per-expert loop
let b_scales_cpu = b_scales.to_device(xserv_tensor::Device::Cpu);
let b_s_data = b_scales_cpu.as_slice::<f32>();
CUBLASLT_CTX.with(|cell| {
let ctx = cell.borrow();
let handle = ctx.handle;
// Strides (in bytes) for one expert slice
let stride_a = m * k; // FP8: 1 byte per elem
let stride_b = n * k; // FP8: 1 byte per elem (transposed: [N, K])
let stride_c = m * n * 2; // BF16: 2 bytes per elem
for e in 0..batch {
let a_ptr = unsafe { (a_fp8.data_ptr() as *const u8).add(e * stride_a) as *const c_void };
let b_ptr = unsafe { (b_fp8_t.data_ptr() as *const u8).add(e * stride_b) as *const c_void };
let c_ptr = unsafe { (c.data_ptr() as *mut u8).add(e * stride_c) as *mut c_void };
// alpha = b_scale (weight scale). Per-row activation scale applied post-GEMM.
// GEMM computes: D = alpha * (A_fp8 @ B_fp8_T)
// = b_scale * ((A_real / a_scale_row) @ (B_real / b_scale))
// = (A_real / a_scale_row) @ B_real
// Post-multiply row i by a_scale[i] to recover the correct result.
let alpha: f32 = b_s_data[e];
let beta: f32 = 0.0;
unsafe {
// cuBLASLt FP8 on Blackwell requires transA=T, transB=N.
// cuBLASLt computes: D(m,n) = op(A)(m,k) * B(k,n) with transA=T
//
// We want: D_row[M,N] = A_act_row[M,K] @ B_wt_row[K,N]
// Map to cuBLASLt with m_lt=N, n_lt=M, k_lt=K:
// "A" (transA=T): stored as (K, N) col-major ld=K → transposed to (N, K)
// Our weights are stored TRANSPOSED as [E, N, K] row-major = col-major (K, N) ld=K ✓
// "B" (transB=N): stored as (K, M) col-major ld=K
// Our activations A_act_row[M,K] = col-major (K, M) ld=K ✓
// "D": stored as (N, M) col-major ld=N
// Our output D_row[M,N] = col-major (N, M) ld=N ✓
let m_lt = n as u64;
let n_lt = m as u64;
let k_lt = k as u64;
let mut matmul_desc: CublasLtMatmulDesc = std::ptr::null_mut();
cublasLtMatmulDescCreate(&mut matmul_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
// Set transA=T (required for FP8 on Blackwell)
let trans_a: i32 = 1; // CUBLAS_OP_T
cublasLtMatmulDescSetAttribute(matmul_desc, 3 /*TRANSA*/, &trans_a as *const i32 as _, 4);
// FP8 requires scale pointers. We fold the actual scales into alpha,
// so set dummy 1.0 scale pointers on device.
let one_val: f32 = 1.0;
let mut one_buf = xserv_cuda::allocator::cached_alloc(4).unwrap();
one_buf.copy_from_host(&one_val.to_le_bytes()).unwrap();
let one_ptr = one_buf.as_ptr() as *const c_void;
cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &one_ptr as *const _ as _, std::mem::size_of::<*const c_void>());
cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &one_ptr as *const _ as _, std::mem::size_of::<*const c_void>());
// "A" layout (weights, transposed): physical (K, N) col-major, ld=K
let mut a_layout: CublasLtMatrixLayout = std::ptr::null_mut();
cublasLtMatrixLayoutCreate(&mut a_layout, CUDA_R_8F_E4M3, k_lt, m_lt, k as i64);
// "B" layout (activations): physical (K, M) col-major, ld=K
let mut b_layout: CublasLtMatrixLayout = std::ptr::null_mut();
cublasLtMatrixLayoutCreate(&mut b_layout, CUDA_R_8F_E4M3, k_lt, n_lt, k as i64);
// "C"/"D" layout (output): physical (N, M) col-major, ld=N
let mut c_layout: CublasLtMatrixLayout = std::ptr::null_mut();
cublasLtMatrixLayoutCreate(&mut c_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
let mut d_layout: CublasLtMatrixLayout = std::ptr::null_mut();
cublasLtMatrixLayoutCreate(&mut d_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
// Get algo heuristic
let mut pref: CublasLtMatmulPreference = std::ptr::null_mut();
cublasLtMatmulPreferenceCreate(&mut pref);
let ws_bytes = WORKSPACE_BYTES as u64;
cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws_bytes as *const u64 as _, 8);
let mut heuristic = std::mem::zeroed::<CublasLtMatmulHeuristicResult>();
let mut found: i32 = 0;
let status = cublasLtMatmulAlgoGetHeuristic(
handle, matmul_desc,
a_layout, b_layout, c_layout, d_layout,
pref, 1, &mut heuristic, &mut found,
);
assert!(status == 0 && found > 0,
"cublasLtMatmulAlgoGetHeuristic failed for FP8 GEMM: status={status}, found={found}");
let status = cublasLtMatmul(
handle, matmul_desc,
&alpha as *const f32 as _,
b_ptr, // cuBLASLt "A" = weights
a_layout,
a_ptr, // cuBLASLt "B" = activations
b_layout,
&beta as *const f32 as _,
c_ptr, // C (unused with beta=0)
c_layout,
c_ptr, // D = output
d_layout,
&heuristic.algo,
ctx.workspace.as_ptr() as *mut c_void,
heuristic.workspace_size,
std::ptr::null_mut(),
);
assert_eq!(status, 0, "cublasLtMatmul FP8 failed for expert {e}: status={status}");
cublasLtMatmulPreferenceDestroy(pref);
cublasLtMatrixLayoutDestroy(a_layout);
cublasLtMatrixLayoutDestroy(b_layout);
cublasLtMatrixLayoutDestroy(c_layout);
cublasLtMatrixLayoutDestroy(d_layout);
cublasLtMatmulDescDestroy(matmul_desc);
}
}
});
// Post-GEMM: multiply each row of c by its activation scale.
// c is [batch, M, N] BF16. a_scales is [batch * M] F32.
// This recovers the per-token scale that was divided out during quantization.
let total_rows = (batch * m) as i32;
let cols = n as i32;
unsafe {
launch_rowwise_scale_bf16(
c.data_ptr() as *mut c_void,
a_scales.data_ptr() as *const c_void,
total_rows, cols,
std::ptr::null_mut(),
);
}
c
}

View File

@@ -47,10 +47,11 @@ struct GptOssBlock {
expert_gate_up_bias: Tensor, // [local_experts, 2*inter]
expert_down_wt: Tensor, // [local_experts, inter, hidden] BF16
expert_down_bias: Tensor, // [local_experts, hidden]
// FP8 quantized expert weights (Some when running FP8 W8A16)
expert_gate_up_fp8: Option<Tensor>, // [local_experts, hidden, 2*inter] FP8E4M3
// FP8 quantized expert weights (Some when running FP8 W8A8)
// Transposed layout [E, N, K] for cuBLASLt FP8 (Blackwell requires transA=T)
expert_gate_up_fp8: Option<Tensor>, // [local_experts, 2*inter, hidden] FP8E4M3
expert_gate_up_scale: Option<Tensor>,// [local_experts] F32
expert_down_fp8: Option<Tensor>, // [local_experts, inter, hidden] FP8E4M3
expert_down_fp8: Option<Tensor>, // [local_experts, hidden, inter] FP8E4M3
expert_down_scale: Option<Tensor>, // [local_experts] F32
local_experts: usize,
// Activation params
@@ -183,9 +184,12 @@ impl GptOss {
let expert_down_scale_gpu;
if is_fp8 {
// FP8 path: load quantized weights and scales
expert_gate_up_fp8 = Some(slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, hidden, inter2).to_device(dev));
expert_down_fp8 = Some(slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden).to_device(dev));
// FP8 W8A8 path: load and TRANSPOSE weights for cuBLASLt (requires transA=T on Blackwell).
// Original: [E, K, N] → Transposed: [E, N, K]
let gu_sliced = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, hidden, inter2);
let dn_sliced = slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden);
expert_gate_up_fp8 = Some(transpose_3d_inner_raw(&gu_sliced, local_experts, hidden, inter2).to_device(dev));
expert_down_fp8 = Some(transpose_3d_inner_raw(&dn_sliced, local_experts, inter, hidden).to_device(dev));
// Scales: [num_experts] F32 → slice to [local_experts]
let gu_s = gate_up_scale.expect("FP8 model missing gate_up_proj_scale");
let d_s = down_scale.expect("FP8 model missing down_proj_scale");
@@ -255,7 +259,7 @@ impl GptOss {
eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm");
}
if is_fp8 {
eprintln!("gpt-oss: FP8 E4M3 quantized expert weights detected (W8A16 mode)");
eprintln!("gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)");
}
}
@@ -515,12 +519,15 @@ impl GptOss {
let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
// 4. Batched GEMM gate_up: [E, tokens, hidden] @ [E, hidden, 2*inter] → [E, tokens, 2*inter]
let gate_up_wt = if let Some(ref fp8) = layer.expert_gate_up_fp8 {
xserv_kernels::quantization::dequant_fp8_to_bf16(fp8, layer.expert_gate_up_scale.as_ref().unwrap())
let gate_up = if let Some(ref wt_fp8_t) = layer.expert_gate_up_fp8 {
// W8A8: quantize activations with per-expert scalar scale, use cuBLASLt FP8 GEMM
let (x_fp8, x_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep);
xserv_kernels::quantization::batched_gemm_fp8(
&x_fp8, &x_scales, wt_fp8_t, layer.expert_gate_up_scale.as_ref().unwrap(),
)
} else {
layer.expert_gate_up_wt.clone()
xserv_kernels::moe::batched_gemm_strided(&x_rep, &layer.expert_gate_up_wt)
};
let gate_up = xserv_kernels::moe::batched_gemm_strided(&x_rep, &gate_up_wt);
// 5. Bias add: gate_up += expert_gate_up_bias (in-place)
xserv_kernels::moe::moe_bias_add_3d(&gate_up, &layer.expert_gate_up_bias);
@@ -534,12 +541,15 @@ impl GptOss {
let activated = activated_flat.reshape(&[local_experts, num_tokens, inter]);
// 7. Batched GEMM down: [E, tokens, inter] @ [E, inter, hidden] → [E, tokens, hidden]
let down_wt = if let Some(ref fp8) = layer.expert_down_fp8 {
xserv_kernels::quantization::dequant_fp8_to_bf16(fp8, layer.expert_down_scale.as_ref().unwrap())
let down = if let Some(ref wt_fp8) = layer.expert_down_fp8 {
// W8A8: quantize post-GLU activations to FP8, use cuBLASLt FP8 GEMM
let (act_fp8, act_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated);
xserv_kernels::quantization::batched_gemm_fp8(
&act_fp8, &act_scales, wt_fp8, layer.expert_down_scale.as_ref().unwrap(),
)
} else {
layer.expert_down_wt.clone()
xserv_kernels::moe::batched_gemm_strided(&activated, &layer.expert_down_wt)
};
let down = xserv_kernels::moe::batched_gemm_strided(&activated, &down_wt);
// 8. Bias add: down += expert_down_bias (in-place)
xserv_kernels::moe::moe_bias_add_3d(&down, &layer.expert_down_bias);
@@ -636,6 +646,27 @@ fn shard_1d(t: &Tensor, rank: usize, world: usize) -> Tensor {
Tensor::from_slice(&shard, &[local])
}
/// Transpose the inner two dimensions of a [batch, rows, cols] tensor → [batch, cols, rows].
/// Works on raw bytes (any dtype). CPU-only.
fn transpose_3d_inner_raw(t: &Tensor, batch: usize, rows: usize, cols: usize) -> Tensor {
assert_eq!(t.ndim(), 3);
assert_eq!(t.shape(), &[batch, rows, cols]);
let host = t.to_device(Device::Cpu);
let es = t.dtype().size_bytes();
let raw = host.as_raw_bytes();
let mut out = vec![0u8; batch * cols * rows * es];
for b in 0..batch {
for r in 0..rows {
for c in 0..cols {
let src_off = (b * rows * cols + r * cols + c) * es;
let dst_off = (b * cols * rows + c * rows + r) * es;
out[dst_off..dst_off + es].copy_from_slice(&raw[src_off..src_off + es]);
}
}
}
Tensor::from_raw_bytes(&out, &[batch, cols, rows], t.dtype())
}
/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor (any dtype, raw bytes).
fn slice_expert_range_3d_raw(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
assert_eq!(t.ndim(), 3);