2 Commits

Author SHA1 Message Date
d77f921a12 phase 3: GEMM kernels (naive, tiled, cuBLAS)
- Naive GEMM kernel: one thread per output element (F32 + BF16)
- Tiled GEMM kernel: 32x32 shared memory tiles (F32 + BF16)
- cuBLAS wrapper: cublasGemmEx with row-major trick
- GemmBackend enum for runtime backend selection
- CublasContext RAII handle
- Made error::check public for cross-crate use
- 17 GEMM tests: small/medium/rect sizes, all backends, F32+BF16
- Cross-backend consistency verified (naive vs tiled vs cuBLAS)
- All 44 tests pass across all crates

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-21 19:48:05 +08:00
a83971fa25 phase 2: tensor abstraction layer
- DType enum (F32, F16, BF16) with TensorDType trait
- Shape utilities: contiguous_strides, broadcast_shape, broadcast_strides
- Storage with Arc reference counting (CPU Vec<u8> or GPU GpuBuffer)
- Device enum (Cpu, Cuda(id)) with to_device transfer
- Tensor type with strided layout: reshape, transpose, squeeze, unsqueeze
- contiguous() copies non-contiguous views to contiguous layout
- from_slice, zeros, ones constructors
- as_slice<T> for typed CPU read access, data_ptr for GPU kernel launch
- CPU↔GPU roundtrip verified
- All 27 tests pass (12 cuda + 4 shape + 11 tensor)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-21 19:45:22 +08:00
16 changed files with 1173 additions and 1 deletions

View File

@@ -2,6 +2,8 @@
resolver = "2"
members = [
"crates/xserv-cuda",
"crates/xserv-tensor",
"crates/xserv-kernels",
]
[workspace.package]

View File

@@ -23,7 +23,7 @@ impl std::error::Error for CudaError {}
pub type Result<T> = std::result::Result<T, CudaError>;
pub(crate) fn check(code: i32) -> Result<()> {
pub fn check(code: i32) -> Result<()> {
if code == ffi::CUDA_SUCCESS {
return Ok(());
}

View File

@@ -0,0 +1,12 @@
[package]
name = "xserv-kernels"
version.workspace = true
edition.workspace = true
[build-dependencies]
cc = "1"
[dependencies]
xserv-cuda = { path = "../xserv-cuda" }
xserv-tensor = { path = "../xserv-tensor" }
half.workspace = true

View File

@@ -0,0 +1,21 @@
use std::env;
fn main() {
let cuda_path = env::var("CUDA_HOME")
.or_else(|_| env::var("CUDA_PATH"))
.unwrap_or_else(|_| "/usr/local/cuda".to_string());
println!("cargo:rustc-link-search=native={cuda_path}/lib64");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=cublas");
cc::Build::new()
.cuda(true)
.cudart("shared")
.flag("-gencode=arch=compute_120,code=sm_120")
.file("../../csrc/gemm/naive.cu")
.file("../../csrc/gemm/tiled.cu")
.compile("xserv_gemm_kernels");
println!("cargo:rerun-if-changed=../../csrc/gemm/");
}

View File

@@ -0,0 +1,151 @@
use std::ffi::c_void;
use xserv_cuda::error::{self, Result};
use xserv_tensor::{DType, Device, Tensor};
#[derive(Debug, Clone, Copy)]
pub enum GemmBackend {
Naive,
Tiled,
CuBlas,
}
// --- FFI: custom CUDA kernels ---
unsafe extern "C" {
fn launch_gemm_naive_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
fn launch_gemm_naive_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
fn launch_gemm_tiled_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
fn launch_gemm_tiled_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
}
// --- FFI: cuBLAS ---
type CublasHandle = *mut c_void;
#[allow(non_upper_case_globals)]
const CUBLAS_OP_N: i32 = 0;
// cudaDataType
const CUDA_R_32F: i32 = 0;
const CUDA_R_16BF: i32 = 14;
// cublasComputeType
const CUBLAS_COMPUTE_32F: i32 = 68;
unsafe extern "C" {
fn cublasCreate_v2(handle: *mut CublasHandle) -> i32;
fn cublasDestroy_v2(handle: CublasHandle) -> i32;
fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32;
fn cublasGemmEx(
handle: CublasHandle,
transa: i32, transb: i32,
m: i32, n: i32, k: i32,
alpha: *const c_void,
a: *const c_void, a_type: i32, lda: i32,
b: *const c_void, b_type: i32, ldb: i32,
beta: *const c_void,
c: *mut c_void, c_type: i32, ldc: i32,
compute_type: i32,
algo: i32,
) -> i32;
}
pub struct CublasContext {
handle: CublasHandle,
}
impl CublasContext {
pub fn new() -> Result<Self> {
let mut handle = std::ptr::null_mut();
error::check(unsafe { cublasCreate_v2(&mut handle) })?;
Ok(Self { handle })
}
}
impl Drop for CublasContext {
fn drop(&mut self) {
if !self.handle.is_null() {
unsafe { cublasDestroy_v2(self.handle) };
}
}
}
/// Matrix multiplication: C = A @ B
/// A: [M, K], B: [K, N], C: [M, N]
/// All tensors must be contiguous and on the same GPU.
pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
assert_eq!(a.ndim(), 2);
assert_eq!(b.ndim(), 2);
assert_eq!(a.shape()[1], b.shape()[0], "inner dimension mismatch");
assert_eq!(a.dtype(), b.dtype(), "dtype mismatch");
assert!(a.is_contiguous() && b.is_contiguous(), "matmul requires contiguous tensors");
assert!(matches!(a.device(), Device::Cuda(_)), "matmul requires GPU tensors");
let m = a.shape()[0];
let k = a.shape()[1];
let n = b.shape()[1];
let dtype = a.dtype();
let c = Tensor::zeros(&[m, n], dtype, a.device());
let a_ptr = a.data_ptr() as *const c_void;
let b_ptr = b.data_ptr() as *const c_void;
let c_ptr = c.data_ptr() as *mut c_void;
let null_stream = std::ptr::null_mut();
match backend {
GemmBackend::Naive => {
unsafe {
match dtype {
DType::F32 => launch_gemm_naive_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
DType::BF16 => launch_gemm_naive_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
_ => panic!("unsupported dtype for naive GEMM"),
}
}
xserv_cuda::device::synchronize().unwrap();
}
GemmBackend::Tiled => {
unsafe {
match dtype {
DType::F32 => launch_gemm_tiled_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
DType::BF16 => launch_gemm_tiled_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
_ => panic!("unsupported dtype for tiled GEMM"),
}
}
xserv_cuda::device::synchronize().unwrap();
}
GemmBackend::CuBlas => {
// cuBLAS uses column-major, but we have row-major tensors.
// Trick: compute C^T = B^T @ A^T, which gives us C in row-major.
// cuBLAS sees our row-major data as column-major transposed.
let ctx = CublasContext::new().unwrap();
let alpha = 1.0f32;
let beta = 0.0f32;
let (a_type, b_type, c_type) = match dtype {
DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F),
DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF),
_ => panic!("unsupported dtype for cuBLAS GEMM"),
};
unsafe {
cublasSetStream_v2(ctx.handle, null_stream);
// Row-major trick: swap A/B and transpose flags
// C(row-major) = A @ B <=> C^T(col-major) = B^T @ A^T
error::check(cublasGemmEx(
ctx.handle,
CUBLAS_OP_N, CUBLAS_OP_N,
n as i32, m as i32, k as i32,
&alpha as *const f32 as *const c_void,
b_ptr, b_type, n as i32, // B as col-major = B^T
a_ptr, a_type, k as i32, // A as col-major = A^T
&beta as *const f32 as *const c_void,
c_ptr, c_type, n as i32, // C as col-major = C^T
CUBLAS_COMPUTE_32F,
-1, // default algo
)).expect("cuBLAS GEMM failed");
}
xserv_cuda::device::synchronize().unwrap();
}
}
c
}

View File

@@ -0,0 +1,3 @@
pub mod gemm;
pub use gemm::{GemmBackend, matmul};

View File

@@ -0,0 +1,152 @@
use half::bf16;
use xserv_kernels::{matmul, GemmBackend};
use xserv_tensor::{Device, Tensor};
fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
let mut c = vec![0.0f32; m * n];
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for kk in 0..k {
sum += a[i * k + kk] * b[kk * n + j];
}
c[i * n + j] = sum;
}
}
c
}
fn check_close_f32(result: &[f32], expected: &[f32], atol: f32) {
assert_eq!(result.len(), expected.len());
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
assert!(
(r - e).abs() <= atol,
"mismatch at index {i}: got {r}, expected {e}, diff {}",
(r - e).abs()
);
}
}
fn check_close_bf16(result: &[bf16], expected: &[f32], atol: f32) {
assert_eq!(result.len(), expected.len());
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
let rv = r.to_f32();
assert!(
(rv - e).abs() <= atol,
"mismatch at index {i}: got {rv}, expected {e}, diff {}",
(rv - e).abs()
);
}
}
fn run_gemm_test_f32(backend: GemmBackend, m: usize, n: usize, k: usize) {
xserv_cuda::device::set_device(0).unwrap();
let a_data: Vec<f32> = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();
let b_data: Vec<f32> = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect();
let expected = cpu_matmul_f32(&a_data, &b_data, m, n, k);
let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0));
let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0));
let c = matmul(&a, &b, backend);
let c_cpu = c.to_device(Device::Cpu);
check_close_f32(c_cpu.as_slice::<f32>(), &expected, 1e-4);
}
fn run_gemm_test_bf16(backend: GemmBackend, m: usize, n: usize, k: usize) {
xserv_cuda::device::set_device(0).unwrap();
let a_f32: Vec<f32> = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();
let b_f32: Vec<f32> = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect();
let expected = cpu_matmul_f32(&a_f32, &b_f32, m, n, k);
let a_data: Vec<bf16> = a_f32.iter().map(|&v| bf16::from_f32(v)).collect();
let b_data: Vec<bf16> = b_f32.iter().map(|&v| bf16::from_f32(v)).collect();
let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0));
let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0));
let c = matmul(&a, &b, backend);
let c_cpu = c.to_device(Device::Cpu);
check_close_bf16(c_cpu.as_slice::<bf16>(), &expected, 0.1);
}
// --- F32 tests ---
#[test]
fn test_gemm_naive_f32_small() { run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4); }
#[test]
fn test_gemm_naive_f32_medium() { run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64); }
#[test]
fn test_gemm_naive_f32_rect() { run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48); }
#[test]
fn test_gemm_tiled_f32_small() { run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4); }
#[test]
fn test_gemm_tiled_f32_medium() { run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128); }
#[test]
fn test_gemm_tiled_f32_rect() { run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97); }
#[test]
fn test_gemm_cublas_f32_small() { run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4); }
#[test]
fn test_gemm_cublas_f32_medium() { run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256); }
#[test]
fn test_gemm_cublas_f32_rect() { run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97); }
// --- BF16 tests ---
#[test]
fn test_gemm_naive_bf16_small() { run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4); }
#[test]
fn test_gemm_naive_bf16_medium() { run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64); }
#[test]
fn test_gemm_tiled_bf16_small() { run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4); }
#[test]
fn test_gemm_tiled_bf16_medium() { run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128); }
#[test]
fn test_gemm_cublas_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4); }
#[test]
fn test_gemm_cublas_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256); }
// --- Larger benchmark-style tests ---
#[test]
fn test_gemm_cublas_f32_1024() { run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024); }
#[test]
fn test_gemm_consistency_all_backends() {
xserv_cuda::device::set_device(0).unwrap();
let m = 64;
let n = 64;
let k = 64;
let a_data: Vec<f32> = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();
let b_data: Vec<f32> = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect();
let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0));
let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0));
let c_naive = matmul(&a, &b, GemmBackend::Naive).to_device(Device::Cpu);
let c_tiled = matmul(&a, &b, GemmBackend::Tiled).to_device(Device::Cpu);
let c_cublas = matmul(&a, &b, GemmBackend::CuBlas).to_device(Device::Cpu);
let naive = c_naive.as_slice::<f32>();
let tiled = c_tiled.as_slice::<f32>();
let cublas = c_cublas.as_slice::<f32>();
check_close_f32(naive, cublas, 1e-4);
check_close_f32(tiled, cublas, 1e-4);
}

View File

@@ -0,0 +1,9 @@
[package]
name = "xserv-tensor"
version.workspace = true
edition.workspace = true
[dependencies]
xserv-cuda = { path = "../xserv-cuda" }
half.workspace = true
smallvec.workspace = true

View File

@@ -0,0 +1,57 @@
use half::{bf16, f16};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
F32,
F16,
BF16,
}
impl DType {
pub fn size_bytes(self) -> usize {
match self {
DType::F32 => 4,
DType::F16 => 2,
DType::BF16 => 2,
}
}
pub fn name(self) -> &'static str {
match self {
DType::F32 => "f32",
DType::F16 => "f16",
DType::BF16 => "bf16",
}
}
}
impl std::fmt::Display for DType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.name())
}
}
/// Trait for types that can be stored in a Tensor.
pub trait TensorDType: Copy + Send + Sync + 'static {
const DTYPE: DType;
fn to_f64(self) -> f64;
fn from_f64(v: f64) -> Self;
}
impl TensorDType for f32 {
const DTYPE: DType = DType::F32;
fn to_f64(self) -> f64 { self as f64 }
fn from_f64(v: f64) -> Self { v as f32 }
}
impl TensorDType for f16 {
const DTYPE: DType = DType::F16;
fn to_f64(self) -> f64 { self.to_f32() as f64 }
fn from_f64(v: f64) -> Self { f16::from_f32(v as f32) }
}
impl TensorDType for bf16 {
const DTYPE: DType = DType::BF16;
fn to_f64(self) -> f64 { self.to_f32() as f64 }
fn from_f64(v: f64) -> Self { bf16::from_f32(v as f32) }
}

View File

@@ -0,0 +1,8 @@
pub mod dtype;
pub mod shape;
pub mod storage;
pub mod tensor;
pub use dtype::{DType, TensorDType};
pub use storage::Device;
pub use tensor::Tensor;

View File

@@ -0,0 +1,105 @@
use smallvec::SmallVec;
pub type Dims = SmallVec<[usize; 4]>;
/// Compute contiguous strides for a given shape (row-major / C order).
/// Example: shape [2, 3, 4] => strides [12, 4, 1]
pub fn contiguous_strides(shape: &[usize]) -> Dims {
let mut strides = SmallVec::with_capacity(shape.len());
strides.resize(shape.len(), 0);
if shape.is_empty() {
return strides;
}
strides[shape.len() - 1] = 1;
for i in (0..shape.len() - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
strides
}
/// Check if the given strides represent contiguous (row-major) layout for the shape.
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
if shape.is_empty() {
return true;
}
let expected = contiguous_strides(shape);
strides == expected.as_slice()
}
/// Total number of elements given a shape.
pub fn num_elements(shape: &[usize]) -> usize {
shape.iter().product()
}
/// Compute the shape after broadcasting two shapes together (NumPy rules).
/// Returns None if shapes are not broadcastable.
pub fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Dims> {
let ndim = a.len().max(b.len());
let mut result = SmallVec::with_capacity(ndim);
for i in 0..ndim {
let da = if i < ndim - a.len() { 1 } else { a[i - (ndim - a.len())] };
let db = if i < ndim - b.len() { 1 } else { b[i - (ndim - b.len())] };
if da == db {
result.push(da);
} else if da == 1 {
result.push(db);
} else if db == 1 {
result.push(da);
} else {
return None;
}
}
Some(result)
}
/// Compute broadcast strides: for dimensions where size is 1 but output is >1, stride becomes 0.
pub fn broadcast_strides(shape: &[usize], strides: &[usize], target_shape: &[usize]) -> Dims {
let ndim = target_shape.len();
let offset = ndim - shape.len();
let mut result = SmallVec::with_capacity(ndim);
for i in 0..ndim {
if i < offset {
result.push(0);
} else {
let orig_idx = i - offset;
if shape[orig_idx] == 1 && target_shape[i] > 1 {
result.push(0);
} else {
result.push(strides[orig_idx]);
}
}
}
result
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_contiguous_strides() {
assert_eq!(contiguous_strides(&[2, 3, 4]).as_slice(), &[12, 4, 1]);
assert_eq!(contiguous_strides(&[5]).as_slice(), &[1]);
assert_eq!(contiguous_strides(&[2, 3]).as_slice(), &[3, 1]);
}
#[test]
fn test_is_contiguous() {
assert!(is_contiguous(&[2, 3], &[3, 1]));
assert!(!is_contiguous(&[3, 2], &[1, 3])); // transposed
}
#[test]
fn test_broadcast_shape() {
assert_eq!(broadcast_shape(&[3, 1], &[1, 4]).unwrap().as_slice(), &[3, 4]);
assert_eq!(broadcast_shape(&[2, 3, 4], &[4]).unwrap().as_slice(), &[2, 3, 4]);
assert_eq!(broadcast_shape(&[1], &[5, 3]).unwrap().as_slice(), &[5, 3]);
assert!(broadcast_shape(&[3], &[4]).is_none());
}
#[test]
fn test_broadcast_strides() {
// [3,1] with strides [1,1] broadcast to [3,4]
assert_eq!(broadcast_strides(&[3, 1], &[1, 1], &[3, 4]).as_slice(), &[1, 0]);
}
}

View File

@@ -0,0 +1,119 @@
use std::sync::Arc;
use xserv_cuda::{GpuBuffer, Result as CudaResult};
enum StorageInner {
Cpu { data: Vec<u8> },
Cuda { buffer: GpuBuffer },
}
/// Reference-counted storage for tensor data. Multiple tensors can share
/// the same storage (e.g., after transpose or slice — view semantics).
#[derive(Clone)]
pub struct Storage(Arc<StorageInner>);
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
Cpu,
Cuda(u32),
}
impl std::fmt::Display for Device {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Device::Cpu => write!(f, "cpu"),
Device::Cuda(i) => write!(f, "cuda:{i}"),
}
}
}
impl Storage {
pub fn cpu(data: Vec<u8>) -> Self {
Self(Arc::new(StorageInner::Cpu { data }))
}
pub fn cuda(buffer: GpuBuffer) -> Self {
Self(Arc::new(StorageInner::Cuda { buffer }))
}
pub fn device(&self) -> Device {
match self.0.as_ref() {
StorageInner::Cpu { .. } => Device::Cpu,
StorageInner::Cuda { .. } => Device::Cuda(0),
}
}
pub fn len_bytes(&self) -> usize {
match self.0.as_ref() {
StorageInner::Cpu { data } => data.len(),
StorageInner::Cuda { buffer } => buffer.len(),
}
}
/// Get a read-only view of CPU data. Panics if storage is on GPU.
pub fn as_cpu_bytes(&self) -> &[u8] {
match self.0.as_ref() {
StorageInner::Cpu { data } => data,
StorageInner::Cuda { .. } => panic!("cannot access GPU storage as CPU bytes"),
}
}
pub fn gpu_buffer(&self) -> &GpuBuffer {
match self.0.as_ref() {
StorageInner::Cuda { buffer } => buffer,
StorageInner::Cpu { .. } => panic!("cannot access CPU storage as GPU buffer"),
}
}
/// Copy to a different device. If already on the target device, clones the Arc (no copy).
pub fn to_device(&self, target: Device) -> CudaResult<Self> {
let current = self.device();
if current == target {
return Ok(self.clone());
}
match (current, target) {
(Device::Cpu, Device::Cuda(_dev)) => {
let cpu_data = self.as_cpu_bytes();
let mut buf = GpuBuffer::alloc(cpu_data.len())?;
buf.copy_from_host(cpu_data)?;
Ok(Storage::cuda(buf))
}
(Device::Cuda(_), Device::Cpu) => {
let gpu_buf = self.gpu_buffer();
let mut data = vec![0u8; gpu_buf.len()];
gpu_buf.copy_to_host(&mut data)?;
Ok(Storage::cpu(data))
}
(Device::Cuda(_), Device::Cuda(_)) => {
let src = self.gpu_buffer();
let mut dst = GpuBuffer::alloc(src.len())?;
dst.copy_from_device(src)?;
Ok(Storage::cuda(dst))
}
_ => unreachable!(),
}
}
/// Create a new owned copy of the storage on the same device.
pub fn deep_copy(&self) -> CudaResult<Self> {
match self.0.as_ref() {
StorageInner::Cpu { data } => Ok(Storage::cpu(data.clone())),
StorageInner::Cuda { buffer } => {
let mut dst = GpuBuffer::alloc(buffer.len())?;
dst.copy_from_device(buffer)?;
Ok(Storage::cuda(dst))
}
}
}
/// Allocate zeroed storage on the given device.
pub fn zeros(len_bytes: usize, device: Device) -> CudaResult<Self> {
match device {
Device::Cpu => Ok(Storage::cpu(vec![0u8; len_bytes])),
Device::Cuda(_) => {
let mut buf = GpuBuffer::alloc(len_bytes)?;
buf.zero()?;
Ok(Storage::cuda(buf))
}
}
}
}

View File

@@ -0,0 +1,228 @@
use crate::dtype::{DType, TensorDType};
use crate::shape::{self, Dims};
use crate::storage::{Device, Storage};
/// Multi-dimensional array with CPU or GPU storage.
///
/// Tensors support view semantics: transpose, slice, etc. share
/// the underlying storage and only change shape/strides/offset.
#[derive(Clone)]
pub struct Tensor {
storage: Storage,
shape: Dims,
strides: Dims,
offset: usize,
dtype: DType,
}
impl Tensor {
// --- Creation ---
pub fn from_slice<T: TensorDType>(data: &[T], shape: &[usize]) -> Self {
let numel: usize = shape.iter().product();
assert_eq!(data.len(), numel, "data length mismatch with shape");
let bytes = unsafe {
std::slice::from_raw_parts(data.as_ptr() as *const u8, numel * T::DTYPE.size_bytes())
};
Self {
storage: Storage::cpu(bytes.to_vec()),
shape: Dims::from_slice(shape),
strides: shape::contiguous_strides(shape),
offset: 0,
dtype: T::DTYPE,
}
}
pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
let numel = shape::num_elements(shape);
let len_bytes = numel * dtype.size_bytes();
let storage = Storage::zeros(len_bytes, device).expect("alloc failed");
Self {
storage,
shape: Dims::from_slice(shape),
strides: shape::contiguous_strides(shape),
offset: 0,
dtype,
}
}
pub fn ones(shape: &[usize], dtype: DType) -> Self {
let numel = shape::num_elements(shape);
match dtype {
DType::F32 => Self::from_slice(&vec![1.0f32; numel], shape),
DType::F16 => Self::from_slice(&vec![half::f16::from_f32(1.0); numel], shape),
DType::BF16 => Self::from_slice(&vec![half::bf16::from_f32(1.0); numel], shape),
}
}
// --- Properties ---
pub fn shape(&self) -> &[usize] { &self.shape }
pub fn strides(&self) -> &[usize] { &self.strides }
pub fn dtype(&self) -> DType { self.dtype }
pub fn ndim(&self) -> usize { self.shape.len() }
pub fn numel(&self) -> usize { shape::num_elements(&self.shape) }
pub fn offset(&self) -> usize { self.offset }
pub fn device(&self) -> Device { self.storage.device() }
pub fn is_contiguous(&self) -> bool {
shape::is_contiguous(&self.shape, &self.strides)
}
// --- Shape operations (view, no copy) ---
pub fn reshape(&self, new_shape: &[usize]) -> Self {
assert!(self.is_contiguous(), "reshape requires contiguous tensor");
let new_numel: usize = new_shape.iter().product();
assert_eq!(new_numel, self.numel(), "reshape numel mismatch");
Self {
storage: self.storage.clone(),
shape: Dims::from_slice(new_shape),
strides: shape::contiguous_strides(new_shape),
offset: self.offset,
dtype: self.dtype,
}
}
pub fn transpose(&self, dim0: usize, dim1: usize) -> Self {
assert!(dim0 < self.ndim() && dim1 < self.ndim());
let mut new_shape = self.shape.clone();
let mut new_strides = self.strides.clone();
new_shape.swap(dim0, dim1);
new_strides.swap(dim0, dim1);
Self {
storage: self.storage.clone(),
shape: new_shape,
strides: new_strides,
offset: self.offset,
dtype: self.dtype,
}
}
pub fn squeeze(&self, dim: usize) -> Self {
assert!(dim < self.ndim() && self.shape[dim] == 1);
let mut new_shape = self.shape.clone();
let mut new_strides = self.strides.clone();
new_shape.remove(dim);
new_strides.remove(dim);
Self {
storage: self.storage.clone(),
shape: new_shape,
strides: new_strides,
offset: self.offset,
dtype: self.dtype,
}
}
pub fn unsqueeze(&self, dim: usize) -> Self {
assert!(dim <= self.ndim());
let mut new_shape = self.shape.clone();
let mut new_strides = self.strides.clone();
new_shape.insert(dim, 1);
let stride_val = if dim < self.strides.len() { self.strides[dim] } else { 1 };
new_strides.insert(dim, stride_val);
Self {
storage: self.storage.clone(),
shape: new_shape,
strides: new_strides,
offset: self.offset,
dtype: self.dtype,
}
}
/// Make contiguous: if already contiguous, return clone (shared storage).
/// Otherwise, copy data into a new contiguous buffer.
pub fn contiguous(&self) -> Self {
if self.is_contiguous() {
return self.clone();
}
// Copy to contiguous layout on CPU
assert_eq!(self.device(), Device::Cpu, "contiguous() on GPU not yet supported");
let numel = self.numel();
let elem_size = self.dtype.size_bytes();
let src_bytes = self.storage.as_cpu_bytes();
let mut dst = vec![0u8; numel * elem_size];
// Iterate all elements using strides
let ndim = self.ndim();
let mut idx = vec![0usize; ndim];
for flat in 0..numel {
let src_offset = self.offset + idx.iter().zip(self.strides.iter()).map(|(i, s)| i * s).sum::<usize>();
let src_byte_offset = src_offset * elem_size;
let dst_byte_offset = flat * elem_size;
dst[dst_byte_offset..dst_byte_offset + elem_size]
.copy_from_slice(&src_bytes[src_byte_offset..src_byte_offset + elem_size]);
// Increment index (rightmost first)
for d in (0..ndim).rev() {
idx[d] += 1;
if idx[d] < self.shape[d] {
break;
}
idx[d] = 0;
}
}
Self {
storage: Storage::cpu(dst),
shape: self.shape.clone(),
strides: shape::contiguous_strides(&self.shape),
offset: 0,
dtype: self.dtype,
}
}
// --- Device transfer ---
pub fn to_device(&self, device: Device) -> Self {
let t = if self.is_contiguous() { self.clone() } else { self.contiguous() };
if t.device() == device {
return t;
}
let new_storage = t.storage.to_device(device).expect("device transfer failed");
Self {
storage: new_storage,
shape: t.shape,
strides: t.strides,
offset: 0,
dtype: t.dtype,
}
}
// --- Data access (CPU only) ---
/// Read tensor data as a typed slice. Requires contiguous CPU tensor.
pub fn as_slice<T: TensorDType>(&self) -> &[T] {
assert_eq!(T::DTYPE, self.dtype, "dtype mismatch");
assert!(self.is_contiguous(), "as_slice requires contiguous");
assert_eq!(self.device(), Device::Cpu, "as_slice requires CPU");
let bytes = self.storage.as_cpu_bytes();
let elem_size = self.dtype.size_bytes();
let start = self.offset * elem_size;
let len = self.numel();
unsafe { std::slice::from_raw_parts(bytes[start..].as_ptr() as *const T, len) }
}
/// Raw pointer to storage start (for GPU kernel launch).
pub fn data_ptr(&self) -> *const u8 {
match self.device() {
Device::Cpu => {
let bytes = self.storage.as_cpu_bytes();
unsafe { bytes.as_ptr().add(self.offset * self.dtype.size_bytes()) }
}
Device::Cuda(_) => {
let buf = self.storage.gpu_buffer();
unsafe { buf.as_ptr().add(self.offset * self.dtype.size_bytes()) }
}
}
}
pub fn storage(&self) -> &Storage { &self.storage }
}
impl std::fmt::Debug for Tensor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f, "Tensor(shape={:?}, dtype={}, device={}, contiguous={})",
self.shape.as_slice(), self.dtype, self.device(), self.is_contiguous()
)
}
}

View File

@@ -0,0 +1,127 @@
use half::bf16;
use xserv_tensor::*;
#[test]
fn test_from_slice_and_shape() {
let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let t = Tensor::from_slice(&data, &[2, 3]);
assert_eq!(t.shape(), &[2, 3]);
assert_eq!(t.strides(), &[3, 1]);
assert_eq!(t.numel(), 6);
assert_eq!(t.ndim(), 2);
assert!(t.is_contiguous());
assert_eq!(t.dtype(), DType::F32);
assert_eq!(t.device(), Device::Cpu);
}
#[test]
fn test_as_slice() {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let t = Tensor::from_slice(&data, &[4]);
assert_eq!(t.as_slice::<f32>(), &[1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_zeros_and_ones() {
let z = Tensor::zeros(&[2, 3], DType::F32, Device::Cpu);
assert_eq!(z.as_slice::<f32>(), &[0.0; 6]);
let o = Tensor::ones(&[3], DType::F32);
assert_eq!(o.as_slice::<f32>(), &[1.0, 1.0, 1.0]);
}
#[test]
fn test_bf16_tensor() {
let data: Vec<bf16> = vec![bf16::from_f32(1.0), bf16::from_f32(2.5), bf16::from_f32(-3.0)];
let t = Tensor::from_slice(&data, &[3]);
assert_eq!(t.dtype(), DType::BF16);
let out = t.as_slice::<bf16>();
assert_eq!(out[0].to_f32(), 1.0);
assert!((out[1].to_f32() - 2.5).abs() < 0.01);
assert_eq!(out[2].to_f32(), -3.0);
}
#[test]
fn test_reshape() {
let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let t = Tensor::from_slice(&data, &[2, 3]);
let t2 = t.reshape(&[3, 2]);
assert_eq!(t2.shape(), &[3, 2]);
assert_eq!(t2.as_slice::<f32>(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let t3 = t.reshape(&[6]);
assert_eq!(t3.shape(), &[6]);
}
#[test]
fn test_transpose() {
let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let t = Tensor::from_slice(&data, &[2, 3]);
let tt = t.transpose(0, 1);
assert_eq!(tt.shape(), &[3, 2]);
assert_eq!(tt.strides(), &[1, 3]);
assert!(!tt.is_contiguous());
}
#[test]
fn test_contiguous_from_transpose() {
let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
// Original [2,3]: [[1,2,3],[4,5,6]]
let t = Tensor::from_slice(&data, &[2, 3]);
// Transpose to [3,2]: [[1,4],[2,5],[3,6]]
let tt = t.transpose(0, 1);
let tc = tt.contiguous();
assert!(tc.is_contiguous());
assert_eq!(tc.shape(), &[3, 2]);
assert_eq!(tc.as_slice::<f32>(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
}
#[test]
fn test_squeeze_unsqueeze() {
let data = vec![1.0f32, 2.0, 3.0];
let t = Tensor::from_slice(&data, &[1, 3]);
let squeezed = t.squeeze(0);
assert_eq!(squeezed.shape(), &[3]);
let unsqueezed = squeezed.unsqueeze(0);
assert_eq!(unsqueezed.shape(), &[1, 3]);
let unsqueezed2 = squeezed.unsqueeze(1);
assert_eq!(unsqueezed2.shape(), &[3, 1]);
}
#[test]
fn test_cpu_to_gpu_roundtrip() {
xserv_cuda::device::set_device(0).unwrap();
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let cpu_t = Tensor::from_slice(&data, &[2, 2]);
let gpu_t = cpu_t.to_device(Device::Cuda(0));
assert_eq!(gpu_t.device(), Device::Cuda(0));
assert_eq!(gpu_t.shape(), &[2, 2]);
let back = gpu_t.to_device(Device::Cpu);
assert_eq!(back.device(), Device::Cpu);
assert_eq!(back.as_slice::<f32>(), &[1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_zeros_gpu() {
xserv_cuda::device::set_device(0).unwrap();
let t = Tensor::zeros(&[4, 4], DType::F32, Device::Cuda(0));
assert_eq!(t.device(), Device::Cuda(0));
assert_eq!(t.shape(), &[4, 4]);
let cpu = t.to_device(Device::Cpu);
assert_eq!(cpu.as_slice::<f32>(), &[0.0f32; 16]);
}
#[test]
fn test_debug_format() {
let t = Tensor::from_slice(&[1.0f32], &[1]);
let dbg = format!("{:?}", t);
assert!(dbg.contains("shape=[1]"));
assert!(dbg.contains("f32"));
assert!(dbg.contains("cpu"));
}

62
csrc/gemm/naive.cu Normal file
View File

@@ -0,0 +1,62 @@
#include <cuda_bf16.h>
// Naive GEMM: each thread computes one element of C.
// C[i][j] = sum_k A[i][k] * B[k][j]
// All matrices are row-major.
__global__ void gemm_naive_bf16(
const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* C,
int M, int N, int K
) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row < M && col < N) {
float sum = 0.0f;
for (int k = 0; k < K; k++) {
sum += __bfloat162float(A[row * K + k]) * __bfloat162float(B[k * N + col]);
}
C[row * N + col] = __float2bfloat16(sum);
}
}
__global__ void gemm_naive_f32(
const float* A, const float* B, float* C,
int M, int N, int K
) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row < M && col < N) {
float sum = 0.0f;
for (int k = 0; k < K; k++) {
sum += A[row * K + k] * B[k * N + col];
}
C[row * N + col] = sum;
}
}
extern "C" {
void launch_gemm_naive_bf16(
const void* A, const void* B, void* C,
int M, int N, int K, void* stream
) {
dim3 block(16, 16);
dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
gemm_naive_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
);
}
void launch_gemm_naive_f32(
const void* A, const void* B, void* C,
int M, int N, int K, void* stream
) {
dim3 block(16, 16);
dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
gemm_naive_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
(const float*)A, (const float*)B, (float*)C, M, N, K
);
}
} // extern "C"

116
csrc/gemm/tiled.cu Normal file
View File

@@ -0,0 +1,116 @@
#include <cuda_bf16.h>
// Tiled GEMM using shared memory.
// Each thread block loads TILE_SIZE x TILE_SIZE tiles of A and B
// into shared memory, then computes a partial dot product.
#define TILE_SIZE 32
__global__ void gemm_tiled_f32(
const float* A, const float* B, float* C,
int M, int N, int K
) {
__shared__ float As[TILE_SIZE][TILE_SIZE];
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
int row = blockIdx.y * TILE_SIZE + threadIdx.y;
int col = blockIdx.x * TILE_SIZE + threadIdx.x;
float sum = 0.0f;
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
// Load tile of A
int a_col = t * TILE_SIZE + threadIdx.x;
if (row < M && a_col < K) {
As[threadIdx.y][threadIdx.x] = A[row * K + a_col];
} else {
As[threadIdx.y][threadIdx.x] = 0.0f;
}
// Load tile of B
int b_row = t * TILE_SIZE + threadIdx.y;
if (b_row < K && col < N) {
Bs[threadIdx.y][threadIdx.x] = B[b_row * N + col];
} else {
Bs[threadIdx.y][threadIdx.x] = 0.0f;
}
__syncthreads();
for (int k = 0; k < TILE_SIZE; k++) {
sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
}
__syncthreads();
}
if (row < M && col < N) {
C[row * N + col] = sum;
}
}
__global__ void gemm_tiled_bf16(
const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* C,
int M, int N, int K
) {
__shared__ float As[TILE_SIZE][TILE_SIZE];
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
int row = blockIdx.y * TILE_SIZE + threadIdx.y;
int col = blockIdx.x * TILE_SIZE + threadIdx.x;
float sum = 0.0f;
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
int a_col = t * TILE_SIZE + threadIdx.x;
if (row < M && a_col < K) {
As[threadIdx.y][threadIdx.x] = __bfloat162float(A[row * K + a_col]);
} else {
As[threadIdx.y][threadIdx.x] = 0.0f;
}
int b_row = t * TILE_SIZE + threadIdx.y;
if (b_row < K && col < N) {
Bs[threadIdx.y][threadIdx.x] = __bfloat162float(B[b_row * N + col]);
} else {
Bs[threadIdx.y][threadIdx.x] = 0.0f;
}
__syncthreads();
for (int k = 0; k < TILE_SIZE; k++) {
sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
}
__syncthreads();
}
if (row < M && col < N) {
C[row * N + col] = __float2bfloat16(sum);
}
}
extern "C" {
void launch_gemm_tiled_f32(
const void* A, const void* B, void* C,
int M, int N, int K, void* stream
) {
dim3 block(TILE_SIZE, TILE_SIZE);
dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
gemm_tiled_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
(const float*)A, (const float*)B, (float*)C, M, N, K
);
}
void launch_gemm_tiled_bf16(
const void* A, const void* B, void* C,
int M, int N, int K, void* stream
) {
dim3 block(TILE_SIZE, TILE_SIZE);
dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
gemm_tiled_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
);
}
} // extern "C"