# Phase 2: Tensor Abstraction Layer — Design Document ## Goal 实现核心 Tensor 类型,支持 CPU/GPU 存储、多种数据类型、strided view 操作,作为后续所有算子和模型的数据基础。 ## Module Layout ``` crates/xserv-tensor/ ├── Cargo.toml └── src/ ├── lib.rs # re-exports ├── dtype.rs # DType enum, TensorDType trait ├── shape.rs # strides 计算, broadcast 规则 ├── storage.rs # Storage (Arc引用计数), Device enum └── tensor.rs # Tensor 主体: 创建, 形状操作, 设备迁移 ``` ## Key Design Decisions ### DType + TensorDType Trait ```rust pub enum DType { F32, F16, BF16 } pub trait TensorDType: Copy + Send + Sync + 'static { const DTYPE: DType; fn to_f64(self) -> f64; fn from_f64(v: f64) -> Self; } ``` - 用 `half` crate 的 `bf16`/`f16` 表示半精度类型 - `TensorDType` trait 让 `from_slice` 和 `as_slice` 有类型安全 - GPU kernel 中通过 `DType` dispatch 到对应的 CUDA 类型 (`__nv_bfloat16` / `float`) ### Storage 引用计数 ```rust pub struct Storage(Arc); enum StorageInner { Cpu { data: Vec }, Cuda { buffer: GpuBuffer }, } ``` - `Arc` 引用计数让 transpose/slice/reshape 能共享底层数据(view 语义) - 不实现 CoW(copy-on-write),view 只能读不能写 - `to_device()` 总是创建新的 Storage ### Strided Tensor ```rust pub struct Tensor { storage: Storage, shape: SmallVec<[usize; 4]>, strides: SmallVec<[usize; 4]>, offset: usize, dtype: DType, } ``` - `SmallVec<[usize; 4]>` 避免大多数 tensor (≤4D) 的堆分配 - `strides` 以元素为单位(不是字节) - `offset` 支持 slice 操作(view 到 storage 的中间位置) - `is_contiguous()` 检查 strides 是否与 shape 匹配 - 非 contiguous 的 tensor 调 `contiguous()` 才能送入 CUDA kernel ### Broadcast 规则 实现了 NumPy-style broadcasting: - 维度从尾部对齐 - 大小为 1 的维度可以广播到任意大小 - `broadcast_strides()` 将 size=1 维度的 stride 置为 0(虚拟广播,不复制数据) ## Test Plan - [x] from_slice → shape/strides 正确 - [x] reshape, transpose, squeeze, unsqueeze - [x] transpose 后 contiguous() 重排数据 - [x] BF16 tensor 的精度验证 - [x] CPU↔GPU roundtrip - [x] zeros on GPU → 拷回 CPU 验证全 0 - [x] broadcast_shape 单元测试 ## Takeaways 1. **`SmallVec` 是正确选择**:绝大多数 tensor ≤ 4D,避免了频繁堆分配。LLM 推理中常见的维度是 `[B, S, H]` (3D) 和 `[B, H, S, D]` (4D)。 2. **View 语义的取舍**:Arc 共享 storage 实现了零拷贝 transpose/reshape,但代价是无法原地修改 view 后的 tensor。对于推理引擎这是可以接受的——推理路径上大部分操作是只读的。 3. **contiguous() 的隐性开销**:非 contiguous tensor 在送入 kernel 前需要 `contiguous()` 拷贝。这意味着 `transpose → matmul` 会产生一次额外拷贝。后续优化方向:在 kernel 中直接支持 strided input。 4. **Rust 2024 edition 变化**:`unsafe fn` 内部的 unsafe 调用也需要显式 `unsafe {}` 块,`extern "C"` 块必须加 `unsafe` 前缀。这个 edition 对安全性更严格。 5. **CPU 实现先行**:先在 CPU 上验证逻辑正确性(如 contiguous 重排),再扩展到 GPU。这个策略在后续 phase 中应该继续沿用。