use std::sync::OnceLock; use crate::dtype::{DType, TensorDType}; use crate::shape::{self, Dims}; use crate::storage::{Device, Storage}; /// Global hook for GPU strided-to-contiguous copy. /// Set by `xserv-kernels` (or any crate that provides a GPU kernel) via /// `register_gpu_contiguous`. When set, `contiguous()` on a non-contiguous /// GPU tensor calls this instead of doing a CPU round-trip. static GPU_CONTIGUOUS_FN: OnceLock Tensor> = OnceLock::new(); /// Register a function that makes a non-contiguous GPU tensor contiguous. /// Intended to be called once by the kernel crate at startup. pub fn register_gpu_contiguous(f: fn(&Tensor) -> Tensor) { let _ = GPU_CONTIGUOUS_FN.set(f); } /// Multi-dimensional array with CPU or GPU storage. /// /// Tensors support view semantics: transpose, slice, etc. share /// the underlying storage and only change shape/strides/offset. #[derive(Clone)] pub struct Tensor { storage: Storage, shape: Dims, strides: Dims, offset: usize, dtype: DType, } impl Tensor { // --- Creation --- /// Create a tensor from raw components (for advanced use like GPU KV cache). pub fn from_storage(storage: Storage, shape: Dims, strides: Dims, offset: usize, dtype: DType) -> Self { Self { storage, shape, strides, offset, dtype } } pub fn from_slice(data: &[T], shape: &[usize]) -> Self { let numel: usize = shape.iter().product(); assert_eq!(data.len(), numel, "data length mismatch with shape"); let bytes = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, numel * T::DTYPE.size_bytes()) }; Self { storage: Storage::cpu(bytes.to_vec()), shape: Dims::from_slice(shape), strides: shape::contiguous_strides(shape), offset: 0, dtype: T::DTYPE, } } pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self { let numel = shape::num_elements(shape); let len_bytes = numel * dtype.size_bytes(); let storage = Storage::zeros(len_bytes, device).expect("alloc failed"); Self { storage, shape: Dims::from_slice(shape), strides: shape::contiguous_strides(shape), offset: 0, dtype, } } pub fn ones(shape: &[usize], dtype: DType) -> Self { let numel = shape::num_elements(shape); match dtype { DType::F32 => Self::from_slice(&vec![1.0f32; numel], shape), DType::F16 => Self::from_slice(&vec![half::f16::from_f32(1.0); numel], shape), DType::BF16 => Self::from_slice(&vec![half::bf16::from_f32(1.0); numel], shape), } } // --- Properties --- pub fn shape(&self) -> &[usize] { &self.shape } pub fn strides(&self) -> &[usize] { &self.strides } pub fn dtype(&self) -> DType { self.dtype } pub fn ndim(&self) -> usize { self.shape.len() } pub fn numel(&self) -> usize { shape::num_elements(&self.shape) } pub fn offset(&self) -> usize { self.offset } pub fn device(&self) -> Device { self.storage.device() } pub fn is_contiguous(&self) -> bool { shape::is_contiguous(&self.shape, &self.strides) } // --- Shape operations (view, no copy) --- pub fn reshape(&self, new_shape: &[usize]) -> Self { assert!(self.is_contiguous(), "reshape requires contiguous tensor"); let new_numel: usize = new_shape.iter().product(); assert_eq!(new_numel, self.numel(), "reshape numel mismatch"); Self { storage: self.storage.clone(), shape: Dims::from_slice(new_shape), strides: shape::contiguous_strides(new_shape), offset: self.offset, dtype: self.dtype, } } pub fn transpose(&self, dim0: usize, dim1: usize) -> Self { assert!(dim0 < self.ndim() && dim1 < self.ndim()); let mut new_shape = self.shape.clone(); let mut new_strides = self.strides.clone(); new_shape.swap(dim0, dim1); new_strides.swap(dim0, dim1); Self { storage: self.storage.clone(), shape: new_shape, strides: new_strides, offset: self.offset, dtype: self.dtype, } } pub fn squeeze(&self, dim: usize) -> Self { assert!(dim < self.ndim() && self.shape[dim] == 1); let mut new_shape = self.shape.clone(); let mut new_strides = self.strides.clone(); new_shape.remove(dim); new_strides.remove(dim); Self { storage: self.storage.clone(), shape: new_shape, strides: new_strides, offset: self.offset, dtype: self.dtype, } } pub fn unsqueeze(&self, dim: usize) -> Self { assert!(dim <= self.ndim()); let mut new_shape = self.shape.clone(); new_shape.insert(dim, 1); let new_strides = if self.is_contiguous() { shape::contiguous_strides(&new_shape) } else { let mut s = self.strides.clone(); let stride_val = if dim < self.strides.len() { self.strides[dim] } else { 1 }; s.insert(dim, stride_val); s }; Self { storage: self.storage.clone(), shape: new_shape, strides: new_strides, offset: self.offset, dtype: self.dtype, } } /// Make contiguous: if already contiguous, return clone (shared storage). /// Otherwise, copy data into a new contiguous buffer. pub fn contiguous(&self) -> Self { if self.is_contiguous() { return self.clone(); } // For GPU tensors: use the registered GPU kernel if available, // otherwise fall back to CPU round-trip. if matches!(self.device(), Device::Cuda(_)) { if let Some(gpu_fn) = GPU_CONTIGUOUS_FN.get() { return gpu_fn(self); } let cpu = self.to_device(Device::Cpu); let contig = cpu.contiguous(); return contig.to_device(self.device()); } let numel = self.numel(); let elem_size = self.dtype.size_bytes(); let src_bytes = self.storage.as_cpu_bytes(); let mut dst = vec![0u8; numel * elem_size]; // Iterate all elements using strides let ndim = self.ndim(); let mut idx = vec![0usize; ndim]; for flat in 0..numel { let src_offset = self.offset + idx.iter().zip(self.strides.iter()).map(|(i, s)| i * s).sum::(); let src_byte_offset = src_offset * elem_size; let dst_byte_offset = flat * elem_size; dst[dst_byte_offset..dst_byte_offset + elem_size] .copy_from_slice(&src_bytes[src_byte_offset..src_byte_offset + elem_size]); // Increment index (rightmost first) for d in (0..ndim).rev() { idx[d] += 1; if idx[d] < self.shape[d] { break; } idx[d] = 0; } } Self { storage: Storage::cpu(dst), shape: self.shape.clone(), strides: shape::contiguous_strides(&self.shape), offset: 0, dtype: self.dtype, } } // --- Device transfer --- pub fn to_device(&self, device: Device) -> Self { if self.device() == device { return self.clone(); } // Transfer the raw storage (preserving strides/offset). // Non-contiguous layout is preserved — the user can call contiguous() after. let new_storage = self.storage.to_device(device).expect("device transfer failed"); Self { storage: new_storage, shape: self.shape.clone(), strides: self.strides.clone(), offset: self.offset, dtype: self.dtype, } } // --- Data access (CPU only) --- /// Read tensor data as a typed slice. Requires contiguous CPU tensor. pub fn as_slice(&self) -> &[T] { assert_eq!(T::DTYPE, self.dtype, "dtype mismatch"); assert!(self.is_contiguous(), "as_slice requires contiguous"); assert_eq!(self.device(), Device::Cpu, "as_slice requires CPU"); let bytes = self.storage.as_cpu_bytes(); let elem_size = self.dtype.size_bytes(); let start = self.offset * elem_size; let len = self.numel(); unsafe { std::slice::from_raw_parts(bytes[start..].as_ptr() as *const T, len) } } /// Raw pointer to storage start (for GPU kernel launch). pub fn data_ptr(&self) -> *const u8 { match self.device() { Device::Cpu => { let bytes = self.storage.as_cpu_bytes(); unsafe { bytes.as_ptr().add(self.offset * self.dtype.size_bytes()) } } Device::Cuda(_) => { let buf = self.storage.gpu_buffer(); unsafe { buf.as_ptr().add(self.offset * self.dtype.size_bytes()) } } } } pub fn storage(&self) -> &Storage { &self.storage } } impl std::fmt::Debug for Tensor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "Tensor(shape={:?}, dtype={}, device={}, contiguous={})", self.shape.as_slice(), self.dtype, self.device(), self.is_contiguous() ) } } #[cfg(test)] mod tests { use super::*; fn contiguous_2d() -> Tensor { Tensor::from_slice(&[1.0f32; 12], &[3, 4]) } #[test] fn unsqueeze_dim0_contiguous() { let t = contiguous_2d(); let u = t.unsqueeze(0); assert_eq!(u.shape(), &[1, 3, 4]); assert!(u.is_contiguous()); assert_eq!(u.strides(), &[12, 4, 1]); } #[test] fn unsqueeze_dim1_contiguous() { let t = contiguous_2d(); let u = t.unsqueeze(1); assert_eq!(u.shape(), &[3, 1, 4]); assert!(u.is_contiguous()); assert_eq!(u.strides(), &[4, 4, 1]); } #[test] fn unsqueeze_dim2_contiguous() { let t = contiguous_2d(); let u = t.unsqueeze(2); assert_eq!(u.shape(), &[3, 4, 1]); assert!(u.is_contiguous()); assert_eq!(u.strides(), &[4, 1, 1]); } #[test] fn unsqueeze_noncontiguous() { // Transpose makes [3,4] into [4,3] with strides [1,4] (non-contiguous) let t = contiguous_2d().transpose(0, 1); assert!(!t.is_contiguous()); let u = t.unsqueeze(0); assert_eq!(u.shape(), &[1, 4, 3]); // Non-contiguous path: stride_val copied from strides[0]=1 assert_eq!(u.strides(), &[1, 1, 4]); } #[test] fn unsqueeze_squeeze_roundtrip() { let t = contiguous_2d(); let u = t.unsqueeze(1).squeeze(1); assert_eq!(u.shape(), t.shape()); assert!(u.is_contiguous()); } }