Files
xserv/crates/xserv-tensor/tests/integration.rs
Gahow Wang a83971fa25 phase 2: tensor abstraction layer
- DType enum (F32, F16, BF16) with TensorDType trait
- Shape utilities: contiguous_strides, broadcast_shape, broadcast_strides
- Storage with Arc reference counting (CPU Vec<u8> or GPU GpuBuffer)
- Device enum (Cpu, Cuda(id)) with to_device transfer
- Tensor type with strided layout: reshape, transpose, squeeze, unsqueeze
- contiguous() copies non-contiguous views to contiguous layout
- from_slice, zeros, ones constructors
- as_slice<T> for typed CPU read access, data_ptr for GPU kernel launch
- CPU↔GPU roundtrip verified
- All 27 tests pass (12 cuda + 4 shape + 11 tensor)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-21 19:45:22 +08:00

128 lines
3.6 KiB
Rust

use half::bf16;
use xserv_tensor::*;
/// `from_slice` on an `f32` buffer yields a contiguous, row-major CPU tensor
/// whose metadata (rank, shape, strides, element count, dtype, device)
/// reflects the input.
#[test]
fn test_from_slice_and_shape() {
    let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let tensor = Tensor::from_slice(&values, &[2, 3]);

    // Layout: a row-major [2, 3] tensor advances 3 elements per row, 1 per column.
    assert_eq!(tensor.ndim(), 2);
    assert_eq!(tensor.shape(), &[2, 3]);
    assert_eq!(tensor.strides(), &[3, 1]);
    assert_eq!(tensor.numel(), 6);
    assert!(tensor.is_contiguous());

    // Element type follows the slice's element type; the tensor lives on the CPU.
    assert_eq!(tensor.dtype(), DType::F32);
    assert_eq!(tensor.device(), Device::Cpu);
}
/// `as_slice::<f32>()` exposes the CPU buffer of an `f32` tensor in element order.
#[test]
fn test_as_slice() {
    let expected = [1.0f32, 2.0, 3.0, 4.0];
    let tensor = Tensor::from_slice(&expected, &[4]);
    assert_eq!(tensor.as_slice::<f32>(), &expected);
}
/// `zeros` and `ones` produce `f32` tensors filled with the matching constant.
#[test]
fn test_zeros_and_ones() {
    let zero_filled = Tensor::zeros(&[2, 3], DType::F32, Device::Cpu);
    assert_eq!(zero_filled.as_slice::<f32>(), &[0.0; 6]);

    // NOTE(review): unlike `zeros`, `ones` takes no device argument here —
    // presumably it always allocates on the CPU; confirm against the API.
    let one_filled = Tensor::ones(&[3], DType::F32);
    assert_eq!(one_filled.as_slice::<f32>(), &[1.0; 3]);
}
/// A `bf16` tensor reports `DType::BF16` and reads back exactly the values it
/// was built from.
///
/// 1.0, 2.5 and -3.0 are all exactly representable in bf16, so the read-back
/// is compared exactly rather than with a tolerance.
#[test]
fn test_bf16_tensor() {
    let data = [bf16::from_f32(1.0), bf16::from_f32(2.5), bf16::from_f32(-3.0)];
    let t = Tensor::from_slice(&data, &[3]);
    assert_eq!(t.dtype(), DType::BF16);
    assert_eq!(t.shape(), &[3]);

    let out = t.as_slice::<bf16>();
    // Whole-slice comparison also pins the length, which per-index checks did not.
    assert_eq!(out, &data[..]);

    // Widening to f32 must reproduce the original literals exactly.
    let widened: Vec<f32> = out.iter().map(|v| v.to_f32()).collect();
    assert_eq!(widened, [1.0, 2.5, -3.0]);
}
/// `reshape` of a contiguous tensor reinterprets the same row-major elements
/// under a new shape, for both a 2-D and a flattening reshape.
#[test]
fn test_reshape() {
    let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let t = Tensor::from_slice(&data, &[2, 3]);

    let t2 = t.reshape(&[3, 2]);
    assert_eq!(t2.shape(), &[3, 2]);
    assert_eq!(t2.numel(), t.numel());
    assert_eq!(t2.as_slice::<f32>(), &data);

    // Flattening must preserve element order, not just produce the right shape.
    let t3 = t.reshape(&[6]);
    assert_eq!(t3.shape(), &[6]);
    assert_eq!(t3.numel(), t.numel());
    assert_eq!(t3.as_slice::<f32>(), &data);

    // `reshape` returns new tensors; the source keeps its original shape.
    assert_eq!(t.shape(), &[2, 3]);
}
/// Transposing a [2, 3] tensor swaps both its shape and its strides, so the
/// result is no longer contiguous.
#[test]
fn test_transpose() {
    let source = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
    let swapped = source.transpose(0, 1);

    // Row-major [2, 3] has strides [3, 1]; swapping dims 0 and 1 swaps both lists.
    assert_eq!(swapped.shape(), &[3, 2]);
    assert_eq!(swapped.strides(), &[1, 3]);
    assert!(!swapped.is_contiguous());
}
/// `contiguous()` on a transposed view yields a contiguous tensor whose buffer
/// holds the view's elements in logical row-major order.
#[test]
fn test_contiguous_from_transpose() {
    // Source [2, 3]:       [[1, 2, 3], [4, 5, 6]]
    // Transposed [3, 2]:   [[1, 4], [2, 5], [3, 6]]
    let original = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
    let transposed = original.transpose(0, 1);
    let packed = transposed.contiguous();

    assert!(packed.is_contiguous());
    assert_eq!(packed.shape(), &[3, 2]);

    let expected = [1.0f32, 4.0, 2.0, 5.0, 3.0, 6.0];
    assert_eq!(packed.as_slice::<f32>(), &expected);
}
/// `squeeze` removes a size-1 axis, and `unsqueeze` inserts one at the
/// requested position.
#[test]
fn test_squeeze_unsqueeze() {
    let row = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[1, 3]);

    let flat = row.squeeze(0);
    assert_eq!(flat.shape(), &[3]);

    // Inserting at axis 0 gives back the [1, 3] row; at axis 1 a [3, 1] column.
    assert_eq!(flat.unsqueeze(0).shape(), &[1, 3]);
    assert_eq!(flat.unsqueeze(1).shape(), &[3, 1]);
}
/// Copying a CPU tensor to CUDA device 0 and back preserves its shape, dtype
/// and values, and updates the device tag at each step.
///
/// Requires a CUDA-capable GPU at index 0. The test panics with an explicit
/// message if `set_device(0)` fails.
#[test]
fn test_cpu_to_gpu_roundtrip() {
    xserv_cuda::device::set_device(0).expect("CUDA device 0 is required for this test");
    let data = [1.0f32, 2.0, 3.0, 4.0];
    let cpu_t = Tensor::from_slice(&data, &[2, 2]);

    let gpu_t = cpu_t.to_device(Device::Cuda(0));
    assert_eq!(gpu_t.device(), Device::Cuda(0));
    assert_eq!(gpu_t.shape(), &[2, 2]);
    assert_eq!(gpu_t.dtype(), DType::F32);

    let back = gpu_t.to_device(Device::Cpu);
    assert_eq!(back.device(), Device::Cpu);
    // Metadata must survive the return trip too, not only the element bytes.
    assert_eq!(back.shape(), &[2, 2]);
    assert_eq!(back.dtype(), DType::F32);
    assert_eq!(back.as_slice::<f32>(), &data);
}
/// `zeros` allocated directly on CUDA device 0 reads back as all zeros.
///
/// Requires a CUDA-capable GPU at index 0. The test panics with an explicit
/// message if `set_device(0)` fails.
#[test]
fn test_zeros_gpu() {
    xserv_cuda::device::set_device(0).expect("CUDA device 0 is required for this test");
    let t = Tensor::zeros(&[4, 4], DType::F32, Device::Cuda(0));
    assert_eq!(t.device(), Device::Cuda(0));
    assert_eq!(t.shape(), &[4, 4]);
    assert_eq!(t.dtype(), DType::F32);

    // `as_slice` is CPU read access, so copy to the host before inspecting values.
    let cpu = t.to_device(Device::Cpu);
    assert_eq!(cpu.shape(), &[4, 4]);
    assert_eq!(cpu.as_slice::<f32>(), &[0.0f32; 16]);
}
/// The `Debug` rendering of a tensor mentions its shape, dtype and device.
#[test]
fn test_debug_format() {
    let single = Tensor::from_slice(&[1.0f32], &[1]);
    let rendered = format!("{:?}", single);
    for needle in ["shape=[1]", "f32", "cpu"] {
        assert!(rendered.contains(needle));
    }
}