Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d77f921a12 | |||
| a83971fa25 |
@@ -2,6 +2,8 @@
|
||||
resolver = "2"
|
||||
members = [
|
||||
"crates/xserv-cuda",
|
||||
"crates/xserv-tensor",
|
||||
"crates/xserv-kernels",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
|
||||
@@ -23,7 +23,7 @@ impl std::error::Error for CudaError {}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, CudaError>;
|
||||
|
||||
pub(crate) fn check(code: i32) -> Result<()> {
|
||||
pub fn check(code: i32) -> Result<()> {
|
||||
if code == ffi::CUDA_SUCCESS {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
12
crates/xserv-kernels/Cargo.toml
Normal file
12
crates/xserv-kernels/Cargo.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "xserv-kernels"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[build-dependencies]
|
||||
cc = "1"
|
||||
|
||||
[dependencies]
|
||||
xserv-cuda = { path = "../xserv-cuda" }
|
||||
xserv-tensor = { path = "../xserv-tensor" }
|
||||
half.workspace = true
|
||||
21
crates/xserv-kernels/build.rs
Normal file
21
crates/xserv-kernels/build.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
use std::env;
|
||||
|
||||
fn main() {
|
||||
let cuda_path = env::var("CUDA_HOME")
|
||||
.or_else(|_| env::var("CUDA_PATH"))
|
||||
.unwrap_or_else(|_| "/usr/local/cuda".to_string());
|
||||
|
||||
println!("cargo:rustc-link-search=native={cuda_path}/lib64");
|
||||
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||
println!("cargo:rustc-link-lib=dylib=cublas");
|
||||
|
||||
cc::Build::new()
|
||||
.cuda(true)
|
||||
.cudart("shared")
|
||||
.flag("-gencode=arch=compute_120,code=sm_120")
|
||||
.file("../../csrc/gemm/naive.cu")
|
||||
.file("../../csrc/gemm/tiled.cu")
|
||||
.compile("xserv_gemm_kernels");
|
||||
|
||||
println!("cargo:rerun-if-changed=../../csrc/gemm/");
|
||||
}
|
||||
151
crates/xserv-kernels/src/gemm.rs
Normal file
151
crates/xserv-kernels/src/gemm.rs
Normal file
@@ -0,0 +1,151 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::error::{self, Result};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum GemmBackend {
|
||||
Naive,
|
||||
Tiled,
|
||||
CuBlas,
|
||||
}
|
||||
|
||||
// --- FFI: custom CUDA kernels ---
|
||||
unsafe extern "C" {
|
||||
fn launch_gemm_naive_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_naive_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_tiled_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_tiled_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
// --- FFI: cuBLAS ---
|
||||
type CublasHandle = *mut c_void;
|
||||
|
||||
#[allow(non_upper_case_globals)]
|
||||
const CUBLAS_OP_N: i32 = 0;
|
||||
|
||||
// cudaDataType
|
||||
const CUDA_R_32F: i32 = 0;
|
||||
const CUDA_R_16BF: i32 = 14;
|
||||
|
||||
// cublasComputeType
|
||||
const CUBLAS_COMPUTE_32F: i32 = 68;
|
||||
|
||||
unsafe extern "C" {
|
||||
fn cublasCreate_v2(handle: *mut CublasHandle) -> i32;
|
||||
fn cublasDestroy_v2(handle: CublasHandle) -> i32;
|
||||
fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32;
|
||||
fn cublasGemmEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32, transb: i32,
|
||||
m: i32, n: i32, k: i32,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void, a_type: i32, lda: i32,
|
||||
b: *const c_void, b_type: i32, ldb: i32,
|
||||
beta: *const c_void,
|
||||
c: *mut c_void, c_type: i32, ldc: i32,
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
) -> i32;
|
||||
}
|
||||
|
||||
pub struct CublasContext {
|
||||
handle: CublasHandle,
|
||||
}
|
||||
|
||||
impl CublasContext {
|
||||
pub fn new() -> Result<Self> {
|
||||
let mut handle = std::ptr::null_mut();
|
||||
error::check(unsafe { cublasCreate_v2(&mut handle) })?;
|
||||
Ok(Self { handle })
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CublasContext {
|
||||
fn drop(&mut self) {
|
||||
if !self.handle.is_null() {
|
||||
unsafe { cublasDestroy_v2(self.handle) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Matrix multiplication: C = A @ B
|
||||
/// A: [M, K], B: [K, N], C: [M, N]
|
||||
/// All tensors must be contiguous and on the same GPU.
|
||||
pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
assert_eq!(a.ndim(), 2);
|
||||
assert_eq!(b.ndim(), 2);
|
||||
assert_eq!(a.shape()[1], b.shape()[0], "inner dimension mismatch");
|
||||
assert_eq!(a.dtype(), b.dtype(), "dtype mismatch");
|
||||
assert!(a.is_contiguous() && b.is_contiguous(), "matmul requires contiguous tensors");
|
||||
assert!(matches!(a.device(), Device::Cuda(_)), "matmul requires GPU tensors");
|
||||
|
||||
let m = a.shape()[0];
|
||||
let k = a.shape()[1];
|
||||
let n = b.shape()[1];
|
||||
let dtype = a.dtype();
|
||||
|
||||
let c = Tensor::zeros(&[m, n], dtype, a.device());
|
||||
|
||||
let a_ptr = a.data_ptr() as *const c_void;
|
||||
let b_ptr = b.data_ptr() as *const c_void;
|
||||
let c_ptr = c.data_ptr() as *mut c_void;
|
||||
let null_stream = std::ptr::null_mut();
|
||||
|
||||
match backend {
|
||||
GemmBackend::Naive => {
|
||||
unsafe {
|
||||
match dtype {
|
||||
DType::F32 => launch_gemm_naive_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
DType::BF16 => launch_gemm_naive_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
_ => panic!("unsupported dtype for naive GEMM"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
GemmBackend::Tiled => {
|
||||
unsafe {
|
||||
match dtype {
|
||||
DType::F32 => launch_gemm_tiled_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
DType::BF16 => launch_gemm_tiled_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
_ => panic!("unsupported dtype for tiled GEMM"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
GemmBackend::CuBlas => {
|
||||
// cuBLAS uses column-major, but we have row-major tensors.
|
||||
// Trick: compute C^T = B^T @ A^T, which gives us C in row-major.
|
||||
// cuBLAS sees our row-major data as column-major transposed.
|
||||
let ctx = CublasContext::new().unwrap();
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
|
||||
let (a_type, b_type, c_type) = match dtype {
|
||||
DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F),
|
||||
DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF),
|
||||
_ => panic!("unsupported dtype for cuBLAS GEMM"),
|
||||
};
|
||||
|
||||
unsafe {
|
||||
cublasSetStream_v2(ctx.handle, null_stream);
|
||||
// Row-major trick: swap A/B and transpose flags
|
||||
// C(row-major) = A @ B <=> C^T(col-major) = B^T @ A^T
|
||||
error::check(cublasGemmEx(
|
||||
ctx.handle,
|
||||
CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
n as i32, m as i32, k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b_ptr, b_type, n as i32, // B as col-major = B^T
|
||||
a_ptr, a_type, k as i32, // A as col-major = A^T
|
||||
&beta as *const f32 as *const c_void,
|
||||
c_ptr, c_type, n as i32, // C as col-major = C^T
|
||||
CUBLAS_COMPUTE_32F,
|
||||
-1, // default algo
|
||||
)).expect("cuBLAS GEMM failed");
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
c
|
||||
}
|
||||
3
crates/xserv-kernels/src/lib.rs
Normal file
3
crates/xserv-kernels/src/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod gemm;
|
||||
|
||||
pub use gemm::{GemmBackend, matmul};
|
||||
152
crates/xserv-kernels/tests/gemm_test.rs
Normal file
152
crates/xserv-kernels/tests/gemm_test.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
use half::bf16;
|
||||
use xserv_kernels::{matmul, GemmBackend};
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
|
||||
let mut c = vec![0.0f32; m * n];
|
||||
for i in 0..m {
|
||||
for j in 0..n {
|
||||
let mut sum = 0.0f32;
|
||||
for kk in 0..k {
|
||||
sum += a[i * k + kk] * b[kk * n + j];
|
||||
}
|
||||
c[i * n + j] = sum;
|
||||
}
|
||||
}
|
||||
c
|
||||
}
|
||||
|
||||
fn check_close_f32(result: &[f32], expected: &[f32], atol: f32) {
|
||||
assert_eq!(result.len(), expected.len());
|
||||
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
|
||||
assert!(
|
||||
(r - e).abs() <= atol,
|
||||
"mismatch at index {i}: got {r}, expected {e}, diff {}",
|
||||
(r - e).abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn check_close_bf16(result: &[bf16], expected: &[f32], atol: f32) {
|
||||
assert_eq!(result.len(), expected.len());
|
||||
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
|
||||
let rv = r.to_f32();
|
||||
assert!(
|
||||
(rv - e).abs() <= atol,
|
||||
"mismatch at index {i}: got {rv}, expected {e}, diff {}",
|
||||
(rv - e).abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn run_gemm_test_f32(backend: GemmBackend, m: usize, n: usize, k: usize) {
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
|
||||
let a_data: Vec<f32> = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();
|
||||
let b_data: Vec<f32> = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect();
|
||||
let expected = cpu_matmul_f32(&a_data, &b_data, m, n, k);
|
||||
|
||||
let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0));
|
||||
let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0));
|
||||
let c = matmul(&a, &b, backend);
|
||||
|
||||
let c_cpu = c.to_device(Device::Cpu);
|
||||
check_close_f32(c_cpu.as_slice::<f32>(), &expected, 1e-4);
|
||||
}
|
||||
|
||||
fn run_gemm_test_bf16(backend: GemmBackend, m: usize, n: usize, k: usize) {
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
|
||||
let a_f32: Vec<f32> = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();
|
||||
let b_f32: Vec<f32> = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect();
|
||||
let expected = cpu_matmul_f32(&a_f32, &b_f32, m, n, k);
|
||||
|
||||
let a_data: Vec<bf16> = a_f32.iter().map(|&v| bf16::from_f32(v)).collect();
|
||||
let b_data: Vec<bf16> = b_f32.iter().map(|&v| bf16::from_f32(v)).collect();
|
||||
|
||||
let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0));
|
||||
let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0));
|
||||
let c = matmul(&a, &b, backend);
|
||||
|
||||
let c_cpu = c.to_device(Device::Cpu);
|
||||
check_close_bf16(c_cpu.as_slice::<bf16>(), &expected, 0.1);
|
||||
}
|
||||
|
||||
// --- F32 tests ---
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_f32_small() { run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_f32_medium() { run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_f32_rect() { run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_f32_small() { run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_f32_medium() { run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_f32_rect() { run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_small() { run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_medium() { run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_rect() { run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97); }
|
||||
|
||||
// --- BF16 tests ---
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_bf16_small() { run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_bf16_medium() { run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_bf16_small() { run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_bf16_medium() { run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256); }
|
||||
|
||||
// --- Larger benchmark-style tests ---
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_1024() { run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_consistency_all_backends() {
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
|
||||
let m = 64;
|
||||
let n = 64;
|
||||
let k = 64;
|
||||
let a_data: Vec<f32> = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();
|
||||
let b_data: Vec<f32> = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect();
|
||||
|
||||
let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0));
|
||||
let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0));
|
||||
|
||||
let c_naive = matmul(&a, &b, GemmBackend::Naive).to_device(Device::Cpu);
|
||||
let c_tiled = matmul(&a, &b, GemmBackend::Tiled).to_device(Device::Cpu);
|
||||
let c_cublas = matmul(&a, &b, GemmBackend::CuBlas).to_device(Device::Cpu);
|
||||
|
||||
let naive = c_naive.as_slice::<f32>();
|
||||
let tiled = c_tiled.as_slice::<f32>();
|
||||
let cublas = c_cublas.as_slice::<f32>();
|
||||
|
||||
check_close_f32(naive, cublas, 1e-4);
|
||||
check_close_f32(tiled, cublas, 1e-4);
|
||||
}
|
||||
9
crates/xserv-tensor/Cargo.toml
Normal file
9
crates/xserv-tensor/Cargo.toml
Normal file
@@ -0,0 +1,9 @@
|
||||
[package]
|
||||
name = "xserv-tensor"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
xserv-cuda = { path = "../xserv-cuda" }
|
||||
half.workspace = true
|
||||
smallvec.workspace = true
|
||||
57
crates/xserv-tensor/src/dtype.rs
Normal file
57
crates/xserv-tensor/src/dtype.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use half::{bf16, f16};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum DType {
|
||||
F32,
|
||||
F16,
|
||||
BF16,
|
||||
}
|
||||
|
||||
impl DType {
|
||||
pub fn size_bytes(self) -> usize {
|
||||
match self {
|
||||
DType::F32 => 4,
|
||||
DType::F16 => 2,
|
||||
DType::BF16 => 2,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name(self) -> &'static str {
|
||||
match self {
|
||||
DType::F32 => "f32",
|
||||
DType::F16 => "f16",
|
||||
DType::BF16 => "bf16",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(self.name())
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for types that can be stored in a Tensor.
|
||||
pub trait TensorDType: Copy + Send + Sync + 'static {
|
||||
const DTYPE: DType;
|
||||
fn to_f64(self) -> f64;
|
||||
fn from_f64(v: f64) -> Self;
|
||||
}
|
||||
|
||||
impl TensorDType for f32 {
|
||||
const DTYPE: DType = DType::F32;
|
||||
fn to_f64(self) -> f64 { self as f64 }
|
||||
fn from_f64(v: f64) -> Self { v as f32 }
|
||||
}
|
||||
|
||||
impl TensorDType for f16 {
|
||||
const DTYPE: DType = DType::F16;
|
||||
fn to_f64(self) -> f64 { self.to_f32() as f64 }
|
||||
fn from_f64(v: f64) -> Self { f16::from_f32(v as f32) }
|
||||
}
|
||||
|
||||
impl TensorDType for bf16 {
|
||||
const DTYPE: DType = DType::BF16;
|
||||
fn to_f64(self) -> f64 { self.to_f32() as f64 }
|
||||
fn from_f64(v: f64) -> Self { bf16::from_f32(v as f32) }
|
||||
}
|
||||
8
crates/xserv-tensor/src/lib.rs
Normal file
8
crates/xserv-tensor/src/lib.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
pub mod dtype;
|
||||
pub mod shape;
|
||||
pub mod storage;
|
||||
pub mod tensor;
|
||||
|
||||
pub use dtype::{DType, TensorDType};
|
||||
pub use storage::Device;
|
||||
pub use tensor::Tensor;
|
||||
105
crates/xserv-tensor/src/shape.rs
Normal file
105
crates/xserv-tensor/src/shape.rs
Normal file
@@ -0,0 +1,105 @@
|
||||
use smallvec::SmallVec;
|
||||
|
||||
pub type Dims = SmallVec<[usize; 4]>;
|
||||
|
||||
/// Compute contiguous strides for a given shape (row-major / C order).
|
||||
/// Example: shape [2, 3, 4] => strides [12, 4, 1]
|
||||
pub fn contiguous_strides(shape: &[usize]) -> Dims {
|
||||
let mut strides = SmallVec::with_capacity(shape.len());
|
||||
strides.resize(shape.len(), 0);
|
||||
if shape.is_empty() {
|
||||
return strides;
|
||||
}
|
||||
strides[shape.len() - 1] = 1;
|
||||
for i in (0..shape.len() - 1).rev() {
|
||||
strides[i] = strides[i + 1] * shape[i + 1];
|
||||
}
|
||||
strides
|
||||
}
|
||||
|
||||
/// Check if the given strides represent contiguous (row-major) layout for the shape.
|
||||
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
|
||||
if shape.is_empty() {
|
||||
return true;
|
||||
}
|
||||
let expected = contiguous_strides(shape);
|
||||
strides == expected.as_slice()
|
||||
}
|
||||
|
||||
/// Total number of elements given a shape.
|
||||
pub fn num_elements(shape: &[usize]) -> usize {
|
||||
shape.iter().product()
|
||||
}
|
||||
|
||||
/// Compute the shape after broadcasting two shapes together (NumPy rules).
|
||||
/// Returns None if shapes are not broadcastable.
|
||||
pub fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Dims> {
|
||||
let ndim = a.len().max(b.len());
|
||||
let mut result = SmallVec::with_capacity(ndim);
|
||||
for i in 0..ndim {
|
||||
let da = if i < ndim - a.len() { 1 } else { a[i - (ndim - a.len())] };
|
||||
let db = if i < ndim - b.len() { 1 } else { b[i - (ndim - b.len())] };
|
||||
if da == db {
|
||||
result.push(da);
|
||||
} else if da == 1 {
|
||||
result.push(db);
|
||||
} else if db == 1 {
|
||||
result.push(da);
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
Some(result)
|
||||
}
|
||||
|
||||
/// Compute broadcast strides: for dimensions where size is 1 but output is >1, stride becomes 0.
|
||||
pub fn broadcast_strides(shape: &[usize], strides: &[usize], target_shape: &[usize]) -> Dims {
|
||||
let ndim = target_shape.len();
|
||||
let offset = ndim - shape.len();
|
||||
let mut result = SmallVec::with_capacity(ndim);
|
||||
for i in 0..ndim {
|
||||
if i < offset {
|
||||
result.push(0);
|
||||
} else {
|
||||
let orig_idx = i - offset;
|
||||
if shape[orig_idx] == 1 && target_shape[i] > 1 {
|
||||
result.push(0);
|
||||
} else {
|
||||
result.push(strides[orig_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_contiguous_strides() {
|
||||
assert_eq!(contiguous_strides(&[2, 3, 4]).as_slice(), &[12, 4, 1]);
|
||||
assert_eq!(contiguous_strides(&[5]).as_slice(), &[1]);
|
||||
assert_eq!(contiguous_strides(&[2, 3]).as_slice(), &[3, 1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_contiguous() {
|
||||
assert!(is_contiguous(&[2, 3], &[3, 1]));
|
||||
assert!(!is_contiguous(&[3, 2], &[1, 3])); // transposed
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_broadcast_shape() {
|
||||
assert_eq!(broadcast_shape(&[3, 1], &[1, 4]).unwrap().as_slice(), &[3, 4]);
|
||||
assert_eq!(broadcast_shape(&[2, 3, 4], &[4]).unwrap().as_slice(), &[2, 3, 4]);
|
||||
assert_eq!(broadcast_shape(&[1], &[5, 3]).unwrap().as_slice(), &[5, 3]);
|
||||
assert!(broadcast_shape(&[3], &[4]).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_broadcast_strides() {
|
||||
// [3,1] with strides [1,1] broadcast to [3,4]
|
||||
assert_eq!(broadcast_strides(&[3, 1], &[1, 1], &[3, 4]).as_slice(), &[1, 0]);
|
||||
}
|
||||
}
|
||||
119
crates/xserv-tensor/src/storage.rs
Normal file
119
crates/xserv-tensor/src/storage.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
use std::sync::Arc;
|
||||
use xserv_cuda::{GpuBuffer, Result as CudaResult};
|
||||
|
||||
enum StorageInner {
|
||||
Cpu { data: Vec<u8> },
|
||||
Cuda { buffer: GpuBuffer },
|
||||
}
|
||||
|
||||
/// Reference-counted storage for tensor data. Multiple tensors can share
|
||||
/// the same storage (e.g., after transpose or slice — view semantics).
|
||||
#[derive(Clone)]
|
||||
pub struct Storage(Arc<StorageInner>);
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Device {
|
||||
Cpu,
|
||||
Cuda(u32),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Device {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Device::Cpu => write!(f, "cpu"),
|
||||
Device::Cuda(i) => write!(f, "cuda:{i}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
pub fn cpu(data: Vec<u8>) -> Self {
|
||||
Self(Arc::new(StorageInner::Cpu { data }))
|
||||
}
|
||||
|
||||
pub fn cuda(buffer: GpuBuffer) -> Self {
|
||||
Self(Arc::new(StorageInner::Cuda { buffer }))
|
||||
}
|
||||
|
||||
pub fn device(&self) -> Device {
|
||||
match self.0.as_ref() {
|
||||
StorageInner::Cpu { .. } => Device::Cpu,
|
||||
StorageInner::Cuda { .. } => Device::Cuda(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn len_bytes(&self) -> usize {
|
||||
match self.0.as_ref() {
|
||||
StorageInner::Cpu { data } => data.len(),
|
||||
StorageInner::Cuda { buffer } => buffer.len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a read-only view of CPU data. Panics if storage is on GPU.
|
||||
pub fn as_cpu_bytes(&self) -> &[u8] {
|
||||
match self.0.as_ref() {
|
||||
StorageInner::Cpu { data } => data,
|
||||
StorageInner::Cuda { .. } => panic!("cannot access GPU storage as CPU bytes"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn gpu_buffer(&self) -> &GpuBuffer {
|
||||
match self.0.as_ref() {
|
||||
StorageInner::Cuda { buffer } => buffer,
|
||||
StorageInner::Cpu { .. } => panic!("cannot access CPU storage as GPU buffer"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Copy to a different device. If already on the target device, clones the Arc (no copy).
|
||||
pub fn to_device(&self, target: Device) -> CudaResult<Self> {
|
||||
let current = self.device();
|
||||
if current == target {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
match (current, target) {
|
||||
(Device::Cpu, Device::Cuda(_dev)) => {
|
||||
let cpu_data = self.as_cpu_bytes();
|
||||
let mut buf = GpuBuffer::alloc(cpu_data.len())?;
|
||||
buf.copy_from_host(cpu_data)?;
|
||||
Ok(Storage::cuda(buf))
|
||||
}
|
||||
(Device::Cuda(_), Device::Cpu) => {
|
||||
let gpu_buf = self.gpu_buffer();
|
||||
let mut data = vec![0u8; gpu_buf.len()];
|
||||
gpu_buf.copy_to_host(&mut data)?;
|
||||
Ok(Storage::cpu(data))
|
||||
}
|
||||
(Device::Cuda(_), Device::Cuda(_)) => {
|
||||
let src = self.gpu_buffer();
|
||||
let mut dst = GpuBuffer::alloc(src.len())?;
|
||||
dst.copy_from_device(src)?;
|
||||
Ok(Storage::cuda(dst))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new owned copy of the storage on the same device.
|
||||
pub fn deep_copy(&self) -> CudaResult<Self> {
|
||||
match self.0.as_ref() {
|
||||
StorageInner::Cpu { data } => Ok(Storage::cpu(data.clone())),
|
||||
StorageInner::Cuda { buffer } => {
|
||||
let mut dst = GpuBuffer::alloc(buffer.len())?;
|
||||
dst.copy_from_device(buffer)?;
|
||||
Ok(Storage::cuda(dst))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate zeroed storage on the given device.
|
||||
pub fn zeros(len_bytes: usize, device: Device) -> CudaResult<Self> {
|
||||
match device {
|
||||
Device::Cpu => Ok(Storage::cpu(vec![0u8; len_bytes])),
|
||||
Device::Cuda(_) => {
|
||||
let mut buf = GpuBuffer::alloc(len_bytes)?;
|
||||
buf.zero()?;
|
||||
Ok(Storage::cuda(buf))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
228
crates/xserv-tensor/src/tensor.rs
Normal file
228
crates/xserv-tensor/src/tensor.rs
Normal file
@@ -0,0 +1,228 @@
|
||||
use crate::dtype::{DType, TensorDType};
|
||||
use crate::shape::{self, Dims};
|
||||
use crate::storage::{Device, Storage};
|
||||
|
||||
/// Multi-dimensional array with CPU or GPU storage.
|
||||
///
|
||||
/// Tensors support view semantics: transpose, slice, etc. share
|
||||
/// the underlying storage and only change shape/strides/offset.
|
||||
#[derive(Clone)]
|
||||
pub struct Tensor {
|
||||
storage: Storage,
|
||||
shape: Dims,
|
||||
strides: Dims,
|
||||
offset: usize,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
// --- Creation ---
|
||||
|
||||
pub fn from_slice<T: TensorDType>(data: &[T], shape: &[usize]) -> Self {
|
||||
let numel: usize = shape.iter().product();
|
||||
assert_eq!(data.len(), numel, "data length mismatch with shape");
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, numel * T::DTYPE.size_bytes())
|
||||
};
|
||||
Self {
|
||||
storage: Storage::cpu(bytes.to_vec()),
|
||||
shape: Dims::from_slice(shape),
|
||||
strides: shape::contiguous_strides(shape),
|
||||
offset: 0,
|
||||
dtype: T::DTYPE,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
|
||||
let numel = shape::num_elements(shape);
|
||||
let len_bytes = numel * dtype.size_bytes();
|
||||
let storage = Storage::zeros(len_bytes, device).expect("alloc failed");
|
||||
Self {
|
||||
storage,
|
||||
shape: Dims::from_slice(shape),
|
||||
strides: shape::contiguous_strides(shape),
|
||||
offset: 0,
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ones(shape: &[usize], dtype: DType) -> Self {
|
||||
let numel = shape::num_elements(shape);
|
||||
match dtype {
|
||||
DType::F32 => Self::from_slice(&vec![1.0f32; numel], shape),
|
||||
DType::F16 => Self::from_slice(&vec![half::f16::from_f32(1.0); numel], shape),
|
||||
DType::BF16 => Self::from_slice(&vec![half::bf16::from_f32(1.0); numel], shape),
|
||||
}
|
||||
}
|
||||
|
||||
// --- Properties ---
|
||||
|
||||
pub fn shape(&self) -> &[usize] { &self.shape }
|
||||
pub fn strides(&self) -> &[usize] { &self.strides }
|
||||
pub fn dtype(&self) -> DType { self.dtype }
|
||||
pub fn ndim(&self) -> usize { self.shape.len() }
|
||||
pub fn numel(&self) -> usize { shape::num_elements(&self.shape) }
|
||||
pub fn offset(&self) -> usize { self.offset }
|
||||
|
||||
pub fn device(&self) -> Device { self.storage.device() }
|
||||
|
||||
pub fn is_contiguous(&self) -> bool {
|
||||
shape::is_contiguous(&self.shape, &self.strides)
|
||||
}
|
||||
|
||||
// --- Shape operations (view, no copy) ---
|
||||
|
||||
pub fn reshape(&self, new_shape: &[usize]) -> Self {
|
||||
assert!(self.is_contiguous(), "reshape requires contiguous tensor");
|
||||
let new_numel: usize = new_shape.iter().product();
|
||||
assert_eq!(new_numel, self.numel(), "reshape numel mismatch");
|
||||
Self {
|
||||
storage: self.storage.clone(),
|
||||
shape: Dims::from_slice(new_shape),
|
||||
strides: shape::contiguous_strides(new_shape),
|
||||
offset: self.offset,
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn transpose(&self, dim0: usize, dim1: usize) -> Self {
|
||||
assert!(dim0 < self.ndim() && dim1 < self.ndim());
|
||||
let mut new_shape = self.shape.clone();
|
||||
let mut new_strides = self.strides.clone();
|
||||
new_shape.swap(dim0, dim1);
|
||||
new_strides.swap(dim0, dim1);
|
||||
Self {
|
||||
storage: self.storage.clone(),
|
||||
shape: new_shape,
|
||||
strides: new_strides,
|
||||
offset: self.offset,
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn squeeze(&self, dim: usize) -> Self {
|
||||
assert!(dim < self.ndim() && self.shape[dim] == 1);
|
||||
let mut new_shape = self.shape.clone();
|
||||
let mut new_strides = self.strides.clone();
|
||||
new_shape.remove(dim);
|
||||
new_strides.remove(dim);
|
||||
Self {
|
||||
storage: self.storage.clone(),
|
||||
shape: new_shape,
|
||||
strides: new_strides,
|
||||
offset: self.offset,
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unsqueeze(&self, dim: usize) -> Self {
|
||||
assert!(dim <= self.ndim());
|
||||
let mut new_shape = self.shape.clone();
|
||||
let mut new_strides = self.strides.clone();
|
||||
new_shape.insert(dim, 1);
|
||||
let stride_val = if dim < self.strides.len() { self.strides[dim] } else { 1 };
|
||||
new_strides.insert(dim, stride_val);
|
||||
Self {
|
||||
storage: self.storage.clone(),
|
||||
shape: new_shape,
|
||||
strides: new_strides,
|
||||
offset: self.offset,
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
/// Make contiguous: if already contiguous, return clone (shared storage).
|
||||
/// Otherwise, copy data into a new contiguous buffer.
|
||||
pub fn contiguous(&self) -> Self {
|
||||
if self.is_contiguous() {
|
||||
return self.clone();
|
||||
}
|
||||
// Copy to contiguous layout on CPU
|
||||
assert_eq!(self.device(), Device::Cpu, "contiguous() on GPU not yet supported");
|
||||
let numel = self.numel();
|
||||
let elem_size = self.dtype.size_bytes();
|
||||
let src_bytes = self.storage.as_cpu_bytes();
|
||||
let mut dst = vec![0u8; numel * elem_size];
|
||||
// Iterate all elements using strides
|
||||
let ndim = self.ndim();
|
||||
let mut idx = vec![0usize; ndim];
|
||||
for flat in 0..numel {
|
||||
let src_offset = self.offset + idx.iter().zip(self.strides.iter()).map(|(i, s)| i * s).sum::<usize>();
|
||||
let src_byte_offset = src_offset * elem_size;
|
||||
let dst_byte_offset = flat * elem_size;
|
||||
dst[dst_byte_offset..dst_byte_offset + elem_size]
|
||||
.copy_from_slice(&src_bytes[src_byte_offset..src_byte_offset + elem_size]);
|
||||
// Increment index (rightmost first)
|
||||
for d in (0..ndim).rev() {
|
||||
idx[d] += 1;
|
||||
if idx[d] < self.shape[d] {
|
||||
break;
|
||||
}
|
||||
idx[d] = 0;
|
||||
}
|
||||
}
|
||||
Self {
|
||||
storage: Storage::cpu(dst),
|
||||
shape: self.shape.clone(),
|
||||
strides: shape::contiguous_strides(&self.shape),
|
||||
offset: 0,
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
// --- Device transfer ---
|
||||
|
||||
pub fn to_device(&self, device: Device) -> Self {
|
||||
let t = if self.is_contiguous() { self.clone() } else { self.contiguous() };
|
||||
if t.device() == device {
|
||||
return t;
|
||||
}
|
||||
let new_storage = t.storage.to_device(device).expect("device transfer failed");
|
||||
Self {
|
||||
storage: new_storage,
|
||||
shape: t.shape,
|
||||
strides: t.strides,
|
||||
offset: 0,
|
||||
dtype: t.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
// --- Data access (CPU only) ---
|
||||
|
||||
/// Read tensor data as a typed slice. Requires contiguous CPU tensor.
|
||||
pub fn as_slice<T: TensorDType>(&self) -> &[T] {
|
||||
assert_eq!(T::DTYPE, self.dtype, "dtype mismatch");
|
||||
assert!(self.is_contiguous(), "as_slice requires contiguous");
|
||||
assert_eq!(self.device(), Device::Cpu, "as_slice requires CPU");
|
||||
let bytes = self.storage.as_cpu_bytes();
|
||||
let elem_size = self.dtype.size_bytes();
|
||||
let start = self.offset * elem_size;
|
||||
let len = self.numel();
|
||||
unsafe { std::slice::from_raw_parts(bytes[start..].as_ptr() as *const T, len) }
|
||||
}
|
||||
|
||||
/// Raw pointer to storage start (for GPU kernel launch).
|
||||
pub fn data_ptr(&self) -> *const u8 {
|
||||
match self.device() {
|
||||
Device::Cpu => {
|
||||
let bytes = self.storage.as_cpu_bytes();
|
||||
unsafe { bytes.as_ptr().add(self.offset * self.dtype.size_bytes()) }
|
||||
}
|
||||
Device::Cuda(_) => {
|
||||
let buf = self.storage.gpu_buffer();
|
||||
unsafe { buf.as_ptr().add(self.offset * self.dtype.size_bytes()) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn storage(&self) -> &Storage { &self.storage }
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Tensor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f, "Tensor(shape={:?}, dtype={}, device={}, contiguous={})",
|
||||
self.shape.as_slice(), self.dtype, self.device(), self.is_contiguous()
|
||||
)
|
||||
}
|
||||
}
|
||||
127
crates/xserv-tensor/tests/integration.rs
Normal file
127
crates/xserv-tensor/tests/integration.rs
Normal file
@@ -0,0 +1,127 @@
|
||||
use half::bf16;
|
||||
use xserv_tensor::*;
|
||||
|
||||
#[test]
|
||||
fn test_from_slice_and_shape() {
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let t = Tensor::from_slice(&data, &[2, 3]);
|
||||
assert_eq!(t.shape(), &[2, 3]);
|
||||
assert_eq!(t.strides(), &[3, 1]);
|
||||
assert_eq!(t.numel(), 6);
|
||||
assert_eq!(t.ndim(), 2);
|
||||
assert!(t.is_contiguous());
|
||||
assert_eq!(t.dtype(), DType::F32);
|
||||
assert_eq!(t.device(), Device::Cpu);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_as_slice() {
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let t = Tensor::from_slice(&data, &[4]);
|
||||
assert_eq!(t.as_slice::<f32>(), &[1.0, 2.0, 3.0, 4.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zeros_and_ones() {
|
||||
let z = Tensor::zeros(&[2, 3], DType::F32, Device::Cpu);
|
||||
assert_eq!(z.as_slice::<f32>(), &[0.0; 6]);
|
||||
|
||||
let o = Tensor::ones(&[3], DType::F32);
|
||||
assert_eq!(o.as_slice::<f32>(), &[1.0, 1.0, 1.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bf16_tensor() {
|
||||
let data: Vec<bf16> = vec![bf16::from_f32(1.0), bf16::from_f32(2.5), bf16::from_f32(-3.0)];
|
||||
let t = Tensor::from_slice(&data, &[3]);
|
||||
assert_eq!(t.dtype(), DType::BF16);
|
||||
let out = t.as_slice::<bf16>();
|
||||
assert_eq!(out[0].to_f32(), 1.0);
|
||||
assert!((out[1].to_f32() - 2.5).abs() < 0.01);
|
||||
assert_eq!(out[2].to_f32(), -3.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reshape() {
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let t = Tensor::from_slice(&data, &[2, 3]);
|
||||
let t2 = t.reshape(&[3, 2]);
|
||||
assert_eq!(t2.shape(), &[3, 2]);
|
||||
assert_eq!(t2.as_slice::<f32>(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
|
||||
let t3 = t.reshape(&[6]);
|
||||
assert_eq!(t3.shape(), &[6]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transpose() {
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let t = Tensor::from_slice(&data, &[2, 3]);
|
||||
let tt = t.transpose(0, 1);
|
||||
assert_eq!(tt.shape(), &[3, 2]);
|
||||
assert_eq!(tt.strides(), &[1, 3]);
|
||||
assert!(!tt.is_contiguous());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_contiguous_from_transpose() {
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
// Original [2,3]: [[1,2,3],[4,5,6]]
|
||||
let t = Tensor::from_slice(&data, &[2, 3]);
|
||||
// Transpose to [3,2]: [[1,4],[2,5],[3,6]]
|
||||
let tt = t.transpose(0, 1);
|
||||
let tc = tt.contiguous();
|
||||
assert!(tc.is_contiguous());
|
||||
assert_eq!(tc.shape(), &[3, 2]);
|
||||
assert_eq!(tc.as_slice::<f32>(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_squeeze_unsqueeze() {
|
||||
let data = vec![1.0f32, 2.0, 3.0];
|
||||
let t = Tensor::from_slice(&data, &[1, 3]);
|
||||
let squeezed = t.squeeze(0);
|
||||
assert_eq!(squeezed.shape(), &[3]);
|
||||
|
||||
let unsqueezed = squeezed.unsqueeze(0);
|
||||
assert_eq!(unsqueezed.shape(), &[1, 3]);
|
||||
|
||||
let unsqueezed2 = squeezed.unsqueeze(1);
|
||||
assert_eq!(unsqueezed2.shape(), &[3, 1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cpu_to_gpu_roundtrip() {
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let cpu_t = Tensor::from_slice(&data, &[2, 2]);
|
||||
let gpu_t = cpu_t.to_device(Device::Cuda(0));
|
||||
assert_eq!(gpu_t.device(), Device::Cuda(0));
|
||||
assert_eq!(gpu_t.shape(), &[2, 2]);
|
||||
|
||||
let back = gpu_t.to_device(Device::Cpu);
|
||||
assert_eq!(back.device(), Device::Cpu);
|
||||
assert_eq!(back.as_slice::<f32>(), &[1.0, 2.0, 3.0, 4.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zeros_gpu() {
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
|
||||
let t = Tensor::zeros(&[4, 4], DType::F32, Device::Cuda(0));
|
||||
assert_eq!(t.device(), Device::Cuda(0));
|
||||
assert_eq!(t.shape(), &[4, 4]);
|
||||
|
||||
let cpu = t.to_device(Device::Cpu);
|
||||
assert_eq!(cpu.as_slice::<f32>(), &[0.0f32; 16]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_debug_format() {
|
||||
let t = Tensor::from_slice(&[1.0f32], &[1]);
|
||||
let dbg = format!("{:?}", t);
|
||||
assert!(dbg.contains("shape=[1]"));
|
||||
assert!(dbg.contains("f32"));
|
||||
assert!(dbg.contains("cpu"));
|
||||
}
|
||||
62
csrc/gemm/naive.cu
Normal file
62
csrc/gemm/naive.cu
Normal file
@@ -0,0 +1,62 @@
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
// Naive GEMM: each thread computes one element of C.
|
||||
// C[i][j] = sum_k A[i][k] * B[k][j]
|
||||
// All matrices are row-major.
|
||||
__global__ void gemm_naive_bf16(
|
||||
const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* C,
|
||||
int M, int N, int K
|
||||
) {
|
||||
int row = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (row < M && col < N) {
|
||||
float sum = 0.0f;
|
||||
for (int k = 0; k < K; k++) {
|
||||
sum += __bfloat162float(A[row * K + k]) * __bfloat162float(B[k * N + col]);
|
||||
}
|
||||
C[row * N + col] = __float2bfloat16(sum);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void gemm_naive_f32(
|
||||
const float* A, const float* B, float* C,
|
||||
int M, int N, int K
|
||||
) {
|
||||
int row = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (row < M && col < N) {
|
||||
float sum = 0.0f;
|
||||
for (int k = 0; k < K; k++) {
|
||||
sum += A[row * K + k] * B[k * N + col];
|
||||
}
|
||||
C[row * N + col] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_gemm_naive_bf16(
|
||||
const void* A, const void* B, void* C,
|
||||
int M, int N, int K, void* stream
|
||||
) {
|
||||
dim3 block(16, 16);
|
||||
dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
|
||||
gemm_naive_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
|
||||
);
|
||||
}
|
||||
|
||||
void launch_gemm_naive_f32(
|
||||
const void* A, const void* B, void* C,
|
||||
int M, int N, int K, void* stream
|
||||
) {
|
||||
dim3 block(16, 16);
|
||||
dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
|
||||
gemm_naive_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)A, (const float*)B, (float*)C, M, N, K
|
||||
);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
116
csrc/gemm/tiled.cu
Normal file
116
csrc/gemm/tiled.cu
Normal file
@@ -0,0 +1,116 @@
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
// Tiled GEMM using shared memory.
|
||||
// Each thread block loads TILE_SIZE x TILE_SIZE tiles of A and B
|
||||
// into shared memory, then computes a partial dot product.
|
||||
#define TILE_SIZE 32
|
||||
|
||||
__global__ void gemm_tiled_f32(
|
||||
const float* A, const float* B, float* C,
|
||||
int M, int N, int K
|
||||
) {
|
||||
__shared__ float As[TILE_SIZE][TILE_SIZE];
|
||||
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
|
||||
|
||||
int row = blockIdx.y * TILE_SIZE + threadIdx.y;
|
||||
int col = blockIdx.x * TILE_SIZE + threadIdx.x;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
|
||||
// Load tile of A
|
||||
int a_col = t * TILE_SIZE + threadIdx.x;
|
||||
if (row < M && a_col < K) {
|
||||
As[threadIdx.y][threadIdx.x] = A[row * K + a_col];
|
||||
} else {
|
||||
As[threadIdx.y][threadIdx.x] = 0.0f;
|
||||
}
|
||||
|
||||
// Load tile of B
|
||||
int b_row = t * TILE_SIZE + threadIdx.y;
|
||||
if (b_row < K && col < N) {
|
||||
Bs[threadIdx.y][threadIdx.x] = B[b_row * N + col];
|
||||
} else {
|
||||
Bs[threadIdx.y][threadIdx.x] = 0.0f;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int k = 0; k < TILE_SIZE; k++) {
|
||||
sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (row < M && col < N) {
|
||||
C[row * N + col] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void gemm_tiled_bf16(
|
||||
const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* C,
|
||||
int M, int N, int K
|
||||
) {
|
||||
__shared__ float As[TILE_SIZE][TILE_SIZE];
|
||||
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
|
||||
|
||||
int row = blockIdx.y * TILE_SIZE + threadIdx.y;
|
||||
int col = blockIdx.x * TILE_SIZE + threadIdx.x;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
|
||||
int a_col = t * TILE_SIZE + threadIdx.x;
|
||||
if (row < M && a_col < K) {
|
||||
As[threadIdx.y][threadIdx.x] = __bfloat162float(A[row * K + a_col]);
|
||||
} else {
|
||||
As[threadIdx.y][threadIdx.x] = 0.0f;
|
||||
}
|
||||
|
||||
int b_row = t * TILE_SIZE + threadIdx.y;
|
||||
if (b_row < K && col < N) {
|
||||
Bs[threadIdx.y][threadIdx.x] = __bfloat162float(B[b_row * N + col]);
|
||||
} else {
|
||||
Bs[threadIdx.y][threadIdx.x] = 0.0f;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int k = 0; k < TILE_SIZE; k++) {
|
||||
sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (row < M && col < N) {
|
||||
C[row * N + col] = __float2bfloat16(sum);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_gemm_tiled_f32(
|
||||
const void* A, const void* B, void* C,
|
||||
int M, int N, int K, void* stream
|
||||
) {
|
||||
dim3 block(TILE_SIZE, TILE_SIZE);
|
||||
dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
|
||||
gemm_tiled_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)A, (const float*)B, (float*)C, M, N, K
|
||||
);
|
||||
}
|
||||
|
||||
void launch_gemm_tiled_bf16(
|
||||
const void* A, const void* B, void* C,
|
||||
int M, int N, int K, void* stream
|
||||
) {
|
||||
dim3 block(TILE_SIZE, TILE_SIZE);
|
||||
dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
|
||||
gemm_tiled_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
|
||||
);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
Reference in New Issue
Block a user