- docs/01-cuda-ffi.md: added takeaways (struct layout pitfall, Rust 2024 unsafe changes, caching allocator strategy, etc.) - docs/02-tensor.md: design doc + takeaways for tensor abstraction - docs/03-gemm.md: design doc + takeaways for GEMM kernels Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
98 lines
3.4 KiB
Markdown
98 lines
3.4 KiB
Markdown
# Phase 2: Tensor Abstraction Layer — Design Document
|
||
|
||
## Goal
|
||
|
||
实现核心 Tensor 类型,支持 CPU/GPU 存储、多种数据类型、strided view 操作,作为后续所有算子和模型的数据基础。
|
||
|
||
## Module Layout
|
||
|
||
```
|
||
crates/xserv-tensor/
|
||
├── Cargo.toml
|
||
└── src/
|
||
├── lib.rs # re-exports
|
||
├── dtype.rs # DType enum, TensorDType trait
|
||
├── shape.rs # strides 计算, broadcast 规则
|
||
├── storage.rs # Storage (Arc引用计数), Device enum
|
||
└── tensor.rs # Tensor 主体: 创建, 形状操作, 设备迁移
|
||
```
|
||
|
||
## Key Design Decisions
|
||
|
||
### DType + TensorDType Trait
|
||
|
||
```rust
|
||
pub enum DType { F32, F16, BF16 }
|
||
|
||
pub trait TensorDType: Copy + Send + Sync + 'static {
|
||
const DTYPE: DType;
|
||
fn to_f64(self) -> f64;
|
||
fn from_f64(v: f64) -> Self;
|
||
}
|
||
```
|
||
|
||
- 用 `half` crate 的 `bf16`/`f16` 表示半精度类型
|
||
- `TensorDType` trait 让 `from_slice<T>` 和 `as_slice<T>` 有类型安全
|
||
- GPU kernel 中通过 `DType` dispatch 到对应的 CUDA 类型 (`__nv_bfloat16` / `float`)
|
||
|
||
### Storage 引用计数
|
||
|
||
```rust
|
||
pub struct Storage(Arc<StorageInner>);
|
||
enum StorageInner {
|
||
Cpu { data: Vec<u8> },
|
||
Cuda { buffer: GpuBuffer },
|
||
}
|
||
```
|
||
|
||
- `Arc` 引用计数让 transpose/slice/reshape 能共享底层数据(view 语义)
|
||
- 不实现 CoW(copy-on-write),view 只能读不能写
|
||
- `to_device()` 总是创建新的 Storage
|
||
|
||
### Strided Tensor
|
||
|
||
```rust
|
||
pub struct Tensor {
|
||
storage: Storage,
|
||
shape: SmallVec<[usize; 4]>,
|
||
strides: SmallVec<[usize; 4]>,
|
||
offset: usize,
|
||
dtype: DType,
|
||
}
|
||
```
|
||
|
||
- `SmallVec<[usize; 4]>` 避免大多数 tensor (≤4D) 的堆分配
|
||
- `strides` 以元素为单位(不是字节)
|
||
- `offset` 支持 slice 操作(view 到 storage 的中间位置)
|
||
- `is_contiguous()` 检查 strides 是否与 shape 匹配
|
||
- 非 contiguous 的 tensor 调 `contiguous()` 才能送入 CUDA kernel
|
||
|
||
### Broadcast 规则
|
||
|
||
实现了 NumPy-style broadcasting:
|
||
- 维度从尾部对齐
|
||
- 大小为 1 的维度可以广播到任意大小
|
||
- `broadcast_strides()` 将 size=1 维度的 stride 置为 0(虚拟广播,不复制数据)
|
||
|
||
## Test Plan
|
||
|
||
- [x] from_slice → shape/strides 正确
|
||
- [x] reshape, transpose, squeeze, unsqueeze
|
||
- [x] transpose 后 contiguous() 重排数据
|
||
- [x] BF16 tensor 的精度验证
|
||
- [x] CPU↔GPU roundtrip
|
||
- [x] zeros on GPU → 拷回 CPU 验证全 0
|
||
- [x] broadcast_shape 单元测试
|
||
|
||
## Takeaways
|
||
|
||
1. **`SmallVec` 是正确选择**:绝大多数 tensor ≤ 4D,避免了频繁堆分配。LLM 推理中常见的维度是 `[B, S, H]` (3D) 和 `[B, H, S, D]` (4D)。
|
||
|
||
2. **View 语义的取舍**:Arc 共享 storage 实现了零拷贝 transpose/reshape,但代价是无法原地修改 view 后的 tensor。对于推理引擎这是可以接受的——推理路径上大部分操作是只读的。
|
||
|
||
3. **contiguous() 的隐性开销**:非 contiguous tensor 在送入 kernel 前需要 `contiguous()` 拷贝。这意味着 `transpose → matmul` 会产生一次额外拷贝。后续优化方向:在 kernel 中直接支持 strided input。
|
||
|
||
4. **Rust 2024 edition 变化**:`unsafe fn` 内部的 unsafe 调用也需要显式 `unsafe {}` 块,`extern "C"` 块必须加 `unsafe` 前缀。这个 edition 对安全性更严格。
|
||
|
||
5. **CPU 实现先行**:先在 CPU 上验证逻辑正确性(如 contiguous 重排),再扩展到 GPU。这个策略在后续 phase 中应该继续沿用。
|