tensor: minimal Tensor crate over xtrain-cuda
New xtrain-tensor crate: DType (F32), shape/stride helpers, Arc-counted host/device Storage with CPU↔CUDA copy, and a contiguous Tensor with creation, host↔device transfer, and a scale() op driving the elementwise kernel. GPU integration tests (host↔device roundtrip + scale correctness) gated behind not(no_cuda); a thin build.rs emits the no_cuda cfg so the kernel call sites compile out locally. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
93
Cargo.lock
generated
93
Cargo.lock
generated
@@ -12,21 +12,114 @@ dependencies = [
|
||||
"shlex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
|
||||
|
||||
[[package]]
|
||||
name = "crunchy"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5"
|
||||
|
||||
[[package]]
|
||||
name = "find-msvc-tools"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582"
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "2.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crunchy",
|
||||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.106"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.45"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shlex"
|
||||
version = "2.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8fadd59c855ef2080decdef8ff161eb6661b86933c9d82e5ba29dc602a55aba"
|
||||
|
||||
[[package]]
|
||||
name = "smallvec"
|
||||
version = "1.15.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ed6a63f02c8539c91a8685a86f4099661ba3da017932f6ebbea6de3f0fa7c90"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.117"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
|
||||
|
||||
[[package]]
|
||||
name = "xtrain-cuda"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xtrain-tensor"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"half",
|
||||
"smallvec",
|
||||
"xtrain-cuda",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy"
|
||||
version = "0.8.52"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ce1022995ff5ff5d841ad7d994facc23098cd40152f2c1d11cd607c6f530653f"
|
||||
dependencies = [
|
||||
"zerocopy-derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy-derive"
|
||||
version = "0.8.52"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1ae7f38b72ec2a254e2b87ef277cf2cd4fb97cbebf944faa6f33354da0867930"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
@@ -2,9 +2,14 @@
|
||||
resolver = "2"
|
||||
members = [
|
||||
"crates/xtrain-cuda",
|
||||
"crates/xtrain-tensor",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
license = "MIT"
|
||||
|
||||
[workspace.dependencies]
|
||||
half = "2"
|
||||
smallvec = "1"
|
||||
|
||||
9
crates/xtrain-tensor/Cargo.toml
Normal file
9
crates/xtrain-tensor/Cargo.toml
Normal file
@@ -0,0 +1,9 @@
|
||||
[package]
|
||||
name = "xtrain-tensor"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
xtrain-cuda = { path = "../xtrain-cuda" }
|
||||
half.workspace = true
|
||||
smallvec.workspace = true
|
||||
26
crates/xtrain-tensor/build.rs
Normal file
26
crates/xtrain-tensor/build.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use std::env;
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
// xtrain-tensor calls GPU kernels (via xtrain-cuda's FFI), so it gates those
|
||||
// call sites behind `not(no_cuda)` — the same convention xtrain-cuda uses. This
|
||||
// build script only detects nvcc and emits that cfg; it compiles no CUDA itself
|
||||
// (the kernels are built by xtrain-cuda's build.rs).
|
||||
fn main() {
|
||||
println!("cargo:rustc-check-cfg=cfg(no_cuda)");
|
||||
|
||||
let cuda_path = env::var("CUDA_HOME")
|
||||
.or_else(|_| env::var("CUDA_PATH"))
|
||||
.unwrap_or_else(|_| "/usr/local/cuda".to_string());
|
||||
|
||||
if !nvcc_available(&cuda_path) {
|
||||
println!("cargo:rustc-cfg=no_cuda");
|
||||
}
|
||||
}
|
||||
|
||||
/// Reports whether an `nvcc` compiler can be located.
///
/// Two probes are tried in order:
/// 1. run `nvcc --version` through `PATH` and require a successful exit
///    status (merely being able to spawn something called `nvcc` is not
///    enough);
/// 2. look for the `nvcc` executable under `<cuda_path>/bin`, using the
///    platform's executable suffix so `nvcc.exe` is found on Windows.
fn nvcc_available(cuda_path: &str) -> bool {
    let runs_from_path = Command::new("nvcc")
        .arg("--version")
        .output()
        .map(|out| out.status.success())
        .unwrap_or(false);
    if runs_from_path {
        return true;
    }

    let exe_name = format!("nvcc{}", std::env::consts::EXE_SUFFIX);
    Path::new(cuda_path).join("bin").join(exe_name).exists()
}
|
||||
47
crates/xtrain-tensor/src/dtype.rs
Normal file
47
crates/xtrain-tensor/src/dtype.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
//! Tensor data types.
|
||||
//!
|
||||
//! T2 only needs `F32`, but the enum + `TensorDType` trait are structured so
|
||||
//! half-precision types (F16/BF16) can be added later (T7 mixed precision)
|
||||
//! without touching call sites.
|
||||
|
||||
/// Element type tag carried by a tensor. `F32` is the only variant for now.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 32-bit float: four bytes per element, printed as `f32`.
    F32,
}

impl DType {
    /// Width of a single element of this type, in bytes.
    pub fn size_bytes(self) -> usize {
        match self {
            Self::F32 => 4,
        }
    }

    /// Short lowercase label for this type (also used by `Display`).
    pub fn name(self) -> &'static str {
        match self {
            Self::F32 => "f32",
        }
    }
}

impl std::fmt::Display for DType {
    /// Writes the label from [`DType::name`] verbatim.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = (*self).name();
        out.write_str(label)
    }
}
|
||||
|
||||
/// Rust types that can back a tensor. Gives `from_slice`/`as_slice` type safety.
|
||||
pub trait TensorDType: Copy + Send + Sync + 'static {
|
||||
const DTYPE: DType;
|
||||
fn to_f64(self) -> f64;
|
||||
fn from_f64(v: f64) -> Self;
|
||||
}
|
||||
|
||||
impl TensorDType for f32 {
|
||||
const DTYPE: DType = DType::F32;
|
||||
fn to_f64(self) -> f64 {
|
||||
self as f64
|
||||
}
|
||||
fn from_f64(v: f64) -> Self {
|
||||
v as f32
|
||||
}
|
||||
}
|
||||
15
crates/xtrain-tensor/src/lib.rs
Normal file
15
crates/xtrain-tensor/src/lib.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
//! Minimal tensor abstraction for xtrain (Phase T2).
|
||||
//!
|
||||
//! Provides a `DType`, shape/stride helpers, reference-counted host/device
|
||||
//! `Storage`, and a `Tensor` with creation, host↔device transfer, and one
|
||||
//! elementwise CUDA op (`scale`) wired end-to-end.
|
||||
|
||||
pub mod dtype;
|
||||
pub mod shape;
|
||||
pub mod storage;
|
||||
pub mod tensor;
|
||||
|
||||
pub use dtype::{DType, TensorDType};
|
||||
pub use shape::Dims;
|
||||
pub use storage::{Device, Storage};
|
||||
pub use tensor::Tensor;
|
||||
57
crates/xtrain-tensor/src/shape.rs
Normal file
57
crates/xtrain-tensor/src/shape.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
//! Shape / stride helpers. Strides are in **elements** (not bytes), row-major.
|
||||
|
||||
use smallvec::SmallVec;
|
||||
|
||||
/// Inline storage for the common ≤4D case; spills to the heap beyond that.
|
||||
pub type Dims = SmallVec<[usize; 4]>;
|
||||
|
||||
/// Row-major (C order) contiguous strides for a shape.
|
||||
/// Example: `[2, 3, 4]` => `[12, 4, 1]`.
|
||||
pub fn contiguous_strides(shape: &[usize]) -> Dims {
|
||||
let mut strides: Dims = SmallVec::with_capacity(shape.len());
|
||||
strides.resize(shape.len(), 0);
|
||||
if shape.is_empty() {
|
||||
return strides;
|
||||
}
|
||||
strides[shape.len() - 1] = 1;
|
||||
for i in (0..shape.len() - 1).rev() {
|
||||
strides[i] = strides[i + 1] * shape[i + 1];
|
||||
}
|
||||
strides
|
||||
}
|
||||
|
||||
/// True if `strides` describe a row-major contiguous layout for `shape`.
///
/// A mismatched stride on a size-1 dimension is fine (it is never stepped).
/// Returns `false` when `shape` and `strides` have different lengths, since
/// such a pair cannot describe any layout.
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.len() != strides.len() {
        return false;
    }

    // Stride the current axis must have: the product of all inner extents.
    let mut expected = 1usize;
    for (&extent, &stride) in shape.iter().zip(strides).rev() {
        if extent != 1 && stride != expected {
            return false;
        }
        // Saturate rather than overflow: the product that includes the
        // outermost extent is never compared against a stride, so a huge
        // outer dimension must not make a valid layout panic.
        expected = expected.saturating_mul(extent);
    }
    true
}
|
||||
|
||||
/// Total element count: the product of all extents in `shape`.
///
/// An empty shape has one element; a shape containing a zero extent has zero
/// elements.
///
/// # Panics
///
/// Panics if the product does not fit in `usize`, instead of silently
/// wrapping in builds without overflow checks.
pub fn num_elements(shape: &[usize]) -> usize {
    // Decide the zero case first so it cannot be masked by an overflow in the
    // dimensions that precede the zero.
    if shape.contains(&0) {
        return 0;
    }
    shape
        .iter()
        .try_fold(1usize, |count, &extent| count.checked_mul(extent))
        .expect("tensor element count overflows usize")
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Strides for a 3-D and a 1-D shape follow the row-major rule.
    #[test]
    fn contiguous_strides_basic() {
        let rank3 = contiguous_strides(&[2, 3, 4]);
        assert_eq!(rank3.as_slice(), &[12, 4, 1]);

        let rank1 = contiguous_strides(&[5]);
        assert_eq!(rank1.as_slice(), &[1]);
    }

    /// Row-major strides are accepted; transposed strides are rejected.
    #[test]
    fn is_contiguous_detects_transpose() {
        let row_major = is_contiguous(&[2, 3], &[3, 1]);
        let transposed = is_contiguous(&[3, 2], &[1, 3]);
        assert!(row_major);
        assert!(!transposed);
    }
}
|
||||
109
crates/xtrain-tensor/src/storage.rs
Normal file
109
crates/xtrain-tensor/src/storage.rs
Normal file
@@ -0,0 +1,109 @@
|
||||
//! Tensor storage: host (CPU) bytes or a GPU buffer, reference-counted so
|
||||
//! views (clones with different shape/strides) can share the backing data.
|
||||
|
||||
use std::sync::Arc;
|
||||
use xtrain_cuda::{GpuBuffer, Result as CudaResult};
|
||||
|
||||
/// Backing allocation behind a [`Storage`]: either host bytes or a GPU buffer.
enum StorageInner {
    /// Host memory, kept as untyped bytes.
    Cpu { data: Vec<u8> },
    /// A GPU buffer together with the device index recorded for it.
    Cuda { buffer: GpuBuffer, device: u32 },
}

/// Reference-counted tensor storage. Cloning is cheap (bumps the `Arc`).
#[derive(Clone)]
pub struct Storage(Arc<StorageInner>);
|
||||
|
||||
/// Where a storage lives: host memory or a CUDA device identified by index.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    /// Host (CPU) memory.
    Cpu,
    /// A CUDA device, carrying its index.
    Cuda(u32),
}

impl std::fmt::Display for Device {
    /// Formats as `cpu` or `cuda:<index>`.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Cuda(index) => write!(out, "cuda:{index}"),
            Self::Cpu => out.write_str("cpu"),
        }
    }
}
|
||||
|
||||
impl Storage {
|
||||
pub fn cpu(data: Vec<u8>) -> Self {
|
||||
Self(Arc::new(StorageInner::Cpu { data }))
|
||||
}
|
||||
|
||||
pub fn cuda(buffer: GpuBuffer, device: u32) -> Self {
|
||||
Self(Arc::new(StorageInner::Cuda { buffer, device }))
|
||||
}
|
||||
|
||||
pub fn device(&self) -> Device {
|
||||
match self.0.as_ref() {
|
||||
StorageInner::Cpu { .. } => Device::Cpu,
|
||||
StorageInner::Cuda { device, .. } => Device::Cuda(*device),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn len_bytes(&self) -> usize {
|
||||
match self.0.as_ref() {
|
||||
StorageInner::Cpu { data } => data.len(),
|
||||
StorageInner::Cuda { buffer, .. } => buffer.len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Read-only view of CPU bytes. Panics if the storage lives on the GPU.
|
||||
pub fn as_cpu_bytes(&self) -> &[u8] {
|
||||
match self.0.as_ref() {
|
||||
StorageInner::Cpu { data } => data,
|
||||
StorageInner::Cuda { .. } => panic!("cannot read GPU storage as CPU bytes"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Borrow the GPU buffer. Panics if the storage lives on the CPU.
|
||||
pub fn gpu_buffer(&self) -> &GpuBuffer {
|
||||
match self.0.as_ref() {
|
||||
StorageInner::Cuda { buffer, .. } => buffer,
|
||||
StorageInner::Cpu { .. } => panic!("cannot read CPU storage as GPU buffer"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Copy to another device. Returns a clone of the `Arc` when already there.
|
||||
/// T2 supports CPU↔CUDA(0); device-to-device copy across GPUs is out of scope.
|
||||
pub fn to_device(&self, target: Device) -> CudaResult<Self> {
|
||||
let current = self.device();
|
||||
if current == target {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
match (current, target) {
|
||||
(Device::Cpu, Device::Cuda(dev)) => {
|
||||
let host = self.as_cpu_bytes();
|
||||
let mut buf = GpuBuffer::alloc(host.len())?;
|
||||
buf.copy_from_host(host)?;
|
||||
Ok(Storage::cuda(buf, dev))
|
||||
}
|
||||
(Device::Cuda(_), Device::Cpu) => {
|
||||
let src = self.gpu_buffer();
|
||||
let mut host = vec![0u8; src.len()];
|
||||
src.copy_to_host(&mut host)?;
|
||||
Ok(Storage::cpu(host))
|
||||
}
|
||||
(Device::Cuda(_), Device::Cuda(_)) => {
|
||||
panic!("cross-GPU storage transfer is not supported in T2")
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Zeroed storage on the given device.
|
||||
pub fn zeros(len_bytes: usize, device: Device) -> CudaResult<Self> {
|
||||
match device {
|
||||
Device::Cpu => Ok(Storage::cpu(vec![0u8; len_bytes])),
|
||||
Device::Cuda(dev) => {
|
||||
// No device memset in T2: stage zeros from the host.
|
||||
let mut buf = GpuBuffer::alloc(len_bytes)?;
|
||||
buf.copy_from_host(&vec![0u8; len_bytes])?;
|
||||
Ok(Storage::cuda(buf, dev))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
194
crates/xtrain-tensor/src/tensor.rs
Normal file
194
crates/xtrain-tensor/src/tensor.rs
Normal file
@@ -0,0 +1,194 @@
|
||||
//! The `Tensor` type: shape/strides/dtype over reference-counted [`Storage`],
|
||||
//! with host↔device transfer and one elementwise op (`scale`) wired end-to-end
|
||||
//! through a CUDA kernel.
|
||||
|
||||
use crate::dtype::{DType, TensorDType};
|
||||
use crate::shape::{self, Dims};
|
||||
use crate::storage::{Device, Storage};
|
||||
|
||||
/// Multi-dimensional array backed by CPU or GPU storage.
|
||||
///
|
||||
/// Strides are in elements (row-major). T2 tensors created here are always
|
||||
/// contiguous; the `strides`/`offset` fields exist so later phases can add
|
||||
/// zero-copy views without changing this type's shape.
|
||||
#[derive(Clone)]
|
||||
pub struct Tensor {
|
||||
storage: Storage,
|
||||
shape: Dims,
|
||||
strides: Dims,
|
||||
offset: usize,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
// --- Creation ---
|
||||
|
||||
/// Build a contiguous CPU tensor from a typed host slice.
|
||||
pub fn from_slice<T: TensorDType>(data: &[T], shape: &[usize]) -> Self {
|
||||
let numel = shape::num_elements(shape);
|
||||
assert_eq!(
|
||||
data.len(),
|
||||
numel,
|
||||
"data length {} != shape numel {numel}",
|
||||
data.len()
|
||||
);
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, numel * T::DTYPE.size_bytes())
|
||||
};
|
||||
Self {
|
||||
storage: Storage::cpu(bytes.to_vec()),
|
||||
shape: Dims::from_slice(shape),
|
||||
strides: shape::contiguous_strides(shape),
|
||||
offset: 0,
|
||||
dtype: T::DTYPE,
|
||||
}
|
||||
}
|
||||
|
||||
/// Zero-filled contiguous tensor on the given device.
|
||||
pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
|
||||
let len_bytes = shape::num_elements(shape) * dtype.size_bytes();
|
||||
let storage = Storage::zeros(len_bytes, device).expect("zeros alloc failed");
|
||||
Self {
|
||||
storage,
|
||||
shape: Dims::from_slice(shape),
|
||||
strides: shape::contiguous_strides(shape),
|
||||
offset: 0,
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
// --- Properties ---
|
||||
|
||||
pub fn shape(&self) -> &[usize] {
|
||||
&self.shape
|
||||
}
|
||||
pub fn strides(&self) -> &[usize] {
|
||||
&self.strides
|
||||
}
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
pub fn ndim(&self) -> usize {
|
||||
self.shape.len()
|
||||
}
|
||||
pub fn numel(&self) -> usize {
|
||||
shape::num_elements(&self.shape)
|
||||
}
|
||||
pub fn offset(&self) -> usize {
|
||||
self.offset
|
||||
}
|
||||
pub fn device(&self) -> Device {
|
||||
self.storage.device()
|
||||
}
|
||||
pub fn is_contiguous(&self) -> bool {
|
||||
shape::is_contiguous(&self.shape, &self.strides)
|
||||
}
|
||||
pub fn storage(&self) -> &Storage {
|
||||
&self.storage
|
||||
}
|
||||
|
||||
// --- Device transfer ---
|
||||
|
||||
/// Move (copy) the tensor to `device`. Returns a cheap clone if already there.
|
||||
pub fn to_device(&self, device: Device) -> Self {
|
||||
if self.device() == device {
|
||||
return self.clone();
|
||||
}
|
||||
let storage = self
|
||||
.storage
|
||||
.to_device(device)
|
||||
.expect("device transfer failed");
|
||||
Self {
|
||||
storage,
|
||||
shape: self.shape.clone(),
|
||||
strides: self.strides.clone(),
|
||||
offset: self.offset,
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
// --- Host data access (CPU only) ---
|
||||
|
||||
/// Typed read-only view of the data. Requires a contiguous CPU tensor.
|
||||
pub fn as_slice<T: TensorDType>(&self) -> &[T] {
|
||||
assert_eq!(T::DTYPE, self.dtype, "dtype mismatch");
|
||||
assert_eq!(self.device(), Device::Cpu, "as_slice requires CPU tensor");
|
||||
assert!(self.is_contiguous(), "as_slice requires contiguous tensor");
|
||||
let bytes = self.storage.as_cpu_bytes();
|
||||
let start = self.offset * self.dtype.size_bytes();
|
||||
unsafe { std::slice::from_raw_parts(bytes[start..].as_ptr() as *const T, self.numel()) }
|
||||
}
|
||||
|
||||
/// Raw element pointer at the tensor's offset (for kernel launches).
|
||||
pub fn data_ptr(&self) -> *const u8 {
|
||||
let byte_off = self.offset * self.dtype.size_bytes();
|
||||
match self.device() {
|
||||
Device::Cpu => unsafe { self.storage.as_cpu_bytes().as_ptr().add(byte_off) },
|
||||
Device::Cuda(_) => unsafe { self.storage.gpu_buffer().as_ptr().add(byte_off) },
|
||||
}
|
||||
}
|
||||
|
||||
// --- Elementwise op (the T2 end-to-end kernel) ---
|
||||
|
||||
/// Out-of-place elementwise scale: returns a new tensor `out[i] = self[i] * alpha`.
|
||||
///
|
||||
/// Runs the `scale_f32` CUDA kernel. Requires a contiguous F32 tensor on the
|
||||
/// GPU. Available only when CUDA was compiled in (`not(no_cuda)`).
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn scale(&self, alpha: f32) -> Self {
|
||||
assert_eq!(self.dtype, DType::F32, "scale only supports F32 in T2");
|
||||
assert!(self.is_contiguous(), "scale requires contiguous tensor");
|
||||
assert!(
|
||||
matches!(self.device(), Device::Cuda(_)),
|
||||
"scale requires a CUDA tensor"
|
||||
);
|
||||
|
||||
let out = Tensor::zeros(&self.shape, self.dtype, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_scale_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
alpha,
|
||||
self.numel() as i32,
|
||||
std::ptr::null_mut(), // default stream
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("scale kernel sync failed");
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Tensor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Tensor(shape={:?}, dtype={}, device={}, contiguous={})",
|
||||
self.shape.as_slice(),
|
||||
self.dtype,
|
||||
self.device(),
|
||||
self.is_contiguous()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A tensor built from a host slice reports the expected metadata and
    /// hands the same values back.
    #[test]
    fn from_slice_shape_and_data() {
        let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_slice(&values, &[2, 3]);

        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.strides(), &[3, 1]);
        assert_eq!(tensor.numel(), 6);
        assert_eq!(tensor.device(), Device::Cpu);
        assert!(tensor.is_contiguous());
        assert_eq!(tensor.as_slice::<f32>(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    /// A CPU `zeros` tensor reads back as all zeros.
    #[test]
    fn zeros_cpu() {
        let tensor = Tensor::zeros(&[4], DType::F32, Device::Cpu);
        let contents = tensor.as_slice::<f32>();
        assert_eq!(contents, &[0.0, 0.0, 0.0, 0.0]);
    }
}
|
||||
58
crates/xtrain-tensor/tests/integration.rs
Normal file
58
crates/xtrain-tensor/tests/integration.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
// GPU integration tests for the tensor abstraction. Both require nvcc + a GPU,
// so they are gated behind `not(no_cuda)`. On a machine without nvcc, build.rs
// sets the `no_cuda` cfg and these compile out, keeping host `cargo check` green.
// (With nvcc but no GPU they still compile and fail on the "no CUDA device" assert.)
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_tensor::{Device, Tensor};
|
||||
|
||||
/// (a) Host → device → host roundtrip preserves the data exactly.
#[test]
fn host_device_roundtrip() {
    let gpu_count = device::device_count().expect("device count");
    assert!(gpu_count > 0, "no CUDA device");
    device::set_device(0).unwrap();

    // 1024 distinct values, each exactly representable in f32.
    let original: Vec<f32> = (0..1024).map(|i| i as f32 * 0.5).collect();
    let on_host = Tensor::from_slice(&original, &[1024]);

    let on_gpu = on_host.to_device(Device::Cuda(0));
    assert_eq!(on_gpu.device(), Device::Cuda(0));
    assert_eq!(on_gpu.shape(), &[1024]);

    let returned = on_gpu.to_device(Device::Cpu);
    assert_eq!(returned.device(), Device::Cpu);
    assert_eq!(returned.as_slice::<f32>(), original.as_slice());
    println!("roundtrip OK: {} elems preserved", original.len());
}
|
||||
|
||||
/// (b) The elementwise `scale` kernel produces correct results.
#[test]
fn elementwise_scale_kernel() {
    let gpu_count = device::device_count().expect("device count");
    assert!(gpu_count > 0, "no CUDA device");
    device::set_device(0).unwrap();

    let alpha = 3.0f32;
    let input: Vec<f32> = (0..2048).map(|i| i as f32).collect();
    // Host-side reference result to compare the kernel output against.
    let expected: Vec<f32> = input.iter().map(|&x| x * alpha).collect();

    let on_gpu = Tensor::from_slice(&input, &[2048]).to_device(Device::Cuda(0));
    let result = on_gpu.scale(alpha).to_device(Device::Cpu);

    assert_eq!(result.shape(), &[2048]);
    assert_eq!(result.as_slice::<f32>(), expected.as_slice());

    let r = result.as_slice::<f32>();
    let (first, mid, last) = (r[0], r[r.len() / 2], r[r.len() - 1]);
    println!(
        "scale OK (alpha={alpha}): first={} mid={} last={} ({} elems)",
        first,
        mid,
        last,
        r.len()
    );
}
|
||||
Reference in New Issue
Block a user