Compare commits
123 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6309dc1181 | |||
| 264c004662 | |||
| 2fe903ecea | |||
| aac9ace144 | |||
| 6da0972740 | |||
| 40d8a29e33 | |||
| fd392f7fbb | |||
| 10a98539d0 | |||
| cc3bc2188c | |||
| 06a798cab9 | |||
| 9a1af0adee | |||
| d2c55c47b2 | |||
| 14925154a3 | |||
| a24621fa6a | |||
| 68b55fa1e6 | |||
| 8f11d6e5cd | |||
| e04a8ffb18 | |||
| 6485c87c5b | |||
| a77239c0c8 | |||
| e5734b41fa | |||
| 42e13f33dd | |||
| fcf531a9b2 | |||
| d96ee0766c | |||
| ce10e4a998 | |||
| 5f060902f6 | |||
| a67753f516 | |||
| f5ec10c2c3 | |||
| ce7229f4fe | |||
| 5b350ee5f0 | |||
| 0314b4f3ac | |||
| cfbd64d206 | |||
| 531cd3fe08 | |||
| 013465fc06 | |||
| 8414f8d1e6 | |||
| 34224c7c93 | |||
| 4088f49b7d | |||
| 2a92f268a9 | |||
| 5343391dbd | |||
| 1897b2e17a | |||
| 63f5599717 | |||
| fb20178992 | |||
| cf1e9e41db | |||
| d33220498a | |||
| e631a71b68 | |||
| 24c49c31c2 | |||
| 5a16225c1f | |||
| 3a530956af | |||
| 76487b7963 | |||
| 9f1fbbb98b | |||
| e1eb77baa4 | |||
| 34e9bee375 | |||
| 3b9e32e6cd | |||
| 5157b2cd30 | |||
| ea5d8ba7ea | |||
| c0a81c84e7 | |||
| 3d6bb1918e | |||
| f2e60218b4 | |||
| 3ee8df2c0f | |||
| ae08896f46 | |||
|
|
1d0ec32e8d | ||
|
|
4368e79695 | ||
|
|
377a04b81f | ||
|
|
241009a96c | ||
| 0c6135aea3 | |||
| ffd90ce7fb | |||
| 3c9d5e260e | |||
| 99b212e6c1 | |||
| e11f15e009 | |||
| 9c98c169ff | |||
|
|
5cb3cf28f9 | ||
|
|
15c51f143e | ||
|
|
d29c39d74e | ||
|
|
9ad91a4a92 | ||
|
|
46bfb59f30 | ||
|
|
9a01c60100 | ||
|
|
c679f618fd | ||
|
|
cc4bd4cfe5 | ||
|
|
13ae3de69e | ||
|
|
6ce21345be | ||
|
|
1ab6ca9c09 | ||
| 11e0154e4d | |||
| d5dcf1a5ab | |||
| 824cc58daa | |||
| da3aaa134a | |||
| 859c0cc0b6 | |||
| c2362df1f1 | |||
| 7b8b520cda | |||
| a4a171d425 | |||
| 95eb61d639 | |||
| f17011129e | |||
| 453520d622 | |||
| 76fffb3b68 | |||
| 14a44b503e | |||
| 80157e614a | |||
| fc1900a745 | |||
| d52baa0006 | |||
| 4c3f914459 | |||
| 3f1c3d429a | |||
| 950ccf3822 | |||
| 7cb9ee3870 | |||
| 49c7653222 | |||
| 9bb5c5c328 | |||
| 986a289616 | |||
| a67e724119 | |||
| d5532ef209 | |||
| e207523e21 | |||
| 876d3f5d6a | |||
| 9783fcf410 | |||
| 6cc1c9332d | |||
| d67dda404e | |||
| ee68d3565d | |||
| d8493bd70f | |||
| 7d05ececa0 | |||
| da043554ba | |||
| 2be27d6d94 | |||
| 2d48f25e66 | |||
| be5c64ea8a | |||
| 268e40d764 | |||
| 246ae1c590 | |||
| 64084d3489 | |||
| cb12250ef0 | |||
| e1e75fc7f6 | |||
| 6035ffdc0b |
16
.gitignore
vendored
16
.gitignore
vendored
@@ -7,3 +7,19 @@
|
||||
**/*.rs.bk
|
||||
.env
|
||||
*.npy
|
||||
|
||||
# llama.cpp baseline (cloned/submoduled by tools/setup-llama-cpp.sh)
|
||||
/third_party/llama.cpp/build/
|
||||
/third_party/llama.cpp/models/
|
||||
*.gguf
|
||||
|
||||
# Claude Code runtime state
|
||||
/.claude/
|
||||
|
||||
# Benchmark output + fetched datasets (transferred to GPU host, not committed)
|
||||
/bench-out/
|
||||
/tools/bench/data/
|
||||
/tools/__pycache__/
|
||||
/tools/bench/__pycache__/
|
||||
/tools/bench/**/__pycache__/
|
||||
|
||||
|
||||
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[submodule "third_party/llama.cpp"]
|
||||
path = third_party/llama.cpp
|
||||
url = https://github.com/ggerganov/llama.cpp
|
||||
1214
Cargo.lock
generated
Normal file
1214
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
15
Cargo.toml
15
Cargo.toml
@@ -4,6 +4,10 @@ members = [
|
||||
"crates/xserv-cuda",
|
||||
"crates/xserv-tensor",
|
||||
"crates/xserv-kernels",
|
||||
"crates/xserv-model",
|
||||
"crates/xserv-tokenizer",
|
||||
"crates/xserv-server",
|
||||
"crates/xserv-distributed",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -14,3 +18,14 @@ license = "MIT"
|
||||
[workspace.dependencies]
|
||||
half = "2"
|
||||
smallvec = "1"
|
||||
libc = "0.2"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
safetensors = "0.5"
|
||||
regex = "1"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
axum = "0.8"
|
||||
uuid = { version = "1", features = ["v4"] }
|
||||
tokio-stream = "0.1"
|
||||
rand = "0.8"
|
||||
minijinja = { version = "2", features = ["builtins"] }
|
||||
|
||||
208
README.md
Normal file
208
README.md
Normal file
@@ -0,0 +1,208 @@
|
||||
# xserv
|
||||
|
||||
> 从零用 **Rust + CUDA** 构建的 LLM 推理引擎,目标是吃透 LLM Serving 全栈技术。
|
||||
|
||||
xserv 不依赖 PyTorch / vLLM / TensorRT 等现成框架,自己实现了张量抽象、CUDA kernel、
|
||||
分词器、模型前向、KV cache、调度器和 OpenAI 兼容的 HTTP 服务。支持 **Qwen3-8B**(BF16)
|
||||
和 **gpt-oss-20b**(MoE,BF16/FP8/MXFP4 量化),多卡 TP/PP,并提供一套与 **llama.cpp**
|
||||
对比正确性和性能的标准 benchmark。
|
||||
|
||||
## 现状一览
|
||||
|
||||
- **模型**:GPT-2(124M)、Qwen3-8B(BF16)、gpt-oss-20b(32 专家 top-4 MoE,harmony 格式)
|
||||
- **性能**(RTX 5090,贪心,单流):
|
||||
- Qwen3-8B BF16 单卡:约 56 tok/s(HF transformers 的 1.4×)
|
||||
- gpt-oss-20b FP8 稀疏 MoE + CUDA Graph decode:**TPOT 5.8ms(~172 tok/s,
|
||||
TP=1/2 同速)**;同配置 TP=2 全面快于 llama.cpp(1.26-1.47×),llama
|
||||
单卡模式(2.8ms)仍领先,差距 2.0×
|
||||
- **精度**:GSM8K 全量与 llama.cpp 同权重持平(94.5% vs 94.4%);FP8/MXFP4 量化无回归
|
||||
- **服务**:OpenAI 兼容 `/v1/chat/completions`,SSE 流式;gpt-oss 量化后可**单卡 32GB 服务**
|
||||
- **关键能力**:自写 GEMM / Flash-Attention 2(SM120,含 attention sinks + sliding window) /
|
||||
Paged-Attention kernel、分页 KV cache(含 **CPU 换出/换入**)、连续批处理、
|
||||
CUDA Graph 解码(Qwen3 单卡 + gpt-oss 全路径整图回放)、**Tensor/Pipeline 并行**(NCCL,TP=1/2/4、PP=2/4)、
|
||||
**FP8 W8A8 / MXFP4 W4A16 量化**、**稀疏 top-k MoE decode**(只算被路由的专家)
|
||||
|
||||
> 这是一个以学习为主的项目,逐 Phase 推进,每步都做数值/端到端验证。
|
||||
|
||||
## 架构
|
||||
|
||||
```
|
||||
xserv/
|
||||
├── csrc/ # CUDA 源码 (.cu/.cuh)
|
||||
│ ├── gemm/ # GEMM (naive / tiled / gemv)
|
||||
│ ├── attention/ # Flash-Attention 2 (SM120)、Paged-Attention、causal mask
|
||||
│ ├── normalization/ # LayerNorm / RMSNorm
|
||||
│ ├── activation/ # GELU / SiLU / gpt-oss GLU
|
||||
│ ├── embedding/ # embedding lookup / RoPE / transpose
|
||||
│ ├── moe/ # MoE top-k 路由、稀疏专家 GEMV、加权求和
|
||||
│ ├── quantization/ # FP8 量化/反量化、cuBLASLt FP8 GEMM、MXFP4 GEMV
|
||||
│ └── reduce/ # softmax
|
||||
├── crates/
|
||||
│ ├── xserv-cuda/ # CUDA FFI、Stream、显存分配器、Pinned 内存、CUDA Graph
|
||||
│ ├── xserv-tensor/ # Tensor 类型(strided 布局、BF16/F16/F32、CPU↔GPU)
|
||||
│ ├── xserv-kernels/ # kernel registry(自写 kernel + cuBLAS 可切换)
|
||||
│ ├── xserv-tokenizer/ # BPE 分词器
|
||||
│ ├── xserv-distributed/ # NCCL FFI、TP 上下文(AllReduce)
|
||||
│ ├── xserv-model/ # 模型定义(GPT-2 / Qwen3 / gpt-oss MoE)、权重加载、KV cache、采样
|
||||
│ └── xserv-server/ # tokio + axum HTTP 服务、调度器、TP/PP 引擎
|
||||
├── tools/ # 辅助脚本 + benchmark 套件(见下)
|
||||
└── docs/ # 每个 Phase 的设计文档 + benchmark 报告
|
||||
```
|
||||
|
||||
## 环境要求
|
||||
|
||||
- **GPU**:NVIDIA,计算能力 SM120(RTX 5090 / Blackwell)。其它架构需调整 `CUDA_ARCH`。
|
||||
- **CUDA Toolkit**:12.9(`nvcc` 需在 `PATH`,构建 `.cu` 依赖它)
|
||||
- **Rust**:edition 2024(建议较新的 stable 工具链)
|
||||
- **模型**:HuggingFace 目录格式(含 `config.json`、`tokenizer.json`、`*.safetensors`)
|
||||
|
||||
## 构建
|
||||
|
||||
```bash
|
||||
export CUDA_HOME=/usr/local/cuda-12.9
|
||||
export PATH=$CUDA_HOME/bin:$PATH
|
||||
cargo build --release
|
||||
```
|
||||
|
||||
如果本地没有 GPU/CUDA,可用远端构建脚本把代码同步到带卡的机器上构建/运行/测试:
|
||||
|
||||
```bash
|
||||
./tools/sync-and-build.sh build # 远端 cargo build --release
|
||||
./tools/sync-and-build.sh test # 远端 cargo test
|
||||
```
|
||||
|
||||
(远端主机、目录、模型路径在 `tools/sync-and-build.sh` 顶部配置。)
|
||||
|
||||
## 基本用法
|
||||
|
||||
### 1. 启动 HTTP 服务(OpenAI 兼容)
|
||||
|
||||
```bash
|
||||
./target/release/xserv-server /path/to/qwen3-8b \
|
||||
--port 8080 \
|
||||
--max-batch 4 \
|
||||
--max-seq-len 8192 \
|
||||
--swap-space-gb 8
|
||||
```
|
||||
|
||||
参数说明:
|
||||
|
||||
| 参数 | 含义 | 默认 |
|
||||
|------|------|------|
|
||||
| `--port` | 监听端口 | 8080 |
|
||||
| `--max-batch` | 解码批大小(并发上限) | 4 |
|
||||
| `--max-seq-len` | 单序列最大长度 | 2048 |
|
||||
| `--swap-space-gb` | KV 换出到 CPU 的 pinned 内存大小(0 关闭) | 8 |
|
||||
|
||||
请求示例(流式):
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "qwen3-8b",
|
||||
"messages": [{"role": "user", "content": "用一句话解释什么是注意力机制"}],
|
||||
"max_tokens": 256,
|
||||
"temperature": 0,
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
其它端点:`GET /health`、`GET /v1/models`。
|
||||
|
||||
### 2. 命令行推理
|
||||
|
||||
```bash
|
||||
# 单轮生成
|
||||
cargo run --release --bin xserv-cli -- /path/to/qwen3-8b --max-tokens 256
|
||||
|
||||
# 交互式多轮对话
|
||||
cargo run --release --bin xserv-chat -- /path/to/qwen3-8b
|
||||
```
|
||||
|
||||
### 3. 单机性能基准
|
||||
|
||||
```bash
|
||||
# 输出每个 prompt 的 TTFT / TBT / TPOT(JSON)
|
||||
cargo run --release --bin bench-qwen3 -- /path/to/qwen3-8b --gen-tokens 64 [--cuda-graph]
|
||||
```
|
||||
|
||||
## 与 llama.cpp 对比 benchmark
|
||||
|
||||
`tools/bench/` 提供一套一键对比套件,把 xserv 和 **llama.cpp**(同一份 BF16 权重)放在
|
||||
相同负载下,黑盒通过 OpenAI API 对比:
|
||||
|
||||
- **性能**:TTFT、TPOT、吞吐(单流 + 不同并发)
|
||||
- **精度**:AIME 2025、GSM8K(标准数据集,exact-match 评分)
|
||||
|
||||
```bash
|
||||
# 一次性准备(需联网的机器):拉取 llama.cpp 子模块 + 下载数据集
|
||||
git submodule update --init third_party/llama.cpp # 固定在 tag b9371
|
||||
HF_ENDPOINT=https://hf-mirror.com python3 -m tools.bench.fetch_datasets
|
||||
|
||||
# 一键对比(构建 llama.cpp + 转 GGUF + 构建 xserv + 跑两套 + 出报告)
|
||||
./tools/sync-and-build.sh bench -- --max-seq-len 8192 --quality-limit 50
|
||||
./tools/sync-and-build.sh fetch-bench-out
|
||||
# 报告产物:bench-out/comparison-<时间戳>.{md,json}
|
||||
```
|
||||
|
||||
设计细节见 `docs/16-llama-cpp-comparison.md`,结果报告见 `docs/benchmarks/llama-cpp-comparison.md`。
|
||||
|
||||
## 文档
|
||||
|
||||
- `docs/00-roadmap.md`:总体路线图与各 Phase 设计
|
||||
- `docs/01..15-*.md`:CUDA FFI / Tensor / GEMM / Attention / KV cache / 性能优化等每个 Phase 的设计文档
|
||||
- `docs/16-llama-cpp-comparison.md`:llama.cpp 对比基准的设计
|
||||
- `docs/17-tensor-parallelism.md`:张量并行(TP)设计
|
||||
- `docs/18-pipeline-parallelism.md`:流水线并行(PP)设计
|
||||
- `docs/benchmarks/`:各阶段的 benchmark 报告(含 `pp-sweep.md`)
|
||||
|
||||
## 多卡并行(TP / PP)
|
||||
|
||||
单机多卡,复用 NCCL(crate `xserv-distributed`)。两种切法正交、二选一:
|
||||
|
||||
- **张量并行 `--tp N`**:按 head / 中间维切每一层,层内用 AllReduce 聚合(每 token `2·层数` 次)。
|
||||
- **流水线并行 `--pp N`**:按层切成 N 段,相邻段间用 NCCL **P2P** 传 hidden state(每 token 仅 `N-1` 次),
|
||||
通信量远小于 AllReduce,对无 NVLink 的 PCIe 更友好。
|
||||
|
||||
```bash
|
||||
# 组内 GPU 0-3:4 卡张量并行 / 4 卡流水线并行
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 ./target/release/xserv-server /path/to/qwen3-8b --tp 4
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 ./target/release/xserv-server /path/to/qwen3-8b --pp 4
|
||||
```
|
||||
|
||||
**PP 实测**(dash5,Qwen3-8B BF16,单流贪心;每卡显存为权重+最小 KV 池):
|
||||
|
||||
| 配置 | TTFT | TPOT | tok/s | 每卡显存 |
|
||||
|------|------|------|-------|----------|
|
||||
| 单卡 | 33ms | 17.4ms | 57.5 | 24.0 GB |
|
||||
| PP=2 | 36ms | 18.1ms | 55.3 | 11.6 / 13.6 GB |
|
||||
| PP=4 | 36ms | 17.9ms | 55.8 | 7.3 / 5.3 / 5.3 / 9.4 GB |
|
||||
|
||||
**质量对比**(AIME 2025 30 题 + GSM8K 30 题,贪心,xserv 在 GPU 0-3、llama.cpp 在 GPU 4-7 并行):
|
||||
|
||||
| 引擎 | PP | AIME | GSM8K |
|
||||
|------|----|------|-------|
|
||||
| xserv | 1/2/4 | 8 / 7 / 7 (/30) | 29/30 (96.7%) 全部一致 |
|
||||
| llama | 1/2/4 | 7 / 7 / 7 (/30) | 29/30 (96.7%) 全部一致 |
|
||||
|
||||
正确性:hidden state 跨段是 **bit-exact BF16 P2P 拷贝**,PP=4 输出与单卡逐字节一致(用「单卡×2 vs
|
||||
PP=4×2」对照确认——单卡自身因 cuBLAS 非确定性 run-to-run 会变,而 PP=4 可复现且落在某次单卡轨迹上)。
|
||||
GSM8K 12 个格子全是 29/30,xserv 与 llama.cpp 完全一致;AIME 的 ±1 是长生成下贪心对 GEMM 抖动的敏感,
|
||||
非 PP 或引擎效应。**收益在显存**(每卡权重+KV ≈ 1/N);v1 为串行流水线,单流 TPOT 基本持平、不优于单卡,
|
||||
真正的吞吐提升需后续做 microbatch / 1F1B 重叠。完整数据见 `docs/benchmarks/pp-sweep.md`。
|
||||
|
||||
## 路线图(节选)
|
||||
|
||||
已完成 Phase 0–21:CUDA 基础设施 → Tensor → GEMM → Transformer kernels → Attention →
|
||||
模型加载 → 分词器 → GPT-2 → KV cache → Qwen3-8B → Paged Attention → 连续批处理 →
|
||||
HTTP API → Flash Attention 2 → 性能优化 → **张量并行(TP)** → **流水线并行(PP)** →
|
||||
**gpt-oss MoE + FP8/MXFP4 量化** → **稀疏 top-k MoE decode** → **decode CUDA Graph 整图回放**;
|
||||
并加入了 **llama.cpp 对比基准** 与 **KV CPU 换出** 等基础设施。
|
||||
|
||||
后续方向:非专家权重量化(lm_head/qkv/o)、稀疏 prefill(grouped GEMM)、server 侧 harmony
|
||||
channel 分离、PP microbatch/1F1B、投机解码、多模态。详见 `docs/00-roadmap.md` 的实际进展记录。
|
||||
|
||||
## 许可
|
||||
|
||||
MIT
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::error::Result;
|
||||
use crate::ffi;
|
||||
use crate::memory::GpuBuffer;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Caching allocator that reuses freed GPU buffers instead of calling
|
||||
@@ -84,6 +85,94 @@ impl Drop for CachingAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static ALLOCATOR: RefCell<CachingAllocator> = RefCell::new(CachingAllocator::new());
|
||||
}
|
||||
|
||||
/// Allocate a GPU buffer through the caching allocator.
|
||||
/// The returned buffer has `pooled = true` so it will be returned
|
||||
/// to the pool on drop instead of calling cudaFree.
|
||||
pub fn cached_alloc(size: usize) -> Result<GpuBuffer> {
|
||||
ALLOCATOR.with(|cell| {
|
||||
let mut buf = cell.borrow_mut().alloc(size)?;
|
||||
buf.set_pooled(true);
|
||||
Ok(buf)
|
||||
})
|
||||
}
|
||||
|
||||
/// Free all cached (unused) GPU buffers back to the driver.
|
||||
pub fn cached_trim() {
|
||||
ALLOCATOR.with(|cell| {
|
||||
cell.borrow_mut().trim();
|
||||
});
|
||||
}
|
||||
|
||||
/// Return a raw GPU pointer to the caching allocator's free list.
|
||||
/// Called from `GpuBuffer::Drop` for pooled buffers. Takes raw pointer
|
||||
/// and size to avoid re-triggering Drop.
|
||||
pub fn return_to_pool(ptr: *mut u8, len: usize) {
|
||||
// During CUDA graph capture, buffers freed by the captured code are
|
||||
// quarantined instead of pooled: the instantiated graph references their
|
||||
// addresses on every replay, so they must never be handed to another
|
||||
// consumer for as long as the graph lives.
|
||||
let quarantined = RETAINED.with(|cell| {
|
||||
let mut r = cell.borrow_mut();
|
||||
if let Some(list) = r.as_mut() {
|
||||
list.push((ptr, len));
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
if quarantined {
|
||||
return;
|
||||
}
|
||||
ALLOCATOR.with(|cell| {
|
||||
let mut alloc = cell.borrow_mut();
|
||||
let bucket = bucket_size(len);
|
||||
alloc.stats.current_allocated = alloc.stats.current_allocated.saturating_sub(len);
|
||||
alloc.free_lists.entry(bucket).or_default().push((ptr, len));
|
||||
});
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static RETAINED: RefCell<Option<Vec<(*mut u8, usize)>>> = const { RefCell::new(None) };
|
||||
}
|
||||
|
||||
/// Buffers freed while a retain window was active. Holding this keeps their
|
||||
/// memory out of the pool; dropping it returns the blocks (on the owning
|
||||
/// thread) for reuse.
|
||||
pub struct RetainedBlocks(Vec<(*mut u8, usize)>);
|
||||
|
||||
impl Drop for RetainedBlocks {
|
||||
fn drop(&mut self) {
|
||||
for (ptr, len) in self.0.drain(..) {
|
||||
return_to_pool(ptr, len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Start quarantining buffers freed on this thread (see `return_to_pool`).
|
||||
/// Must be paired with `end_retain` on the same thread; nesting unsupported.
|
||||
pub fn begin_retain() {
|
||||
RETAINED.with(|cell| {
|
||||
let mut r = cell.borrow_mut();
|
||||
assert!(r.is_none(), "begin_retain: retain window already active");
|
||||
*r = Some(Vec::new());
|
||||
});
|
||||
}
|
||||
|
||||
/// Stop quarantining and hand the quarantined blocks to the caller.
|
||||
pub fn end_retain() -> RetainedBlocks {
|
||||
RETAINED.with(|cell| {
|
||||
let list = cell
|
||||
.borrow_mut()
|
||||
.take()
|
||||
.expect("end_retain without begin_retain");
|
||||
RetainedBlocks(list)
|
||||
})
|
||||
}
|
||||
|
||||
/// Round up to next power-of-2, minimum 512 bytes.
|
||||
fn bucket_size(size: usize) -> usize {
|
||||
let min = 512;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::error::{self, Result};
|
||||
use crate::ffi;
|
||||
use std::ffi::CStr;
|
||||
use std::os::raw::c_char;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DeviceInfo {
|
||||
@@ -44,10 +45,12 @@ pub fn current_device() -> Result<u32> {
|
||||
}
|
||||
|
||||
pub fn device_info(device: u32) -> Result<DeviceInfo> {
|
||||
// Get device name from cudaGetDeviceProperties (only use the name field).
|
||||
let mut prop = unsafe { std::mem::zeroed::<ffi::CudaDeviceProp>() };
|
||||
error::check(unsafe { ffi::cudaGetDeviceProperties(&mut prop, device as i32) })?;
|
||||
let name = unsafe { CStr::from_ptr(prop.name.as_ptr()) }
|
||||
// Heap-allocate oversized buffer for cudaDeviceProp (layout varies by CUDA version).
|
||||
// CUDA 12.x struct is ~5-6 KB; use 32 KB to guard against future growth.
|
||||
let mut prop_buf = vec![0u8; 32768];
|
||||
error::check(unsafe { ffi::cudaGetDeviceProperties(prop_buf.as_mut_ptr(), device as i32) })?;
|
||||
// Name is always the first field: char[256].
|
||||
let name = unsafe { CStr::from_ptr(prop_buf.as_ptr() as *const c_char) }
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ use std::os::raw::c_char;
|
||||
|
||||
pub type CudaStream = *mut c_void;
|
||||
pub type CudaEvent = *mut c_void;
|
||||
pub type CudaGraph = *mut c_void;
|
||||
pub type CudaGraphExec = *mut c_void;
|
||||
|
||||
pub const CUDA_MEMCPY_H2D: i32 = 1;
|
||||
pub const CUDA_MEMCPY_D2H: i32 = 2;
|
||||
@@ -11,31 +13,17 @@ pub const CUDA_MEMCPY_D2D: i32 = 3;
|
||||
pub const CUDA_SUCCESS: i32 = 0;
|
||||
pub const CUDA_ERROR_OUT_OF_MEMORY: i32 = 2;
|
||||
|
||||
#[repr(C)]
|
||||
pub struct CudaDeviceProp {
|
||||
pub name: [c_char; 256],
|
||||
pub total_global_mem: usize,
|
||||
pub shared_mem_per_block: usize,
|
||||
pub regs_per_block: i32,
|
||||
pub warp_size: i32,
|
||||
pub max_threads_per_block: i32,
|
||||
pub max_threads_dim: [i32; 3],
|
||||
pub max_grid_size: [i32; 3],
|
||||
pub clock_rate: i32,
|
||||
pub total_const_mem: usize,
|
||||
pub major: i32,
|
||||
pub minor: i32,
|
||||
// There are many more fields; we only read up to what we need.
|
||||
// cudaDeviceProp is a large struct (~1KB). We pad the rest.
|
||||
_pad: [u8; 4096],
|
||||
}
|
||||
/// cudaStreamCaptureMode::cudaStreamCaptureModeGlobal
|
||||
pub const CUDA_STREAM_CAPTURE_MODE_GLOBAL: i32 = 0;
|
||||
pub const CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL: i32 = 1;
|
||||
|
||||
unsafe extern "C" {
|
||||
// --- Device ---
|
||||
pub fn cudaGetDeviceCount(count: *mut i32) -> i32;
|
||||
pub fn cudaSetDevice(device: i32) -> i32;
|
||||
pub fn cudaGetDevice(device: *mut i32) -> i32;
|
||||
pub fn cudaGetDeviceProperties(prop: *mut CudaDeviceProp, device: i32) -> i32;
|
||||
/// Takes a raw pointer; caller provides a heap buffer large enough for any CUDA version.
|
||||
pub fn cudaGetDeviceProperties(prop: *mut u8, device: i32) -> i32;
|
||||
pub fn cudaDeviceSynchronize() -> i32;
|
||||
|
||||
// --- Memory ---
|
||||
@@ -52,6 +40,7 @@ unsafe extern "C" {
|
||||
stream: CudaStream,
|
||||
) -> i32;
|
||||
pub fn cudaMemset(devptr: *mut u8, value: i32, count: usize) -> i32;
|
||||
pub fn cudaMemsetAsync(devptr: *mut u8, value: i32, count: usize, stream: CudaStream) -> i32;
|
||||
|
||||
// --- Stream ---
|
||||
pub fn cudaStreamCreate(stream: *mut CudaStream) -> i32;
|
||||
@@ -62,12 +51,18 @@ unsafe extern "C" {
|
||||
pub fn cudaGetLastError() -> i32;
|
||||
pub fn cudaGetErrorString(error: i32) -> *const c_char;
|
||||
|
||||
// --- CUDA Graphs ---
|
||||
pub fn cudaStreamBeginCapture(stream: CudaStream, mode: i32) -> i32;
|
||||
pub fn cudaStreamEndCapture(stream: CudaStream, graph: *mut CudaGraph) -> i32;
|
||||
pub fn cudaGraphInstantiate(
|
||||
graph_exec: *mut CudaGraphExec,
|
||||
graph: CudaGraph,
|
||||
flags: u64,
|
||||
) -> i32;
|
||||
pub fn cudaGraphLaunch(graph_exec: CudaGraphExec, stream: CudaStream) -> i32;
|
||||
pub fn cudaGraphDestroy(graph: CudaGraph) -> i32;
|
||||
pub fn cudaGraphExecDestroy(graph_exec: CudaGraphExec) -> i32;
|
||||
|
||||
// --- Our test kernel ---
|
||||
pub fn launch_vecadd_f32(
|
||||
a: *const f32,
|
||||
b: *const f32,
|
||||
c: *mut f32,
|
||||
n: i32,
|
||||
stream: CudaStream,
|
||||
);
|
||||
pub fn launch_vecadd_f32(a: *const f32, b: *const f32, c: *mut f32, n: i32, stream: CudaStream);
|
||||
}
|
||||
|
||||
92
crates/xserv-cuda/src/graph.rs
Normal file
92
crates/xserv-cuda/src/graph.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
//! CUDA Graphs: capture a sequence of kernel launches and replay them with
|
||||
//! near-zero host-side overhead (~3-5 us per launch eliminated).
|
||||
//!
|
||||
//! Usage:
|
||||
//! ```ignore
|
||||
//! let stream = CudaStream::new()?;
|
||||
//! let mut graph = CudaGraph::new();
|
||||
//!
|
||||
//! // First call: capture
|
||||
//! graph.begin_capture(&stream)?;
|
||||
//! // ... launch kernels on `stream` ...
|
||||
//! graph.end_capture(&stream)?;
|
||||
//!
|
||||
//! // Subsequent calls: replay
|
||||
//! graph.launch(&stream)?;
|
||||
//! ```
|
||||
//!
|
||||
//! Requirements for captured kernels:
|
||||
//! - All tensor shapes must be identical between capture and replay.
|
||||
//! - No host-side branching during the captured section.
|
||||
//! - Memory addresses used during capture must remain valid during replay.
|
||||
|
||||
use crate::error::{self, Result};
|
||||
use crate::ffi;
|
||||
use crate::stream::CudaStream;
|
||||
|
||||
/// RAII wrapper around a captured CUDA graph and its executable instance.
|
||||
pub struct CudaGraph {
|
||||
graph: ffi::CudaGraph,
|
||||
exec: ffi::CudaGraphExec,
|
||||
}
|
||||
|
||||
impl CudaGraph {
|
||||
/// Create an empty graph handle (not yet captured).
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
graph: std::ptr::null_mut(),
|
||||
exec: std::ptr::null_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if a graph has been captured and instantiated.
|
||||
pub fn is_ready(&self) -> bool {
|
||||
!self.exec.is_null()
|
||||
}
|
||||
|
||||
/// Begin capturing kernel launches on `stream`.
|
||||
/// All subsequent kernel launches on this stream are recorded into the
|
||||
/// graph instead of being executed.
|
||||
pub fn begin_capture(&mut self, stream: &CudaStream) -> Result<()> {
|
||||
// If we have an old graph, destroy it first
|
||||
self.destroy_inner();
|
||||
// THREAD_LOCAL: only "potentially unsafe" CUDA calls (cudaMalloc etc.)
|
||||
// made by THIS thread invalidate the capture. With GLOBAL mode, TP rank
|
||||
// threads capturing concurrently would poison each other's captures.
|
||||
error::check(unsafe {
|
||||
ffi::cudaStreamBeginCapture(stream.as_raw(), ffi::CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL)
|
||||
})
|
||||
}
|
||||
|
||||
/// End capture and instantiate the executable graph.
|
||||
pub fn end_capture(&mut self, stream: &CudaStream) -> Result<()> {
|
||||
error::check(unsafe { ffi::cudaStreamEndCapture(stream.as_raw(), &mut self.graph) })?;
|
||||
error::check(unsafe { ffi::cudaGraphInstantiate(&mut self.exec, self.graph, 0) })
|
||||
}
|
||||
|
||||
/// Replay the captured graph on `stream`.
|
||||
/// Panics if no graph has been captured yet.
|
||||
pub fn launch(&self, stream: &CudaStream) -> Result<()> {
|
||||
assert!(self.is_ready(), "CudaGraph::launch called before capture");
|
||||
error::check(unsafe { ffi::cudaGraphLaunch(self.exec, stream.as_raw()) })
|
||||
}
|
||||
|
||||
fn destroy_inner(&mut self) {
|
||||
if !self.exec.is_null() {
|
||||
unsafe { ffi::cudaGraphExecDestroy(self.exec) };
|
||||
self.exec = std::ptr::null_mut();
|
||||
}
|
||||
if !self.graph.is_null() {
|
||||
unsafe { ffi::cudaGraphDestroy(self.graph) };
|
||||
self.graph = std::ptr::null_mut();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CudaGraph {
|
||||
fn drop(&mut self) {
|
||||
self.destroy_inner();
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for CudaGraph {}
|
||||
@@ -2,11 +2,13 @@ pub mod allocator;
|
||||
pub mod device;
|
||||
pub mod error;
|
||||
pub mod ffi;
|
||||
pub mod graph;
|
||||
pub mod memory;
|
||||
pub mod stream;
|
||||
|
||||
pub use allocator::CachingAllocator;
|
||||
pub use device::DeviceInfo;
|
||||
pub use error::{CudaError, Result};
|
||||
pub use graph::CudaGraph;
|
||||
pub use memory::{GpuBuffer, PinnedBuffer};
|
||||
pub use stream::CudaStream;
|
||||
pub use stream::{CudaStream, StreamGuard, current_stream_raw, push_stream};
|
||||
|
||||
@@ -3,9 +3,18 @@ use crate::ffi;
|
||||
use crate::stream::CudaStream;
|
||||
|
||||
/// Handle to a GPU memory allocation with drop-time cleanup.
///
/// What happens on drop depends on two flags:
///
/// * `owned` (true by default): the buffer is responsible for the memory and
///   releases it when dropped. A borrowed buffer (`owned = false`) releases
///   nothing — whoever created it must keep the backing allocation alive for
///   as long as the borrow exists.
/// * `pooled`: when set, the memory goes back to the caching allocator's
///   free list on drop instead of being passed to `cudaFree`.
pub struct GpuBuffer {
    /// Device address of the first byte.
    ptr: *mut u8,
    /// Size of the allocation in bytes.
    len: usize,
    /// Whether dropping this handle releases the memory.
    owned: bool,
    /// Whether the memory is returned to the caching allocator rather than freed.
    pooled: bool,
}
|
||||
|
||||
impl GpuBuffer {
|
||||
@@ -13,7 +22,18 @@ impl GpuBuffer {
|
||||
assert!(len > 0, "cannot allocate 0 bytes on GPU");
|
||||
let mut ptr = std::ptr::null_mut();
|
||||
error::check(unsafe { ffi::cudaMalloc(&mut ptr, len) })?;
|
||||
Ok(Self { ptr, len })
|
||||
Ok(Self {
|
||||
ptr,
|
||||
len,
|
||||
owned: true,
|
||||
pooled: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Mark this buffer as pooled (returned to caching allocator on drop)
|
||||
/// or not. Called by `cached_alloc` after obtaining a buffer.
|
||||
pub fn set_pooled(&mut self, pooled: bool) {
|
||||
self.pooled = pooled;
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
@@ -77,9 +97,7 @@ impl GpuBuffer {
|
||||
/// Copy from another GPU buffer (D2D).
|
||||
pub fn copy_from_device(&mut self, src: &GpuBuffer) -> Result<()> {
|
||||
let n = src.len.min(self.len);
|
||||
error::check(unsafe {
|
||||
ffi::cudaMemcpy(self.ptr, src.ptr, n, ffi::CUDA_MEMCPY_D2D)
|
||||
})
|
||||
error::check(unsafe { ffi::cudaMemcpy(self.ptr, src.ptr, n, ffi::CUDA_MEMCPY_D2D) })
|
||||
}
|
||||
|
||||
/// Fill buffer with zeros.
|
||||
@@ -87,6 +105,81 @@ impl GpuBuffer {
|
||||
error::check(unsafe { ffi::cudaMemset(self.ptr, 0, self.len) })
|
||||
}
|
||||
|
||||
/// Copy `count` bytes from `src` buffer at `src_offset` to this buffer at `dst_offset`.
|
||||
pub fn copy_from_device_at(
|
||||
&mut self,
|
||||
src: &GpuBuffer,
|
||||
src_offset: usize,
|
||||
dst_offset: usize,
|
||||
count: usize,
|
||||
) -> Result<()> {
|
||||
assert!(src_offset + count <= src.len);
|
||||
assert!(dst_offset + count <= self.len);
|
||||
error::check(unsafe {
|
||||
ffi::cudaMemcpy(
|
||||
self.ptr.add(dst_offset),
|
||||
src.ptr.add(src_offset),
|
||||
count,
|
||||
ffi::CUDA_MEMCPY_D2D,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Async copy `count` bytes from `src` at `src_offset` to `self` at `dst_offset` on `stream`.
|
||||
pub fn copy_from_device_at_async(
|
||||
&mut self,
|
||||
src: &GpuBuffer,
|
||||
src_offset: usize,
|
||||
dst_offset: usize,
|
||||
count: usize,
|
||||
stream: &CudaStream,
|
||||
) -> Result<()> {
|
||||
assert!(src_offset + count <= src.len);
|
||||
assert!(dst_offset + count <= self.len);
|
||||
error::check(unsafe {
|
||||
ffi::cudaMemcpyAsync(
|
||||
self.ptr.add(dst_offset),
|
||||
src.ptr.add(src_offset),
|
||||
count,
|
||||
ffi::CUDA_MEMCPY_D2D,
|
||||
stream.as_raw(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Copy `count` bytes from this GPU buffer at `src_offset` to a host slice (D2H).
|
||||
pub fn copy_to_host_at(&self, dst: &mut [u8], src_offset: usize, count: usize) -> Result<()> {
|
||||
assert!(src_offset + count <= self.len, "src range out of bounds");
|
||||
assert!(count <= dst.len(), "host dst too small");
|
||||
error::check(unsafe {
|
||||
ffi::cudaMemcpy(
|
||||
dst.as_mut_ptr(),
|
||||
self.ptr.add(src_offset),
|
||||
count,
|
||||
ffi::CUDA_MEMCPY_D2H,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Copy `count` bytes from a host slice to this GPU buffer at `dst_offset` (H2D).
|
||||
pub fn copy_from_host_at(&mut self, src: &[u8], dst_offset: usize, count: usize) -> Result<()> {
|
||||
assert!(dst_offset + count <= self.len, "dst range out of bounds");
|
||||
assert!(count <= src.len(), "host src too small");
|
||||
error::check(unsafe {
|
||||
ffi::cudaMemcpy(
|
||||
self.ptr.add(dst_offset),
|
||||
src.as_ptr(),
|
||||
count,
|
||||
ffi::CUDA_MEMCPY_H2D,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Async zero fill on stream.
|
||||
pub fn zero_async(&mut self, stream: &CudaStream) -> Result<()> {
|
||||
error::check(unsafe { ffi::cudaMemsetAsync(self.ptr, 0, self.len, stream.as_raw()) })
|
||||
}
|
||||
|
||||
/// Consume the buffer without freeing GPU memory. Returns the raw pointer and length.
|
||||
/// Caller is responsible for eventually calling cudaFree.
|
||||
pub fn into_raw(self) -> (*mut u8, usize) {
|
||||
@@ -99,14 +192,39 @@ impl GpuBuffer {
|
||||
/// Reconstruct a GpuBuffer from a raw pointer + length.
|
||||
/// Safety: ptr must have been allocated with cudaMalloc, len must be correct.
|
||||
pub unsafe fn from_raw(ptr: *mut u8, len: usize) -> Self {
|
||||
Self { ptr, len }
|
||||
Self {
|
||||
ptr,
|
||||
len,
|
||||
owned: true,
|
||||
pooled: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a non-owning view of GPU memory. Dropping this buffer does NOT
|
||||
/// call `cudaFree`. The caller must ensure the underlying allocation
|
||||
/// outlives this borrow.
|
||||
///
|
||||
/// # Safety
|
||||
/// `ptr` must point to a valid GPU allocation of at least `len` bytes that
|
||||
/// will remain live for the lifetime of the returned `GpuBuffer`.
|
||||
pub unsafe fn borrow_raw(ptr: *mut u8, len: usize) -> Self {
|
||||
Self {
|
||||
ptr,
|
||||
len,
|
||||
owned: false,
|
||||
pooled: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for GpuBuffer {
|
||||
fn drop(&mut self) {
|
||||
if !self.ptr.is_null() {
|
||||
unsafe { ffi::cudaFree(self.ptr) };
|
||||
if self.owned && !self.ptr.is_null() {
|
||||
if self.pooled {
|
||||
crate::allocator::return_to_pool(self.ptr, self.len);
|
||||
} else {
|
||||
unsafe { ffi::cudaFree(self.ptr) };
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,3 +31,39 @@ impl Drop for CudaStream {
|
||||
|
||||
// Can move across threads, but not shared without synchronization
|
||||
unsafe impl Send for CudaStream {}
|
||||
|
||||
// --- Thread-local launch stream -------------------------------------------
|
||||
//
|
||||
// Every kernel wrapper in xserv-kernels launches on `current_stream_raw()`,
|
||||
// which defaults to the legacy null stream (the historical behavior). CUDA
|
||||
// graph capture requires work to be issued on an explicit stream, so capture
|
||||
// code installs its stream here for the duration of the captured region via
|
||||
// `push_stream` / `StreamGuard`.
|
||||
|
||||
use std::cell::Cell;
|
||||
|
||||
thread_local! {
|
||||
static CURRENT_STREAM: Cell<ffi::CudaStream> = const { Cell::new(std::ptr::null_mut()) };
|
||||
}
|
||||
|
||||
/// The stream kernel launches on this thread should use (null = legacy default).
|
||||
pub fn current_stream_raw() -> ffi::CudaStream {
|
||||
CURRENT_STREAM.with(|c| c.get())
|
||||
}
|
||||
|
||||
/// RAII guard that installs a launch stream for the current thread and
|
||||
/// restores the previous one on drop.
|
||||
pub struct StreamGuard {
|
||||
prev: ffi::CudaStream,
|
||||
}
|
||||
|
||||
pub fn push_stream(stream: &CudaStream) -> StreamGuard {
|
||||
let prev = CURRENT_STREAM.with(|c| c.replace(stream.as_raw()));
|
||||
StreamGuard { prev }
|
||||
}
|
||||
|
||||
impl Drop for StreamGuard {
|
||||
fn drop(&mut self) {
|
||||
CURRENT_STREAM.with(|c| c.set(self.prev));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,10 @@ fn test_device_info() {
|
||||
info.compute_major, info.compute_minor
|
||||
);
|
||||
println!(" SM Count: {}", info.sm_count);
|
||||
println!(" Shared Mem/Block: {} KB", info.shared_mem_per_block / 1024);
|
||||
println!(
|
||||
" Shared Mem/Block: {} KB",
|
||||
info.shared_mem_per_block / 1024
|
||||
);
|
||||
println!(" Warp Size: {}", info.warp_size);
|
||||
println!(" Max Threads/Block: {}", info.max_threads_per_block);
|
||||
|
||||
@@ -145,7 +148,11 @@ fn test_caching_allocator() {
|
||||
|
||||
// Second allocation of same size: should hit cache
|
||||
let _buf2 = alloc.alloc(1024).unwrap();
|
||||
assert_eq!(alloc.stats().cuda_malloc_count, 1, "should reuse cached buffer");
|
||||
assert_eq!(
|
||||
alloc.stats().cuda_malloc_count,
|
||||
1,
|
||||
"should reuse cached buffer"
|
||||
);
|
||||
assert_eq!(alloc.stats().cache_hit_count, 1);
|
||||
}
|
||||
|
||||
@@ -198,11 +205,17 @@ fn test_async_copy() {
|
||||
}
|
||||
|
||||
let mut gpu = GpuBuffer::alloc(4096).unwrap();
|
||||
unsafe { gpu.copy_from_host_async(pinned.as_slice(), &stream).unwrap() };
|
||||
unsafe {
|
||||
gpu.copy_from_host_async(pinned.as_slice(), &stream)
|
||||
.unwrap()
|
||||
};
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
let mut out_pinned = PinnedBuffer::alloc(4096).unwrap();
|
||||
unsafe { gpu.copy_to_host_async(out_pinned.as_mut_slice(), &stream).unwrap() };
|
||||
unsafe {
|
||||
gpu.copy_to_host_async(out_pinned.as_mut_slice(), &stream)
|
||||
.unwrap()
|
||||
};
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
assert_eq!(pinned.as_slice(), out_pinned.as_slice());
|
||||
|
||||
8
crates/xserv-distributed/Cargo.toml
Normal file
8
crates/xserv-distributed/Cargo.toml
Normal file
@@ -0,0 +1,8 @@
|
||||
[package]
|
||||
name = "xserv-distributed"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
xserv-cuda = { path = "../xserv-cuda" }
|
||||
half.workspace = true
|
||||
13
crates/xserv-distributed/build.rs
Normal file
13
crates/xserv-distributed/build.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
use std::env;
|
||||
|
||||
/// Emits the linker directives for NCCL and the CUDA runtime.
///
/// The CUDA toolkit root is taken from `CUDA_HOME`, then `CUDA_PATH`, and
/// falls back to `/usr/local/cuda`. If `NCCL_HOME` is set, its `lib`
/// directory is added to the library search path as well.
fn main() {
    // Tell cargo which inputs influence the link line so that changing one of
    // the environment variables re-runs this script, while unrelated source
    // edits do not.
    println!("cargo:rerun-if-changed=build.rs");
    for var in ["CUDA_HOME", "CUDA_PATH", "NCCL_HOME"] {
        println!("cargo:rerun-if-env-changed={var}");
    }

    let cuda_path = env::var("CUDA_HOME")
        .or_else(|_| env::var("CUDA_PATH"))
        .unwrap_or_else(|_| "/usr/local/cuda".to_string());

    println!("cargo:rustc-link-search=native={cuda_path}/lib64");
    // Optional custom NCCL prefix (e.g. a tarball install).
    if let Ok(nccl_home) = env::var("NCCL_HOME") {
        println!("cargo:rustc-link-search=native={nccl_home}/lib");
    }
    // NCCL is typically installed as a system library.
    println!("cargo:rustc-link-search=native=/usr/lib/x86_64-linux-gnu");
    println!("cargo:rustc-link-lib=dylib=nccl");
    println!("cargo:rustc-link-lib=dylib=cudart");
}
|
||||
92
crates/xserv-distributed/src/ffi.rs
Normal file
92
crates/xserv-distributed/src/ffi.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
//! Minimal NCCL FFI bindings (hand-written, like the CUDA bindings).
|
||||
//! Only the collectives we need for tensor parallelism.
|
||||
|
||||
use std::ffi::c_void;
|
||||
use std::os::raw::c_char;
|
||||
use xserv_cuda::ffi::CudaStream;
|
||||
|
||||
/// Opaque NCCL communicator handle (`ncclComm_t`).
|
||||
pub type NcclComm = *mut c_void;
|
||||
|
||||
/// Size in bytes of the opaque `ncclUniqueId` payload.
const NCCL_UNIQUE_ID_BYTES: usize = 128;

/// Mirror of `ncclUniqueId`: an opaque 128-byte blob that is shared from
/// rank 0 to all ranks.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct NcclUniqueId {
    pub internal: [c_char; NCCL_UNIQUE_ID_BYTES],
}

impl Default for NcclUniqueId {
    /// Returns an all-zero id, suitable as an out-parameter for
    /// `ncclGetUniqueId`.
    fn default() -> Self {
        let internal = [0 as c_char; NCCL_UNIQUE_ID_BYTES];
        Self { internal }
    }
}
|
||||
|
||||
// ncclDataType_t (subset)
|
||||
pub const NCCL_FLOAT32: i32 = 7;
|
||||
pub const NCCL_BF16: i32 = 9;
|
||||
|
||||
// ncclRedOp_t
|
||||
pub const NCCL_SUM: i32 = 0;
|
||||
|
||||
// ncclResult_t
|
||||
pub const NCCL_SUCCESS: i32 = 0;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub fn ncclGetUniqueId(uid: *mut NcclUniqueId) -> i32;
|
||||
// ncclUniqueId is passed BY VALUE (a 128-byte struct) per the NCCL ABI.
|
||||
pub fn ncclCommInitRank(
|
||||
comm: *mut NcclComm,
|
||||
nranks: i32,
|
||||
commid: NcclUniqueId,
|
||||
rank: i32,
|
||||
) -> i32;
|
||||
pub fn ncclCommDestroy(comm: NcclComm) -> i32;
|
||||
pub fn ncclAllReduce(
|
||||
sendbuff: *const c_void,
|
||||
recvbuff: *mut c_void,
|
||||
count: usize,
|
||||
datatype: i32,
|
||||
op: i32,
|
||||
comm: NcclComm,
|
||||
stream: CudaStream,
|
||||
) -> i32;
|
||||
// Point-to-point primitives for pipeline parallelism (Phase 18).
|
||||
pub fn ncclSend(
|
||||
sendbuff: *const c_void,
|
||||
count: usize,
|
||||
datatype: i32,
|
||||
peer: i32,
|
||||
comm: NcclComm,
|
||||
stream: CudaStream,
|
||||
) -> i32;
|
||||
pub fn ncclRecv(
|
||||
recvbuff: *mut c_void,
|
||||
count: usize,
|
||||
datatype: i32,
|
||||
peer: i32,
|
||||
comm: NcclComm,
|
||||
stream: CudaStream,
|
||||
) -> i32;
|
||||
pub fn ncclGroupStart() -> i32;
|
||||
pub fn ncclGroupEnd() -> i32;
|
||||
pub fn ncclGetErrorString(result: i32) -> *const c_char;
|
||||
}
|
||||
|
||||
pub fn err_string(result: i32) -> String {
|
||||
unsafe {
|
||||
let p = ncclGetErrorString(result);
|
||||
if p.is_null() {
|
||||
return format!("nccl error {result}");
|
||||
}
|
||||
std::ffi::CStr::from_ptr(p).to_string_lossy().into_owned()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check(result: i32, what: &str) {
|
||||
assert_eq!(
|
||||
result,
|
||||
NCCL_SUCCESS,
|
||||
"{what} failed: {}",
|
||||
err_string(result)
|
||||
);
|
||||
}
|
||||
192
crates/xserv-distributed/src/lib.rs
Normal file
192
crates/xserv-distributed/src/lib.rs
Normal file
@@ -0,0 +1,192 @@
|
||||
//! Tensor-parallel primitives for xserv.
|
||||
//!
|
||||
//! Process model: one OS thread per TP rank, each bound to one GPU. NCCL is
|
||||
//! used for the collective (AllReduce); a hand-rolled P2P AllReduce may replace
|
||||
//! it later as a learning exercise (see docs/17-tensor-parallelism.md).
|
||||
|
||||
pub mod ffi;
|
||||
|
||||
use std::ffi::c_void;
|
||||
|
||||
use ffi::{NcclComm, NcclUniqueId};
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_cuda::device;
|
||||
|
||||
pub use ffi::NcclUniqueId as UniqueId;
|
||||
|
||||
/// NCCL is issued on the thread's current launch stream (legacy null stream
|
||||
/// by default, the capture stream during CUDA graph capture). The model's
|
||||
/// kernels run on the same stream, so AllReduce stays correctly ordered after
|
||||
/// the producing matmul and before the consuming kernel — no extra sync.
|
||||
fn launch_stream() -> xserv_cuda::ffi::CudaStream {
|
||||
xserv_cuda::stream::current_stream_raw()
|
||||
}
|
||||
|
||||
/// Generate a unique id on one rank (typically rank 0) and broadcast the bytes
|
||||
/// to all ranks out-of-band (e.g. via a shared variable across threads).
|
||||
pub fn get_unique_id() -> NcclUniqueId {
|
||||
let mut id = NcclUniqueId::default();
|
||||
ffi::check(unsafe { ffi::ncclGetUniqueId(&mut id) }, "ncclGetUniqueId");
|
||||
id
|
||||
}
|
||||
|
||||
/// Per-rank tensor-parallel context: NCCL communicator + a dedicated stream.
|
||||
pub struct TpContext {
|
||||
pub rank: usize,
|
||||
pub world: usize,
|
||||
pub device: u32,
|
||||
comm: NcclComm,
|
||||
}
|
||||
|
||||
// The NCCL communicator is owned by exactly one rank thread.
|
||||
unsafe impl Send for TpContext {}
|
||||
|
||||
impl TpContext {
|
||||
/// Initialize this rank. Must be called from the thread that will own this
|
||||
/// rank's GPU work; binds the thread to `device` first. All ranks must call
|
||||
/// this concurrently with the same `id` and `world`.
|
||||
pub fn init(rank: usize, world: usize, id: NcclUniqueId, device: u32) -> Self {
|
||||
device::set_device(device).expect("set_device");
|
||||
let mut comm: NcclComm = std::ptr::null_mut();
|
||||
// Wrap the concurrent inits in a group so they rendezvous without deadlock.
|
||||
ffi::check(unsafe { ffi::ncclGroupStart() }, "ncclGroupStart(init)");
|
||||
ffi::check(
|
||||
unsafe { ffi::ncclCommInitRank(&mut comm, world as i32, id, rank as i32) },
|
||||
"ncclCommInitRank",
|
||||
);
|
||||
ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)");
|
||||
Self {
|
||||
rank,
|
||||
world,
|
||||
device,
|
||||
comm,
|
||||
}
|
||||
}
|
||||
|
||||
/// In-place AllReduce(sum) over `count` BF16 elements in `buf`.
|
||||
pub fn all_reduce_sum_bf16(&self, buf: &mut GpuBuffer, count: usize) {
|
||||
self.all_reduce_sum_bf16_ptr(buf.as_mut_ptr() as *mut c_void, count);
|
||||
}
|
||||
|
||||
/// In-place AllReduce(sum) directly on a device pointer (`count` BF16 elems),
|
||||
/// issued on the null stream so it is ordered with the model's kernels.
|
||||
/// Asynchronous: a later sync (e.g. the D2H logits copy) completes it.
|
||||
///
|
||||
/// # Safety
|
||||
/// `ptr` must point to at least `count` BF16 elements of valid device memory
|
||||
/// on this rank's device. The reduction is in-place (send == recv).
|
||||
pub fn all_reduce_sum_bf16_ptr(&self, ptr: *mut c_void, count: usize) {
|
||||
if self.world == 1 {
|
||||
return; // nothing to reduce
|
||||
}
|
||||
ffi::check(
|
||||
unsafe {
|
||||
ffi::ncclAllReduce(
|
||||
ptr as *const c_void,
|
||||
ptr,
|
||||
count,
|
||||
ffi::NCCL_BF16,
|
||||
ffi::NCCL_SUM,
|
||||
self.comm,
|
||||
launch_stream(),
|
||||
)
|
||||
},
|
||||
"ncclAllReduce",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TpContext {
|
||||
fn drop(&mut self) {
|
||||
if !self.comm.is_null() {
|
||||
unsafe { ffi::ncclCommDestroy(self.comm) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-stage pipeline-parallel context: a NCCL communicator spanning all `P`
|
||||
/// stages plus point-to-point send/recv of the hidden state to the neighbour
|
||||
/// stages. Init is identical to `TpContext` (one comm across `world` ranks);
|
||||
/// only the collective differs — PP hands off `[tokens, hidden]` between
|
||||
/// consecutive stages instead of AllReducing within a layer.
|
||||
pub struct PpContext {
|
||||
pub stage: usize,
|
||||
pub world: usize,
|
||||
pub device: u32,
|
||||
comm: NcclComm,
|
||||
}
|
||||
|
||||
// The NCCL communicator is owned by exactly one stage thread.
|
||||
unsafe impl Send for PpContext {}
|
||||
|
||||
impl PpContext {
|
||||
/// Initialize this stage. Must be called from the thread that owns this
|
||||
/// stage's GPU; binds the thread to `device` first. All stages call this
|
||||
/// concurrently with the same `id` and `world`.
|
||||
pub fn init(stage: usize, world: usize, id: NcclUniqueId, device: u32) -> Self {
|
||||
device::set_device(device).expect("set_device");
|
||||
let mut comm: NcclComm = std::ptr::null_mut();
|
||||
ffi::check(unsafe { ffi::ncclGroupStart() }, "ncclGroupStart(init)");
|
||||
ffi::check(
|
||||
unsafe { ffi::ncclCommInitRank(&mut comm, world as i32, id, stage as i32) },
|
||||
"ncclCommInitRank",
|
||||
);
|
||||
ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)");
|
||||
Self {
|
||||
stage,
|
||||
world,
|
||||
device,
|
||||
comm,
|
||||
}
|
||||
}
|
||||
|
||||
/// Send `count` BF16 elements at `ptr` to `peer`, on the null stream so it is
|
||||
/// ordered after the producing matmul. Asynchronous — a later `synchronize`
|
||||
/// (the caller must do one before reusing/freeing the buffer) completes it.
|
||||
///
|
||||
/// # Safety
|
||||
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
|
||||
pub fn send_bf16_ptr(&self, ptr: *const c_void, count: usize, peer: usize) {
|
||||
ffi::check(
|
||||
unsafe {
|
||||
ffi::ncclSend(
|
||||
ptr,
|
||||
count,
|
||||
ffi::NCCL_BF16,
|
||||
peer as i32,
|
||||
self.comm,
|
||||
launch_stream(),
|
||||
)
|
||||
},
|
||||
"ncclSend",
|
||||
);
|
||||
}
|
||||
|
||||
/// Receive `count` BF16 elements from `peer` into `ptr`, on the null stream.
|
||||
///
|
||||
/// # Safety
|
||||
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
|
||||
pub fn recv_bf16_ptr(&self, ptr: *mut c_void, count: usize, peer: usize) {
|
||||
ffi::check(
|
||||
unsafe {
|
||||
ffi::ncclRecv(
|
||||
ptr,
|
||||
count,
|
||||
ffi::NCCL_BF16,
|
||||
peer as i32,
|
||||
self.comm,
|
||||
launch_stream(),
|
||||
)
|
||||
},
|
||||
"ncclRecv",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PpContext {
|
||||
fn drop(&mut self) {
|
||||
if !self.comm.is_null() {
|
||||
unsafe { ffi::ncclCommDestroy(self.comm) };
|
||||
}
|
||||
}
|
||||
}
|
||||
48
crates/xserv-distributed/tests/allreduce.rs
Normal file
48
crates/xserv-distributed/tests/allreduce.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
//! 2-GPU AllReduce smoke test. Skips if fewer than 2 GPUs are present.
|
||||
|
||||
use half::bf16;
|
||||
use std::thread;
|
||||
use xserv_cuda::{GpuBuffer, device};
|
||||
use xserv_distributed::{TpContext, get_unique_id};
|
||||
|
||||
/// Two ranks each fill a buffer with `rank + 1` and AllReduce(sum) it; every
/// rank must then observe `1 + 2 = 3` in the first and last element.
#[test]
fn allreduce_two_gpu_sum() {
    let world = 2usize;
    if device::device_count().unwrap_or(0) < world as i32 {
        eprintln!("skip: need >= {world} GPUs");
        return;
    }

    let id = get_unique_id();
    let n = 4096usize;

    let handles: Vec<_> = (0..world)
        .map(|rank| {
            thread::spawn(move || {
                let ctx = TpContext::init(rank, world, id, rank as u32);

                // Rank r fills its buffer with (r + 1).
                let val = bf16::from_f32((rank + 1) as f32);
                let host = vec![val; n];
                // Viewing `bf16` storage as bytes is fine: `u8` has alignment 1.
                let src = unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
                let mut buf = GpuBuffer::alloc(n * 2).unwrap();
                buf.copy_from_host(src).unwrap();

                ctx.all_reduce_sum_bf16(&mut buf, n);

                let mut out = vec![0u8; n * 2];
                buf.copy_to_host(&mut out).unwrap();
                // Decode from bytes instead of reinterpreting the `Vec<u8>` as
                // `&[bf16]`: a byte buffer carries no 2-byte alignment guarantee.
                let elem = |i: usize| {
                    bf16::from_bits(u16::from_ne_bytes([out[2 * i], out[2 * i + 1]])).to_f32()
                };
                (elem(0), elem(n - 1))
            })
        })
        .collect();

    // sum over ranks of (r+1) = 1 + 2 = 3
    for h in handles {
        let (first, last) = h.join().unwrap();
        assert_eq!(first, 3.0, "AllReduce(sum) first element");
        assert_eq!(last, 3.0, "AllReduce(sum) last element");
    }
}
|
||||
63
crates/xserv-distributed/tests/sendrecv.rs
Normal file
63
crates/xserv-distributed/tests/sendrecv.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
//! 2-GPU NCCL P2P send/recv smoke test for pipeline parallelism.
|
||||
//! Stage 0 sends a known vector to stage 1, which verifies it. Skips if fewer
|
||||
//! than 2 GPUs are present. Mirrors `allreduce.rs` (GpuBuffer + half only —
|
||||
//! this crate does not depend on xserv-tensor).
|
||||
|
||||
use half::bf16;
|
||||
use std::ffi::c_void;
|
||||
use std::thread;
|
||||
use xserv_cuda::{GpuBuffer, device};
|
||||
use xserv_distributed::{PpContext, get_unique_id};
|
||||
|
||||
/// Stage 0 sends a known BF16 pattern to stage 1, which reads it back and
/// reports the first, second and last element for verification.
#[test]
fn pp_send_recv_two_stages() {
    let world = 2usize;
    if device::device_count().unwrap_or(0) < world as i32 {
        eprintln!("skip: need >= {world} GPUs");
        return;
    }

    let id = get_unique_id();
    let n = 4096usize; // one [1, hidden]-sized hand-off

    let handles: Vec<_> = (0..world)
        .map(|stage| {
            thread::spawn(move || {
                let pp = PpContext::init(stage, world, id, stage as u32);
                let mut buf = GpuBuffer::alloc(n * 2).unwrap();

                if stage == 0 {
                    // Fill with a known pattern and send to stage 1.
                    let host: Vec<bf16> = (0..n).map(|i| bf16::from_f32((i % 97) as f32)).collect();
                    // Viewing `bf16` storage as bytes is fine: `u8` has alignment 1.
                    let src =
                        unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
                    buf.copy_from_host(src).unwrap();
                    pp.send_bf16_ptr(buf.as_mut_ptr() as *const c_void, n, 1);
                    device::synchronize().unwrap();
                    None
                } else {
                    // Receive into a zeroed buffer and read it back.
                    buf.copy_from_host(&vec![0u8; n * 2]).unwrap();
                    pp.recv_bf16_ptr(buf.as_mut_ptr() as *mut c_void, n, 0);
                    device::synchronize().unwrap();
                    let mut out = vec![0u8; n * 2];
                    buf.copy_to_host(&mut out).unwrap();
                    // Decode from bytes instead of reinterpreting the `Vec<u8>` as
                    // `&[bf16]`: a byte buffer carries no 2-byte alignment guarantee.
                    let elem = |i: usize| {
                        bf16::from_bits(u16::from_ne_bytes([out[2 * i], out[2 * i + 1]])).to_f32()
                    };
                    Some((elem(0), elem(1), elem(n - 1)))
                }
            })
        })
        .collect();

    let mut checked = false;
    for h in handles {
        if let Some((first, second, last)) = h.join().unwrap() {
            assert_eq!(first, 0.0, "recv[0]");
            assert_eq!(second, 1.0, "recv[1]");
            assert_eq!(last, ((n - 1) % 97) as f32, "recv[last]");
            checked = true;
        }
    }
    assert!(checked, "stage 1 never verified the received buffer");
}
|
||||
@@ -8,6 +8,7 @@ fn main() {
|
||||
println!("cargo:rustc-link-search=native={cuda_path}/lib64");
|
||||
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||
println!("cargo:rustc-link-lib=dylib=cublas");
|
||||
println!("cargo:rustc-link-lib=dylib=cublasLt");
|
||||
|
||||
cc::Build::new()
|
||||
.cuda(true)
|
||||
@@ -16,12 +17,24 @@ fn main() {
|
||||
.include("../../csrc")
|
||||
.file("../../csrc/gemm/naive.cu")
|
||||
.file("../../csrc/gemm/tiled.cu")
|
||||
.file("../../csrc/gemm/gemv.cu")
|
||||
.file("../../csrc/normalization/rmsnorm.cu")
|
||||
.file("../../csrc/normalization/layernorm.cu")
|
||||
.file("../../csrc/activation/activations.cu")
|
||||
.file("../../csrc/reduce/softmax.cu")
|
||||
.file("../../csrc/reduce/argmax.cu")
|
||||
.file("../../csrc/embedding/embedding.cu")
|
||||
.file("../../csrc/embedding/rope.cu")
|
||||
.file("../../csrc/attention/causal_mask.cu")
|
||||
.file("../../csrc/embedding/transpose.cu")
|
||||
.file("../../csrc/attention/flash_attention.cu")
|
||||
.file("../../csrc/attention/paged_attention.cu")
|
||||
.file("../../csrc/attention/reshape_and_cache.cu")
|
||||
.file("../../csrc/moe/moe_kernels.cu")
|
||||
.file("../../csrc/moe/moe_sparse.cu")
|
||||
.file("../../csrc/quantization/dequant_fp8.cu")
|
||||
.file("../../csrc/quantization/quantize_fp8.cu")
|
||||
.file("../../csrc/quantization/mxfp4_gemm.cu")
|
||||
.compile("xserv_kernels");
|
||||
|
||||
println!("cargo:rerun-if-changed=../../csrc/");
|
||||
|
||||
@@ -6,36 +6,271 @@ unsafe extern "C" {
|
||||
fn launch_gelu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_silu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_scale_f32(
|
||||
x: *const c_void,
|
||||
out: *mut c_void,
|
||||
scale: f32,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_scale_bf16(
|
||||
x: *const c_void,
|
||||
out: *mut c_void,
|
||||
scale: f32,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_add_f32(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_add_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_mul_f32(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_mul_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_silu_mul_bf16(
|
||||
gate: *const c_void,
|
||||
up: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_gpt_oss_glu_bf16(
|
||||
gate_up: *const c_void,
|
||||
out: *mut c_void,
|
||||
n_elements: i32,
|
||||
alpha: f32,
|
||||
limit: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_bias_add_2d_bf16(
|
||||
x: *const c_void,
|
||||
bias: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
fn dispatch_unary(
|
||||
x: &Tensor,
|
||||
f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
|
||||
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
|
||||
) -> Tensor {
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel();
|
||||
assert!(
|
||||
n <= i32::MAX as usize,
|
||||
"tensor too large for i32 kernel param ({n} elements)"
|
||||
);
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => f32_fn(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => bf16_fn(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype"),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn dispatch_binary(
|
||||
a: &Tensor,
|
||||
b: &Tensor,
|
||||
f32_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
|
||||
bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
|
||||
) -> Tensor {
|
||||
assert_eq!(a.shape(), b.shape());
|
||||
assert!(a.is_contiguous() && b.is_contiguous());
|
||||
assert!(matches!(a.device(), Device::Cuda(_)));
|
||||
assert_eq!(a.dtype(), b.dtype());
|
||||
let out = Tensor::empty(a.shape(), a.dtype(), a.device());
|
||||
let n = a.numel();
|
||||
assert!(
|
||||
n <= i32::MAX as usize,
|
||||
"tensor too large for i32 kernel param ({n} elements)"
|
||||
);
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
match a.dtype() {
|
||||
DType::F32 => f32_fn(
|
||||
a.data_ptr() as _,
|
||||
b.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => bf16_fn(
|
||||
a.data_ptr() as _,
|
||||
b.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype"),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub fn gelu(x: &Tensor) -> Tensor {
|
||||
assert!(x.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel() as i32;
|
||||
dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16)
|
||||
}
|
||||
pub fn silu(x: &Tensor) -> Tensor {
|
||||
dispatch_unary(x, launch_silu_f32, launch_silu_bf16)
|
||||
}
|
||||
|
||||
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel();
|
||||
assert!(
|
||||
n <= i32::MAX as usize,
|
||||
"tensor too large for i32 kernel param ({n} elements)"
|
||||
);
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_gelu_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
DType::BF16 => launch_gelu_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
_ => panic!("unsupported dtype for gelu"),
|
||||
DType::F32 => launch_scale_f32(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
scale_val,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_scale_bf16(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
scale_val,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for scale"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
pub fn silu(x: &Tensor) -> Tensor {
|
||||
assert!(x.is_contiguous());
|
||||
pub fn add(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
dispatch_binary(a, b, launch_add_f32, launch_add_bf16)
|
||||
}
|
||||
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16)
|
||||
}
|
||||
|
||||
/// Row-broadcast bias add: out[r, c] = x[r, c] + bias[c] (BF16 only).
|
||||
pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor {
|
||||
assert_eq!(x.ndim(), 2);
|
||||
assert_eq!(bias.ndim(), 1);
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert_eq!(bias.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && bias.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel() as i32;
|
||||
let rows = x.shape()[0];
|
||||
let cols = x.shape()[1];
|
||||
assert_eq!(
|
||||
bias.shape()[0],
|
||||
cols,
|
||||
"bias size {} != cols {cols}",
|
||||
bias.shape()[0]
|
||||
);
|
||||
assert!(rows * cols <= i32::MAX as usize);
|
||||
let out = Tensor::empty(&[rows, cols], DType::BF16, x.device());
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_silu_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
DType::BF16 => launch_silu_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
_ => panic!("unsupported dtype for silu"),
|
||||
}
|
||||
launch_bias_add_2d_bf16(
|
||||
x.data_ptr() as _,
|
||||
bias.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Fused SiLU×Mul: out = silu(gate) * up (BF16 only)
|
||||
/// Saves one HBM read + one HBM write compared to separate silu + mul.
|
||||
pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
|
||||
assert_eq!(gate.shape(), up.shape());
|
||||
assert!(gate.is_contiguous() && up.is_contiguous());
|
||||
assert!(matches!(gate.device(), Device::Cuda(_)));
|
||||
assert_eq!(gate.dtype(), DType::BF16, "silu_mul requires BF16");
|
||||
let out = Tensor::empty(gate.shape(), gate.dtype(), gate.device());
|
||||
let n = gate.numel();
|
||||
assert!(
|
||||
n <= i32::MAX as usize,
|
||||
"tensor too large for i32 kernel param ({n} elements)"
|
||||
);
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
launch_silu_mul_bf16(
|
||||
gate.data_ptr() as *const c_void,
|
||||
up.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// gpt-oss fused GLU activation (BF16 only).
|
||||
/// Input: gate_up [rows, 2*D] with interleaved columns (gate=even, up=odd).
|
||||
/// Output: [rows, D]
|
||||
/// Computes: gate.clamp(max=limit) * sigmoid(gate * alpha) * (up.clamp(-limit,limit) + 1)
|
||||
pub fn gpt_oss_glu(gate_up: &Tensor, alpha: f32, limit: f32) -> Tensor {
|
||||
assert!(gate_up.is_contiguous());
|
||||
assert!(matches!(gate_up.device(), Device::Cuda(_)));
|
||||
assert_eq!(gate_up.dtype(), DType::BF16, "gpt_oss_glu requires BF16");
|
||||
assert_eq!(gate_up.ndim(), 2);
|
||||
let rows = gate_up.shape()[0];
|
||||
let cols = gate_up.shape()[1];
|
||||
assert_eq!(cols % 2, 0);
|
||||
let d = cols / 2;
|
||||
let out = Tensor::empty(&[rows, d], gate_up.dtype(), gate_up.device());
|
||||
let n_elements = (rows * d) as i32;
|
||||
unsafe {
|
||||
launch_gpt_oss_glu_bf16(
|
||||
gate_up.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n_elements,
|
||||
alpha,
|
||||
limit,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
72
crates/xserv-kernels/src/argmax.rs
Normal file
72
crates/xserv-kernels/src/argmax.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_argmax_bf16(
|
||||
logits: *const c_void,
|
||||
out_idx: *mut c_void,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
/// GPU argmax over the last dim of a [rows, cols] BF16 tensor.
|
||||
///
|
||||
/// Returns a host `Vec<u32>` of length `rows`. Internally:
|
||||
/// - launches one kernel that writes [rows] i32 indices on device
|
||||
/// - D2H copies just `rows * 4` bytes (vs `rows * cols * 2` for the
|
||||
/// "copy logits to CPU then argmax" path it replaces)
|
||||
///
|
||||
/// This is the greedy-decode hot path: avoids touching the full
|
||||
/// [B, vocab] logits buffer on the host every step.
|
||||
pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
|
||||
assert_eq!(logits.ndim(), 2, "argmax expects a 2D [rows, cols] tensor");
|
||||
assert_eq!(logits.dtype(), DType::BF16, "argmax kernel is BF16-only");
|
||||
assert!(logits.is_contiguous(), "argmax requires contiguous input");
|
||||
assert!(
|
||||
matches!(logits.device(), Device::Cuda(_)),
|
||||
"argmax requires GPU input"
|
||||
);
|
||||
|
||||
let rows = logits.shape()[0];
|
||||
let cols = logits.shape()[1];
|
||||
assert!(rows <= i32::MAX as usize);
|
||||
assert!(cols <= i32::MAX as usize);
|
||||
|
||||
// Output buffer: rows * i32. Pooled allocator so this is essentially free
|
||||
// after the first call.
|
||||
let bytes = rows * std::mem::size_of::<i32>();
|
||||
let mut out = xserv_cuda::allocator::cached_alloc(bytes).expect("argmax out alloc");
|
||||
|
||||
unsafe {
|
||||
launch_argmax_bf16(
|
||||
logits.data_ptr() as *const c_void,
|
||||
out.as_mut_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
let mut host_bytes = vec![0u8; bytes];
|
||||
out.copy_to_host(&mut host_bytes).expect("argmax D2H");
|
||||
drop(out); // returned to pool
|
||||
|
||||
let host_i32: &[i32] =
|
||||
unsafe { std::slice::from_raw_parts(host_bytes.as_ptr() as *const i32, rows) };
|
||||
host_i32.iter().map(|&v| v as u32).collect()
|
||||
}
|
||||
|
||||
/// Convenience: argmax of a single row [1, cols] (or [cols] reshaped to [1, cols]).
|
||||
pub fn argmax_bf16_single(logits: &Tensor) -> u32 {
|
||||
let cols = *logits.shape().last().unwrap();
|
||||
let rows = logits.numel() / cols;
|
||||
assert_eq!(rows, 1, "argmax_bf16_single requires a single row");
|
||||
let view = if logits.ndim() == 2 {
|
||||
logits.clone()
|
||||
} else {
|
||||
logits.reshape(&[1, cols])
|
||||
};
|
||||
argmax_bf16_to_host(&view)[0]
|
||||
}
|
||||
682
crates/xserv-kernels/src/attention.rs
Normal file
682
crates/xserv-kernels/src/attention.rs
Normal file
@@ -0,0 +1,682 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Tensor};
|
||||
|
||||
use crate::activation::scale;
|
||||
use crate::gemm::batched_matmul;
|
||||
use crate::softmax::softmax;
|
||||
|
||||
// Declarations of the CUDA launchers this module calls. Their definitions live
// outside this file; each takes the raw stream handle to launch on as its last
// argument.
unsafe extern "C" {
    /// Causal-mask launcher for F32 scores; `scores` is written through.
    fn launch_causal_mask_f32(
        scores: *mut c_void,
        batch: i32,
        rows: i32,
        cols: i32,
        offset: i32,
        stream: *mut c_void,
    );

    /// BF16 counterpart of `launch_causal_mask_f32`.
    fn launch_causal_mask_bf16(
        scores: *mut c_void,
        batch: i32,
        rows: i32,
        cols: i32,
        offset: i32,
        stream: *mut c_void,
    );

    /// Flash-attention launcher (BF16); result is written to `o`.
    fn launch_flash_attention_bf16(
        q: *const c_void,
        k: *const c_void,
        v: *const c_void,
        o: *mut c_void,
        batch: i32,
        num_q_heads: i32,
        num_kv_heads: i32,
        q_len: i32,
        kv_len: i32,
        head_dim: i32,
        scale: f32,
        causal: i32,
        stream: *mut c_void,
    );

    /// Flash-attention launcher with per-head `sinks` and a `window_size`.
    fn launch_flash_attention_sinks_bf16(
        q: *const c_void,
        k: *const c_void,
        v: *const c_void,
        o: *mut c_void,
        sinks: *const c_void,
        batch: i32,
        num_q_heads: i32,
        num_kv_heads: i32,
        q_len: i32,
        kv_len: i32,
        head_dim: i32,
        scale: f32,
        causal: i32,
        window_size: i32,
        stream: *mut c_void,
    );

    /// Single-query (decode) attention launcher (BF16).
    fn launch_decode_attention_bf16(
        q: *const c_void,
        k: *const c_void,
        v: *const c_void,
        o: *mut c_void,
        batch: i32,
        num_q_heads: i32,
        num_kv_heads: i32,
        kv_len: i32,
        head_dim: i32,
        scale: f32,
        causal: i32,
        stream: *mut c_void,
    );

    /// Decode attention over a paged K/V cache addressed via `block_tables`.
    fn launch_paged_decode_attention_bf16(
        q: *const c_void,
        k_cache: *const c_void,
        v_cache: *const c_void,
        o: *mut c_void,
        block_tables: *const i32,
        context_lens: *const i32,
        batch: i32,
        num_q_heads: i32,
        num_kv_heads: i32,
        head_dim: i32,
        max_blocks_per_seq: i32,
        scale: f32,
        stream: *mut c_void,
    );

    /// Paged decode attention with an extra `tree_mask` over
    /// `[tree_start, tree_start + tree_len)`.
    fn launch_paged_decode_attention_tree_bf16(
        q: *const c_void,
        k_cache: *const c_void,
        v_cache: *const c_void,
        o: *mut c_void,
        block_tables: *const i32,
        context_lens: *const i32,
        tree_mask: *const i32,
        batch: i32,
        num_q_heads: i32,
        num_kv_heads: i32,
        head_dim: i32,
        max_blocks_per_seq: i32,
        tree_start: i32,
        tree_len: i32,
        scale: f32,
        stream: *mut c_void,
    );

    /// Paged decode attention with `sinks` and a `window_size`.
    fn launch_paged_decode_attention_sinks_bf16(
        q: *const c_void,
        k_cache: *const c_void,
        v_cache: *const c_void,
        o: *mut c_void,
        block_tables: *const i32,
        context_lens: *const i32,
        sinks: *const c_void,
        batch: i32,
        num_q_heads: i32,
        num_kv_heads: i32,
        head_dim: i32,
        max_blocks_per_seq: i32,
        scale: f32,
        window_size: i32,
        stream: *mut c_void,
    );

    /// Scatters K/V for one sequence into the paged pools.
    fn launch_reshape_and_cache_bf16(
        k_src: *const c_void,
        v_src: *const c_void,
        k_pool: *mut c_void,
        v_pool: *mut c_void,
        block_ids: *const c_void,
        num_tokens: i32,
        num_heads: i32,
        head_dim: i32,
        start_pos: i32,
        block_size: i32,
        stream: *mut c_void,
    );

    /// Batched variant of `launch_reshape_and_cache_bf16` driven by
    /// `block_tables` and `kv_lens`.
    fn launch_reshape_and_cache_batched_bf16(
        k_src: *const c_void,
        v_src: *const c_void,
        k_pool: *mut c_void,
        v_pool: *mut c_void,
        block_tables: *const c_void,
        kv_lens: *const c_void,
        batch: i32,
        num_heads: i32,
        head_dim: i32,
        block_size: i32,
        max_blocks_per_seq: i32,
        stream: *mut c_void,
    );

    /// Copies the K/V entry at `src_pos` to `dst_pos` inside the paged pools.
    fn launch_copy_kv_position(
        k_pool: *mut c_void,
        v_pool: *mut c_void,
        block_ids: *const i32,
        src_pos: i32,
        dst_pos: i32,
        num_kv_heads: i32,
        head_dim: i32,
        block_size: i32,
        stream: *mut c_void,
    );
}
|
||||
|
||||
/// Scatter `[num_kv_heads, num_tokens, head_dim]` BF16 K/V into a paged
|
||||
/// pool for a single sequence whose block table lives at `block_ids_gpu`
|
||||
/// (int32, on device).
|
||||
///
|
||||
/// `k_pool_ptr`/`v_pool_ptr` point to one layer's pool, of logical shape
|
||||
/// `[num_blocks_total, num_kv_heads, block_size, head_dim]`.
|
||||
///
|
||||
/// All pointers must be on the same GPU as the launching context.
|
||||
///
|
||||
/// # Safety
|
||||
/// Pointers must be valid GPU pointers with the documented layouts.
|
||||
/// `block_ids_gpu` must contain at least `(start_pos + num_tokens + block_size - 1) / block_size`
|
||||
/// valid physical block ids.
|
||||
pub unsafe fn reshape_and_cache_bf16(
|
||||
k_src: *const c_void,
|
||||
v_src: *const c_void,
|
||||
k_pool_ptr: *mut c_void,
|
||||
v_pool_ptr: *mut c_void,
|
||||
block_ids_gpu: *const i32,
|
||||
num_tokens: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
start_pos: usize,
|
||||
block_size: usize,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
unsafe {
|
||||
launch_reshape_and_cache_bf16(
|
||||
k_src,
|
||||
v_src,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
block_ids_gpu as *const c_void,
|
||||
num_tokens as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
start_pos as i32,
|
||||
block_size as i32,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Batched scatter for the multi-sequence decode step. Reads
|
||||
/// `block_tables` (`[batch, max_blocks_per_seq]` int32 — same buffer the
|
||||
/// paged-attention kernel reads) and `kv_lens` (`[batch]` int32, current
|
||||
/// seq_len + 1 — i.e., the index of the just-written token + 1) so the
|
||||
/// caller doesn't need a separate per-step upload of block ids.
|
||||
///
|
||||
/// # Safety
|
||||
/// All pointers must be on the same GPU. `block_tables` and `kv_lens` must
|
||||
/// already be synced to the device for the active batch.
|
||||
pub unsafe fn reshape_and_cache_batched_bf16(
|
||||
k_src: *const c_void,
|
||||
v_src: *const c_void,
|
||||
k_pool_ptr: *mut c_void,
|
||||
v_pool_ptr: *mut c_void,
|
||||
block_tables_gpu: *const i32,
|
||||
kv_lens_gpu: *const i32,
|
||||
batch: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
block_size: usize,
|
||||
max_blocks_per_seq: usize,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
unsafe {
|
||||
launch_reshape_and_cache_batched_bf16(
|
||||
k_src,
|
||||
v_src,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
block_tables_gpu as *const c_void,
|
||||
kv_lens_gpu as *const c_void,
|
||||
batch as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
block_size as i32,
|
||||
max_blocks_per_seq as i32,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Copy one token's K/V from `src_pos` to `dst_pos` within the same sequence's
|
||||
/// paged cache (one layer). Used by tree speculative decoding to remap
|
||||
/// accepted sibling K/V to canonical sequential positions after acceptance.
|
||||
///
|
||||
/// # Safety
|
||||
/// Pool and block_ids pointers must be valid GPU pointers for the given layer.
|
||||
pub unsafe fn copy_kv_position(
|
||||
k_pool_ptr: *mut c_void,
|
||||
v_pool_ptr: *mut c_void,
|
||||
block_ids_gpu: *const i32,
|
||||
src_pos: usize,
|
||||
dst_pos: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
block_size: usize,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_copy_kv_position(
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
block_ids_gpu,
|
||||
src_pos as i32,
|
||||
dst_pos as i32,
|
||||
num_kv_heads as i32,
|
||||
head_dim as i32,
|
||||
block_size as i32,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
|
||||
fn apply_causal_mask(scores: &Tensor, offset: usize) {
|
||||
let ndim = scores.ndim();
|
||||
let rows = scores.shape()[ndim - 2];
|
||||
let cols = scores.shape()[ndim - 1];
|
||||
let batch: usize = scores.shape()[..ndim - 2].iter().product();
|
||||
|
||||
unsafe {
|
||||
match scores.dtype() {
|
||||
DType::F32 => launch_causal_mask_f32(
|
||||
scores.data_ptr() as *mut c_void,
|
||||
batch as i32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
offset as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_causal_mask_bf16(
|
||||
scores.data_ptr() as *mut c_void,
|
||||
batch as i32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
offset as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for causal mask"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Multi-head attention (naive, materializes S×S score matrix).
|
||||
///
|
||||
/// q, k, v: [batch, num_heads, seq_len, head_dim] — contiguous, on GPU
|
||||
/// Returns: [batch, num_heads, seq_len, head_dim]
|
||||
pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(k.ndim(), 4);
|
||||
assert_eq!(v.ndim(), 4);
|
||||
assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous());
|
||||
|
||||
let batch = q.shape()[0];
|
||||
let num_heads = q.shape()[1];
|
||||
let q_len = q.shape()[2];
|
||||
let head_dim = q.shape()[3];
|
||||
let kv_len = k.shape()[2];
|
||||
|
||||
assert_eq!(k.shape(), &[batch, num_heads, kv_len, head_dim]);
|
||||
assert_eq!(v.shape(), &[batch, num_heads, kv_len, head_dim]);
|
||||
|
||||
// scores = Q @ K^T → [B, H, q_len, kv_len]
|
||||
let k_t = k.transpose(2, 3).contiguous();
|
||||
let scores = batched_matmul(q, &k_t);
|
||||
|
||||
// Scale by 1/sqrt(head_dim)
|
||||
let scale_factor = 1.0 / (head_dim as f32).sqrt();
|
||||
let scaled_scores = scale(&scores, scale_factor);
|
||||
|
||||
// Causal mask
|
||||
if causal {
|
||||
let offset = kv_len - q_len;
|
||||
apply_causal_mask(&scaled_scores, offset);
|
||||
}
|
||||
|
||||
// Softmax
|
||||
let weights = softmax(&scaled_scores);
|
||||
|
||||
// output = weights @ V → [B, H, q_len, head_dim]
|
||||
batched_matmul(&weights, v)
|
||||
}
|
||||
|
||||
/// Decode Attention — optimized for single-token decode (q_len=1).
|
||||
///
|
||||
/// q: [batch, num_q_heads, 1, head_dim] BF16, contiguous, GPU
|
||||
/// k: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
|
||||
/// v: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
|
||||
///
|
||||
/// Returns: [batch, num_q_heads, 1, head_dim] BF16
|
||||
pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(q.shape()[2], 1, "decode_attention requires q_len == 1");
|
||||
|
||||
let batch = q.shape()[0];
|
||||
let num_q_heads = q.shape()[1];
|
||||
let head_dim = q.shape()[3];
|
||||
let num_kv_heads = k.shape()[1];
|
||||
let kv_len = k.shape()[2];
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
|
||||
|
||||
unsafe {
|
||||
launch_decode_attention_bf16(
|
||||
q.data_ptr() as *const c_void,
|
||||
k.data_ptr() as *const c_void,
|
||||
v.data_ptr() as *const c_void,
|
||||
output.data_ptr() as *mut c_void,
|
||||
batch as i32,
|
||||
num_q_heads as i32,
|
||||
num_kv_heads as i32,
|
||||
kv_len as i32,
|
||||
head_dim as i32,
|
||||
scale,
|
||||
1, // causal (always 1 for decode)
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Flash Attention 2 — O(1) extra memory, supports GQA natively.
|
||||
/// Auto-dispatches to decode_attention when q_len == 1.
|
||||
///
|
||||
/// q: [batch, num_q_heads, q_len, head_dim] BF16, contiguous, GPU
|
||||
/// k: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
|
||||
/// v: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
|
||||
///
|
||||
/// Returns: [batch, num_q_heads, q_len, head_dim] BF16
|
||||
pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(k.ndim(), 4);
|
||||
assert_eq!(v.ndim(), 4);
|
||||
assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous());
|
||||
assert_eq!(q.dtype(), DType::BF16, "flash_attention requires BF16");
|
||||
assert_eq!(k.dtype(), DType::BF16);
|
||||
assert_eq!(v.dtype(), DType::BF16);
|
||||
|
||||
let batch = q.shape()[0];
|
||||
let num_q_heads = q.shape()[1];
|
||||
let q_len = q.shape()[2];
|
||||
let head_dim = q.shape()[3];
|
||||
let num_kv_heads = k.shape()[1];
|
||||
let kv_len = k.shape()[2];
|
||||
|
||||
assert_eq!(k.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
|
||||
assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
|
||||
assert!(
|
||||
num_q_heads % num_kv_heads == 0,
|
||||
"num_q_heads must be divisible by num_kv_heads"
|
||||
);
|
||||
assert!(
|
||||
head_dim <= 128,
|
||||
"flash_attention supports head_dim up to 128"
|
||||
);
|
||||
|
||||
// Dispatch to specialized decode kernel for single-token generation
|
||||
if q_len == 1 {
|
||||
return decode_attention(q, k, v);
|
||||
}
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::empty(
|
||||
&[batch, num_q_heads, q_len, head_dim],
|
||||
DType::BF16,
|
||||
q.device(),
|
||||
);
|
||||
|
||||
unsafe {
|
||||
launch_flash_attention_bf16(
|
||||
q.data_ptr() as *const c_void,
|
||||
k.data_ptr() as *const c_void,
|
||||
v.data_ptr() as *const c_void,
|
||||
output.data_ptr() as *mut c_void,
|
||||
batch as i32,
|
||||
num_q_heads as i32,
|
||||
num_kv_heads as i32,
|
||||
q_len as i32,
|
||||
kv_len as i32,
|
||||
head_dim as i32,
|
||||
scale,
|
||||
if causal { 1 } else { 0 },
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Flash attention for prefill with gpt-oss attention sinks + optional sliding window.
|
||||
///
|
||||
/// Same layout/contract as `flash_attention`, plus a per-head `sinks` tensor
|
||||
/// ([num_q_heads] BF16, GPU) folded into the softmax denominator, and a
|
||||
/// `window_size` (0 = full causal, >0 = sliding window). Always causal.
|
||||
pub fn flash_attention_sinks(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
sinks: &Tensor,
|
||||
window_size: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(k.ndim(), 4);
|
||||
assert_eq!(v.ndim(), 4);
|
||||
assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous());
|
||||
assert_eq!(q.dtype(), DType::BF16);
|
||||
assert_eq!(k.dtype(), DType::BF16);
|
||||
assert_eq!(v.dtype(), DType::BF16);
|
||||
|
||||
let batch = q.shape()[0];
|
||||
let num_q_heads = q.shape()[1];
|
||||
let q_len = q.shape()[2];
|
||||
let head_dim = q.shape()[3];
|
||||
let num_kv_heads = k.shape()[1];
|
||||
let kv_len = k.shape()[2];
|
||||
|
||||
assert_eq!(k.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
|
||||
assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
|
||||
assert!(num_q_heads % num_kv_heads == 0);
|
||||
assert!(head_dim <= 128);
|
||||
assert_eq!(
|
||||
sinks.shape()[0],
|
||||
num_q_heads,
|
||||
"sinks must have num_q_heads entries"
|
||||
);
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::empty(
|
||||
&[batch, num_q_heads, q_len, head_dim],
|
||||
DType::BF16,
|
||||
q.device(),
|
||||
);
|
||||
|
||||
unsafe {
|
||||
launch_flash_attention_sinks_bf16(
|
||||
q.data_ptr() as *const c_void,
|
||||
k.data_ptr() as *const c_void,
|
||||
v.data_ptr() as *const c_void,
|
||||
output.data_ptr() as *mut c_void,
|
||||
sinks.data_ptr() as *const c_void,
|
||||
batch as i32,
|
||||
num_q_heads as i32,
|
||||
num_kv_heads as i32,
|
||||
q_len as i32,
|
||||
kv_len as i32,
|
||||
head_dim as i32,
|
||||
scale,
|
||||
1, // always causal
|
||||
window_size as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Paged decode attention.
|
||||
///
|
||||
/// q: [batch, num_q_heads, 1, head_dim] BF16, contiguous, GPU
|
||||
/// k_cache_ptr / v_cache_ptr: pointers to [num_blocks, num_kv_heads, BLOCK_SIZE, head_dim] BF16 pools
|
||||
/// block_tables_ptr: i32 [batch, max_blocks_per_seq] (rows already arranged for this batch)
|
||||
/// context_lens_ptr: i32 [batch]
|
||||
///
|
||||
/// Returns: [batch, num_q_heads, 1, head_dim] BF16
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn paged_decode_attention(
|
||||
q: &Tensor,
|
||||
k_cache_ptr: *const c_void,
|
||||
v_cache_ptr: *const c_void,
|
||||
block_tables_ptr: *const i32,
|
||||
context_lens_ptr: *const i32,
|
||||
batch: usize,
|
||||
num_q_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
max_blocks_per_seq: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(
|
||||
q.shape()[2],
|
||||
1,
|
||||
"paged_decode_attention requires q_len == 1"
|
||||
);
|
||||
assert_eq!(q.dtype(), DType::BF16);
|
||||
assert!(
|
||||
num_q_heads % num_kv_heads == 0,
|
||||
"GQA: num_q_heads must be divisible by num_kv_heads"
|
||||
);
|
||||
assert!(head_dim <= 128);
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
|
||||
|
||||
unsafe {
|
||||
launch_paged_decode_attention_bf16(
|
||||
q.data_ptr() as *const c_void,
|
||||
k_cache_ptr,
|
||||
v_cache_ptr,
|
||||
output.data_ptr() as *mut c_void,
|
||||
block_tables_ptr,
|
||||
context_lens_ptr,
|
||||
batch as i32,
|
||||
num_q_heads as i32,
|
||||
num_kv_heads as i32,
|
||||
head_dim as i32,
|
||||
max_blocks_per_seq as i32,
|
||||
scale,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Tree-aware paged decode attention. Adds a per-query attention mask over
|
||||
/// the newly-written K/V region `[tree_start, tree_start+tree_len)`. Query i
|
||||
/// attends to position tree_start+j iff tree_mask[i, j] != 0. Positions <
|
||||
/// tree_start are always attended.
|
||||
///
|
||||
/// Used by speculative decoding with tree drafting to let sibling candidates
|
||||
/// share position slots without seeing each other's K/V.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn paged_decode_attention_tree(
|
||||
q: &Tensor,
|
||||
k_cache_ptr: *const c_void,
|
||||
v_cache_ptr: *const c_void,
|
||||
block_tables_ptr: *const i32,
|
||||
context_lens_ptr: *const i32,
|
||||
tree_mask_ptr: *const i32,
|
||||
batch: usize,
|
||||
num_q_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
max_blocks_per_seq: usize,
|
||||
tree_start: usize,
|
||||
tree_len: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(q.shape()[2], 1);
|
||||
assert_eq!(q.dtype(), DType::BF16);
|
||||
assert!(num_q_heads % num_kv_heads == 0);
|
||||
assert!(head_dim <= 128);
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
|
||||
|
||||
unsafe {
|
||||
launch_paged_decode_attention_tree_bf16(
|
||||
q.data_ptr() as *const c_void,
|
||||
k_cache_ptr,
|
||||
v_cache_ptr,
|
||||
output.data_ptr() as *mut c_void,
|
||||
block_tables_ptr,
|
||||
context_lens_ptr,
|
||||
tree_mask_ptr,
|
||||
batch as i32,
|
||||
num_q_heads as i32,
|
||||
num_kv_heads as i32,
|
||||
head_dim as i32,
|
||||
max_blocks_per_seq as i32,
|
||||
tree_start as i32,
|
||||
tree_len as i32,
|
||||
scale,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Paged decode attention with attention sinks and optional sliding window.
|
||||
///
|
||||
/// sinks_ptr: pointer to [num_q_heads] BF16 on GPU (or null for no sinks)
|
||||
/// window_size: 0 = full attention, >0 = sliding window
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn paged_decode_attention_sinks(
|
||||
q: &Tensor,
|
||||
k_cache_ptr: *const c_void,
|
||||
v_cache_ptr: *const c_void,
|
||||
block_tables_ptr: *const i32,
|
||||
context_lens_ptr: *const i32,
|
||||
sinks_ptr: *const c_void,
|
||||
batch: usize,
|
||||
num_q_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
max_blocks_per_seq: usize,
|
||||
window_size: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(q.shape()[2], 1);
|
||||
assert_eq!(q.dtype(), DType::BF16);
|
||||
assert!(num_q_heads % num_kv_heads == 0);
|
||||
assert!(head_dim <= 128);
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
|
||||
|
||||
unsafe {
|
||||
launch_paged_decode_attention_sinks_bf16(
|
||||
q.data_ptr() as *const c_void,
|
||||
k_cache_ptr,
|
||||
v_cache_ptr,
|
||||
output.data_ptr() as *mut c_void,
|
||||
block_tables_ptr,
|
||||
context_lens_ptr,
|
||||
sinks_ptr,
|
||||
batch as i32,
|
||||
num_q_heads as i32,
|
||||
num_kv_heads as i32,
|
||||
head_dim as i32,
|
||||
max_blocks_per_seq as i32,
|
||||
scale,
|
||||
window_size as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
316
crates/xserv-kernels/src/dispatch.rs
Normal file
316
crates/xserv-kernels/src/dispatch.rs
Normal file
@@ -0,0 +1,316 @@
|
||||
//! Low-level kernel dispatchers for CUDA Graph capture.
|
||||
//! These functions write to pre-allocated output buffers and accept an explicit stream.
|
||||
|
||||
use std::ffi::c_void;
|
||||
|
||||
// The launchers below are declared a second time here (the individual kernel
// modules carry the same declarations) so this module can call them directly.
unsafe extern "C" {
    /// RMSNorm launcher (BF16); result goes to `out`.
    fn launch_rmsnorm_bf16(
        x: *const c_void,
        gamma: *const c_void,
        out: *mut c_void,
        rows: i32,
        hidden_size: i32,
        eps: f32,
        stream: *mut c_void,
    );

    /// Add + RMSNorm launcher (BF16) with two outputs: `normed_out` and `sum_out`.
    fn launch_add_rmsnorm_bf16(
        x: *const c_void,
        residual: *const c_void,
        gamma: *const c_void,
        normed_out: *mut c_void,
        sum_out: *mut c_void,
        rows: i32,
        hidden_size: i32,
        eps: f32,
        stream: *mut c_void,
    );

    /// SiLU-multiply launcher (BF16) over `n` elements.
    fn launch_silu_mul_bf16(
        gate: *const c_void,
        up: *const c_void,
        out: *mut c_void,
        n: i32,
        stream: *mut c_void,
    );

    /// Elementwise add launcher (BF16) over `n` elements.
    fn launch_add_bf16(
        a: *const c_void,
        b: *const c_void,
        out: *mut c_void,
        n: i32,
        stream: *mut c_void,
    );

    /// Embedding-lookup launcher (BF16).
    fn launch_embedding_bf16(
        table: *const c_void,
        token_ids: *const c_void,
        out: *mut c_void,
        num_tokens: i32,
        hidden_size: i32,
        vocab_size: i32,
        stream: *mut c_void,
    );

    /// Head-reshape launcher (BF16).
    fn launch_reshape_heads_bf16(
        inp: *const c_void,
        out: *mut c_void,
        seq_len: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    /// Head-merge launcher (BF16).
    fn launch_merge_heads_bf16(
        inp: *const c_void,
        out: *mut c_void,
        seq_len: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    /// HSD -> SHD transpose launcher (BF16).
    fn launch_transpose_hsd_to_shd_bf16(
        inp: *const c_void,
        out: *mut c_void,
        seq_len: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    /// SHD -> HSD transpose launcher (BF16).
    fn launch_transpose_shd_to_hsd_bf16(
        inp: *const c_void,
        out: *mut c_void,
        seq_len: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    /// RoPE launcher (BF16); `x` is mutable, i.e. updated in place.
    fn launch_rope_bf16(
        x: *mut c_void,
        cos_cache: *const c_void,
        sin_cache: *const c_void,
        positions: *const c_void,
        num_tokens: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    /// GEMV launcher (BF16) taking an additional FP32 buffer.
    fn launch_gemv_bf16(
        x: *const c_void,
        w: *const c_void,
        y_bf16: *mut c_void,
        y_fp32_buf: *mut c_void,
        k: i32,
        n: i32,
        stream: *mut c_void,
    );

    /// Single-query (decode) attention launcher (BF16).
    fn launch_decode_attention_bf16(
        q: *const c_void,
        k: *const c_void,
        v: *const c_void,
        o: *mut c_void,
        batch: i32,
        num_q_heads: i32,
        num_kv_heads: i32,
        kv_len: i32,
        head_dim: i32,
        scale: f32,
        causal: i32,
        stream: *mut c_void,
    );
}
|
||||
|
||||
/// Raw rmsnorm dispatch: writes to pre-allocated `out`.
|
||||
pub unsafe fn rmsnorm_bf16(
|
||||
x: *const c_void,
|
||||
gamma: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_rmsnorm_bf16(x, gamma, out, rows, hidden_size, eps, stream);
|
||||
}
|
||||
|
||||
/// Raw add_rmsnorm dispatch.
|
||||
pub unsafe fn add_rmsnorm_bf16(
|
||||
x: *const c_void,
|
||||
residual: *const c_void,
|
||||
gamma: *const c_void,
|
||||
normed_out: *mut c_void,
|
||||
sum_out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_add_rmsnorm_bf16(
|
||||
x,
|
||||
residual,
|
||||
gamma,
|
||||
normed_out,
|
||||
sum_out,
|
||||
rows,
|
||||
hidden_size,
|
||||
eps,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
|
||||
/// Raw silu_mul dispatch.
|
||||
pub unsafe fn silu_mul_bf16(
|
||||
gate: *const c_void,
|
||||
up: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_silu_mul_bf16(gate, up, out, n, stream);
|
||||
}
|
||||
|
||||
/// Raw add dispatch.
|
||||
pub unsafe fn add_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_add_bf16(a, b, out, n, stream);
|
||||
}
|
||||
|
||||
/// Raw embedding dispatch.
|
||||
pub unsafe fn embedding_bf16(
|
||||
table: *const c_void,
|
||||
token_ids: *const c_void,
|
||||
out: *mut c_void,
|
||||
num_tokens: i32,
|
||||
hidden_size: i32,
|
||||
vocab_size: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_embedding_bf16(
|
||||
table,
|
||||
token_ids,
|
||||
out,
|
||||
num_tokens,
|
||||
hidden_size,
|
||||
vocab_size,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
|
||||
/// Raw reshape_heads dispatch.
|
||||
pub unsafe fn reshape_heads_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_reshape_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||
}
|
||||
|
||||
/// Raw merge_heads dispatch.
|
||||
pub unsafe fn merge_heads_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_merge_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||
}
|
||||
|
||||
/// Raw transpose HSD->SHD dispatch.
|
||||
pub unsafe fn transpose_hsd_to_shd_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_transpose_hsd_to_shd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||
}
|
||||
|
||||
/// Raw transpose SHD->HSD dispatch.
|
||||
pub unsafe fn transpose_shd_to_hsd_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_transpose_shd_to_hsd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||
}
|
||||
|
||||
/// Raw RoPE dispatch (in-place).
|
||||
pub unsafe fn rope_bf16(
|
||||
x: *mut c_void,
|
||||
cos_cache: *const c_void,
|
||||
sin_cache: *const c_void,
|
||||
positions: *const c_void,
|
||||
num_tokens: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_rope_bf16(
|
||||
x, cos_cache, sin_cache, positions, num_tokens, num_heads, head_dim, stream,
|
||||
);
|
||||
}
|
||||
|
||||
/// Raw GEMV dispatch (BF16, M=1). Caller must provide fp32 accumulator buffer.
|
||||
pub unsafe fn gemv_bf16(
|
||||
x: *const c_void,
|
||||
w: *const c_void,
|
||||
y_bf16: *mut c_void,
|
||||
y_fp32_buf: *mut c_void,
|
||||
k: i32,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_gemv_bf16(x, w, y_bf16, y_fp32_buf, k, n, stream);
|
||||
}
|
||||
|
||||
/// Raw decode attention dispatch.
|
||||
pub unsafe fn decode_attention_bf16(
|
||||
q: *const c_void,
|
||||
k: *const c_void,
|
||||
v: *const c_void,
|
||||
o: *mut c_void,
|
||||
batch: i32,
|
||||
num_q_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
kv_len: i32,
|
||||
head_dim: i32,
|
||||
scale: f32,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_decode_attention_bf16(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
batch,
|
||||
num_q_heads,
|
||||
num_kv_heads,
|
||||
kv_len,
|
||||
head_dim,
|
||||
scale,
|
||||
1,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
|
||||
// cuBLAS FFI

/// Opaque cuBLAS handle, carried as an untyped pointer.
pub type CublasHandle = *mut c_void;

unsafe extern "C" {
    /// cuBLAS stream setter; the returned `i32` is the cuBLAS status code.
    fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32;
}
|
||||
|
||||
/// Set cuBLAS stream. Must be called before any cuBLAS operations during graph capture.
|
||||
pub unsafe fn set_cublas_stream(handle: CublasHandle, stream: *mut c_void) {
|
||||
cublasSetStream_v2(handle, stream);
|
||||
}
|
||||
@@ -1,12 +1,25 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_embedding_f32(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||
num_tokens: i32, hidden_size: i32, stream: *mut c_void);
|
||||
fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||
num_tokens: i32, hidden_size: i32, stream: *mut c_void);
|
||||
fn launch_embedding_f32(
|
||||
table: *const c_void,
|
||||
token_ids: *const c_void,
|
||||
out: *mut c_void,
|
||||
num_tokens: i32,
|
||||
hidden_size: i32,
|
||||
vocab_size: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_embedding_bf16(
|
||||
table: *const c_void,
|
||||
token_ids: *const c_void,
|
||||
out: *mut c_void,
|
||||
num_tokens: i32,
|
||||
hidden_size: i32,
|
||||
vocab_size: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
/// Embedding lookup: table[token_ids[i]] for each i.
|
||||
@@ -18,6 +31,15 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
|
||||
|
||||
let hidden_size = table.shape()[1];
|
||||
let num_tokens = token_ids.len();
|
||||
let vocab_size = table.shape()[0];
|
||||
assert!(
|
||||
num_tokens <= i32::MAX as usize,
|
||||
"too many tokens for i32 kernel param"
|
||||
);
|
||||
assert!(
|
||||
hidden_size <= i32::MAX as usize,
|
||||
"hidden_size too large for i32 kernel param"
|
||||
);
|
||||
|
||||
// Upload token_ids to GPU
|
||||
let ids_bytes = unsafe {
|
||||
@@ -26,26 +48,54 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
|
||||
num_tokens * std::mem::size_of::<u32>(),
|
||||
)
|
||||
};
|
||||
let mut ids_gpu = GpuBuffer::alloc(ids_bytes.len()).expect("alloc token_ids");
|
||||
let mut ids_gpu =
|
||||
xserv_cuda::allocator::cached_alloc(ids_bytes.len()).expect("alloc token_ids");
|
||||
ids_gpu.copy_from_host(ids_bytes).unwrap();
|
||||
|
||||
let out = Tensor::zeros(&[num_tokens, hidden_size], table.dtype(), table.device());
|
||||
for &tid in token_ids {
|
||||
assert!(
|
||||
(tid as usize) < vocab_size,
|
||||
"token_id {tid} out of bounds (vocab_size={vocab_size})"
|
||||
);
|
||||
}
|
||||
|
||||
embedding_device_ids(table, ids_gpu.as_ptr() as *const c_void, num_tokens)
|
||||
}
|
||||
|
||||
/// Embedding lookup with token ids already on the GPU (u32, [num_tokens]).
|
||||
/// Used by the CUDA-graph decode path, where ids live in a persistent device
|
||||
/// buffer updated outside the captured region (no bounds check possible here).
|
||||
pub fn embedding_device_ids(table: &Tensor, ids_gpu: *const c_void, num_tokens: usize) -> Tensor {
|
||||
assert_eq!(table.ndim(), 2);
|
||||
assert!(table.is_contiguous());
|
||||
assert!(matches!(table.device(), Device::Cuda(_)));
|
||||
let hidden_size = table.shape()[1];
|
||||
let vocab_size = table.shape()[0];
|
||||
|
||||
let out = Tensor::empty(&[num_tokens, hidden_size], table.dtype(), table.device());
|
||||
|
||||
unsafe {
|
||||
match table.dtype() {
|
||||
DType::F32 => launch_embedding_f32(
|
||||
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
||||
table.data_ptr() as _,
|
||||
ids_gpu,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden_size as i32, std::ptr::null_mut(),
|
||||
num_tokens as i32,
|
||||
hidden_size as i32,
|
||||
vocab_size as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_embedding_bf16(
|
||||
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
||||
table.data_ptr() as _,
|
||||
ids_gpu,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden_size as i32, std::ptr::null_mut(),
|
||||
num_tokens as i32,
|
||||
hidden_size as i32,
|
||||
vocab_size as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for embedding"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
@@ -1,7 +1,36 @@
|
||||
use std::cell::RefCell;
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_cuda::error::{self, Result};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
const CUBLAS_WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
|
||||
const GEMV_TILE_K: usize = 256;
|
||||
|
||||
// GEMV launchers.
// NOTE(review): the previous comment here read "single-kernel, no FP32 temp
// buffer needed", yet both signatures take a `y_fp32_buf` — confirm against
// the kernel which statement is current.
unsafe extern "C" {
    /// GEMV launcher for a single input row (no `m` parameter).
    fn launch_gemv_bf16(
        x: *const c_void,
        w: *const c_void,
        y_bf16: *mut c_void,
        y_fp32_buf: *mut c_void,
        k: i32,
        n: i32,
        stream: *mut c_void,
    );

    /// GEMV launcher for `m` input rows.
    fn launch_gemv_bf16_batched(
        x: *const c_void,
        w: *const c_void,
        y_bf16: *mut c_void,
        y_fp32_buf: *mut c_void,
        m: i32,
        k: i32,
        n: i32,
        stream: *mut c_void,
    );
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum GemmBackend {
|
||||
Naive,
|
||||
@@ -9,16 +38,101 @@ pub enum GemmBackend {
|
||||
CuBlas,
|
||||
}
|
||||
|
||||
pub fn gemv_scratch_elems(k: usize, n: usize) -> usize {
|
||||
n * k.div_ceil(GEMV_TILE_K)
|
||||
}
|
||||
|
||||
/// Batched GEMV: [M, K] × [K, N] → [M, N], all BF16.
|
||||
/// Bit-exact with calling matmul on each row individually (same K-block partial
|
||||
/// + fixed-order reduction path), but in a single kernel launch per phase.
|
||||
pub fn matmul_batched_gemv(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert_eq!(a.ndim(), 2);
|
||||
assert_eq!(b.ndim(), 2);
|
||||
assert!(a.is_contiguous());
|
||||
assert!(b.is_contiguous());
|
||||
assert_eq!(a.dtype(), DType::BF16);
|
||||
assert_eq!(b.dtype(), DType::BF16);
|
||||
let m = a.shape()[0];
|
||||
let k = a.shape()[1];
|
||||
let n = b.shape()[1];
|
||||
assert_eq!(b.shape()[0], k);
|
||||
|
||||
let out = Tensor::empty(&[m, n], DType::BF16, a.device());
|
||||
let scratch_elems = m * gemv_scratch_elems(k, n);
|
||||
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(scratch_elems * 4).unwrap();
|
||||
|
||||
let null_stream = xserv_cuda::current_stream_raw();
|
||||
if m == 1 {
|
||||
unsafe {
|
||||
launch_gemv_bf16(
|
||||
a.data_ptr() as *const c_void,
|
||||
b.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
fp32_buf.as_mut_ptr() as *mut c_void,
|
||||
k as i32,
|
||||
n as i32,
|
||||
null_stream,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
unsafe {
|
||||
launch_gemv_bf16_batched(
|
||||
a.data_ptr() as *const c_void,
|
||||
b.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
fp32_buf.as_mut_ptr() as *mut c_void,
|
||||
m as i32,
|
||||
k as i32,
|
||||
n as i32,
|
||||
null_stream,
|
||||
);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
// --- FFI: custom CUDA kernels ---
|
||||
unsafe extern "C" {
|
||||
fn launch_gemm_naive_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_naive_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_tiled_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_tiled_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_naive_f32(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
c: *mut c_void,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_gemm_naive_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
c: *mut c_void,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_gemm_tiled_f32(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
c: *mut c_void,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_gemm_tiled_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
c: *mut c_void,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
// --- FFI: cuBLAS ---
|
||||
type CublasHandle = *mut c_void;
|
||||
pub type CublasHandle = *mut c_void;
|
||||
|
||||
#[allow(non_upper_case_globals)]
|
||||
const CUBLAS_OP_N: i32 = 0;
|
||||
@@ -34,15 +148,50 @@ unsafe extern "C" {
|
||||
fn cublasCreate_v2(handle: *mut CublasHandle) -> i32;
|
||||
fn cublasDestroy_v2(handle: CublasHandle) -> i32;
|
||||
fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32;
|
||||
fn cublasSetWorkspace_v2(handle: CublasHandle, workspace: *mut c_void, size: usize) -> i32;
|
||||
fn cublasGemmEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32, transb: i32,
|
||||
m: i32, n: i32, k: i32,
|
||||
transa: i32,
|
||||
transb: i32,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void, a_type: i32, lda: i32,
|
||||
b: *const c_void, b_type: i32, ldb: i32,
|
||||
a: *const c_void,
|
||||
a_type: i32,
|
||||
lda: i32,
|
||||
b: *const c_void,
|
||||
b_type: i32,
|
||||
ldb: i32,
|
||||
beta: *const c_void,
|
||||
c: *mut c_void, c_type: i32, ldc: i32,
|
||||
c: *mut c_void,
|
||||
c_type: i32,
|
||||
ldc: i32,
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
) -> i32;
|
||||
fn cublasGemmStridedBatchedEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32,
|
||||
transb: i32,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void,
|
||||
a_type: i32,
|
||||
lda: i32,
|
||||
stride_a: i64,
|
||||
b: *const c_void,
|
||||
b_type: i32,
|
||||
ldb: i32,
|
||||
stride_b: i64,
|
||||
beta: *const c_void,
|
||||
c: *mut c_void,
|
||||
c_type: i32,
|
||||
ldc: i32,
|
||||
stride_c: i64,
|
||||
batch_count: i32,
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
) -> i32;
|
||||
@@ -50,13 +199,32 @@ unsafe extern "C" {
|
||||
|
||||
pub struct CublasContext {
|
||||
handle: CublasHandle,
|
||||
/// Dedicated 32 MiB workspace owned by this handle. Held to keep the GPU
|
||||
/// buffer alive for the lifetime of the handle; cuBLAS reads/writes into
|
||||
/// it during GEMM. Dropped after `cublasDestroy_v2` so cuBLAS can't touch
|
||||
/// freed memory.
|
||||
_workspace: Option<GpuBuffer>,
|
||||
}
|
||||
|
||||
impl CublasContext {
|
||||
pub fn new() -> Result<Self> {
|
||||
let mut handle = std::ptr::null_mut();
|
||||
error::check(unsafe { cublasCreate_v2(&mut handle) })?;
|
||||
Ok(Self { handle })
|
||||
// Attach a per-handle workspace. cublasSetWorkspace requires the
|
||||
// pointer to remain valid until destroy or until a new workspace is
|
||||
// set, so we keep the GpuBuffer in this struct.
|
||||
let mut workspace = GpuBuffer::alloc(CUBLAS_WORKSPACE_BYTES)?;
|
||||
error::check(unsafe {
|
||||
cublasSetWorkspace_v2(
|
||||
handle,
|
||||
workspace.as_mut_ptr() as *mut c_void,
|
||||
CUBLAS_WORKSPACE_BYTES,
|
||||
)
|
||||
})?;
|
||||
Ok(Self {
|
||||
handle,
|
||||
_workspace: Some(workspace),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,9 +233,32 @@ impl Drop for CublasContext {
|
||||
if !self.handle.is_null() {
|
||||
unsafe { cublasDestroy_v2(self.handle) };
|
||||
}
|
||||
// _workspace drops here, after cublasDestroy_v2 has released the handle.
|
||||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static CUBLAS_CTX: RefCell<CublasContext> = RefCell::new(
|
||||
CublasContext::new().expect("failed to create thread-local cuBLAS handle")
|
||||
);
|
||||
}
|
||||
|
||||
/// Borrow the thread-local cuBLAS handle for the duration of a closure.
|
||||
fn with_cublas<F, R>(f: F) -> R
|
||||
where
|
||||
F: FnOnce(CublasHandle) -> R,
|
||||
{
|
||||
CUBLAS_CTX.with(|cell| {
|
||||
let ctx = cell.borrow();
|
||||
f(ctx.handle)
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the thread-local cuBLAS handle for use with dispatch module.
|
||||
pub fn cublas_handle() -> CublasHandle {
|
||||
CUBLAS_CTX.with(|cell| cell.borrow().handle)
|
||||
}
|
||||
|
||||
/// Matrix multiplication: C = A @ B
|
||||
/// A: [M, K], B: [K, N], C: [M, N]
|
||||
/// All tensors must be contiguous and on the same GPU.
|
||||
@@ -76,76 +267,206 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
assert_eq!(b.ndim(), 2);
|
||||
assert_eq!(a.shape()[1], b.shape()[0], "inner dimension mismatch");
|
||||
assert_eq!(a.dtype(), b.dtype(), "dtype mismatch");
|
||||
assert!(a.is_contiguous() && b.is_contiguous(), "matmul requires contiguous tensors");
|
||||
assert!(matches!(a.device(), Device::Cuda(_)), "matmul requires GPU tensors");
|
||||
assert!(
|
||||
a.is_contiguous() && b.is_contiguous(),
|
||||
"matmul requires contiguous tensors"
|
||||
);
|
||||
assert!(
|
||||
matches!(a.device(), Device::Cuda(_)),
|
||||
"matmul requires GPU tensors"
|
||||
);
|
||||
|
||||
let m = a.shape()[0];
|
||||
let k = a.shape()[1];
|
||||
let n = b.shape()[1];
|
||||
let dtype = a.dtype();
|
||||
|
||||
let c = Tensor::zeros(&[m, n], dtype, a.device());
|
||||
// All backends (naive, tiled, cuBLAS with beta=0, custom GEMV) fully
|
||||
// overwrite every element of C, so we skip the cudaMemset.
|
||||
let c = Tensor::empty(&[m, n], dtype, a.device());
|
||||
|
||||
let a_ptr = a.data_ptr() as *const c_void;
|
||||
let b_ptr = b.data_ptr() as *const c_void;
|
||||
let c_ptr = c.data_ptr() as *mut c_void;
|
||||
let null_stream = std::ptr::null_mut();
|
||||
let null_stream = xserv_cuda::current_stream_raw();
|
||||
|
||||
match backend {
|
||||
GemmBackend::Naive => {
|
||||
unsafe {
|
||||
match dtype {
|
||||
DType::F32 => launch_gemm_naive_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
DType::BF16 => launch_gemm_naive_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
_ => panic!("unsupported dtype for naive GEMM"),
|
||||
}
|
||||
GemmBackend::Naive => unsafe {
|
||||
match dtype {
|
||||
DType::F32 => launch_gemm_naive_f32(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
null_stream,
|
||||
),
|
||||
DType::BF16 => launch_gemm_naive_bf16(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
null_stream,
|
||||
),
|
||||
_ => panic!("unsupported dtype for naive GEMM"),
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
GemmBackend::Tiled => {
|
||||
unsafe {
|
||||
match dtype {
|
||||
DType::F32 => launch_gemm_tiled_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
DType::BF16 => launch_gemm_tiled_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
_ => panic!("unsupported dtype for tiled GEMM"),
|
||||
}
|
||||
},
|
||||
GemmBackend::Tiled => unsafe {
|
||||
match dtype {
|
||||
DType::F32 => launch_gemm_tiled_f32(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
null_stream,
|
||||
),
|
||||
DType::BF16 => launch_gemm_tiled_bf16(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
null_stream,
|
||||
),
|
||||
_ => panic!("unsupported dtype for tiled GEMM"),
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
},
|
||||
GemmBackend::CuBlas => {
|
||||
// cuBLAS uses column-major, but we have row-major tensors.
|
||||
// Trick: compute C^T = B^T @ A^T, which gives us C in row-major.
|
||||
// cuBLAS sees our row-major data as column-major transposed.
|
||||
let ctx = CublasContext::new().unwrap();
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
if m == 1 && dtype == DType::BF16 && n >= 256 {
|
||||
let mut fp32_buf =
|
||||
xserv_cuda::allocator::cached_alloc(gemv_scratch_elems(k, n) * 4).unwrap();
|
||||
unsafe {
|
||||
launch_gemv_bf16(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
fp32_buf.as_mut_ptr() as *mut c_void,
|
||||
k as i32,
|
||||
n as i32,
|
||||
null_stream,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
|
||||
let (a_type, b_type, c_type) = match dtype {
|
||||
DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F),
|
||||
DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF),
|
||||
_ => panic!("unsupported dtype for cuBLAS GEMM"),
|
||||
};
|
||||
let (a_type, b_type, c_type) = match dtype {
|
||||
DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F),
|
||||
DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF),
|
||||
_ => panic!("unsupported dtype for cuBLAS GEMM"),
|
||||
};
|
||||
|
||||
unsafe {
|
||||
cublasSetStream_v2(ctx.handle, null_stream);
|
||||
// Row-major trick: swap A/B and transpose flags
|
||||
// C(row-major) = A @ B <=> C^T(col-major) = B^T @ A^T
|
||||
error::check(cublasGemmEx(
|
||||
ctx.handle,
|
||||
CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
n as i32, m as i32, k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b_ptr, b_type, n as i32, // B as col-major = B^T
|
||||
a_ptr, a_type, k as i32, // A as col-major = A^T
|
||||
&beta as *const f32 as *const c_void,
|
||||
c_ptr, c_type, n as i32, // C as col-major = C^T
|
||||
CUBLAS_COMPUTE_32F,
|
||||
-1, // default algo
|
||||
)).expect("cuBLAS GEMM failed");
|
||||
with_cublas(|handle| unsafe {
|
||||
cublasSetStream_v2(handle, null_stream);
|
||||
error::check(cublasGemmEx(
|
||||
handle,
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_N,
|
||||
n as i32,
|
||||
m as i32,
|
||||
k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b_ptr,
|
||||
b_type,
|
||||
n as i32,
|
||||
a_ptr,
|
||||
a_type,
|
||||
k as i32,
|
||||
&beta as *const f32 as *const c_void,
|
||||
c_ptr,
|
||||
c_type,
|
||||
n as i32,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
-1,
|
||||
))
|
||||
.expect("cuBLAS GEMM failed");
|
||||
});
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
c
|
||||
}
|
||||
|
||||
/// Batched matrix multiplication via cuBLAS: C[b] = A[b] @ B[b]
|
||||
/// a: [..., M, K], b: [..., K, N] → [..., M, N]
|
||||
/// Leading dimensions must match and tensors must be contiguous.
|
||||
pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert!(a.ndim() >= 2 && b.ndim() >= 2);
|
||||
assert_eq!(a.ndim(), b.ndim());
|
||||
assert!(a.is_contiguous() && b.is_contiguous());
|
||||
assert!(matches!(a.device(), Device::Cuda(_)));
|
||||
assert_eq!(a.dtype(), b.dtype());
|
||||
|
||||
let ndim = a.ndim();
|
||||
let m = a.shape()[ndim - 2];
|
||||
let k = a.shape()[ndim - 1];
|
||||
let n = b.shape()[ndim - 1];
|
||||
assert_eq!(b.shape()[ndim - 2], k, "inner dimension mismatch");
|
||||
|
||||
// Compute batch count from leading dimensions
|
||||
let batch: usize = a.shape()[..ndim - 2].iter().product();
|
||||
assert_eq!(
|
||||
b.shape()[..ndim - 2].iter().product::<usize>(),
|
||||
batch,
|
||||
"batch dimensions mismatch"
|
||||
);
|
||||
|
||||
let mut out_shape: Vec<usize> = a.shape()[..ndim - 2].to_vec();
|
||||
out_shape.push(m);
|
||||
out_shape.push(n);
|
||||
// cuBLAS with beta=0 fully overwrites every element of C.
|
||||
let c = Tensor::empty(&out_shape, a.dtype(), a.device());
|
||||
|
||||
let dtype = a.dtype();
|
||||
let (a_type, b_type, c_type) = match dtype {
|
||||
DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F),
|
||||
DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF),
|
||||
_ => panic!("unsupported dtype for batched matmul"),
|
||||
};
|
||||
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
// cuBLAS strides are in elements (not bytes)
|
||||
let stride_a = (m * k) as i64;
|
||||
let stride_b = (k * n) as i64;
|
||||
let stride_c = (m * n) as i64;
|
||||
|
||||
with_cublas(|handle| unsafe {
|
||||
cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
|
||||
// Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major)
|
||||
error::check(cublasGemmStridedBatchedEx(
|
||||
handle,
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_N,
|
||||
n as i32,
|
||||
m as i32,
|
||||
k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b.data_ptr() as _,
|
||||
b_type,
|
||||
n as i32,
|
||||
stride_b,
|
||||
a.data_ptr() as _,
|
||||
a_type,
|
||||
k as i32,
|
||||
stride_a,
|
||||
&beta as *const f32 as *const c_void,
|
||||
c.data_ptr() as *mut c_void,
|
||||
c_type,
|
||||
n as i32,
|
||||
stride_c,
|
||||
batch as i32,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
-1,
|
||||
))
|
||||
.expect("cuBLAS batched GEMM failed");
|
||||
});
|
||||
c
|
||||
}
|
||||
|
||||
@@ -2,10 +2,26 @@ use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_layernorm_f32(x: *const c_void, gamma: *const c_void, beta: *const c_void,
|
||||
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_layernorm_bf16(x: *const c_void, gamma: *const c_void, beta: *const c_void,
|
||||
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_layernorm_f32(
|
||||
x: *const c_void,
|
||||
gamma: *const c_void,
|
||||
beta: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_layernorm_bf16(
|
||||
x: *const c_void,
|
||||
gamma: *const c_void,
|
||||
beta: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor {
|
||||
@@ -17,23 +33,40 @@ pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor
|
||||
assert_eq!(beta.shape(), &[hidden_size]);
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
assert!(
|
||||
rows <= i32::MAX as usize,
|
||||
"too many rows for i32 kernel param"
|
||||
);
|
||||
assert!(
|
||||
hidden_size <= i32::MAX as usize,
|
||||
"hidden_size too large for i32 kernel param"
|
||||
);
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_layernorm_f32(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
||||
x.data_ptr() as _,
|
||||
gamma.data_ptr() as _,
|
||||
beta.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_layernorm_bf16(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
||||
x.data_ptr() as _,
|
||||
gamma.data_ptr() as _,
|
||||
beta.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for layernorm"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
@@ -1,15 +1,36 @@
|
||||
pub mod activation;
|
||||
pub mod argmax;
|
||||
pub mod attention;
|
||||
pub mod dispatch;
|
||||
pub mod embedding;
|
||||
pub mod gemm;
|
||||
pub mod layernorm;
|
||||
pub mod moe;
|
||||
pub mod quantization;
|
||||
pub mod rmsnorm;
|
||||
pub mod rope;
|
||||
pub mod softmax;
|
||||
pub mod transpose;
|
||||
|
||||
pub use activation::{gelu, silu};
|
||||
pub use embedding::embedding;
|
||||
pub use gemm::{matmul, GemmBackend};
|
||||
pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu_mul};
|
||||
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
|
||||
pub use attention::{
|
||||
attention, copy_kv_position, decode_attention, flash_attention, flash_attention_sinks,
|
||||
paged_decode_attention, paged_decode_attention_sinks, paged_decode_attention_tree,
|
||||
reshape_and_cache_batched_bf16, reshape_and_cache_bf16,
|
||||
};
|
||||
pub use embedding::{embedding, embedding_device_ids};
|
||||
pub use gemm::{GemmBackend, batched_matmul, matmul, matmul_batched_gemv};
|
||||
pub use layernorm::layernorm;
|
||||
pub use rmsnorm::rmsnorm;
|
||||
pub use rope::{rope_inplace, RopeCache};
|
||||
pub use rmsnorm::{add_rmsnorm, rmsnorm};
|
||||
pub use rope::{RopeCache, rope_inplace, rope_inplace_device_pos};
|
||||
pub use softmax::softmax;
|
||||
pub use transpose::{
|
||||
merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu,
|
||||
transpose_for_rope_gpu, transpose_from_rope_gpu,
|
||||
};
|
||||
|
||||
/// Register GPU kernels with the tensor crate. Call once at startup.
|
||||
pub fn init() {
|
||||
xserv_tensor::register_gpu_contiguous(strided_to_contiguous_gpu);
|
||||
}
|
||||
|
||||
474
crates/xserv-kernels/src/moe.rs
Normal file
474
crates/xserv-kernels/src/moe.rs
Normal file
@@ -0,0 +1,474 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Tensor};
|
||||
|
||||
use crate::gemm::{CublasHandle, cublas_handle};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_moe_topk_softmax_bf16(
|
||||
router_logits: *const c_void,
|
||||
topk_ids: *mut c_void,
|
||||
topk_weights: *mut c_void,
|
||||
num_tokens: i32,
|
||||
num_experts: i32,
|
||||
top_k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_replicate_bf16(
|
||||
x: *const c_void,
|
||||
x_rep: *mut c_void,
|
||||
num_tokens: i32,
|
||||
hidden: i32,
|
||||
local_experts: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_bias_add_3d_bf16(
|
||||
x: *mut c_void,
|
||||
bias: *const c_void,
|
||||
batch: i32,
|
||||
num_tokens: i32,
|
||||
dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_weighted_sum_bf16(
|
||||
expert_out: *const c_void,
|
||||
topk_ids: *const c_void,
|
||||
topk_weights: *const c_void,
|
||||
out: *mut c_void,
|
||||
num_tokens: i32,
|
||||
hidden: i32,
|
||||
top_k: i32,
|
||||
expert_start: i32,
|
||||
local_experts: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
fn launch_moe_sparse_gemv_fp8_bf16(
|
||||
x: *const c_void,
|
||||
w: *const c_void,
|
||||
w_scales: *const c_void,
|
||||
bias: *const c_void,
|
||||
topk_ids: *const c_void,
|
||||
y: *mut c_void,
|
||||
num_tokens: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
top_k: i32,
|
||||
expert_start: i32,
|
||||
local_experts: i32,
|
||||
x_per_slot: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_sparse_gemv_mxfp4_bf16(
|
||||
x: *const c_void,
|
||||
w_packed: *const c_void,
|
||||
w_scales: *const c_void,
|
||||
bias: *const c_void,
|
||||
topk_ids: *const c_void,
|
||||
y: *mut c_void,
|
||||
num_tokens: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
top_k: i32,
|
||||
expert_start: i32,
|
||||
local_experts: i32,
|
||||
x_per_slot: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_weighted_sum_sparse_bf16(
|
||||
down: *const c_void,
|
||||
topk_ids: *const c_void,
|
||||
topk_weights: *const c_void,
|
||||
out: *mut c_void,
|
||||
num_tokens: i32,
|
||||
hidden: i32,
|
||||
top_k: i32,
|
||||
expert_start: i32,
|
||||
local_experts: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
fn cublasGemmStridedBatchedEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32,
|
||||
transb: i32,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void,
|
||||
a_type: i32,
|
||||
lda: i32,
|
||||
stride_a: i64,
|
||||
b: *const c_void,
|
||||
b_type: i32,
|
||||
ldb: i32,
|
||||
stride_b: i64,
|
||||
beta: *const c_void,
|
||||
c: *mut c_void,
|
||||
c_type: i32,
|
||||
ldc: i32,
|
||||
stride_c: i64,
|
||||
batch_count: i32,
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
) -> i32;
|
||||
|
||||
fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32;
|
||||
}
|
||||
|
||||
const CUDA_R_16BF: i32 = 14;
|
||||
const CUBLAS_COMPUTE_32F: i32 = 68;
|
||||
const CUBLAS_GEMM_DEFAULT: i32 = -1;
|
||||
|
||||
/// GPU top-k selection + softmax over router logits.
|
||||
///
|
||||
/// Input: router_logits [num_tokens, num_experts] BF16 on GPU
|
||||
/// Output: (topk_ids [num_tokens, top_k] i32, topk_weights [num_tokens, top_k] f32)
|
||||
pub fn moe_topk_softmax(
|
||||
router_logits: &Tensor,
|
||||
num_experts: usize,
|
||||
top_k: usize,
|
||||
) -> (Tensor, Tensor) {
|
||||
assert_eq!(router_logits.ndim(), 2);
|
||||
assert_eq!(router_logits.dtype(), DType::BF16);
|
||||
assert!(router_logits.is_contiguous());
|
||||
let num_tokens = router_logits.shape()[0];
|
||||
assert_eq!(router_logits.shape()[1], num_experts);
|
||||
|
||||
// NOTE: topk_ids actually holds i32 expert indices; DType has no I32, so
|
||||
// this is a raw 4-byte buffer mislabeled F32. Never read it as floats —
|
||||
// all consumers (weighted-sum / sparse GEMV kernels) cast to int*.
|
||||
let topk_ids = Tensor::empty(&[num_tokens, top_k], DType::F32, router_logits.device());
|
||||
let topk_weights = Tensor::empty(&[num_tokens, top_k], DType::F32, router_logits.device());
|
||||
|
||||
unsafe {
|
||||
launch_moe_topk_softmax_bf16(
|
||||
router_logits.data_ptr() as *const c_void,
|
||||
topk_ids.data_ptr() as *mut c_void,
|
||||
topk_weights.data_ptr() as *mut c_void,
|
||||
num_tokens as i32,
|
||||
num_experts as i32,
|
||||
top_k as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
(topk_ids, topk_weights)
|
||||
}
|
||||
|
||||
/// Replicate x [num_tokens, hidden] → [local_experts, num_tokens, hidden].
|
||||
pub fn moe_replicate(x: &Tensor, local_experts: usize) -> Tensor {
|
||||
assert_eq!(x.ndim(), 2);
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous());
|
||||
let num_tokens = x.shape()[0];
|
||||
let hidden = x.shape()[1];
|
||||
let out = Tensor::empty(
|
||||
&[local_experts, num_tokens, hidden],
|
||||
DType::BF16,
|
||||
x.device(),
|
||||
);
|
||||
|
||||
unsafe {
|
||||
launch_moe_replicate_bf16(
|
||||
x.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32,
|
||||
hidden as i32,
|
||||
local_experts as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
/// In-place 3D bias add: x [batch, num_tokens, dim] += bias [batch, dim].
|
||||
pub fn moe_bias_add_3d(x: &Tensor, bias: &Tensor) {
|
||||
assert_eq!(x.ndim(), 3);
|
||||
assert_eq!(bias.ndim(), 2);
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous());
|
||||
let batch = x.shape()[0];
|
||||
let num_tokens = x.shape()[1];
|
||||
let dim = x.shape()[2];
|
||||
assert_eq!(bias.shape(), &[batch, dim]);
|
||||
|
||||
unsafe {
|
||||
launch_moe_bias_add_3d_bf16(
|
||||
x.data_ptr() as *mut c_void,
|
||||
bias.data_ptr() as *const c_void,
|
||||
batch as i32,
|
||||
num_tokens as i32,
|
||||
dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Weighted sum of expert outputs → [num_tokens, hidden].
|
||||
///
|
||||
/// expert_out: [local_experts, num_tokens, hidden] BF16
|
||||
/// topk_ids: [num_tokens, top_k] i32 (global expert indices)
|
||||
/// topk_weights: [num_tokens, top_k] f32
|
||||
pub fn moe_weighted_sum(
|
||||
expert_out: &Tensor,
|
||||
topk_ids: &Tensor,
|
||||
topk_weights: &Tensor,
|
||||
expert_start: usize,
|
||||
local_experts: usize,
|
||||
top_k: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(expert_out.ndim(), 3);
|
||||
assert_eq!(expert_out.dtype(), DType::BF16);
|
||||
let num_tokens = expert_out.shape()[1];
|
||||
let hidden = expert_out.shape()[2];
|
||||
|
||||
let out = Tensor::empty(&[num_tokens, hidden], DType::BF16, expert_out.device());
|
||||
|
||||
unsafe {
|
||||
launch_moe_weighted_sum_bf16(
|
||||
expert_out.data_ptr() as *const c_void,
|
||||
topk_ids.data_ptr() as *const c_void,
|
||||
topk_weights.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32,
|
||||
hidden as i32,
|
||||
top_k as i32,
|
||||
expert_start as i32,
|
||||
local_experts as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
/// Sparse MoE GEMV (FP8 W8A16): compute only the routed experts.
|
||||
///
|
||||
/// x: [num_tokens, K] BF16 (x_per_slot=false, gate_up) or
|
||||
/// [num_tokens * top_k, K] BF16 (x_per_slot=true, down)
|
||||
/// w_fp8_t: [local_experts, N, K] FP8E4M3 (transposed weight layout)
|
||||
/// w_scales: [local_experts] F32 per-expert scalar scales
|
||||
/// bias: [local_experts, N] BF16 (fused into the epilogue)
|
||||
/// topk_ids: [num_tokens, top_k] i32 global expert ids (GPU)
|
||||
///
|
||||
/// Returns y [num_tokens, top_k, N] BF16. Slots routed to experts NOT
|
||||
/// owned by this rank are left UNWRITTEN (uninitialized memory) — the
|
||||
/// consumer must skip them (see moe_weighted_sum_sparse).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn moe_sparse_gemv_fp8(
|
||||
x: &Tensor,
|
||||
w_fp8_t: &Tensor,
|
||||
w_scales: &Tensor,
|
||||
bias: &Tensor,
|
||||
topk_ids: &Tensor,
|
||||
num_tokens: usize,
|
||||
top_k: usize,
|
||||
expert_start: usize,
|
||||
local_experts: usize,
|
||||
x_per_slot: bool,
|
||||
) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous());
|
||||
assert_eq!(w_fp8_t.dtype(), DType::FP8E4M3);
|
||||
let n = w_fp8_t.shape()[1];
|
||||
let k = w_fp8_t.shape()[2];
|
||||
// The kernel reads weights as uint4 (16 FP8 values per lane) and would
|
||||
// silently skip a K%16 tail.
|
||||
assert_eq!(k % 16, 0, "sparse FP8 GEMV requires K % 16 == 0, got {k}");
|
||||
assert_eq!(x.shape()[x.ndim() - 1], k);
|
||||
assert_eq!(
|
||||
x.shape()[0],
|
||||
if x_per_slot {
|
||||
num_tokens * top_k
|
||||
} else {
|
||||
num_tokens
|
||||
}
|
||||
);
|
||||
|
||||
let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_moe_sparse_gemv_fp8_bf16(
|
||||
x.data_ptr() as *const c_void,
|
||||
w_fp8_t.data_ptr() as *const c_void,
|
||||
w_scales.data_ptr() as *const c_void,
|
||||
bias.data_ptr() as *const c_void,
|
||||
topk_ids.data_ptr() as *const c_void,
|
||||
y.data_ptr() as *mut c_void,
|
||||
num_tokens as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
top_k as i32,
|
||||
expert_start as i32,
|
||||
local_experts as i32,
|
||||
x_per_slot as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
y
|
||||
}
|
||||
|
||||
/// Sparse MoE GEMV (MXFP4 W4A16): same contract as moe_sparse_gemv_fp8,
|
||||
/// with packed 4-bit weights [E, N, K/2] + UE8M0 block scales [E, N, K/32].
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn moe_sparse_gemv_mxfp4(
|
||||
x: &Tensor,
|
||||
w_packed: &Tensor,
|
||||
w_scales: &Tensor,
|
||||
bias: &Tensor,
|
||||
topk_ids: &Tensor,
|
||||
num_tokens: usize,
|
||||
top_k: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
expert_start: usize,
|
||||
local_experts: usize,
|
||||
x_per_slot: bool,
|
||||
) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous());
|
||||
// 32-element MXFP4 blocks, read as uint4 (32 nibbles) per lane.
|
||||
assert_eq!(k % 32, 0, "sparse MXFP4 GEMV requires K % 32 == 0, got {k}");
|
||||
assert_eq!(x.shape()[x.ndim() - 1], k);
|
||||
assert_eq!(
|
||||
x.shape()[0],
|
||||
if x_per_slot {
|
||||
num_tokens * top_k
|
||||
} else {
|
||||
num_tokens
|
||||
}
|
||||
);
|
||||
|
||||
let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_moe_sparse_gemv_mxfp4_bf16(
|
||||
x.data_ptr() as *const c_void,
|
||||
w_packed.data_ptr() as *const c_void,
|
||||
w_scales.data_ptr() as *const c_void,
|
||||
bias.data_ptr() as *const c_void,
|
||||
topk_ids.data_ptr() as *const c_void,
|
||||
y.data_ptr() as *mut c_void,
|
||||
num_tokens as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
top_k as i32,
|
||||
expert_start as i32,
|
||||
local_experts as i32,
|
||||
x_per_slot as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
y
|
||||
}
|
||||
|
||||
/// Weighted sum over the slot axis of the sparse GEMV output.
|
||||
///
|
||||
/// down: [num_tokens, top_k, hidden] BF16 (non-local slots uninitialized
|
||||
/// and skipped, never multiplied by zero — NaN * 0 = NaN).
|
||||
pub fn moe_weighted_sum_sparse(
|
||||
down: &Tensor,
|
||||
topk_ids: &Tensor,
|
||||
topk_weights: &Tensor,
|
||||
expert_start: usize,
|
||||
local_experts: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(down.ndim(), 3);
|
||||
assert_eq!(down.dtype(), DType::BF16);
|
||||
let num_tokens = down.shape()[0];
|
||||
let top_k = down.shape()[1];
|
||||
let hidden = down.shape()[2];
|
||||
|
||||
let out = Tensor::empty(&[num_tokens, hidden], DType::BF16, down.device());
|
||||
unsafe {
|
||||
launch_moe_weighted_sum_sparse_bf16(
|
||||
down.data_ptr() as *const c_void,
|
||||
topk_ids.data_ptr() as *const c_void,
|
||||
topk_weights.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32,
|
||||
hidden as i32,
|
||||
top_k as i32,
|
||||
expert_start as i32,
|
||||
local_experts as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Strided batched GEMM for MoE expert forward.
|
||||
/// C[b] = A[b] @ B[b] for b in 0..batch
|
||||
///
|
||||
/// A: [batch, M, K] BF16 contiguous
|
||||
/// B: [batch, K, N] BF16 contiguous
|
||||
/// Returns C: [batch, M, N] BF16
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn batched_gemm_strided(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert_eq!(a.ndim(), 3);
|
||||
assert_eq!(b.ndim(), 3);
|
||||
assert_eq!(a.dtype(), DType::BF16);
|
||||
assert_eq!(b.dtype(), DType::BF16);
|
||||
assert!(a.is_contiguous() && b.is_contiguous());
|
||||
assert_eq!(a.shape()[0], b.shape()[0]);
|
||||
assert_eq!(a.shape()[2], b.shape()[1]);
|
||||
|
||||
let batch = a.shape()[0];
|
||||
let m = a.shape()[1];
|
||||
let k = a.shape()[2];
|
||||
let n = b.shape()[2];
|
||||
|
||||
let c = Tensor::empty(&[batch, m, n], DType::BF16, a.device());
|
||||
|
||||
let alpha: f32 = 1.0;
|
||||
let beta: f32 = 0.0;
|
||||
|
||||
// cuBLAS column-major: we compute C^T = B^T @ A^T
|
||||
// A is [batch, M, K] row-major → A^T is [K, M] col-major, lda=K
|
||||
// B is [batch, K, N] row-major → B^T is [N, K] col-major, ldb=N? No...
|
||||
//
|
||||
// Actually for row-major: A[M,K] in memory = col-major A^T[K,M] with lda=K.
|
||||
// So we call cublasGemmStridedBatchedEx with:
|
||||
// transa=N, transb=N
|
||||
// m=N, n=M, k=K (because cuBLAS sees col-major)
|
||||
// A_cublas = B_row (pointer), lda=N
|
||||
// B_cublas = A_row (pointer), ldb=K
|
||||
// C_cublas = C_row (pointer), ldc=N
|
||||
|
||||
let stride_a = (m * k) as i64;
|
||||
let stride_b = (k * n) as i64;
|
||||
let stride_c = (m * n) as i64;
|
||||
|
||||
let handle = cublas_handle();
|
||||
unsafe {
|
||||
cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
|
||||
let status = cublasGemmStridedBatchedEx(
|
||||
handle,
|
||||
0,
|
||||
0, // CUBLAS_OP_N, CUBLAS_OP_N
|
||||
n as i32,
|
||||
m as i32,
|
||||
k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b.data_ptr() as *const c_void,
|
||||
CUDA_R_16BF,
|
||||
n as i32,
|
||||
stride_b,
|
||||
a.data_ptr() as *const c_void,
|
||||
CUDA_R_16BF,
|
||||
k as i32,
|
||||
stride_a,
|
||||
&beta as *const f32 as *const c_void,
|
||||
c.data_ptr() as *mut c_void,
|
||||
CUDA_R_16BF,
|
||||
n as i32,
|
||||
stride_c,
|
||||
batch as i32,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
CUBLAS_GEMM_DEFAULT,
|
||||
);
|
||||
assert_eq!(status, 0, "cublasGemmStridedBatchedEx failed: {status}");
|
||||
}
|
||||
|
||||
c
|
||||
}
|
||||
603
crates/xserv-kernels/src/quantization.rs
Normal file
603
crates/xserv-kernels/src/quantization.rs
Normal file
@@ -0,0 +1,603 @@
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_tensor::{DType, Tensor};
|
||||
|
||||
// ============================================================
|
||||
// FFI: custom CUDA kernels
|
||||
// ============================================================
|
||||
|
||||
unsafe extern "C" {
    /// FP8 E4M3 -> BF16 dequantization launcher. The Rust wrapper
    /// `dequant_fp8_to_bf16` passes a `[num_experts, rows, cols]` source and
    /// one F32 scale per expert.
    fn launch_dequant_fp8e4m3_to_bf16(
        src: *const c_void,
        scales: *const c_void,
        dst: *mut c_void,
        num_experts: i32,
        rows: i32,
        cols: i32,
        stream: *mut c_void,
    );

    /// BF16 -> FP8 E4M3 row-wise quantization launcher. Writes the quantized
    /// data to `dst` and one scale per row to `scales`.
    fn launch_quantize_bf16_to_fp8e4m3_rowwise(
        src: *const c_void,
        dst: *mut c_void,
        scales: *mut c_void,
        num_rows: i32,
        cols: i32,
        stream: *mut c_void,
    );

    /// In-place post-GEMM scaling launcher. Its caller in this file describes
    /// the effect as `data[e, t, :] *= a_scales[e * tokens + t] * b_scales[e]`.
    fn launch_rowwise_scale_moe_bf16(
        data: *mut c_void,
        a_scales: *const c_void,
        b_scales: *const c_void,
        num_rows: i32,
        cols: i32,
        tokens: i32,
        stream: *mut c_void,
    );

    /// MXFP4 weight-only batched GEMV launcher (`e` experts, `n` outputs,
    /// reduction length `k`); see `batched_gemv_mxfp4` for the tensor layouts.
    fn launch_batched_gemv_mxfp4_bf16(
        x: *const c_void,
        w_packed: *const c_void,
        w_scales: *const c_void,
        y: *mut c_void,
        e: i32,
        n: i32,
        k: i32,
        stream: *mut c_void,
    );

    /// MXFP4 -> BF16 dequantization launcher that also transposes the weights;
    /// see `dequant_mxfp4_to_bf16_t`.
    fn launch_dequant_mxfp4_to_bf16_t(
        w_packed: *const c_void,
        w_scales: *const c_void,
        out: *mut c_void,
        e: i32,
        n: i32,
        k: i32,
        stream: *mut c_void,
    );
}
|
||||
|
||||
// ============================================================
|
||||
// FFI: cuBLASLt
|
||||
// ============================================================
|
||||
|
||||
/// Opaque `cublasLtHandle_t`.
type CublasLtHandle = *mut c_void;
/// Opaque `cublasLtMatmulDesc_t`.
type CublasLtMatmulDesc = *mut c_void;
/// Opaque `cublasLtMatrixLayout_t`.
type CublasLtMatrixLayout = *mut c_void;
/// Opaque `cublasLtMatmulPreference_t`.
type CublasLtMatmulPreference = *mut c_void;
|
||||
|
||||
/// Mirror of the opaque `cublasLtMatmulAlgo_t` (eight 64-bit words). The
/// contents are only ever produced and consumed by cuBLASLt itself.
#[repr(C)]
#[derive(Clone, Copy)]
struct CublasLtMatmulAlgo {
    data: [u64; 8],
}

/// Mirror of `cublasLtMatmulHeuristicResult_t`, filled in by
/// `cublasLtMatmulAlgoGetHeuristic`.
///
/// The field list follows the C definition one-to-one (`algo`,
/// `workspaceSize`, `state`, `wavesCount`, `reserved[4]`) so the layout is
/// correct by construction and does not depend on trailing padding.
#[repr(C)]
struct CublasLtMatmulHeuristicResult {
    /// Algorithm selected by the heuristic.
    algo: CublasLtMatmulAlgo,
    /// Workspace the algorithm needs, in bytes.
    workspace_size: usize,
    /// `cublasStatus_t` for this entry; the other fields are only meaningful
    /// when it is 0 (success).
    state: i32,
    /// `wavesCount` from the C struct; not used by this crate.
    _waves_count: f32,
    /// `int reserved[4]` from the C struct.
    _reserved: [i32; 4],
}
|
||||
|
||||
unsafe extern "C" {
|
||||
fn cublasLtCreate(handle: *mut CublasLtHandle) -> i32;
|
||||
fn cublasLtDestroy(handle: CublasLtHandle) -> i32;
|
||||
fn cublasLtMatmulDescCreate(
|
||||
desc: *mut CublasLtMatmulDesc,
|
||||
compute_type: i32,
|
||||
scale_type: i32,
|
||||
) -> i32;
|
||||
fn cublasLtMatmulDescDestroy(desc: CublasLtMatmulDesc) -> i32;
|
||||
fn cublasLtMatmulDescSetAttribute(
|
||||
desc: CublasLtMatmulDesc,
|
||||
attr: i32,
|
||||
buf: *const c_void,
|
||||
size: usize,
|
||||
) -> i32;
|
||||
fn cublasLtMatrixLayoutCreate(
|
||||
layout: *mut CublasLtMatrixLayout,
|
||||
dtype: i32,
|
||||
rows: u64,
|
||||
cols: u64,
|
||||
ld: i64,
|
||||
) -> i32;
|
||||
fn cublasLtMatrixLayoutDestroy(layout: CublasLtMatrixLayout) -> i32;
|
||||
fn cublasLtMatrixLayoutSetAttribute(
|
||||
layout: CublasLtMatrixLayout,
|
||||
attr: i32,
|
||||
buf: *const c_void,
|
||||
size: usize,
|
||||
) -> i32;
|
||||
fn cublasLtMatmulPreferenceCreate(pref: *mut CublasLtMatmulPreference) -> i32;
|
||||
fn cublasLtMatmulPreferenceDestroy(pref: CublasLtMatmulPreference) -> i32;
|
||||
fn cublasLtMatmulPreferenceSetAttribute(
|
||||
pref: CublasLtMatmulPreference,
|
||||
attr: i32,
|
||||
buf: *const c_void,
|
||||
size: usize,
|
||||
) -> i32;
|
||||
fn cublasLtMatmulAlgoGetHeuristic(
|
||||
handle: CublasLtHandle,
|
||||
desc: CublasLtMatmulDesc,
|
||||
a_layout: CublasLtMatrixLayout,
|
||||
b_layout: CublasLtMatrixLayout,
|
||||
c_layout: CublasLtMatrixLayout,
|
||||
d_layout: CublasLtMatrixLayout,
|
||||
pref: CublasLtMatmulPreference,
|
||||
requested: i32,
|
||||
results: *mut CublasLtMatmulHeuristicResult,
|
||||
found: *mut i32,
|
||||
) -> i32;
|
||||
fn cublasLtMatmul(
|
||||
handle: CublasLtHandle,
|
||||
desc: CublasLtMatmulDesc,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void,
|
||||
a_layout: CublasLtMatrixLayout,
|
||||
b: *const c_void,
|
||||
b_layout: CublasLtMatrixLayout,
|
||||
beta: *const c_void,
|
||||
c: *const c_void,
|
||||
c_layout: CublasLtMatrixLayout,
|
||||
d: *mut c_void,
|
||||
d_layout: CublasLtMatrixLayout,
|
||||
algo: *const CublasLtMatmulAlgo,
|
||||
workspace: *mut c_void,
|
||||
workspace_size: usize,
|
||||
stream: *mut c_void,
|
||||
) -> i32;
|
||||
}
|
||||
|
||||
// cuBLASLt / CUDA library constants. The numeric values mirror the C enums.

/// `CUBLAS_COMPUTE_32F` (`cublasComputeType_t`).
const CUBLAS_COMPUTE_32F: i32 = 68;
/// `CUDA_R_32F` (`cudaDataType_t`).
const CUDA_R_32F: i32 = 0;
/// `CUDA_R_16BF` (`cudaDataType_t`).
const CUDA_R_16BF: i32 = 14;
/// `CUDA_R_8F_E4M3` (`cudaDataType_t`).
const CUDA_R_8F_E4M3: i32 = 28;

// MatmulDesc attributes (`cublasLtMatmulDescAttributes_t`).
const CUBLASLT_MATMUL_DESC_TRANSA: i32 = 3;
const CUBLASLT_MATMUL_DESC_A_SCALE_POINTER: i32 = 17;
const CUBLASLT_MATMUL_DESC_B_SCALE_POINTER: i32 = 18;

// MatrixLayout attributes (`cublasLtMatrixLayoutAttribute_t`).
const CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT: i32 = 5;
const CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET: i32 = 6;

// MatmulPreference attributes (`cublasLtMatmulPreferenceAttributes_t`).
const CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES: i32 = 1;

/// Size of the shared cuBLASLt workspace (32 MiB); also the upper bound given
/// to the algorithm heuristic.
const WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
|
||||
|
||||
/// A fully-prepared FP8 matmul plan for one (M, N, K) shape: the matmul
|
||||
/// descriptor, the four matrix layouts, and the heuristically-chosen algo.
|
||||
/// Built once per shape and reused across every expert and every forward
|
||||
/// pass — the heuristic search and descriptor/layout creation are the
|
||||
/// expensive parts, so doing them once instead of per-expert-per-layer is
|
||||
/// the difference between FP8 being faster or slower than BF16.
|
||||
#[derive(Clone, Copy)]
|
||||
struct Fp8Plan {
|
||||
desc: CublasLtMatmulDesc,
|
||||
a_layout: CublasLtMatrixLayout,
|
||||
b_layout: CublasLtMatrixLayout,
|
||||
c_layout: CublasLtMatrixLayout,
|
||||
d_layout: CublasLtMatrixLayout,
|
||||
algo: CublasLtMatmulAlgo,
|
||||
workspace_size: usize,
|
||||
}
|
||||
|
||||
struct CublasLtContext {
|
||||
handle: CublasLtHandle,
|
||||
workspace: GpuBuffer,
|
||||
/// Persistent device scalar holding 1.0, used as the A/B scale pointer.
|
||||
/// Scales are applied post-GEMM, so the in-GEMM scales stay 1.0.
|
||||
one_buf: GpuBuffer,
|
||||
/// Cache of prepared matmul plans keyed by (M, N, K, batch).
|
||||
plans: HashMap<(usize, usize, usize, usize), Fp8Plan>,
|
||||
}
|
||||
|
||||
impl CublasLtContext {
|
||||
fn new() -> Self {
|
||||
let mut handle = std::ptr::null_mut();
|
||||
let status = unsafe { cublasLtCreate(&mut handle) };
|
||||
assert_eq!(status, 0, "cublasLtCreate failed: {status}");
|
||||
let workspace = GpuBuffer::alloc(WORKSPACE_BYTES).expect("alloc cublasLt workspace");
|
||||
let mut one_buf = GpuBuffer::alloc(4).expect("alloc cublasLt fp8 scale");
|
||||
one_buf
|
||||
.copy_from_host(&1.0f32.to_le_bytes())
|
||||
.expect("init fp8 scale");
|
||||
Self {
|
||||
handle,
|
||||
workspace,
|
||||
one_buf,
|
||||
plans: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the cached strided-batched plan for (m, n, k, batch), building it on
|
||||
/// first use.
|
||||
fn plan(&mut self, m: usize, n: usize, k: usize, batch: usize) -> Fp8Plan {
|
||||
if let Some(p) = self.plans.get(&(m, n, k, batch)) {
|
||||
return *p;
|
||||
}
|
||||
let one_ptr = self.one_buf.as_ptr() as *const c_void;
|
||||
let plan = unsafe { build_fp8_plan(self.handle, one_ptr, m, n, k, batch) };
|
||||
self.plans.insert((m, n, k, batch), plan);
|
||||
plan
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CublasLtContext {
|
||||
fn drop(&mut self) {
|
||||
// Tear down cached plans before destroying the handle.
|
||||
for (_, p) in self.plans.drain() {
|
||||
unsafe {
|
||||
cublasLtMatrixLayoutDestroy(p.a_layout);
|
||||
cublasLtMatrixLayoutDestroy(p.b_layout);
|
||||
cublasLtMatrixLayoutDestroy(p.c_layout);
|
||||
cublasLtMatrixLayoutDestroy(p.d_layout);
|
||||
cublasLtMatmulDescDestroy(p.desc);
|
||||
}
|
||||
}
|
||||
if !self.handle.is_null() {
|
||||
unsafe { cublasLtDestroy(self.handle) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a strided-batched FP8 matmul plan for `batch` experts of one
|
||||
/// (m, n, k) shape. Row-major → cuBLASLt col-major mapping (transA=T,
|
||||
/// transB=N, m_lt=N, n_lt=M, k_lt=K). A/B scale pointers stay at 1.0 — both
|
||||
/// the per-expert weight scale and the per-token activation scale are applied
|
||||
/// post-GEMM in a fused kernel, which lets all experts run in one matmul.
|
||||
unsafe fn build_fp8_plan(
|
||||
handle: CublasLtHandle,
|
||||
one_ptr: *const c_void,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
batch: usize,
|
||||
) -> Fp8Plan {
|
||||
let m_lt = n as u64;
|
||||
let n_lt = m as u64;
|
||||
let k_lt = k as u64;
|
||||
|
||||
let mut desc: CublasLtMatmulDesc = std::ptr::null_mut();
|
||||
cublasLtMatmulDescCreate(&mut desc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
|
||||
|
||||
// transA=T (required for FP8 on Blackwell)
|
||||
let trans_a: i32 = 1;
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATMUL_DESC_TRANSA,
|
||||
&trans_a as *const i32 as _,
|
||||
4,
|
||||
);
|
||||
let ptr_sz = std::mem::size_of::<*const c_void>();
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
|
||||
&one_ptr as *const _ as _,
|
||||
ptr_sz,
|
||||
);
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
|
||||
&one_ptr as *const _ as _,
|
||||
ptr_sz,
|
||||
);
|
||||
|
||||
// Per-expert strides in ELEMENTS for the strided-batch layout.
|
||||
let stride_a = (n * k) as i64; // weights [N, K]
|
||||
let stride_b = (m * k) as i64; // activations [M, K]
|
||||
let stride_c = (m * n) as i64; // output [M, N]
|
||||
let bc = batch as i32;
|
||||
let set_batch = |layout: CublasLtMatrixLayout, stride: i64| {
|
||||
cublasLtMatrixLayoutSetAttribute(
|
||||
layout,
|
||||
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&bc as *const i32 as _,
|
||||
4,
|
||||
);
|
||||
cublasLtMatrixLayoutSetAttribute(
|
||||
layout,
|
||||
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
|
||||
&stride as *const i64 as _,
|
||||
8,
|
||||
);
|
||||
};
|
||||
|
||||
// "A" layout (weights, transposed): physical (K, N) col-major, ld=K
|
||||
let mut a_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||
cublasLtMatrixLayoutCreate(&mut a_layout, CUDA_R_8F_E4M3, k_lt, m_lt, k as i64);
|
||||
set_batch(a_layout, stride_a);
|
||||
// "B" layout (activations): physical (K, M) col-major, ld=K
|
||||
let mut b_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||
cublasLtMatrixLayoutCreate(&mut b_layout, CUDA_R_8F_E4M3, k_lt, n_lt, k as i64);
|
||||
set_batch(b_layout, stride_b);
|
||||
// "C"/"D" layout (output): physical (N, M) col-major, ld=N
|
||||
let mut c_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||
cublasLtMatrixLayoutCreate(&mut c_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
|
||||
set_batch(c_layout, stride_c);
|
||||
let mut d_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||
cublasLtMatrixLayoutCreate(&mut d_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
|
||||
set_batch(d_layout, stride_c);
|
||||
|
||||
let mut pref: CublasLtMatmulPreference = std::ptr::null_mut();
|
||||
cublasLtMatmulPreferenceCreate(&mut pref);
|
||||
let ws_bytes = WORKSPACE_BYTES as u64;
|
||||
cublasLtMatmulPreferenceSetAttribute(
|
||||
pref,
|
||||
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&ws_bytes as *const u64 as _,
|
||||
8,
|
||||
);
|
||||
|
||||
let mut heuristic = std::mem::zeroed::<CublasLtMatmulHeuristicResult>();
|
||||
let mut found: i32 = 0;
|
||||
let status = cublasLtMatmulAlgoGetHeuristic(
|
||||
handle,
|
||||
desc,
|
||||
a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
d_layout,
|
||||
pref,
|
||||
1,
|
||||
&mut heuristic,
|
||||
&mut found,
|
||||
);
|
||||
assert!(
|
||||
status == 0 && found > 0,
|
||||
"cublasLtMatmulAlgoGetHeuristic failed for batched FP8 GEMM (m={m}, n={n}, k={k}, batch={batch}): status={status}, found={found}"
|
||||
);
|
||||
cublasLtMatmulPreferenceDestroy(pref);
|
||||
|
||||
Fp8Plan {
|
||||
desc,
|
||||
a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
d_layout,
|
||||
algo: heuristic.algo,
|
||||
workspace_size: heuristic.workspace_size,
|
||||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static CUBLASLT_CTX: RefCell<CublasLtContext> = RefCell::new(CublasLtContext::new());
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Public API
|
||||
// ============================================================
|
||||
|
||||
/// Dequantize a 3D FP8 E4M3 tensor to BF16 using per-expert FP32 scales.
|
||||
///
|
||||
/// src: [num_experts, rows, cols] FP8E4M3, contiguous, GPU
|
||||
/// scales: [num_experts] F32, contiguous, GPU
|
||||
///
|
||||
/// Returns: [num_experts, rows, cols] BF16
|
||||
pub fn dequant_fp8_to_bf16(src: &Tensor, scales: &Tensor) -> Tensor {
|
||||
assert_eq!(src.ndim(), 3, "dequant_fp8_to_bf16: src must be 3D");
|
||||
assert_eq!(src.dtype(), DType::FP8E4M3);
|
||||
assert!(src.is_contiguous());
|
||||
assert_eq!(scales.ndim(), 1);
|
||||
assert_eq!(scales.dtype(), DType::F32);
|
||||
assert!(scales.is_contiguous());
|
||||
|
||||
let num_experts = src.shape()[0];
|
||||
let rows = src.shape()[1];
|
||||
let cols = src.shape()[2];
|
||||
assert_eq!(scales.shape()[0], num_experts);
|
||||
|
||||
let out = Tensor::empty(&[num_experts, rows, cols], DType::BF16, src.device());
|
||||
|
||||
unsafe {
|
||||
launch_dequant_fp8e4m3_to_bf16(
|
||||
src.data_ptr() as *const c_void,
|
||||
scales.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_experts as i32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
/// Dynamically quantize a contiguous BF16 tensor to FP8 E4M3 with per-row scales.
|
||||
///
|
||||
/// src: [num_rows, cols] or [batch, rows, cols] BF16, contiguous, GPU
|
||||
/// Treats the tensor as 2D (flattens leading dims into num_rows).
|
||||
///
|
||||
/// Returns: (fp8_data [same shape] FP8E4M3, scales [total_rows] F32)
|
||||
pub fn quantize_bf16_to_fp8_rowwise(src: &Tensor) -> (Tensor, Tensor) {
|
||||
assert_eq!(src.dtype(), DType::BF16);
|
||||
assert!(src.is_contiguous());
|
||||
assert!(src.ndim() >= 2);
|
||||
|
||||
let cols = src.shape()[src.ndim() - 1];
|
||||
let num_rows: usize = src.shape()[..src.ndim() - 1].iter().product();
|
||||
|
||||
let fp8_out = Tensor::empty(src.shape(), DType::FP8E4M3, src.device());
|
||||
let scales = Tensor::empty(&[num_rows], DType::F32, src.device());
|
||||
|
||||
unsafe {
|
||||
launch_quantize_bf16_to_fp8e4m3_rowwise(
|
||||
src.data_ptr() as *const c_void,
|
||||
fp8_out.data_ptr() as *mut c_void,
|
||||
scales.data_ptr() as *mut c_void,
|
||||
num_rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
(fp8_out, scales)
|
||||
}
|
||||
|
||||
/// FP8 batched GEMM via cuBLASLt (transA=T required on Blackwell).
|
||||
///
|
||||
/// Computes: C[b] = scale_a[b] * scale_b[b] * (A_fp8[b] @ B_fp8_T[b]^T)
|
||||
/// effectively C[b] = A[b, M, K] @ W[b, K, N] but W is stored transposed
|
||||
/// as [b, N, K] for cuBLASLt FP8 compatibility.
|
||||
///
|
||||
/// a_fp8: [batch, M, K] FP8E4M3 (activations, quantized per-row)
|
||||
/// a_scales: [batch * M] F32 (per-token activation scales, applied post-GEMM)
|
||||
/// b_fp8_t: [batch, N, K] FP8E4M3 (weights, TRANSPOSED for cuBLASLt)
|
||||
/// b_scales: [batch] F32 (per-expert scalar weight scales, applied in-GEMM)
|
||||
///
|
||||
/// Returns: [batch, M, N] BF16
|
||||
pub fn batched_gemm_fp8(
|
||||
a_fp8: &Tensor,
|
||||
a_scales: &Tensor,
|
||||
b_fp8_t: &Tensor,
|
||||
b_scales: &Tensor,
|
||||
) -> Tensor {
|
||||
assert_eq!(a_fp8.ndim(), 3);
|
||||
assert_eq!(b_fp8_t.ndim(), 3);
|
||||
assert_eq!(a_fp8.dtype(), DType::FP8E4M3);
|
||||
assert_eq!(b_fp8_t.dtype(), DType::FP8E4M3);
|
||||
assert!(a_fp8.is_contiguous() && b_fp8_t.is_contiguous());
|
||||
assert_eq!(a_fp8.shape()[0], b_fp8_t.shape()[0]);
|
||||
// b_fp8_t is [batch, N, K] transposed, so b_fp8_t.shape[2] == K == a_fp8.shape[2]
|
||||
assert_eq!(a_fp8.shape()[2], b_fp8_t.shape()[2]);
|
||||
|
||||
let batch = a_fp8.shape()[0];
|
||||
let m = a_fp8.shape()[1]; // tokens
|
||||
let k = a_fp8.shape()[2]; // hidden
|
||||
let n = b_fp8_t.shape()[1]; // out_dim (from transposed weight)
|
||||
|
||||
// a_scales: [batch * M] per-token activation scales (applied post-GEMM, per row).
|
||||
// b_scales: [batch] per-expert scalar weight scales (applied in-GEMM via B-scale ptr).
|
||||
assert_eq!(a_scales.shape()[0], batch * m);
|
||||
assert_eq!(b_scales.shape()[0], batch);
|
||||
|
||||
let c = Tensor::empty(&[batch, m, n], DType::BF16, a_fp8.device());
|
||||
|
||||
CUBLASLT_CTX.with(|cell| {
|
||||
let mut ctx = cell.borrow_mut();
|
||||
let handle = ctx.handle;
|
||||
let ws_ptr = ctx.workspace.as_ptr() as *mut c_void;
|
||||
// Cached strided-batched plan: heuristic + descriptor/layout creation
|
||||
// happen once per (m, n, k, batch). All experts run in ONE matmul.
|
||||
let plan = ctx.plan(m, n, k, batch);
|
||||
|
||||
// alpha=1, beta=0, in-GEMM scales=1.0. The unscaled result
|
||||
// D_raw[e] = A_fp8[e] @ B_fp8[e]^T
|
||||
// is recovered to the real value by the fused post-scale kernel below.
|
||||
let alpha: f32 = 1.0;
|
||||
let beta: f32 = 0.0;
|
||||
|
||||
unsafe {
|
||||
let status = cublasLtMatmul(
|
||||
handle,
|
||||
plan.desc,
|
||||
&alpha as *const f32 as _,
|
||||
b_fp8_t.data_ptr() as *const c_void, // cuBLASLt "A" = weights
|
||||
plan.a_layout,
|
||||
a_fp8.data_ptr() as *const c_void, // cuBLASLt "B" = activations
|
||||
plan.b_layout,
|
||||
&beta as *const f32 as _,
|
||||
c.data_ptr() as *const c_void, // C (unused with beta=0)
|
||||
plan.c_layout,
|
||||
c.data_ptr() as *mut c_void, // D = output
|
||||
plan.d_layout,
|
||||
&plan.algo,
|
||||
ws_ptr,
|
||||
plan.workspace_size,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
assert_eq!(
|
||||
status, 0,
|
||||
"batched cublasLtMatmul FP8 failed: status={status}"
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
// Post-GEMM: recover the real result in one pass.
|
||||
// c[e, t, :] *= a_scales[e*M + t] * b_scales[e]
|
||||
// (per-token activation scale × per-expert weight scale). BF16's relative
|
||||
// error is scale-invariant, so applying the scale here is precision-
|
||||
// equivalent to folding it into the GEMM epilogue.
|
||||
let total_rows = (batch * m) as i32;
|
||||
unsafe {
|
||||
launch_rowwise_scale_moe_bf16(
|
||||
c.data_ptr() as *mut c_void,
|
||||
a_scales.data_ptr() as *const c_void,
|
||||
b_scales.data_ptr() as *const c_void,
|
||||
total_rows,
|
||||
n as i32,
|
||||
m as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
c
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// MXFP4 W4A16 (weight-only 4-bit) for MoE experts
|
||||
// ============================================================
|
||||
|
||||
/// MXFP4 W4A16 batched GEMV for decode (M=1).
|
||||
///
|
||||
/// x: [E, K] BF16 (per-expert activation; replicated across experts)
|
||||
/// w_packed: [E, N, K/2] byte tensor — two E2M1 nibbles per byte (lo = even k)
|
||||
/// w_scales: [E, N, K/32] byte tensor — UE8M0 scale per 32-element block
|
||||
///
|
||||
/// Returns: [E, N] BF16, where y[e,n] = sum_k x[e,k] * dequant(W[e,n,k]).
|
||||
pub fn batched_gemv_mxfp4(
|
||||
x: &Tensor,
|
||||
w_packed: &Tensor,
|
||||
w_scales: &Tensor,
|
||||
n: usize,
|
||||
k: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous());
|
||||
let e = x.shape()[0];
|
||||
assert_eq!(x.shape()[x.ndim() - 1], k, "GEMV K mismatch");
|
||||
|
||||
let y = Tensor::empty(&[e, n], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_batched_gemv_mxfp4_bf16(
|
||||
x.data_ptr() as *const c_void,
|
||||
w_packed.data_ptr() as *const c_void,
|
||||
w_scales.data_ptr() as *const c_void,
|
||||
y.data_ptr() as *mut c_void,
|
||||
e as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
y
|
||||
}
|
||||
|
||||
/// Dequantize MXFP4 weights [E, N, K] → BF16 [E, K, N] for the prefill GEMM path
|
||||
/// (the BF16 batched GEMM expects weights as [E, K, N]).
|
||||
pub fn dequant_mxfp4_to_bf16_t(
|
||||
w_packed: &Tensor,
|
||||
w_scales: &Tensor,
|
||||
e: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
) -> Tensor {
|
||||
let out = Tensor::empty(&[e, k, n], DType::BF16, w_packed.device());
|
||||
unsafe {
|
||||
launch_dequant_mxfp4_to_bf16_t(
|
||||
w_packed.data_ptr() as *const c_void,
|
||||
w_scales.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
e as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
@@ -2,10 +2,35 @@ use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
// RMSNorm CUDA kernel launchers. Each function is declared exactly once; a
// second declaration of the same name in this block is a compile error.
unsafe extern "C" {
    /// RMSNorm over `rows` rows of `hidden_size` F32 elements.
    fn launch_rmsnorm_f32(
        x: *const c_void,
        gamma: *const c_void,
        out: *mut c_void,
        rows: i32,
        hidden_size: i32,
        eps: f32,
        stream: *mut c_void,
    );

    /// RMSNorm over `rows` rows of `hidden_size` BF16 elements.
    fn launch_rmsnorm_bf16(
        x: *const c_void,
        gamma: *const c_void,
        out: *mut c_void,
        rows: i32,
        hidden_size: i32,
        eps: f32,
        stream: *mut c_void,
    );

    /// Fused residual add + RMSNorm (BF16): writes `x + residual` to `sum_out`
    /// and its normalized form to `normed_out`.
    fn launch_add_rmsnorm_bf16(
        x: *const c_void,
        residual: *const c_void,
        gamma: *const c_void,
        normed_out: *mut c_void,
        sum_out: *mut c_void,
        rows: i32,
        hidden_size: i32,
        eps: f32,
        stream: *mut c_void,
    );
}
|
||||
|
||||
pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
|
||||
@@ -17,21 +42,82 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
|
||||
assert_eq!(x.dtype(), gamma.dtype());
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
assert!(
|
||||
rows <= i32::MAX as usize,
|
||||
"too many rows for i32 kernel param"
|
||||
);
|
||||
assert!(
|
||||
hidden_size <= i32::MAX as usize,
|
||||
"hidden_size too large for i32 kernel param"
|
||||
);
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_rmsnorm_f32(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
x.data_ptr() as _,
|
||||
gamma.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_rmsnorm_bf16(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
x.data_ptr() as _,
|
||||
gamma.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for rmsnorm"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
/// Fused Add + RMSNorm: computes sum = x + residual, then normed = rmsnorm(sum, gamma, eps).
|
||||
/// Returns (normed, sum). BF16 only.
|
||||
/// Saves one kernel launch and one full HBM round-trip per layer.
|
||||
pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (Tensor, Tensor) {
|
||||
assert!(x.ndim() >= 1);
|
||||
assert_eq!(x.shape(), residual.shape());
|
||||
assert!(x.is_contiguous() && residual.is_contiguous() && gamma.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
assert_eq!(x.dtype(), DType::BF16, "add_rmsnorm requires BF16");
|
||||
assert_eq!(residual.dtype(), DType::BF16);
|
||||
assert_eq!(gamma.dtype(), DType::BF16);
|
||||
|
||||
let hidden_size = *x.shape().last().unwrap();
|
||||
assert_eq!(gamma.shape(), &[hidden_size]);
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
assert!(
|
||||
rows <= i32::MAX as usize,
|
||||
"too many rows for i32 kernel param"
|
||||
);
|
||||
assert!(
|
||||
hidden_size <= i32::MAX as usize,
|
||||
"hidden_size too large for i32 kernel param"
|
||||
);
|
||||
let normed_out = Tensor::empty(x.shape(), DType::BF16, x.device());
|
||||
let sum_out = Tensor::empty(x.shape(), DType::BF16, x.device());
|
||||
|
||||
unsafe {
|
||||
launch_add_rmsnorm_bf16(
|
||||
x.data_ptr() as *const c_void,
|
||||
residual.data_ptr() as *const c_void,
|
||||
gamma.data_ptr() as *const c_void,
|
||||
normed_out.data_ptr() as *mut c_void,
|
||||
sum_out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
(normed_out, sum_out)
|
||||
}
|
||||
|
||||
@@ -3,15 +3,34 @@ use xserv_cuda::GpuBuffer;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
// RoPE CUDA kernel launchers. Each function is declared exactly once; a
// second declaration of the same name in this block is a compile error.
unsafe extern "C" {
    /// Applies RoPE in place to F32 data, using the cos/sin caches and one
    /// position per token.
    fn launch_rope_f32(
        x: *mut c_void,
        cos_cache: *const c_void,
        sin_cache: *const c_void,
        positions: *const c_void,
        num_tokens: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    /// BF16 variant of `launch_rope_f32`.
    fn launch_rope_bf16(
        x: *mut c_void,
        cos_cache: *const c_void,
        sin_cache: *const c_void,
        positions: *const c_void,
        num_tokens: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    /// Fills the cos/sin caches on the device for `max_seq_len` positions and
    /// `half_dim` frequencies derived from `theta`.
    fn launch_compute_rope_cache(
        cos_cache: *mut c_void,
        sin_cache: *mut c_void,
        max_seq_len: i32,
        half_dim: i32,
        theta: f32,
        stream: *mut c_void,
    );
}
|
||||
|
||||
pub struct RopeCache {
|
||||
@@ -30,13 +49,99 @@ impl RopeCache {
|
||||
|
||||
unsafe {
|
||||
launch_compute_rope_cache(
|
||||
cos.as_mut_ptr() as _, sin.as_mut_ptr() as _,
|
||||
max_seq_len as i32, half_dim as i32, theta, std::ptr::null_mut(),
|
||||
cos.as_mut_ptr() as _,
|
||||
sin.as_mut_ptr() as _,
|
||||
max_seq_len as i32,
|
||||
half_dim as i32,
|
||||
theta,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
|
||||
Self { cos, sin, max_seq_len, half_dim }
|
||||
Self {
|
||||
cos,
|
||||
sin,
|
||||
max_seq_len,
|
||||
half_dim,
|
||||
}
|
||||
}
|
||||
|
||||
/// YaRN (Yet another RoPE extensioN) RoPE cache. Applies frequency-dependent
|
||||
/// interpolation so the model can extrapolate beyond its training context.
|
||||
pub fn new_yarn(
|
||||
max_seq_len: usize,
|
||||
head_dim: usize,
|
||||
theta: f64,
|
||||
factor: f64,
|
||||
original_max_pos: usize,
|
||||
beta_fast: f64,
|
||||
beta_slow: f64,
|
||||
) -> Self {
|
||||
let half_dim = head_dim / 2;
|
||||
let dim = head_dim as f64;
|
||||
|
||||
// find_correction_dim: inverse formula to find dimension from number of rotations
|
||||
let find_correction_dim = |num_rotations: f64| -> f64 {
|
||||
dim * (original_max_pos as f64 / (num_rotations * 2.0 * std::f64::consts::PI)).ln()
|
||||
/ (2.0 * theta.ln())
|
||||
};
|
||||
|
||||
let low_raw = find_correction_dim(beta_fast);
|
||||
let high_raw = find_correction_dim(beta_slow);
|
||||
// config has truncate=false, so use raw values (no floor/ceil)
|
||||
let low = low_raw.max(0.0);
|
||||
let high = high_raw.min((half_dim - 1) as f64);
|
||||
|
||||
// Compute inv_freq with YaRN interpolation
|
||||
let mut inv_freq = vec![0.0f64; half_dim];
|
||||
for i in 0..half_dim {
|
||||
let pos_freq = theta.powf((2 * i) as f64 / dim);
|
||||
let inv_freq_extrapolation = 1.0 / pos_freq; // original
|
||||
let inv_freq_interpolation = 1.0 / (factor * pos_freq); // scaled
|
||||
|
||||
// Linear ramp: 0 where we keep original, 1 where we interpolate
|
||||
let ramp = if (high - low).abs() < 0.001 {
|
||||
0.5
|
||||
} else {
|
||||
((i as f64 - low) / (high - low)).clamp(0.0, 1.0)
|
||||
};
|
||||
let extrapolation_factor = 1.0 - ramp;
|
||||
|
||||
inv_freq[i] = inv_freq_interpolation * (1.0 - extrapolation_factor)
|
||||
+ inv_freq_extrapolation * extrapolation_factor;
|
||||
}
|
||||
|
||||
// Attention scaling factor for YaRN: 0.1 * ln(factor) + 1.0
|
||||
let attn_factor = 0.1 * factor.ln() + 1.0;
|
||||
|
||||
// Build cos/sin cache on CPU then upload
|
||||
let total = max_seq_len * half_dim;
|
||||
let mut cos_host = vec![0.0f32; total];
|
||||
let mut sin_host = vec![0.0f32; total];
|
||||
for pos in 0..max_seq_len {
|
||||
for i in 0..half_dim {
|
||||
let angle = pos as f64 * inv_freq[i];
|
||||
cos_host[pos * half_dim + i] = (angle.cos() * attn_factor) as f32;
|
||||
sin_host[pos * half_dim + i] = (angle.sin() * attn_factor) as f32;
|
||||
}
|
||||
}
|
||||
|
||||
let nbytes = total * std::mem::size_of::<f32>();
|
||||
let mut cos = GpuBuffer::alloc(nbytes).expect("alloc yarn cos_cache");
|
||||
let mut sin = GpuBuffer::alloc(nbytes).expect("alloc yarn sin_cache");
|
||||
let cos_bytes =
|
||||
unsafe { std::slice::from_raw_parts(cos_host.as_ptr() as *const u8, nbytes) };
|
||||
let sin_bytes =
|
||||
unsafe { std::slice::from_raw_parts(sin_host.as_ptr() as *const u8, nbytes) };
|
||||
cos.copy_from_host(cos_bytes).unwrap();
|
||||
sin.copy_from_host(sin_bytes).unwrap();
|
||||
|
||||
Self {
|
||||
cos,
|
||||
sin,
|
||||
max_seq_len,
|
||||
half_dim,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,27 +164,48 @@ pub fn rope_inplace(x: &Tensor, cache: &RopeCache, positions: &[u32]) {
|
||||
num_tokens * std::mem::size_of::<u32>(),
|
||||
)
|
||||
};
|
||||
let mut pos_gpu = GpuBuffer::alloc(pos_bytes.len()).expect("alloc positions");
|
||||
let mut pos_gpu =
|
||||
xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions");
|
||||
pos_gpu.copy_from_host(pos_bytes).unwrap();
|
||||
|
||||
rope_inplace_device_pos(x, cache, pos_gpu.as_ptr() as *const c_void);
|
||||
}
|
||||
|
||||
/// RoPE in-place with positions already on the GPU (u32, [num_tokens]).
|
||||
/// Used by the CUDA-graph decode path, where the position lives in a
|
||||
/// persistent device buffer updated outside the captured region.
|
||||
pub fn rope_inplace_device_pos(x: &Tensor, cache: &RopeCache, pos_gpu: *const c_void) {
|
||||
assert_eq!(x.ndim(), 3);
|
||||
assert!(x.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
let num_tokens = x.shape()[0];
|
||||
let num_heads = x.shape()[1];
|
||||
let head_dim = x.shape()[2];
|
||||
assert_eq!(head_dim / 2, cache.half_dim);
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_rope_f32(
|
||||
x.data_ptr() as *mut c_void,
|
||||
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
||||
pos_gpu.as_ptr() as _,
|
||||
num_tokens as i32, num_heads as i32, head_dim as i32,
|
||||
std::ptr::null_mut(),
|
||||
cache.cos.as_ptr() as _,
|
||||
cache.sin.as_ptr() as _,
|
||||
pos_gpu,
|
||||
num_tokens as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_rope_bf16(
|
||||
x.data_ptr() as *mut c_void,
|
||||
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
||||
pos_gpu.as_ptr() as _,
|
||||
num_tokens as i32, num_heads as i32, head_dim as i32,
|
||||
std::ptr::null_mut(),
|
||||
cache.cos.as_ptr() as _,
|
||||
cache.sin.as_ptr() as _,
|
||||
pos_gpu,
|
||||
num_tokens as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for rope"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
|
||||
@@ -2,8 +2,20 @@ use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_softmax_f32(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
|
||||
fn launch_softmax_bf16(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
|
||||
fn launch_softmax_f32(
|
||||
x: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_softmax_bf16(
|
||||
x: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
/// Softmax along the last dimension.
|
||||
@@ -14,21 +26,34 @@ pub fn softmax(x: &Tensor) -> Tensor {
|
||||
|
||||
let cols = *x.shape().last().unwrap();
|
||||
let rows = x.numel() / cols;
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
assert!(
|
||||
rows <= i32::MAX as usize,
|
||||
"too many rows for i32 kernel param"
|
||||
);
|
||||
assert!(
|
||||
cols <= i32::MAX as usize,
|
||||
"cols too large for i32 kernel param"
|
||||
);
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_softmax_f32(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32, std::ptr::null_mut(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_softmax_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32, std::ptr::null_mut(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for softmax"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
256
crates/xserv-kernels/src/transpose.rs
Normal file
256
crates/xserv-kernels/src/transpose.rs
Normal file
@@ -0,0 +1,256 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_reshape_heads_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_merge_heads_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_transpose_hsd_to_shd_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_transpose_shd_to_hsd_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_repeat_kv_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
kv_heads: i32,
|
||||
n_rep: i32,
|
||||
seq_len: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_strided_copy_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
numel: i32,
|
||||
ndim: i32,
|
||||
shape0: i32,
|
||||
shape1: i32,
|
||||
shape2: i32,
|
||||
shape3: i32,
|
||||
in_stride0: i32,
|
||||
in_stride1: i32,
|
||||
in_stride2: i32,
|
||||
in_stride3: i32,
|
||||
in_offset: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_strided_copy_f32(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
numel: i32,
|
||||
ndim: i32,
|
||||
shape0: i32,
|
||||
shape1: i32,
|
||||
shape2: i32,
|
||||
shape3: i32,
|
||||
in_stride0: i32,
|
||||
in_stride1: i32,
|
||||
in_stride2: i32,
|
||||
in_stride3: i32,
|
||||
in_offset: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
/// [S, H*D] → [1, H, S, D] on GPU (BF16)
|
||||
pub fn reshape_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_reshape_heads_bf16(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
seq_len as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// [1, H, S, D] → [S, H*D] on GPU (BF16)
|
||||
pub fn merge_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let hidden = num_heads * head_dim;
|
||||
let out = Tensor::empty(&[seq_len, hidden], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_merge_heads_bf16(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
seq_len as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// [1, H, S, D] → [S, H, D] for RoPE on GPU (BF16)
|
||||
pub fn transpose_for_rope_gpu(
|
||||
x: &Tensor,
|
||||
seq_len: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::empty(&[seq_len, num_heads, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_transpose_hsd_to_shd_bf16(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
seq_len as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// [S, H, D] → [1, H, S, D] after RoPE on GPU (BF16)
|
||||
pub fn transpose_from_rope_gpu(
|
||||
x: &Tensor,
|
||||
seq_len: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_transpose_shd_to_hsd_bf16(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
seq_len as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// [1, KV_H, S, D] → [1, KV_H*n_rep, S, D] on GPU (BF16)
|
||||
pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor {
|
||||
if n_rep == 1 {
|
||||
return x.clone();
|
||||
}
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let kv_heads = x.shape()[1];
|
||||
let seq_len = x.shape()[2];
|
||||
let head_dim = x.shape()[3];
|
||||
let new_heads = kv_heads * n_rep;
|
||||
let out = Tensor::empty(&[1, new_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_repeat_kv_bf16(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
kv_heads as i32,
|
||||
n_rep as i32,
|
||||
seq_len as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Make a non-contiguous GPU tensor contiguous via a strided copy kernel.
|
||||
/// Supports BF16 and F32, up to 4D tensors (padded to 4D internally).
|
||||
pub fn strided_to_contiguous_gpu(x: &Tensor) -> Tensor {
|
||||
assert!(matches!(x.device(), Device::Cuda(_)), "expected GPU tensor");
|
||||
assert!(!x.is_contiguous(), "tensor is already contiguous");
|
||||
assert!(x.ndim() <= 4, "strided_to_contiguous_gpu supports up to 4D");
|
||||
|
||||
let ndim = x.ndim();
|
||||
let numel = x.numel();
|
||||
|
||||
// Pad shape and strides to 4D (prepend 1s for shape, 0s for strides)
|
||||
let mut shape4 = [1i32; 4];
|
||||
let mut strides4 = [0i32; 4];
|
||||
let pad = 4 - ndim;
|
||||
for i in 0..ndim {
|
||||
shape4[pad + i] = x.shape()[i] as i32;
|
||||
strides4[pad + i] = x.strides()[i] as i32;
|
||||
}
|
||||
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
// Use storage base pointer + element offset, because strides are relative to
|
||||
// element 0 of the storage, not the data_ptr() (which already adds byte offset).
|
||||
let storage_ptr = x.storage().gpu_buffer().as_ptr();
|
||||
let in_offset = x.offset() as i32;
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::BF16 => launch_strided_copy_bf16(
|
||||
storage_ptr as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
numel as i32,
|
||||
ndim as i32,
|
||||
shape4[0],
|
||||
shape4[1],
|
||||
shape4[2],
|
||||
shape4[3],
|
||||
strides4[0],
|
||||
strides4[1],
|
||||
strides4[2],
|
||||
strides4[3],
|
||||
in_offset,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::F32 => launch_strided_copy_f32(
|
||||
storage_ptr as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
numel as i32,
|
||||
ndim as i32,
|
||||
shape4[0],
|
||||
shape4[1],
|
||||
shape4[2],
|
||||
shape4[3],
|
||||
strides4[0],
|
||||
strides4[1],
|
||||
strides4[2],
|
||||
strides4[3],
|
||||
in_offset,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!(
|
||||
"strided_to_contiguous_gpu: unsupported dtype {:?}",
|
||||
x.dtype()
|
||||
),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
232
crates/xserv-kernels/tests/attention_test.rs
Normal file
232
crates/xserv-kernels/tests/attention_test.rs
Normal file
@@ -0,0 +1,232 @@
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
fn init() {
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
}
|
||||
|
||||
/// CPU reference for scaled dot-product attention.
///
/// `q` is read as `[batch, heads, q_len, head_dim]`, `k` and `v` as
/// `[batch, heads, kv_len, head_dim]`, all row-major. Returns the output as
/// `[batch, heads, q_len, head_dim]`.
///
/// With `causal` set, query row `i` is aligned to key position
/// `i + kv_len - q_len` and may only attend to keys at or before it. The
/// comparison is written without the subtraction so `kv_len < q_len` cannot
/// underflow. A row whose keys are all masked yields zeros rather than NaN.
///
/// # Panics
/// Panics if any slice is shorter than its implied shape.
fn cpu_attention(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    batch: usize,
    heads: usize,
    q_len: usize,
    kv_len: usize,
    head_dim: usize,
    causal: bool,
) -> Vec<f32> {
    let mut out = vec![0.0f32; batch * heads * q_len * head_dim];
    let scale = 1.0 / (head_dim as f32).sqrt();
    // One score matrix, fully overwritten for every (batch, head) pair.
    let mut scores = vec![0.0f32; q_len * kv_len];

    for bh in 0..batch * heads {
        let q_base = bh * q_len * head_dim;
        let kv_base = bh * kv_len * head_dim;

        // scores = Q @ K^T, scaled, with the causal mask applied in place
        for i in 0..q_len {
            let q_row = &q[q_base + i * head_dim..q_base + (i + 1) * head_dim];
            for j in 0..kv_len {
                // Equivalent to `j > i + (kv_len - q_len)` without the subtraction.
                let masked = causal && j + q_len > i + kv_len;
                scores[i * kv_len + j] = if masked {
                    f32::NEG_INFINITY
                } else {
                    let k_row = &k[kv_base + j * head_dim..kv_base + (j + 1) * head_dim];
                    let dot = q_row
                        .iter()
                        .zip(k_row)
                        .fold(0.0f32, |acc, (qi, ki)| acc + qi * ki);
                    dot * scale
                };
            }
        }

        // softmax per row
        for row in scores.chunks_exact_mut(kv_len.max(1)).take(q_len) {
            let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            if max == f32::NEG_INFINITY {
                // Nothing visible: exp(-inf - -inf) would be NaN, so use zero weights.
                row.fill(0.0);
                continue;
            }
            let mut sum = 0.0f32;
            for w in row.iter_mut() {
                *w = (*w - max).exp();
                sum += *w;
            }
            for w in row.iter_mut() {
                *w /= sum;
            }
        }

        // output = weights @ V
        for i in 0..q_len {
            for d in 0..head_dim {
                let mut acc = 0.0f32;
                for j in 0..kv_len {
                    acc += scores[i * kv_len + j] * v[kv_base + j * head_dim + d];
                }
                out[q_base + i * head_dim + d] = acc;
            }
        }
    }
    out
}
|
||||
|
||||
/// Asserts that `a` and `b` have equal length and agree element-wise within the
/// absolute tolerance `atol`, then prints the largest error seen.
///
/// Panics at the first element whose error exceeds `atol`; `name` prefixes
/// every message.
fn check_close(a: &[f32], b: &[f32], atol: f32, name: &str) {
    assert_eq!(a.len(), b.len(), "{name}: length mismatch");
    let max_err = a
        .iter()
        .zip(b)
        .enumerate()
        .fold(0.0f32, |worst, (i, (x, y))| {
            let err = (x - y).abs();
            assert!(
                err <= atol,
                "{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}"
            );
            if err > worst { err } else { worst }
        });
    println!("{name}: max_err = {max_err:.6e}");
}
|
||||
|
||||
/// Produces `n` deterministic values that repeat with period 17, stepping from
/// -0.4 to 0.4 in increments of 0.05.
fn make_data(n: usize) -> Vec<f32> {
    let mut data = Vec::with_capacity(n);
    for idx in 0..n {
        let phase = (idx % 17) as f32;
        data.push((phase - 8.0) * 0.05);
    }
    data
}
|
||||
|
||||
/// Checks `batched_matmul` against a naive CPU product for every
/// (batch, head) slice, not just the first, so that wrong batch strides or
/// offsets in the GPU path are caught as well.
#[test]
fn test_batched_matmul() {
    init();
    let batch = 4;
    let heads = 8;
    let m = 32;
    let k = 64;
    let n = 32;

    let a_data = make_data(batch * heads * m * k);
    let b_data = make_data(batch * heads * k * n);

    let a = Tensor::from_slice(&a_data, &[batch, heads, m, k]).to_device(Device::Cuda(0));
    let b = Tensor::from_slice(&b_data, &[batch, heads, k, n]).to_device(Device::Cuda(0));
    let c = batched_matmul(&a, &b).to_device(Device::Cpu);

    assert_eq!(c.shape(), &[batch, heads, m, n]);
    let result = c.as_slice::<f32>();
    assert_eq!(result.len(), batch * heads * m * n);

    // The generated data has period 17 while a slice has 2048 elements, so each
    // slice holds different values and must be verified on its own.
    let slices = a_data
        .chunks_exact(m * k)
        .zip(b_data.chunks_exact(k * n))
        .zip(result.chunks_exact(m * n));
    for (idx, ((a_cpu, b_cpu), got)) in slices.enumerate() {
        let mut expected = vec![0.0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut s = 0.0f32;
                for kk in 0..k {
                    s += a_cpu[i * k + kk] * b_cpu[kk * n + j];
                }
                expected[i * n + j] = s;
            }
        }
        check_close(got, &expected, 1e-3, &format!("batched_matmul[{idx}]"));
    }
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_no_causal() {
|
||||
init();
|
||||
let b = 1;
|
||||
let h = 2;
|
||||
let s = 8;
|
||||
let d = 16;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data = make_data(b * h * s * d);
|
||||
let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, false);
|
||||
|
||||
let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let out = attention(&q, &k, &v, false).to_device(Device::Cpu);
|
||||
check_close(
|
||||
out.as_slice::<f32>(),
|
||||
&expected,
|
||||
1e-4,
|
||||
"attention_no_causal",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_causal() {
|
||||
init();
|
||||
let b = 1;
|
||||
let h = 2;
|
||||
let s = 16;
|
||||
let d = 32;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data = make_data(b * h * s * d);
|
||||
let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true);
|
||||
|
||||
let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let out = attention(&q, &k, &v, true).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-3, "attention_causal");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_causal_larger() {
|
||||
init();
|
||||
let b = 2;
|
||||
let h = 4;
|
||||
let s = 64;
|
||||
let d = 64;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data = make_data(b * h * s * d);
|
||||
let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true);
|
||||
|
||||
let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let out = attention(&q, &k, &v, true).to_device(Device::Cpu);
|
||||
check_close(
|
||||
out.as_slice::<f32>(),
|
||||
&expected,
|
||||
1e-2,
|
||||
"attention_causal_larger",
|
||||
);
|
||||
}
|
||||
|
||||
/// With a causal mask the first query can only attend to the first key, so the
/// first output row must reproduce the first V row exactly.
#[test]
fn test_attention_causal_first_row_sees_only_first_token() {
    init();
    let (b, h, s, d) = (1, 1, 4, 8);
    let q_data = make_data(b * h * s * d);
    let k_data = make_data(b * h * s * d);

    // only first V row is nonzero
    let mut v_data = vec![0.0f32; s * d];
    v_data[..d].fill(1.0);

    let dims = [b, h, s, d];
    let q = Tensor::from_slice(&q_data, &dims).to_device(Device::Cuda(0));
    let k = Tensor::from_slice(&k_data, &dims).to_device(Device::Cuda(0));
    let v = Tensor::from_slice(&v_data, &dims).to_device(Device::Cuda(0));
    let out = attention(&q, &k, &v, true).to_device(Device::Cpu);

    // Position 0 sees only token 0, so its attention weight there is 1.0 and
    // output[0] should be exactly V[0] = [1, 1, 1, ...1].
    let result = out.as_slice::<f32>();
    for (dim, &got) in result[..d].iter().enumerate() {
        assert!(
            (got - 1.0).abs() < 1e-5,
            "first row should equal V[0], got {} at dim {}",
            got,
            dim
        );
    }
}
|
||||
@@ -1,5 +1,5 @@
|
||||
use half::bf16;
|
||||
use xserv_kernels::{matmul, GemmBackend};
|
||||
use xserv_kernels::{GemmBackend, matmul};
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
|
||||
@@ -75,56 +75,110 @@ fn run_gemm_test_bf16(backend: GemmBackend, m: usize, n: usize, k: usize) {
|
||||
// --- F32 tests ---
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_f32_small() { run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4); }
|
||||
fn test_gemm_naive_f32_small() {
|
||||
run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_f32_medium() { run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64); }
|
||||
fn test_gemm_naive_f32_medium() {
|
||||
run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_f32_rect() { run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48); }
|
||||
fn test_gemm_naive_f32_rect() {
|
||||
run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_f32_small() { run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4); }
|
||||
fn test_gemm_tiled_f32_small() {
|
||||
run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_f32_medium() { run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128); }
|
||||
fn test_gemm_tiled_f32_medium() {
|
||||
run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_f32_rect() { run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97); }
|
||||
fn test_gemm_tiled_f32_rect() {
|
||||
run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_small() { run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4); }
|
||||
fn test_gemm_cublas_f32_small() {
|
||||
run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_medium() { run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256); }
|
||||
fn test_gemm_cublas_f32_medium() {
|
||||
run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_rect() { run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97); }
|
||||
fn test_gemm_cublas_f32_rect() {
|
||||
run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97);
|
||||
}
|
||||
|
||||
// --- BF16 tests ---
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_bf16_small() { run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4); }
|
||||
fn test_gemm_naive_bf16_small() {
|
||||
run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_bf16_medium() { run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64); }
|
||||
fn test_gemm_naive_bf16_medium() {
|
||||
run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_bf16_small() { run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4); }
|
||||
fn test_gemm_tiled_bf16_small() {
|
||||
run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_bf16_medium() { run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128); }
|
||||
fn test_gemm_tiled_bf16_medium() {
|
||||
run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4); }
|
||||
fn test_gemm_cublas_bf16_small() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256); }
|
||||
fn test_gemm_cublas_bf16_medium() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256);
|
||||
}
|
||||
|
||||
// --- Custom GEMV tests (M=1, BF16 fast path) ---
|
||||
|
||||
#[test]
|
||||
fn test_gemv_bf16_small() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 64, 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemv_bf16_medium() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 256, 256);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemv_bf16_4096() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 4096, 4096);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemv_bf16_rect() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 512, 4096);
|
||||
}
|
||||
|
||||
// --- Larger benchmark-style tests ---
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_1024() { run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024); }
|
||||
fn test_gemm_cublas_f32_1024() {
|
||||
run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_consistency_all_backends() {
|
||||
|
||||
@@ -2,7 +2,9 @@ use half::bf16;
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
fn init() { xserv_cuda::device::set_device(0).unwrap(); }
|
||||
fn init() {
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
}
|
||||
|
||||
// --- CPU reference implementations ---
|
||||
|
||||
@@ -37,10 +39,12 @@ fn cpu_layernorm(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32, hidden: usize
|
||||
|
||||
fn cpu_gelu(x: &[f32]) -> Vec<f32> {
|
||||
let sqrt_2_over_pi = 0.7978845608f32;
|
||||
x.iter().map(|&v| {
|
||||
let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
|
||||
0.5 * v * (1.0 + inner.tanh())
|
||||
}).collect()
|
||||
x.iter()
|
||||
.map(|&v| {
|
||||
let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
|
||||
0.5 * v * (1.0 + inner.tanh())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn cpu_silu(x: &[f32]) -> Vec<f32> {
|
||||
@@ -74,10 +78,10 @@ fn cpu_rope(x: &mut [f32], positions: &[u32], num_heads: usize, head_dim: usize,
|
||||
let cos_val = angle.cos();
|
||||
let sin_val = angle.sin();
|
||||
let base = (t * num_heads + h) * head_dim;
|
||||
let x0 = x[base + 2 * i];
|
||||
let x1 = x[base + 2 * i + 1];
|
||||
x[base + 2 * i] = x0 * cos_val - x1 * sin_val;
|
||||
x[base + 2 * i + 1] = x0 * sin_val + x1 * cos_val;
|
||||
let x0 = x[base + i];
|
||||
let x1 = x[base + i + half_dim];
|
||||
x[base + i] = x0 * cos_val - x1 * sin_val;
|
||||
x[base + i + half_dim] = x1 * cos_val + x0 * sin_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -88,8 +92,13 @@ fn check_close(result: &[f32], expected: &[f32], atol: f32, name: &str) {
|
||||
let mut max_err = 0.0f32;
|
||||
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
|
||||
let err = (r - e).abs();
|
||||
if err > max_err { max_err = err; }
|
||||
assert!(err <= atol, "{name}: mismatch at [{i}]: got {r}, expected {e}, err {err}");
|
||||
if err > max_err {
|
||||
max_err = err;
|
||||
}
|
||||
assert!(
|
||||
err <= atol,
|
||||
"{name}: mismatch at [{i}]: got {r}, expected {e}, err {err}"
|
||||
);
|
||||
}
|
||||
println!("{name}: max_err = {max_err:.6e}");
|
||||
}
|
||||
@@ -208,13 +217,18 @@ fn test_softmax_sum_to_one() {
|
||||
init();
|
||||
let rows = 4;
|
||||
let cols = 2048;
|
||||
let data: Vec<f32> = (0..rows * cols).map(|i| ((i % 31) as f32 - 15.0) * 0.5).collect();
|
||||
let data: Vec<f32> = (0..rows * cols)
|
||||
.map(|i| ((i % 31) as f32 - 15.0) * 0.5)
|
||||
.collect();
|
||||
let x = Tensor::from_slice(&data, &[rows, cols]).to_device(Device::Cuda(0));
|
||||
let out = softmax(&x).to_device(Device::Cpu);
|
||||
let result = out.as_slice::<f32>();
|
||||
for r in 0..rows {
|
||||
let row_sum: f32 = result[r * cols..(r + 1) * cols].iter().sum();
|
||||
assert!((row_sum - 1.0).abs() < 1e-5, "softmax row {r} sum = {row_sum}");
|
||||
assert!(
|
||||
(row_sum - 1.0).abs() < 1e-5,
|
||||
"softmax row {r} sum = {row_sum}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,8 +261,10 @@ fn test_embedding_f32() {
|
||||
for i in 0..hidden {
|
||||
let expected = table_data[tid as usize * hidden + i];
|
||||
let got = result[seq_idx * hidden + i];
|
||||
assert!((got - expected).abs() < 1e-6,
|
||||
"embedding mismatch at [{seq_idx},{i}]: got {got}, expected {expected}");
|
||||
assert!(
|
||||
(got - expected).abs() < 1e-6,
|
||||
"embedding mismatch at [{seq_idx},{i}]: got {got}, expected {expected}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -270,8 +286,8 @@ fn test_rope_f32() {
|
||||
let mut expected = x_data.clone();
|
||||
cpu_rope(&mut expected, &positions, num_heads, head_dim, theta);
|
||||
|
||||
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
|
||||
.to_device(Device::Cuda(0));
|
||||
let x =
|
||||
Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim]).to_device(Device::Cuda(0));
|
||||
let cache = RopeCache::new(64, head_dim, theta);
|
||||
rope_inplace(&x, &cache, &positions);
|
||||
|
||||
@@ -292,8 +308,8 @@ fn test_rope_position_0_identity() {
|
||||
.map(|i| (i as f32 + 1.0) * 0.1)
|
||||
.collect();
|
||||
|
||||
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
|
||||
.to_device(Device::Cuda(0));
|
||||
let x =
|
||||
Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim]).to_device(Device::Cuda(0));
|
||||
let cache = RopeCache::new(64, head_dim, 10000.0);
|
||||
rope_inplace(&x, &cache, &positions);
|
||||
|
||||
|
||||
18
crates/xserv-model/Cargo.toml
Normal file
18
crates/xserv-model/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "xserv-model"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
xserv-cuda = { path = "../xserv-cuda" }
|
||||
xserv-tensor = { path = "../xserv-tensor" }
|
||||
xserv-kernels = { path = "../xserv-kernels" }
|
||||
xserv-tokenizer = { path = "../xserv-tokenizer" }
|
||||
xserv-distributed = { path = "../xserv-distributed" }
|
||||
half.workspace = true
|
||||
libc.workspace = true
|
||||
smallvec.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
safetensors.workspace = true
|
||||
rand.workspace = true
|
||||
1126
crates/xserv-model/src/bin/bench-eagle3.rs
Normal file
1126
crates/xserv-model/src/bin/bench-eagle3.rs
Normal file
File diff suppressed because it is too large
Load Diff
421
crates/xserv-model/src/bin/bench-gpt-oss.rs
Normal file
421
crates/xserv-model/src/bin/bench-gpt-oss.rs
Normal file
@@ -0,0 +1,421 @@
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use xserv_distributed::{TpContext, UniqueId, get_unique_id};
|
||||
use xserv_model::{BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, loader};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: bench-gpt-oss <model-dir> [--max-tokens N] [--tp N]");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let max_tokens: usize = get_arg(&args, "--max-tokens").unwrap_or(32);
|
||||
let world: usize = get_arg(&args, "--tp").unwrap_or(2);
|
||||
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
|
||||
eprintln!(
|
||||
"gpt-oss-20b: layers={}, hidden={}, heads={}/{} kv, experts={}, top_k={}, vocab={}",
|
||||
config.num_layers(),
|
||||
config.hidden(),
|
||||
config.num_heads(),
|
||||
config.num_kv_heads(),
|
||||
config.num_experts(),
|
||||
config.experts_per_token(),
|
||||
config.vocab_size
|
||||
);
|
||||
eprintln!("TP world={world}, max_tokens={max_tokens}");
|
||||
|
||||
let max_seq_len: usize = 2048;
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
// TP setup
|
||||
let uid = get_unique_id();
|
||||
let local_kv = config.num_kv_heads() / world;
|
||||
|
||||
// Spawn worker threads for ranks 1..world
|
||||
let mut worker_handles = Vec::new();
|
||||
let mut worker_txs = Vec::new();
|
||||
for rank in 1..world {
|
||||
let (tx, rx) = std::sync::mpsc::channel::<WorkerCmd>();
|
||||
let (ack_tx, ack_rx) = std::sync::mpsc::channel::<()>();
|
||||
let cfg = config.clone();
|
||||
let md = model_dir.clone();
|
||||
let uid_copy = uid;
|
||||
worker_handles.push((
|
||||
std::thread::spawn(move || {
|
||||
worker_loop(rank, world, uid_copy, md, cfg, max_seq_len, rx, ack_tx);
|
||||
}),
|
||||
ack_rx,
|
||||
));
|
||||
worker_txs.push(tx);
|
||||
}
|
||||
|
||||
// Rank 0 setup
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
let tp0 = Arc::new(TpContext::init(0, world, uid, 0));
|
||||
eprintln!("[rank 0] Loading weights...");
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
|
||||
eprintln!(
|
||||
"[rank 0] Loaded {} tensors, building model...",
|
||||
weights.len()
|
||||
);
|
||||
let model = GptOss::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp0));
|
||||
let total_blocks = max_blocks_per_seq + 64;
|
||||
let mut cache = PagedKVCache::new_tp(
|
||||
&config,
|
||||
local_kv,
|
||||
total_blocks,
|
||||
0,
|
||||
4,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
0,
|
||||
);
|
||||
eprintln!("[rank 0] Ready.");
|
||||
|
||||
// Prompt
|
||||
let prompt_arg = get_arg::<String>(&args, "--prompt");
|
||||
let prompt = prompt_arg
|
||||
.as_deref()
|
||||
.unwrap_or("What is the meaning of life?");
|
||||
let token_ids = tokenizer.encode(prompt);
|
||||
eprintln!("Prompt ({} tokens): {prompt}", token_ids.len());
|
||||
|
||||
// Register sequence
|
||||
let slot = 0;
|
||||
cache.register_sequence(slot).unwrap();
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Register(slot));
|
||||
|
||||
// Teacher-forced diagnostic: prefill (prompt + forced ids) in one shot and
|
||||
// report, for each forced position, whether xserv's argmax == the forced
|
||||
// (oracle) next token. Removes free-running compounding so it isolates
|
||||
// whether per-position logits agree with the llama.cpp trajectory.
|
||||
if let Some(forced) = get_arg::<String>(&args, "--forced") {
|
||||
let forced_ids: Vec<u32> = forced
|
||||
.split(',')
|
||||
.filter_map(|s| s.trim().parse().ok())
|
||||
.collect();
|
||||
let mut seq = token_ids.clone();
|
||||
seq.extend_from_slice(&forced_ids);
|
||||
// Workers must run the same prefill in lockstep (TP AllReduces match up).
|
||||
broadcast_cmd(
|
||||
&worker_txs,
|
||||
&worker_handles,
|
||||
WorkerCmd::Prefill {
|
||||
tokens: seq.clone(),
|
||||
slot,
|
||||
},
|
||||
);
|
||||
let logits = model.forward_prefill_paged(&seq, slot, &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let vocab = logits.shape()[1];
|
||||
let data = logits_cpu.as_slice::<half::bf16>();
|
||||
let plen = token_ids.len();
|
||||
let mut matches = 0usize;
|
||||
let mut total = 0usize;
|
||||
// position i predicts seq[i+1]; we check the forced region
|
||||
for i in (plen - 1)..(seq.len() - 1) {
|
||||
let row = &data[i * vocab..(i + 1) * vocab];
|
||||
let argmax = row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(j, _)| j as u32)
|
||||
.unwrap();
|
||||
let expected = seq[i + 1];
|
||||
let ok = argmax == expected;
|
||||
if ok {
|
||||
matches += 1;
|
||||
}
|
||||
total += 1;
|
||||
eprintln!(
|
||||
"pos {i}: xserv_argmax={argmax} oracle={expected} {}",
|
||||
if ok { "OK" } else { "DIFF" }
|
||||
);
|
||||
}
|
||||
eprintln!(
|
||||
"\nTeacher-forced top-1 agreement: {matches}/{total} = {:.1}%",
|
||||
100.0 * matches as f64 / total as f64
|
||||
);
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
|
||||
for (h, _) in worker_handles {
|
||||
h.join().unwrap();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Teacher-forced DECODE diagnostic: prefill the prompt, then walk the oracle
|
||||
// trajectory through the autoregressive decode path (NOT prefill), recording
|
||||
// per-position top-1 agreement bucketed by position. Localizes long-context
|
||||
// decode degradation (which prefill teacher-forcing cannot see).
|
||||
if let Some(forced) = get_arg::<String>(&args, "--forced-decode") {
|
||||
let forced_ids: Vec<u32> = forced
|
||||
.split(',')
|
||||
.filter_map(|s| s.trim().parse().ok())
|
||||
.collect();
|
||||
broadcast_cmd(
|
||||
&worker_txs,
|
||||
&worker_handles,
|
||||
WorkerCmd::Prefill {
|
||||
tokens: token_ids.clone(),
|
||||
slot,
|
||||
},
|
||||
);
|
||||
let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
let mut pred = sample_greedy_last(&logits); // prediction for forced[0]
|
||||
let bucket = 50usize;
|
||||
let mut buckets: Vec<(usize, usize)> = Vec::new();
|
||||
let (mut matches, mut total) = (0usize, 0usize);
|
||||
for (i, &f) in forced_ids.iter().enumerate() {
|
||||
let ok = pred == f;
|
||||
matches += ok as usize;
|
||||
total += 1;
|
||||
let b = i / bucket;
|
||||
if buckets.len() <= b {
|
||||
buckets.push((0, 0));
|
||||
}
|
||||
buckets[b].0 += ok as usize;
|
||||
buckets[b].1 += 1;
|
||||
// Teacher-force: feed the oracle token through the decode path.
|
||||
let pos = cache.seq_len(slot);
|
||||
broadcast_cmd(
|
||||
&worker_txs,
|
||||
&worker_handles,
|
||||
WorkerCmd::Decode {
|
||||
tokens: vec![f],
|
||||
positions: vec![pos],
|
||||
slots: vec![slot],
|
||||
},
|
||||
);
|
||||
let logits = model.forward_decode_paged(&[f], &[pos], &[slot], &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
pred = sample_greedy_last(&logits);
|
||||
}
|
||||
eprintln!(
|
||||
"Teacher-forced DECODE agreement: {matches}/{total} = {:.1}%",
|
||||
100.0 * matches as f64 / total as f64
|
||||
);
|
||||
for (b, (m, t)) in buckets.iter().enumerate() {
|
||||
eprintln!(
|
||||
" pos[{:>4}..{:<4}]: {m:>3}/{t:<3} = {:.0}%",
|
||||
b * bucket,
|
||||
b * bucket + t,
|
||||
100.0 * (*m as f64) / (*t as f64)
|
||||
);
|
||||
}
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
|
||||
for (h, _) in worker_handles {
|
||||
h.join().unwrap();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Prefill
|
||||
let t0 = Instant::now();
|
||||
broadcast_cmd(
|
||||
&worker_txs,
|
||||
&worker_handles,
|
||||
WorkerCmd::Prefill {
|
||||
tokens: token_ids.clone(),
|
||||
slot,
|
||||
},
|
||||
);
|
||||
let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
let ttft = t0.elapsed();
|
||||
|
||||
let mut next = sample_greedy_last(&logits);
|
||||
let mut output_tokens = vec![next];
|
||||
|
||||
eprintln!("TTFT: {:.1}ms", ttft.as_secs_f64() * 1000.0);
|
||||
print!("{prompt}");
|
||||
|
||||
// Decode
|
||||
let mut decoder = GraphedGptOssDecoder::new();
|
||||
let decode_start = Instant::now();
|
||||
for _ in 1..max_tokens {
|
||||
let text = tokenizer.decode(&[next]);
|
||||
print!("{text}");
|
||||
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = cache.seq_len(slot);
|
||||
broadcast_cmd(
|
||||
&worker_txs,
|
||||
&worker_handles,
|
||||
WorkerCmd::Decode {
|
||||
tokens: vec![next],
|
||||
positions: vec![pos],
|
||||
slots: vec![slot],
|
||||
},
|
||||
);
|
||||
let logits = decoder.decode(&model, &[next], &[pos], &[slot], &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
|
||||
next = sample_greedy_last(&logits);
|
||||
output_tokens.push(next);
|
||||
}
|
||||
let decode_elapsed = decode_start.elapsed();
|
||||
println!();
|
||||
|
||||
let gen_tokens = output_tokens.len();
|
||||
let full_text = tokenizer.decode(&output_tokens);
|
||||
eprintln!("\nGenerated text: {full_text}");
|
||||
eprintln!(
|
||||
"Token IDs: {:?}",
|
||||
&output_tokens[..output_tokens.len().min(20)]
|
||||
);
|
||||
let tpot = if gen_tokens > 1 {
|
||||
decode_elapsed.as_secs_f64() * 1000.0 / (gen_tokens - 1) as f64
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
let tok_s = if gen_tokens > 1 {
|
||||
(gen_tokens - 1) as f64 / decode_elapsed.as_secs_f64()
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
eprintln!("\n--- Performance ---");
|
||||
eprintln!("Generated: {} tokens", gen_tokens);
|
||||
eprintln!("TTFT: {:.1}ms", ttft.as_secs_f64() * 1000.0);
|
||||
eprintln!("TPOT: {:.1}ms", tpot);
|
||||
eprintln!("Throughput: {:.1} tok/s", tok_s);
|
||||
|
||||
// Cleanup
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
|
||||
for (h, _) in worker_handles {
|
||||
h.join().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// --- Worker infrastructure ---
|
||||
|
||||
/// Message sent from rank 0 to a worker-rank thread.
///
/// `Clone` is required because the same command is broadcast to every worker.
#[derive(Clone)]
enum WorkerCmd {
    /// Register the given sequence slot in the worker's KV cache.
    Register(usize),
    /// Run a prefill forward pass over `tokens` for sequence `slot`.
    Prefill { tokens: Vec<u32>, slot: usize },
    /// Run one decode forward pass; the three vectors are passed through
    /// unchanged to the worker's decoder.
    Decode {
        tokens: Vec<u32>,
        positions: Vec<usize>,
        slots: Vec<usize>,
    },
    /// Leave the command loop so the worker thread can be joined.
    Shutdown,
}
|
||||
|
||||
fn worker_loop(
|
||||
rank: usize,
|
||||
world: usize,
|
||||
uid: UniqueId,
|
||||
model_dir: PathBuf,
|
||||
config: ModelConfig,
|
||||
max_seq_len: usize,
|
||||
rx: std::sync::mpsc::Receiver<WorkerCmd>,
|
||||
ack_tx: std::sync::mpsc::Sender<()>,
|
||||
) {
|
||||
xserv_cuda::device::set_device(rank as u32).unwrap();
|
||||
let tp = Arc::new(TpContext::init(rank, world, uid, rank as u32));
|
||||
eprintln!("[rank {rank}] Loading weights...");
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
|
||||
let model =
|
||||
GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp));
|
||||
let local_kv = config.num_kv_heads() / world;
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = max_blocks_per_seq + 64;
|
||||
let mut cache = PagedKVCache::new_tp(
|
||||
&config,
|
||||
local_kv,
|
||||
total_blocks,
|
||||
0,
|
||||
4,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
rank as u32,
|
||||
);
|
||||
eprintln!("[rank {rank}] Ready.");
|
||||
ack_tx.send(()).unwrap();
|
||||
|
||||
let mut decoder = GraphedGptOssDecoder::new();
|
||||
while let Ok(cmd) = rx.recv() {
|
||||
match cmd {
|
||||
WorkerCmd::Register(slot) => {
|
||||
let _ = cache.register_sequence(slot);
|
||||
}
|
||||
WorkerCmd::Prefill { tokens, slot } => {
|
||||
let _ = model.forward_prefill_paged(&tokens, slot, &mut cache);
|
||||
}
|
||||
WorkerCmd::Decode {
|
||||
tokens,
|
||||
positions,
|
||||
slots,
|
||||
} => {
|
||||
let _ = decoder.decode(&model, &tokens, &positions, &slots, &mut cache);
|
||||
}
|
||||
WorkerCmd::Shutdown => break,
|
||||
}
|
||||
ack_tx.send(()).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
fn broadcast_cmd(
|
||||
txs: &[std::sync::mpsc::Sender<WorkerCmd>],
|
||||
_handles: &[(std::thread::JoinHandle<()>, std::sync::mpsc::Receiver<()>)],
|
||||
cmd: WorkerCmd,
|
||||
) {
|
||||
for tx in txs {
|
||||
tx.send(cmd.clone()).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
/// Blocks until one acknowledgement has been received from every worker.
///
/// Consumes exactly one `()` from each worker's ack channel, in order.
///
/// # Panics
///
/// Panics if a worker dropped its ack sender (its thread exited or panicked)
/// before acknowledging. Workers are spawned for ranks `1..world` in order,
/// so entry `i` of `handles` belongs to rank `i + 1`, which the panic message
/// names.
fn wait_workers(handles: &[(std::thread::JoinHandle<()>, std::sync::mpsc::Receiver<()>)]) {
    for (idx, (_, ack_rx)) in handles.iter().enumerate() {
        ack_rx.recv().unwrap_or_else(|_| {
            panic!(
                "worker rank {} exited before sending its acknowledgement",
                idx + 1
            )
        });
    }
}
|
||||
|
||||
fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 {
|
||||
use half::bf16;
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
// GPU argmax fast path (4-byte D2H instead of the full logits row).
|
||||
if logits.dtype() == xserv_tensor::DType::BF16 && logits.is_contiguous() {
|
||||
let ids = xserv_kernels::argmax_bf16_to_host(logits);
|
||||
return *ids.last().unwrap();
|
||||
}
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
let data = logits_cpu.as_slice::<bf16>();
|
||||
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
|
||||
last.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| {
|
||||
let af = a.1.to_f32();
|
||||
let bf = b.1.to_f32();
|
||||
af.partial_cmp(&bf).unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.map(|(i, _)| i as u32)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Looks up `flag` in `args` and parses the argument that follows it.
///
/// Returns `None` when the flag is absent, when it is the last argument, or
/// when the following value does not parse as `T`. Only the first occurrence
/// of `flag` is considered.
fn get_arg<T: std::str::FromStr>(args: &[String], flag: &str) -> Option<T> {
    let flag_idx = args.iter().position(|candidate| candidate == flag)?;
    let raw_value = args.get(flag_idx + 1)?;
    raw_value.parse::<T>().ok()
}
|
||||
221
crates/xserv-model/src/bin/bench-gpt2.rs
Normal file
221
crates/xserv-model/src/bin/bench-gpt2.rs
Normal file
@@ -0,0 +1,221 @@
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
use xserv_model::gpt2::{KVCache, sample_greedy};
|
||||
use xserv_model::{GPT2, ModelConfig, loader};
|
||||
use xserv_tensor::Device;
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: bench-gpt2 <model-dir> [--gen-tokens N] [--no-cache]");
|
||||
std::process::exit(1);
|
||||
}
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let gen_tokens: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--gen-tokens")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(20);
|
||||
let use_cache = !args.iter().any(|a| a == "--no-cache");
|
||||
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cuda(0));
|
||||
let model = GPT2::from_weights(config.clone(), weights);
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
|
||||
// Warmup
|
||||
{
|
||||
let ids = tokenizer.encode("warmup");
|
||||
let _ = model.forward(&ids);
|
||||
}
|
||||
|
||||
eprintln!("mode: {}", if use_cache { "KV cache" } else { "no cache" });
|
||||
|
||||
let prompts: Vec<&str> = vec![
|
||||
"The capital of France is",
|
||||
"Once upon a time in a land far away",
|
||||
"Hello, how are you doing today",
|
||||
"In a shocking finding, scientists discovered a",
|
||||
"The weather today is sunny, so I decided to",
|
||||
"Alan Turing was a British mathematician who",
|
||||
"The best way to learn programming is",
|
||||
"Artificial intelligence will change the world because",
|
||||
"The history of the internet began in the",
|
||||
"A good morning routine starts with",
|
||||
"The stock market crashed because investors",
|
||||
"Deep learning is a subset of machine learning that",
|
||||
"The president of the United States announced",
|
||||
"In the year 2050, humans will",
|
||||
"The secret to happiness is",
|
||||
"When I was a child, I used to",
|
||||
"The most important scientific discovery of the century",
|
||||
"Climate change is caused by",
|
||||
"The recipe for chocolate cake requires",
|
||||
"In conclusion, the evidence suggests that",
|
||||
"The cat sat on the mat and",
|
||||
"According to recent studies, exercise can",
|
||||
"The first step in solving any problem is",
|
||||
"Technology has transformed the way we",
|
||||
"The novel begins with the protagonist",
|
||||
"Education is the most powerful weapon",
|
||||
"The ocean covers more than seventy percent of",
|
||||
"Last night I had a dream about",
|
||||
"The company announced its quarterly earnings",
|
||||
"Music has the power to",
|
||||
"The difference between success and failure is",
|
||||
"In the beginning, there was nothing but",
|
||||
"The doctor told me that I should",
|
||||
"Python is a popular programming language because",
|
||||
"The ancient Romans built roads that",
|
||||
"A balanced diet should include",
|
||||
"The movie received mixed reviews from critics",
|
||||
"Space exploration has led to many",
|
||||
"The teacher asked the students to",
|
||||
"Global warming is one of the most",
|
||||
"The bridge collapsed due to structural",
|
||||
"Quantum computing promises to revolutionize",
|
||||
"The new policy will affect millions of",
|
||||
"During the winter months, it is important to",
|
||||
"The human brain contains approximately",
|
||||
"Democracy depends on the active participation of",
|
||||
"The train arrived at the station exactly",
|
||||
"Researchers at MIT have developed a new",
|
||||
"The smartphone has become an essential part of",
|
||||
"After careful consideration, the committee decided to",
|
||||
];
|
||||
|
||||
println!("[");
|
||||
for (i, prompt) in prompts.iter().enumerate() {
|
||||
let input_ids = tokenizer.encode(prompt);
|
||||
let input_len = input_ids.len();
|
||||
|
||||
let (generated_ids, ttft_us, token_times_us) = if use_cache {
|
||||
generate_with_cache(&model, &config, &tokenizer, &input_ids, gen_tokens)
|
||||
} else {
|
||||
generate_no_cache(&model, &tokenizer, &input_ids, gen_tokens)
|
||||
};
|
||||
|
||||
let num_generated = generated_ids.len();
|
||||
let generated_text = tokenizer.decode(&generated_ids);
|
||||
|
||||
let tbt_us = if !token_times_us.is_empty() {
|
||||
token_times_us.iter().sum::<u128>() / token_times_us.len() as u128
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let total_gen_us: u128 = ttft_us + token_times_us.iter().sum::<u128>();
|
||||
let tpot_us = if num_generated > 0 {
|
||||
total_gen_us / num_generated as u128
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let gen_text_escaped = generated_text
|
||||
.replace('\\', "\\\\")
|
||||
.replace('"', "\\\"")
|
||||
.replace('\n', "\\n")
|
||||
.replace('\r', "\\r")
|
||||
.replace('\t', "\\t");
|
||||
let gen_ids_str: Vec<String> = generated_ids.iter().map(|id| id.to_string()).collect();
|
||||
|
||||
print!(" {{\"prompt\": \"{}\", ", prompt.replace('"', "\\\""));
|
||||
print!("\"input_len\": {input_len}, ");
|
||||
print!("\"num_generated\": {num_generated}, ");
|
||||
print!("\"generated_ids\": [{}], ", gen_ids_str.join(", "));
|
||||
print!("\"generated_text\": \"{gen_text_escaped}\", ");
|
||||
print!("\"ttft_us\": {ttft_us}, ");
|
||||
print!("\"tbt_us\": {tbt_us}, ");
|
||||
print!("\"tpot_us\": {tpot_us}}}");
|
||||
if i < prompts.len() - 1 {
|
||||
println!(",");
|
||||
} else {
|
||||
println!();
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"[{}/{}] input={input_len}tok gen={num_generated}tok ttft={:.1}ms tbt={:.1}ms | {}",
|
||||
i + 1,
|
||||
prompts.len(),
|
||||
ttft_us as f64 / 1000.0,
|
||||
tbt_us as f64 / 1000.0,
|
||||
&generated_text.replace('\n', " ")[..generated_text.len().min(60)]
|
||||
);
|
||||
}
|
||||
println!("]");
|
||||
}
|
||||
|
||||
fn generate_with_cache(
|
||||
model: &GPT2,
|
||||
config: &ModelConfig,
|
||||
tokenizer: &Tokenizer,
|
||||
input_ids: &[u32],
|
||||
gen_tokens: usize,
|
||||
) -> (Vec<u32>, u128, Vec<u128>) {
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(),
|
||||
config.num_heads(),
|
||||
config.head_dim(),
|
||||
xserv_tensor::DType::F32,
|
||||
Device::Cuda(0),
|
||||
);
|
||||
|
||||
// Prefill
|
||||
let t0 = Instant::now();
|
||||
let logits = model.forward_with_cache(input_ids, &mut cache);
|
||||
let first_token = sample_greedy(&logits);
|
||||
let ttft_us = t0.elapsed().as_micros();
|
||||
|
||||
let mut generated = vec![first_token];
|
||||
let mut token_times = Vec::new();
|
||||
|
||||
// Decode
|
||||
for _ in 1..gen_tokens {
|
||||
let last = *generated.last().unwrap();
|
||||
let t_start = Instant::now();
|
||||
let logits = model.forward_with_cache(&[last], &mut cache);
|
||||
let next = sample_greedy(&logits);
|
||||
token_times.push(t_start.elapsed().as_micros());
|
||||
generated.push(next);
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(generated, ttft_us, token_times)
|
||||
}
|
||||
|
||||
fn generate_no_cache(
|
||||
model: &GPT2,
|
||||
tokenizer: &Tokenizer,
|
||||
input_ids: &[u32],
|
||||
gen_tokens: usize,
|
||||
) -> (Vec<u32>, u128, Vec<u128>) {
|
||||
let mut all_ids = input_ids.to_vec();
|
||||
|
||||
let t0 = Instant::now();
|
||||
let logits = model.forward(&all_ids);
|
||||
let first_token = sample_greedy(&logits);
|
||||
let ttft_us = t0.elapsed().as_micros();
|
||||
all_ids.push(first_token);
|
||||
|
||||
let mut generated = vec![first_token];
|
||||
let mut token_times = Vec::new();
|
||||
|
||||
for _ in 1..gen_tokens {
|
||||
let t_start = Instant::now();
|
||||
let logits = model.forward(&all_ids);
|
||||
let next = sample_greedy(&logits);
|
||||
token_times.push(t_start.elapsed().as_micros());
|
||||
all_ids.push(next);
|
||||
generated.push(next);
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(generated, ttft_us, token_times)
|
||||
}
|
||||
232
crates/xserv-model/src/bin/bench-qwen3.rs
Normal file
232
crates/xserv-model/src/bin/bench-qwen3.rs
Normal file
@@ -0,0 +1,232 @@
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
use xserv_model::qwen3::sample_greedy;
|
||||
use xserv_model::{DecodeGraphState, GpuKVCache, ModelConfig, Qwen3, loader};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: bench-qwen3 <model-dir> [--gen-tokens N] [--cuda-graph]");
|
||||
std::process::exit(1);
|
||||
}
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let gen_tokens: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--gen-tokens")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(20);
|
||||
let use_cuda_graph = args.iter().any(|a| a == "--cuda-graph");
|
||||
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
eprintln!("Loading Qwen3-8B weights...");
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cuda(0));
|
||||
eprintln!("Loaded {} tensors", weights.len());
|
||||
let model = Qwen3::from_weights(config.clone(), weights);
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
|
||||
// Warmup
|
||||
{
|
||||
let ids = tokenizer.encode("warmup");
|
||||
let mut cache = GpuKVCache::new(&config, 256, DType::BF16, 0);
|
||||
let _ = model.forward_gpu_cache(&ids, &mut cache);
|
||||
}
|
||||
|
||||
// CUDA Graph setup
|
||||
let layer_ptrs = model.layer_weight_ptrs();
|
||||
let (norm_w, lm_head, embed, cos, sin) = model.graph_capture_ptrs();
|
||||
let mut decode_graph = if use_cuda_graph {
|
||||
eprintln!("CUDA Graph mode enabled");
|
||||
Some(DecodeGraphState::new(&config))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut graph_captured = false;
|
||||
|
||||
eprintln!("Warmup done. Running benchmark...");
|
||||
|
||||
let prompts: Vec<&str> = vec![
|
||||
"The capital of France is",
|
||||
"Once upon a time in a land far away",
|
||||
"Hello, how are you doing today",
|
||||
"In a shocking finding, scientists discovered a",
|
||||
"The weather today is sunny, so I decided to",
|
||||
"Alan Turing was a British mathematician who",
|
||||
"The best way to learn programming is",
|
||||
"Artificial intelligence will change the world because",
|
||||
"The history of the internet began in the",
|
||||
"A good morning routine starts with",
|
||||
"The stock market crashed because investors",
|
||||
"Deep learning is a subset of machine learning that",
|
||||
"The president of the United States announced",
|
||||
"In the year 2050, humans will",
|
||||
"The secret to happiness is",
|
||||
"When I was a child, I used to",
|
||||
"The most important scientific discovery of the century",
|
||||
"Climate change is caused by",
|
||||
"The recipe for chocolate cake requires",
|
||||
"In conclusion, the evidence suggests that",
|
||||
"The cat sat on the mat and",
|
||||
"According to recent studies, exercise can",
|
||||
"The first step in solving any problem is",
|
||||
"Technology has transformed the way we",
|
||||
"The novel begins with the protagonist",
|
||||
"Education is the most powerful weapon",
|
||||
"The ocean covers more than seventy percent of",
|
||||
"Last night I had a dream about",
|
||||
"The company announced its quarterly earnings",
|
||||
"Music has the power to",
|
||||
"The difference between success and failure is",
|
||||
"In the beginning, there was nothing but",
|
||||
"The doctor told me that I should",
|
||||
"Python is a popular programming language because",
|
||||
"The ancient Romans built roads that",
|
||||
"A balanced diet should include",
|
||||
"The movie received mixed reviews from critics",
|
||||
"Space exploration has led to many",
|
||||
"The teacher asked the students to",
|
||||
"Global warming is one of the most",
|
||||
"The bridge collapsed due to structural",
|
||||
"Quantum computing promises to revolutionize",
|
||||
"The new policy will affect millions of",
|
||||
"During the winter months, it is important to",
|
||||
"The human brain contains approximately",
|
||||
"Democracy depends on the active participation of",
|
||||
"The train arrived at the station exactly",
|
||||
"Researchers at MIT have developed a new",
|
||||
"The smartphone has become an essential part of",
|
||||
"After careful consideration, the committee decided to",
|
||||
];
|
||||
|
||||
println!("[");
|
||||
for (i, prompt) in prompts.iter().enumerate() {
|
||||
let input_ids = tokenizer.encode(prompt);
|
||||
let input_len = input_ids.len();
|
||||
|
||||
let mut cache = GpuKVCache::new(&config, 256, DType::BF16, 0);
|
||||
|
||||
// Reset graph state for new prompt
|
||||
graph_captured = false;
|
||||
if let Some(ref mut g) = decode_graph {
|
||||
g.invalidate();
|
||||
}
|
||||
|
||||
// Prefill
|
||||
let t0 = Instant::now();
|
||||
let logits = model.forward_gpu_cache(&input_ids, &mut cache);
|
||||
let first_token = sample_greedy(&logits);
|
||||
let ttft_us = t0.elapsed().as_micros();
|
||||
|
||||
let mut generated = vec![first_token];
|
||||
let mut token_times = Vec::new();
|
||||
|
||||
// Decode
|
||||
for _ in 1..gen_tokens {
|
||||
let last = *generated.last().unwrap();
|
||||
let t_start = Instant::now();
|
||||
|
||||
let next = if let Some(ref mut graph) = decode_graph {
|
||||
if !graph_captured {
|
||||
// First decode token: run ungraphed, then capture
|
||||
let logits = model.forward_gpu_cache(&[last], &mut cache);
|
||||
graph_captured = true;
|
||||
graph.capture(&layer_ptrs, norm_w, lm_head, embed, cos, sin);
|
||||
sample_greedy(&logits)
|
||||
} else {
|
||||
// Replay captured graphs
|
||||
let pos = cache.seq_len() as u32;
|
||||
graph.execute(
|
||||
last,
|
||||
pos,
|
||||
&mut cache,
|
||||
&layer_ptrs,
|
||||
embed,
|
||||
config.vocab_size as i32,
|
||||
config.hidden() as i32,
|
||||
);
|
||||
cache.advance_seq_len(1);
|
||||
// Read logits from graph buffer
|
||||
let vocab_size = config.vocab_size;
|
||||
let mut logits_bytes = vec![0u8; vocab_size * 2];
|
||||
graph
|
||||
.logits_buffer()
|
||||
.copy_to_host(&mut logits_bytes)
|
||||
.unwrap();
|
||||
let logits_data: &[half::bf16] = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
logits_bytes.as_ptr() as *const half::bf16,
|
||||
vocab_size,
|
||||
)
|
||||
};
|
||||
logits_data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(idx, _)| idx as u32)
|
||||
.unwrap()
|
||||
}
|
||||
} else {
|
||||
let logits = model.forward_gpu_cache(&[last], &mut cache);
|
||||
sample_greedy(&logits)
|
||||
};
|
||||
|
||||
token_times.push(t_start.elapsed().as_micros());
|
||||
generated.push(next);
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let num_generated = generated.len();
|
||||
let generated_text = tokenizer.decode(&generated);
|
||||
let tbt_us = if !token_times.is_empty() {
|
||||
token_times.iter().sum::<u128>() / token_times.len() as u128
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let total_gen_us: u128 = ttft_us + token_times.iter().sum::<u128>();
|
||||
let tpot_us = if num_generated > 0 {
|
||||
total_gen_us / num_generated as u128
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let gen_text_escaped = generated_text
|
||||
.replace('\\', "\\\\")
|
||||
.replace('"', "\\\"")
|
||||
.replace('\n', "\\n")
|
||||
.replace('\r', "\\r")
|
||||
.replace('\t', "\\t");
|
||||
let gen_ids_str: Vec<String> = generated.iter().map(|id| id.to_string()).collect();
|
||||
|
||||
print!(" {{\"prompt\": \"{}\", ", prompt.replace('"', "\\\""));
|
||||
print!("\"input_len\": {input_len}, ");
|
||||
print!("\"num_generated\": {num_generated}, ");
|
||||
print!("\"generated_ids\": [{}], ", gen_ids_str.join(", "));
|
||||
print!("\"generated_text\": \"{gen_text_escaped}\", ");
|
||||
print!("\"ttft_us\": {ttft_us}, ");
|
||||
print!("\"tbt_us\": {tbt_us}, ");
|
||||
print!("\"tpot_us\": {tpot_us}}}");
|
||||
if i < prompts.len() - 1 {
|
||||
println!(",");
|
||||
} else {
|
||||
println!();
|
||||
}
|
||||
|
||||
let display_text = generated_text.replace('\n', " ");
|
||||
let truncated: String = display_text.chars().take(60).collect();
|
||||
eprintln!(
|
||||
"[{}/{}] input={input_len}tok gen={num_generated}tok ttft={:.1}ms tbt={:.1}ms | {}",
|
||||
i + 1,
|
||||
prompts.len(),
|
||||
ttft_us as f64 / 1000.0,
|
||||
tbt_us as f64 / 1000.0,
|
||||
truncated
|
||||
);
|
||||
}
|
||||
println!("]");
|
||||
}
|
||||
976
crates/xserv-model/src/bin/bench-speculative.rs
Normal file
976
crates/xserv-model/src/bin/bench-speculative.rs
Normal file
@@ -0,0 +1,976 @@
|
||||
//! Draft-model speculative decoding benchmark for Qwen3.
|
||||
//!
|
||||
//! v0 scope:
|
||||
//! - target + draft are Qwen3-family models with the same tokenizer/vocab;
|
||||
//! - batch=1;
|
||||
//! - greedy exact-match acceptance;
|
||||
//! - no probabilistic rejection sampling.
|
||||
|
||||
use half::bf16;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Instant;
|
||||
|
||||
use xserv_model::qwen3_graph::GraphedQwen3Decoder;
|
||||
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
// NOTE(review): presumably the default number of draft tokens proposed per
// speculation step — confirm against the argument parsing further down.
const DEFAULT_GAMMA: usize = 4;
|
||||
// NOTE(review): presumably the default number of tokens to generate per prompt — confirm.
const DEFAULT_GEN_TOKENS: usize = 64;
|
||||
// NOTE(review): presumably the default maximum sequence length used to size the KV cache — confirm.
const DEFAULT_MAX_SEQ_LEN: usize = 2048;
|
||||
|
||||
/// Selects one of the two verification paths offered by this benchmark.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum VerifyPath {
    /// Reported as `"flash"`.
    Flash,
    /// Reported as `"paged-decode"`.
    PagedDecode,
}

impl VerifyPath {
    /// Returns the lowercase name of this path.
    fn as_str(self) -> &'static str {
        match self {
            Self::Flash => "flash",
            Self::PagedDecode => "paged-decode",
        }
    }
}
|
||||
|
||||
/// Fixed prompt set for the benchmark. `--prompts N` runs the first `N`
/// entries, so the order here determines which prompts a shortened run sees.
const PROMPTS: [&str; 50] = [
    "The capital of France is",
    "Once upon a time in a land far away",
    "Hello, how are you doing today",
    "In a shocking finding, scientists discovered a",
    "The weather today is sunny, so I decided to",
    "Alan Turing was a British mathematician who",
    "The best way to learn programming is",
    "Artificial intelligence will change the world because",
    "The history of the internet began in the",
    "A good morning routine starts with",
    "The stock market crashed because investors",
    "Deep learning is a subset of machine learning that",
    "The president of the United States announced",
    "In the year 2050, humans will",
    "The secret to happiness is",
    "When I was a child, I used to",
    "The most important scientific discovery of the century",
    "Climate change is caused by",
    "The recipe for chocolate cake requires",
    "In conclusion, the evidence suggests that",
    "The cat sat on the mat and",
    "According to recent studies, exercise can",
    "The first step in solving any problem is",
    "Technology has transformed the way we",
    "The novel begins with the protagonist",
    "Education is the most powerful weapon",
    "The ocean covers more than seventy percent of",
    "Last night I had a dream about",
    "The company announced its quarterly earnings",
    "Music has the power to",
    "The difference between success and failure is",
    "In the beginning, there was nothing but",
    "The doctor told me that I should",
    "Python is a popular programming language because",
    "The ancient Romans built roads that",
    "A balanced diet should include",
    "The movie received mixed reviews from critics",
    "Space exploration has led to many",
    "The teacher asked the students to",
    "Global warming is one of the most",
    "The bridge collapsed due to structural",
    "Quantum computing promises to revolutionize",
    "The new policy will affect millions of",
    "During the winter months, it is important to",
    "The human brain contains approximately",
    "Democracy depends on the active participation of",
    "The train arrived at the station exactly",
    "Researchers at MIT have developed a new",
    "The smartphone has become an essential part of",
    "After careful consideration, the committee decided to",
];
|
||||
|
||||
/// Outcome of generating from one prompt, for either the baseline or the
/// speculative run. The baseline only fills `ids`, the timings and
/// `target_steps`; every other counter stays at its `Default` of zero.
#[derive(Default)]
struct RunStats {
    /// Generated token ids, in order.
    ids: Vec<u32>,
    /// Wall-clock seconds measured from just before prefill to the end of decode.
    total_s: f64,
    /// Wall-clock seconds spent in prefill (including the device sync after it).
    prefill_s: f64,
    /// Wall-clock seconds spent in the decode loop (including the device sync after it).
    decode_s: f64,
    /// Baseline: number of decode forward passes. Speculative: the sum of
    /// `verify_steps`, `mirror_steps`, `commit_steps` and `correction_steps`.
    target_steps: usize,
    /// Draft tokens that were accepted.
    accepted: usize,
    /// Draft tokens that were proposed.
    proposed: usize,
    /// Verification forward passes over a batch of draft tokens.
    verify_steps: usize,
    /// Decode passes that replay a committed token into the verify cache.
    mirror_steps: usize,
    /// Target decode passes that commit a token into the main target cache.
    commit_steps: usize,
    /// Draft decode passes that feed a correction token to the draft model.
    correction_steps: usize,
    /// Times the verify-pass argmax disagreed with the single-token decode argmax.
    verify_decode_mismatches: usize,
}
|
||||
|
||||
/// Running sums over every benchmarked prompt. `baseline_*` fields accumulate
/// the baseline `RunStats`, `spec_*` fields accumulate the speculative one.
#[derive(Default)]
struct Totals {
    /// Number of prompts processed.
    prompts: usize,
    /// Tokens generated by the baseline runs.
    baseline_generated: usize,
    /// Tokens generated by the speculative runs.
    spec_generated: usize,
    // Baseline timings, in seconds.
    baseline_total_s: f64,
    baseline_prefill_s: f64,
    baseline_decode_s: f64,
    // Speculative timings, in seconds.
    spec_total_s: f64,
    spec_prefill_s: f64,
    spec_decode_s: f64,
    // Speculative step and acceptance counters (see the same-named `RunStats` fields).
    spec_target_steps: usize,
    spec_accepted: usize,
    spec_proposed: usize,
    spec_verify_steps: usize,
    spec_mirror_steps: usize,
    spec_commit_steps: usize,
    spec_correction_steps: usize,
    spec_verify_decode_mismatches: usize,
    /// Prompts whose speculative token ids differed from the baseline token ids.
    mismatches: usize,
}
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 3 {
|
||||
eprintln!(
|
||||
"Usage: bench-speculative <target-model-dir> <draft-model-dir> \
|
||||
[--gen-tokens N] [--gamma N] [--prompts N] [--max-seq-len N] [--device N] \
|
||||
[--use-verify-logits] [--verify-path flash|paged-decode] [--dump-verify-mismatches]"
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let target_dir = PathBuf::from(&args[1]);
|
||||
let draft_dir = PathBuf::from(&args[2]);
|
||||
let gen_tokens = arg_usize(&args, "--gen-tokens", DEFAULT_GEN_TOKENS);
|
||||
let gamma = arg_usize(&args, "--gamma", DEFAULT_GAMMA);
|
||||
let prompt_count = arg_usize(&args, "--prompts", PROMPTS.len()).min(PROMPTS.len());
|
||||
let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN);
|
||||
let device = arg_usize(&args, "--device", 0) as u32;
|
||||
let use_verify_logits = args.iter().any(|a| a == "--use-verify-logits");
|
||||
let verify_path = parse_verify_path(&args, use_verify_logits);
|
||||
let dump_verify_mismatches = args.iter().any(|a| a == "--dump-verify-mismatches");
|
||||
|
||||
assert!(gen_tokens > 0, "--gen-tokens must be > 0");
|
||||
assert!(gamma > 0, "--gamma must be > 0");
|
||||
|
||||
xserv_cuda::device::set_device(device).unwrap();
|
||||
let info = xserv_cuda::device::device_info(device).unwrap();
|
||||
eprintln!(
|
||||
"GPU {device}: {} ({} MB free)",
|
||||
info.name,
|
||||
info.free_memory / 1024 / 1024
|
||||
);
|
||||
|
||||
let target_config = ModelConfig::from_file(&target_dir.join("config.json"));
|
||||
let draft_config = ModelConfig::from_file(&draft_dir.join("config.json"));
|
||||
assert_qwen3(&target_config, "target");
|
||||
assert_qwen3(&draft_config, "draft");
|
||||
assert_eq!(
|
||||
target_config.vocab_size, draft_config.vocab_size,
|
||||
"target and draft vocab_size must match"
|
||||
);
|
||||
|
||||
warn_if_tokenizers_differ(&target_dir, &draft_dir);
|
||||
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
|
||||
if tokenizer.vocab_size() != target_config.vocab_size {
|
||||
eprintln!(
|
||||
"WARNING: tokenizer decoder len {} differs from config vocab_size {}; continuing because token ids come from the shared tokenizer.json",
|
||||
tokenizer.vocab_size(),
|
||||
target_config.vocab_size
|
||||
);
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"Loading target Qwen3: layers={} hidden={} heads={}/{} vocab={}",
|
||||
target_config.num_layers(),
|
||||
target_config.hidden(),
|
||||
target_config.num_heads(),
|
||||
target_config.num_kv_heads(),
|
||||
target_config.vocab_size
|
||||
);
|
||||
let target_weights = loader::load_model_dir(&target_dir, Device::Cuda(device));
|
||||
let target = Qwen3::from_weights(target_config.clone(), target_weights);
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
|
||||
eprintln!(
|
||||
"Loading draft Qwen3: layers={} hidden={} heads={}/{} vocab={}",
|
||||
draft_config.num_layers(),
|
||||
draft_config.hidden(),
|
||||
draft_config.num_heads(),
|
||||
draft_config.num_kv_heads(),
|
||||
draft_config.vocab_size
|
||||
);
|
||||
let draft_weights = loader::load_model_dir(&draft_dir, Device::Cuda(device));
|
||||
let draft = Qwen3::from_weights(draft_config.clone(), draft_weights);
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
|
||||
let warm_ids = tokenizer.encode("warmup");
|
||||
let warm_tokens = gen_tokens.min(4);
|
||||
{
|
||||
let mut target_cache = new_cache(&target_config, max_seq_len, device);
|
||||
let _ = run_baseline(
|
||||
&target,
|
||||
&mut target_cache,
|
||||
&tokenizer,
|
||||
&warm_ids,
|
||||
warm_tokens,
|
||||
);
|
||||
}
|
||||
{
|
||||
let mut target_cache = new_cache_with_rows(
|
||||
&target_config,
|
||||
max_seq_len,
|
||||
device,
|
||||
if use_verify_logits { gamma } else { 1 },
|
||||
);
|
||||
let mut target_verify_cache =
|
||||
new_cache_with_rows(&target_config, max_seq_len, device, gamma);
|
||||
let mut draft_cache = new_cache(&draft_config, max_seq_len, device);
|
||||
let mut draft_decoder = GraphedQwen3Decoder::new();
|
||||
let _ = run_speculative(
|
||||
&target,
|
||||
&draft,
|
||||
&mut target_cache,
|
||||
&mut target_verify_cache,
|
||||
&mut draft_cache,
|
||||
&mut draft_decoder,
|
||||
&tokenizer,
|
||||
&warm_ids,
|
||||
warm_tokens,
|
||||
gamma,
|
||||
use_verify_logits,
|
||||
verify_path,
|
||||
dump_verify_mismatches,
|
||||
);
|
||||
}
|
||||
eprintln!(
|
||||
"Warmup done. Running {prompt_count} prompts, gen_tokens={gen_tokens}, gamma={gamma}, acceptance_mode={}, verify_path={}",
|
||||
if use_verify_logits {
|
||||
"verify_logits"
|
||||
} else {
|
||||
"decode"
|
||||
},
|
||||
verify_path.as_str()
|
||||
);
|
||||
|
||||
let mut totals = Totals::default();
|
||||
|
||||
// Persistent per-benchmark caches so the draft CUDA graph (Phase 24) can be
|
||||
// captured once and replayed across every prompt. Freeing and re-registering
|
||||
// slot 0 between prompts keeps block_table_gpu / context_lens_gpu addresses
|
||||
// stable, which is exactly what the graph captured.
|
||||
let mut target_cache = new_cache_with_rows(
|
||||
&target_config,
|
||||
max_seq_len,
|
||||
device,
|
||||
if use_verify_logits { gamma } else { 1 },
|
||||
);
|
||||
let mut target_verify_cache = new_cache_with_rows(&target_config, max_seq_len, device, gamma);
|
||||
let mut draft_cache = new_cache(&draft_config, max_seq_len, device);
|
||||
let mut draft_decoder = GraphedQwen3Decoder::new();
|
||||
|
||||
for (i, prompt) in PROMPTS.iter().take(prompt_count).enumerate() {
|
||||
let ids = tokenizer.encode(prompt);
|
||||
validate_length_budget(&ids, gen_tokens, max_seq_len, prompt);
|
||||
let mut baseline_cache = new_cache(&target_config, max_seq_len, device);
|
||||
let baseline = run_baseline(&target, &mut baseline_cache, &tokenizer, &ids, gen_tokens);
|
||||
drop(baseline_cache);
|
||||
|
||||
let spec = run_speculative(
|
||||
&target,
|
||||
&draft,
|
||||
&mut target_cache,
|
||||
&mut target_verify_cache,
|
||||
&mut draft_cache,
|
||||
&mut draft_decoder,
|
||||
&tokenizer,
|
||||
&ids,
|
||||
gen_tokens,
|
||||
gamma,
|
||||
use_verify_logits,
|
||||
verify_path,
|
||||
dump_verify_mismatches,
|
||||
);
|
||||
|
||||
let matched = baseline.ids == spec.ids;
|
||||
if !matched {
|
||||
totals.mismatches += 1;
|
||||
eprintln!("MISMATCH prompt {i}: {prompt}");
|
||||
eprintln!(" baseline: {:?}", baseline.ids);
|
||||
eprintln!(" spec: {:?}", spec.ids);
|
||||
}
|
||||
|
||||
println!(
|
||||
"prompt={:02} match={} gen={} accept={}/{} target_steps={} \
|
||||
baseline_e2e_tpot_ms={:.3} spec_e2e_tpot_ms={:.3}",
|
||||
i,
|
||||
matched,
|
||||
spec.ids.len(),
|
||||
spec.accepted,
|
||||
spec.proposed,
|
||||
spec.target_steps,
|
||||
per_token_ms(baseline.total_s, baseline.ids.len()),
|
||||
per_token_ms(spec.total_s, spec.ids.len()),
|
||||
);
|
||||
|
||||
totals.prompts += 1;
|
||||
totals.baseline_generated += baseline.ids.len();
|
||||
totals.spec_generated += spec.ids.len();
|
||||
totals.baseline_total_s += baseline.total_s;
|
||||
totals.baseline_prefill_s += baseline.prefill_s;
|
||||
totals.baseline_decode_s += baseline.decode_s;
|
||||
totals.spec_total_s += spec.total_s;
|
||||
totals.spec_prefill_s += spec.prefill_s;
|
||||
totals.spec_decode_s += spec.decode_s;
|
||||
totals.spec_target_steps += spec.target_steps;
|
||||
totals.spec_accepted += spec.accepted;
|
||||
totals.spec_proposed += spec.proposed;
|
||||
totals.spec_verify_steps += spec.verify_steps;
|
||||
totals.spec_mirror_steps += spec.mirror_steps;
|
||||
totals.spec_commit_steps += spec.commit_steps;
|
||||
totals.spec_correction_steps += spec.correction_steps;
|
||||
totals.spec_verify_decode_mismatches += spec.verify_decode_mismatches;
|
||||
}
|
||||
|
||||
let baseline_decode_tokens = totals.baseline_generated;
|
||||
let spec_decode_tokens = totals.spec_generated;
|
||||
let acceptance = ratio(totals.spec_accepted, totals.spec_proposed);
|
||||
let tokens_per_target_step = ratio(totals.spec_generated, totals.spec_target_steps);
|
||||
let matched =
|
||||
totals.mismatches == 0 && (!use_verify_logits || totals.spec_verify_decode_mismatches == 0);
|
||||
|
||||
println!("--- SUMMARY ---");
|
||||
println!("prompts={} matched={matched}", totals.prompts);
|
||||
println!(
|
||||
"acceptance_mode={}",
|
||||
if use_verify_logits {
|
||||
"verify_logits"
|
||||
} else {
|
||||
"decode"
|
||||
}
|
||||
);
|
||||
println!("verify_path={}", verify_path.as_str());
|
||||
println!(
|
||||
"acceptance_rate={:.4} accepted={} proposed={}",
|
||||
acceptance, totals.spec_accepted, totals.spec_proposed
|
||||
);
|
||||
println!(
|
||||
"tokens_per_target_step={:.4} target_steps={} verify_steps={} mirror_decode_steps={} commit_decode_steps={} correction_steps={}",
|
||||
tokens_per_target_step,
|
||||
totals.spec_target_steps,
|
||||
totals.spec_verify_steps,
|
||||
totals.spec_mirror_steps,
|
||||
totals.spec_commit_steps,
|
||||
totals.spec_correction_steps
|
||||
);
|
||||
println!(
|
||||
"verify_decode_mismatches={}",
|
||||
totals.spec_verify_decode_mismatches
|
||||
);
|
||||
println!(
|
||||
"baseline_e2e_tpot_ms={:.3} baseline_e2e_tok_s={:.3}",
|
||||
per_token_ms(totals.baseline_total_s, totals.baseline_generated),
|
||||
tok_s(totals.baseline_generated, totals.baseline_total_s)
|
||||
);
|
||||
println!(
|
||||
"spec_e2e_tpot_ms={:.3} spec_e2e_tok_s={:.3} speedup_e2e={:.4}",
|
||||
per_token_ms(totals.spec_total_s, totals.spec_generated),
|
||||
tok_s(totals.spec_generated, totals.spec_total_s),
|
||||
speedup(totals.baseline_total_s, totals.spec_total_s)
|
||||
);
|
||||
println!(
|
||||
"baseline_decode_tpot_ms={:.3} baseline_decode_tok_s={:.3}",
|
||||
per_token_ms(totals.baseline_decode_s, baseline_decode_tokens),
|
||||
tok_s(baseline_decode_tokens, totals.baseline_decode_s)
|
||||
);
|
||||
println!(
|
||||
"spec_decode_tpot_ms={:.3} spec_decode_tok_s={:.3} speedup_decode={:.4}",
|
||||
per_token_ms(totals.spec_decode_s, spec_decode_tokens),
|
||||
tok_s(spec_decode_tokens, totals.spec_decode_s),
|
||||
speedup(totals.baseline_decode_s, totals.spec_decode_s)
|
||||
);
|
||||
println!(
|
||||
"decode_token_counts baseline={} spec={}",
|
||||
baseline_decode_tokens, spec_decode_tokens
|
||||
);
|
||||
|
||||
if !matched {
|
||||
std::process::exit(2);
|
||||
}
|
||||
}
|
||||
|
||||
fn run_baseline(
|
||||
model: &Qwen3,
|
||||
cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
prompt_ids: &[u32],
|
||||
gen_tokens: usize,
|
||||
) -> RunStats {
|
||||
let slot = 0;
|
||||
cache.register_sequence(slot).expect("register target slot");
|
||||
|
||||
let t0 = Instant::now();
|
||||
let prefill_start = Instant::now();
|
||||
let logits = model.forward_prefill_paged(prompt_ids, slot, cache);
|
||||
sync_device();
|
||||
let prefill_s = prefill_start.elapsed().as_secs_f64();
|
||||
|
||||
let mut generated = Vec::with_capacity(gen_tokens);
|
||||
let mut next = last_argmax(&logits);
|
||||
generated.push(next);
|
||||
|
||||
let decode_start = Instant::now();
|
||||
let mut target_steps = 0usize;
|
||||
while generated.len() < gen_tokens && !tokenizer.is_eos(next) {
|
||||
let pos = cache.seq_len(slot);
|
||||
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], cache);
|
||||
target_steps += 1;
|
||||
next = last_argmax(&logits);
|
||||
generated.push(next);
|
||||
}
|
||||
sync_device();
|
||||
let decode_s = decode_start.elapsed().as_secs_f64();
|
||||
sync_device();
|
||||
let total_s = t0.elapsed().as_secs_f64();
|
||||
|
||||
cache.free_sequence(slot);
|
||||
RunStats {
|
||||
ids: generated,
|
||||
total_s,
|
||||
prefill_s,
|
||||
decode_s,
|
||||
target_steps,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_speculative(
|
||||
target: &Qwen3,
|
||||
draft: &Qwen3,
|
||||
target_cache: &mut PagedKVCache,
|
||||
target_verify_cache: &mut PagedKVCache,
|
||||
draft_cache: &mut PagedKVCache,
|
||||
draft_decoder: &mut GraphedQwen3Decoder,
|
||||
tokenizer: &Tokenizer,
|
||||
prompt_ids: &[u32],
|
||||
gen_tokens: usize,
|
||||
gamma: usize,
|
||||
use_verify_logits: bool,
|
||||
verify_path: VerifyPath,
|
||||
dump_verify_mismatches: bool,
|
||||
) -> RunStats {
|
||||
let slot = 0;
|
||||
target_cache
|
||||
.register_sequence(slot)
|
||||
.expect("register target slot");
|
||||
target_verify_cache
|
||||
.register_sequence(slot)
|
||||
.expect("register target verify slot");
|
||||
draft_cache
|
||||
.register_sequence(slot)
|
||||
.expect("register draft slot");
|
||||
|
||||
let t0 = Instant::now();
|
||||
let prefill_start = Instant::now();
|
||||
let target_logits = target.forward_prefill_paged(prompt_ids, slot, target_cache);
|
||||
if !use_verify_logits {
|
||||
let _ = target.forward_prefill_paged(prompt_ids, slot, target_verify_cache);
|
||||
}
|
||||
let draft_logits = draft.forward_prefill_paged(prompt_ids, slot, draft_cache);
|
||||
sync_device();
|
||||
let prefill_s = prefill_start.elapsed().as_secs_f64();
|
||||
|
||||
let mut target_next = last_argmax(&target_logits);
|
||||
let mut draft_next = last_argmax(&draft_logits);
|
||||
let mut generated = Vec::with_capacity(gen_tokens);
|
||||
let mut accepted_total = 0usize;
|
||||
let mut proposed_total = 0usize;
|
||||
let mut verify_steps = 0usize;
|
||||
let mut mirror_steps = 0usize;
|
||||
let mut commit_steps = 0usize;
|
||||
let mut correction_steps = 0usize;
|
||||
let mut verify_decode_mismatches = 0usize;
|
||||
|
||||
let decode_start = Instant::now();
|
||||
while generated.len() < gen_tokens {
|
||||
let remaining = gen_tokens - generated.len();
|
||||
let round_gamma = gamma.min(remaining);
|
||||
let round_start_len = target_cache.seq_len(slot);
|
||||
assert_eq!(
|
||||
round_start_len,
|
||||
draft_cache.seq_len(slot),
|
||||
"target and draft cache lengths diverged"
|
||||
);
|
||||
if !use_verify_logits {
|
||||
assert_eq!(
|
||||
round_start_len,
|
||||
target_verify_cache.seq_len(slot),
|
||||
"target verify cache length diverged"
|
||||
);
|
||||
}
|
||||
|
||||
let mut draft_tokens = Vec::with_capacity(round_gamma);
|
||||
for _ in 0..round_gamma {
|
||||
let token = draft_next;
|
||||
draft_tokens.push(token);
|
||||
if tokenizer.is_eos(token) {
|
||||
break;
|
||||
}
|
||||
let pos = draft_cache.seq_len(slot);
|
||||
let logits = draft_decoder.decode(draft, &[token], &[pos], &[slot], draft_cache);
|
||||
draft_next = last_argmax(&logits);
|
||||
}
|
||||
proposed_total += draft_tokens.len();
|
||||
|
||||
if use_verify_logits {
|
||||
verify_steps += 1;
|
||||
let verify_logits =
|
||||
target.forward_verify_paged_decode_attention(&draft_tokens, slot, target_cache);
|
||||
let verify_argmax = argmax_rows(&verify_logits);
|
||||
assert_eq!(
|
||||
verify_argmax.len(),
|
||||
draft_tokens.len(),
|
||||
"verify logits rows must match draft token count"
|
||||
);
|
||||
|
||||
let mut accepted = 0usize;
|
||||
let mut done = false;
|
||||
while accepted < draft_tokens.len() {
|
||||
let expected = if accepted > 0 {
|
||||
verify_argmax[accepted - 1]
|
||||
} else {
|
||||
target_next
|
||||
};
|
||||
if draft_tokens[accepted] != expected {
|
||||
break;
|
||||
}
|
||||
let token = draft_tokens[accepted];
|
||||
generated.push(token);
|
||||
accepted_total += 1;
|
||||
accepted += 1;
|
||||
|
||||
if generated.len() >= gen_tokens || tokenizer.is_eos(token) {
|
||||
done = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if accepted > 0 {
|
||||
target_next = verify_argmax[accepted - 1];
|
||||
}
|
||||
target_cache
|
||||
.truncate_sequence(slot, round_start_len + accepted)
|
||||
.unwrap();
|
||||
|
||||
if done {
|
||||
draft_cache
|
||||
.truncate_sequence(slot, target_cache.seq_len(slot))
|
||||
.unwrap();
|
||||
break;
|
||||
}
|
||||
|
||||
if accepted == draft_tokens.len() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let correction = if accepted > 0 {
|
||||
verify_argmax[accepted - 1]
|
||||
} else {
|
||||
target_next
|
||||
};
|
||||
generated.push(correction);
|
||||
|
||||
draft_cache
|
||||
.truncate_sequence(slot, round_start_len)
|
||||
.unwrap();
|
||||
replay_draft_tokens(
|
||||
draft,
|
||||
draft_decoder,
|
||||
draft_cache,
|
||||
slot,
|
||||
&draft_tokens[..accepted],
|
||||
&mut draft_next,
|
||||
);
|
||||
|
||||
if generated.len() >= gen_tokens || tokenizer.is_eos(correction) {
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = target_cache.seq_len(slot);
|
||||
let logits = target.forward_decode_paged(&[correction], &[pos], &[slot], target_cache);
|
||||
target_next = last_argmax(&logits);
|
||||
commit_steps += 1;
|
||||
|
||||
let pos = draft_cache.seq_len(slot);
|
||||
let logits = draft_decoder.decode(draft, &[correction], &[pos], &[slot], draft_cache);
|
||||
draft_next = last_argmax(&logits);
|
||||
correction_steps += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
verify_steps += 1;
|
||||
let verify_logits = match verify_path {
|
||||
VerifyPath::Flash => {
|
||||
target.forward_prefill_paged(&draft_tokens, slot, target_verify_cache)
|
||||
}
|
||||
VerifyPath::PagedDecode => target.forward_verify_paged_decode_attention(
|
||||
&draft_tokens,
|
||||
slot,
|
||||
target_verify_cache,
|
||||
),
|
||||
};
|
||||
let verify_argmax = argmax_rows(&verify_logits);
|
||||
assert_eq!(
|
||||
verify_argmax.len(),
|
||||
draft_tokens.len(),
|
||||
"verify logits rows must match draft token count"
|
||||
);
|
||||
|
||||
target_verify_cache
|
||||
.truncate_sequence(slot, round_start_len)
|
||||
.unwrap();
|
||||
|
||||
let mut accepted = 0usize;
|
||||
let mut done = false;
|
||||
while accepted < draft_tokens.len() {
|
||||
let expected = if use_verify_logits && accepted > 0 {
|
||||
verify_argmax[accepted - 1]
|
||||
} else {
|
||||
target_next
|
||||
};
|
||||
if draft_tokens[accepted] != expected {
|
||||
break;
|
||||
}
|
||||
let token_idx = accepted;
|
||||
let token = draft_tokens[token_idx];
|
||||
generated.push(token);
|
||||
accepted_total += 1;
|
||||
accepted += 1;
|
||||
|
||||
if generated.len() >= gen_tokens || tokenizer.is_eos(token) {
|
||||
done = true;
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = target_cache.seq_len(slot);
|
||||
let logits = target.forward_decode_paged(&[token], &[pos], &[slot], target_cache);
|
||||
let decode_next = last_argmax(&logits);
|
||||
if verify_argmax[token_idx] != decode_next {
|
||||
verify_decode_mismatches += 1;
|
||||
eprintln!(
|
||||
"VERIFY/DECODE MISMATCH at cache_len={} accepted_idx={}: verify={} decode={}",
|
||||
target_cache.seq_len(slot),
|
||||
token_idx,
|
||||
verify_argmax[token_idx],
|
||||
decode_next
|
||||
);
|
||||
if dump_verify_mismatches {
|
||||
eprintln!(
|
||||
" verify_top5={} decode_top5={}",
|
||||
format_topk(&verify_logits, token_idx, 5),
|
||||
format_topk(&logits, 0, 5)
|
||||
);
|
||||
}
|
||||
}
|
||||
target_next = decode_next;
|
||||
commit_steps += 1;
|
||||
|
||||
advance_target_cache(target, target_verify_cache, slot, token);
|
||||
mirror_steps += 1;
|
||||
}
|
||||
if done {
|
||||
draft_cache
|
||||
.truncate_sequence(slot, target_cache.seq_len(slot))
|
||||
.unwrap();
|
||||
target_verify_cache
|
||||
.truncate_sequence(slot, target_cache.seq_len(slot))
|
||||
.unwrap();
|
||||
break;
|
||||
}
|
||||
|
||||
if accepted == draft_tokens.len() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let correction = if use_verify_logits && accepted > 0 {
|
||||
verify_argmax[accepted - 1]
|
||||
} else {
|
||||
target_next
|
||||
};
|
||||
generated.push(correction);
|
||||
|
||||
draft_cache
|
||||
.truncate_sequence(slot, round_start_len)
|
||||
.unwrap();
|
||||
replay_draft_tokens(
|
||||
draft,
|
||||
draft_decoder,
|
||||
draft_cache,
|
||||
slot,
|
||||
&draft_tokens[..accepted],
|
||||
&mut draft_next,
|
||||
);
|
||||
|
||||
if generated.len() >= gen_tokens || tokenizer.is_eos(correction) {
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = target_cache.seq_len(slot);
|
||||
let logits = target.forward_decode_paged(&[correction], &[pos], &[slot], target_cache);
|
||||
target_next = last_argmax(&logits);
|
||||
commit_steps += 1;
|
||||
|
||||
advance_target_cache(target, target_verify_cache, slot, correction);
|
||||
mirror_steps += 1;
|
||||
|
||||
let pos = draft_cache.seq_len(slot);
|
||||
let logits = draft_decoder.decode(draft, &[correction], &[pos], &[slot], draft_cache);
|
||||
draft_next = last_argmax(&logits);
|
||||
correction_steps += 1;
|
||||
}
|
||||
sync_device();
|
||||
let decode_s = decode_start.elapsed().as_secs_f64();
|
||||
sync_device();
|
||||
let total_s = t0.elapsed().as_secs_f64();
|
||||
|
||||
target_cache.free_sequence(slot);
|
||||
target_verify_cache.free_sequence(slot);
|
||||
draft_cache.free_sequence(slot);
|
||||
|
||||
RunStats {
|
||||
ids: generated,
|
||||
total_s,
|
||||
prefill_s,
|
||||
decode_s,
|
||||
target_steps: verify_steps + mirror_steps + commit_steps + correction_steps,
|
||||
accepted: accepted_total,
|
||||
proposed: proposed_total,
|
||||
verify_steps,
|
||||
mirror_steps,
|
||||
commit_steps,
|
||||
correction_steps,
|
||||
verify_decode_mismatches,
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs one target decode step for `token` at the current end of `slot` in
/// `cache` and throws the logits away. The call is made only for its effect on
/// `cache` (presumably appending the token's KV entries; confirm against
/// `forward_decode_paged`).
fn advance_target_cache(target: &Qwen3, cache: &mut PagedKVCache, slot: usize, token: u32) {
    let next_pos = cache.seq_len(slot);
    drop(target.forward_decode_paged(&[token], &[next_pos], &[slot], cache));
}
|
||||
|
||||
fn replay_draft_tokens(
|
||||
draft: &Qwen3,
|
||||
draft_decoder: &mut GraphedQwen3Decoder,
|
||||
cache: &mut PagedKVCache,
|
||||
slot: usize,
|
||||
tokens: &[u32],
|
||||
next: &mut u32,
|
||||
) {
|
||||
for &token in tokens {
|
||||
let pos = cache.seq_len(slot);
|
||||
let logits = draft_decoder.decode(draft, &[token], &[pos], &[slot], cache);
|
||||
*next = last_argmax(&logits);
|
||||
}
|
||||
}
|
||||
|
||||
fn new_cache(config: &ModelConfig, max_seq_len: usize, device: u32) -> PagedKVCache {
|
||||
new_cache_with_rows(config, max_seq_len, device, 1)
|
||||
}
|
||||
|
||||
fn new_cache_with_rows(
|
||||
config: &ModelConfig,
|
||||
max_seq_len: usize,
|
||||
device: u32,
|
||||
max_rows: usize,
|
||||
) -> PagedKVCache {
|
||||
let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE);
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
PagedKVCache::new(
|
||||
config,
|
||||
total_blocks,
|
||||
0,
|
||||
max_rows.max(1),
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
device,
|
||||
)
|
||||
}
|
||||
|
||||
fn argmax_rows(logits: &Tensor) -> Vec<u32> {
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
if logits.dtype() == DType::BF16
|
||||
&& matches!(logits.device(), Device::Cuda(_))
|
||||
&& logits.is_contiguous()
|
||||
{
|
||||
return xserv_kernels::argmax_bf16_to_host(logits);
|
||||
}
|
||||
|
||||
let vocab_size = logits.shape()[1];
|
||||
let rows = logits.shape()[0];
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
match logits.dtype() {
|
||||
DType::F32 => logits_cpu
|
||||
.as_slice::<f32>()
|
||||
.chunks_exact(vocab_size)
|
||||
.take(rows)
|
||||
.map(argmax_f32)
|
||||
.collect(),
|
||||
DType::BF16 => logits_cpu
|
||||
.as_slice::<bf16>()
|
||||
.chunks_exact(vocab_size)
|
||||
.take(rows)
|
||||
.map(|row| {
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(i, _)| i as u32)
|
||||
.unwrap()
|
||||
})
|
||||
.collect(),
|
||||
_ => panic!("unsupported dtype for argmax: {:?}", logits.dtype()),
|
||||
}
|
||||
}
|
||||
|
||||
fn last_argmax(logits: &Tensor) -> u32 {
|
||||
*argmax_rows(logits).last().unwrap()
|
||||
}
|
||||
|
||||
/// Index of the largest value in `row`; when several values tie for the
/// maximum, the highest index wins.
///
/// # Panics
/// Panics on an empty row, or when two compared values are unordered (NaN).
fn argmax_f32(row: &[f32]) -> u32 {
    row.iter()
        .enumerate()
        .reduce(|best, cand| match best.1.partial_cmp(cand.1).unwrap() {
            // Only a strictly greater incumbent survives, so ties move to the
            // later element.
            std::cmp::Ordering::Greater => best,
            _ => cand,
        })
        .map(|(idx, _)| idx as u32)
        .unwrap()
}
|
||||
|
||||
fn format_topk(logits: &Tensor, row: usize, k: usize) -> String {
|
||||
let vals = topk_row(logits, row, k);
|
||||
vals.iter()
|
||||
.map(|(id, val)| format!("{id}:{val:.3}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
}
|
||||
|
||||
fn topk_row(logits: &Tensor, row: usize, k: usize) -> Vec<(u32, f32)> {
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
let vocab_size = logits.shape()[1];
|
||||
assert!(row < logits.shape()[0], "topk row out of bounds");
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let mut vals: Vec<(u32, f32)> = match logits.dtype() {
|
||||
DType::F32 => logits_cpu.as_slice::<f32>()[row * vocab_size..(row + 1) * vocab_size]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &v)| (i as u32, v))
|
||||
.collect(),
|
||||
DType::BF16 => logits_cpu.as_slice::<bf16>()[row * vocab_size..(row + 1) * vocab_size]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &v)| (i as u32, v.to_f32()))
|
||||
.collect(),
|
||||
_ => panic!("unsupported dtype for topk: {:?}", logits.dtype()),
|
||||
};
|
||||
vals.select_nth_unstable_by(k, |a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
vals.truncate(k);
|
||||
vals.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
vals
|
||||
}
|
||||
|
||||
fn assert_qwen3(config: &ModelConfig, name: &str) {
|
||||
let model_type = config.model_type.as_deref().unwrap_or("unknown");
|
||||
assert!(
|
||||
model_type.contains("qwen"),
|
||||
"{name} model_type must be qwen-like, got {model_type}"
|
||||
);
|
||||
}
|
||||
|
||||
/// Prints a warning to stderr when both directories contain a readable
/// `tokenizer.json` and the two files are not byte-identical.
///
/// Best effort: if either file cannot be read, nothing is printed.
fn warn_if_tokenizers_differ(target_dir: &Path, draft_dir: &Path) {
    let read_tokenizer = |dir: &Path| std::fs::read(dir.join("tokenizer.json"));
    match (read_tokenizer(target_dir), read_tokenizer(draft_dir)) {
        (Ok(target_bytes), Ok(draft_bytes)) if target_bytes != draft_bytes => {
            eprintln!(
                "WARNING: target and draft tokenizer.json differ; v0 assumes identical token ids"
            );
        }
        _ => {}
    }
}
|
||||
|
||||
/// Looks up `flag` in `args` and parses the argument that follows it as `usize`.
///
/// Returns `default` when the flag is absent. When the flag is present but has
/// no following argument, or that argument does not parse, `default` is still
/// returned, but a warning is written to stderr so a mistyped option cannot
/// silently change what the benchmark measures.
fn arg_usize(args: &[String], flag: &str, default: usize) -> usize {
    let Some(flag_idx) = args.iter().position(|a| a == flag) else {
        return default;
    };
    let Some(raw) = args.get(flag_idx + 1) else {
        eprintln!("WARNING: {flag} given without a value; using default {default}");
        return default;
    };
    match raw.parse() {
        Ok(value) => value,
        Err(_) => {
            eprintln!("WARNING: invalid value {raw:?} for {flag}; using default {default}");
            default
        }
    }
}
|
||||
|
||||
fn parse_verify_path(args: &[String], use_verify_logits: bool) -> VerifyPath {
|
||||
let default = if use_verify_logits {
|
||||
VerifyPath::PagedDecode
|
||||
} else {
|
||||
VerifyPath::Flash
|
||||
};
|
||||
let Some(value) = args
|
||||
.iter()
|
||||
.position(|a| a == "--verify-path")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
else {
|
||||
return default;
|
||||
};
|
||||
match value.as_str() {
|
||||
"flash" => VerifyPath::Flash,
|
||||
"paged-decode" => VerifyPath::PagedDecode,
|
||||
_ => {
|
||||
eprintln!("unknown --verify-path {value:?}; expected flash or paged-decode");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Exits with status 2 when the prompt plus the requested number of generated
/// tokens would exceed `max_seq_len`; returns normally otherwise. `prompt` is
/// only used in the error message.
fn validate_length_budget(prompt_ids: &[u32], gen_tokens: usize, max_seq_len: usize, prompt: &str) {
    let prompt_len = prompt_ids.len();
    let required = prompt_len + gen_tokens;
    if required <= max_seq_len {
        return;
    }
    eprintln!(
        "prompt requires prompt_len({}) + gen_tokens({}) = {} tokens, exceeding --max-seq-len {}: {:?}",
        prompt_len, gen_tokens, required, max_seq_len, prompt
    );
    std::process::exit(2);
}
|
||||
|
||||
fn sync_device() {
|
||||
xserv_cuda::device::synchronize().expect("cuda device synchronize");
|
||||
}
|
||||
|
||||
/// `num / den` as a float, or `0.0` when `den` is zero.
fn ratio(num: usize, den: usize) -> f64 {
    match den {
        0 => 0.0,
        d => num as f64 / d as f64,
    }
}
|
||||
|
||||
/// Baseline time divided by speculative time, or `0.0` when the speculative
/// time is zero.
fn speedup(baseline_s: f64, spec_s: f64) -> f64 {
    if spec_s != 0.0 {
        baseline_s / spec_s
    } else {
        0.0
    }
}
|
||||
|
||||
/// Throughput in tokens per second, or `0.0` when no time elapsed.
fn tok_s(tokens: usize, seconds: f64) -> f64 {
    if seconds != 0.0 {
        tokens as f64 / seconds
    } else {
        0.0
    }
}
|
||||
|
||||
/// Average milliseconds per token, or `0.0` when no tokens were produced.
fn per_token_ms(seconds: f64, tokens: usize) -> f64 {
    match tokens {
        0 => 0.0,
        count => {
            let millis = seconds * 1000.0;
            millis / count as f64
        }
    }
}
|
||||
244
crates/xserv-model/src/bin/bench-tp.rs
Normal file
244
crates/xserv-model/src/bin/bench-tp.rs
Normal file
@@ -0,0 +1,244 @@
|
||||
//! Tensor-parallel E2E benchmark for Qwen3.
|
||||
//!
|
||||
//! Spawns one thread per TP rank (each bound to one GPU), loads the sharded
|
||||
//! model, and runs greedy autoregressive generation. Because lm_head is
|
||||
//! replicated and the post-AllReduce hidden state is identical on every rank,
|
||||
//! all ranks compute identical logits and pick the same greedy token — so the
|
||||
//! rank threads stay in lockstep via the per-layer AllReduces without any
|
||||
//! token broadcast. Rank 0 records output + timings.
|
||||
//!
|
||||
//! Usage: bench-tp <model-dir> [--tp N] [--gen-tokens N] [--devices 0,1,2,3]
|
||||
//!
|
||||
//! Run with --tp 1 / 2 / 4 and compare the printed text (correctness) and
|
||||
//! tok/s (performance).
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::time::Instant;
|
||||
|
||||
use xserv_model::qwen3::sample_greedy;
|
||||
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
/// Measurements recorded by rank 0 for a single benchmark prompt.
struct PromptResult {
    /// Generated token ids; the first entry is the token sampled from the prefill logits.
    gen_ids: Vec<u32>,
    /// Wall-clock time of prefill plus first-token sampling, in milliseconds.
    ttft_ms: f64,
    /// Decode steps per second of wall-clock time (`0.0` when no decode step ran).
    decode_tok_s: f64,
}
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: bench-tp <model-dir> [--tp N] [--gen-tokens N] [--devices 0,1,2,3]");
|
||||
std::process::exit(1);
|
||||
}
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let world: usize = arg(&args, "--tp")
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1)
|
||||
.max(1);
|
||||
let gen_tokens: usize = arg(&args, "--gen-tokens")
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(64);
|
||||
let devices: Vec<u32> = match arg(&args, "--devices") {
|
||||
Some(s) => s.split(',').filter_map(|d| d.trim().parse().ok()).collect(),
|
||||
None => (0..world as u32).collect(),
|
||||
};
|
||||
assert_eq!(devices.len(), world, "--devices count must equal --tp");
|
||||
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
assert!(
|
||||
config.num_kv_heads() % world == 0,
|
||||
"num_kv_heads {} not divisible by tp {world}",
|
||||
config.num_kv_heads()
|
||||
);
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
let eos = tokenizer.eos_token_id();
|
||||
|
||||
let prompts: Vec<&str> = vec![
|
||||
"The capital of France is",
|
||||
"Explain photosynthesis in one sentence.",
|
||||
"Write a haiku about the ocean.",
|
||||
"List three uses of a hammer.",
|
||||
"What is the speed of light?",
|
||||
"Describe the water cycle briefly.",
|
||||
"Who wrote Romeo and Juliet?",
|
||||
"Translate 'good morning' into Spanish.",
|
||||
];
|
||||
let prompt_ids: Vec<Vec<u32>> = prompts.iter().map(|p| tokenizer.encode(p)).collect();
|
||||
|
||||
// Tensors are not Send (their Storage holds a raw GPU pointer), so each rank
|
||||
// thread loads its own CPU copy of the weights and shards in-thread. Loading
|
||||
// is not part of the timed region.
|
||||
let id = if world > 1 {
|
||||
Some(xserv_distributed::get_unique_id())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let handles: Vec<_> = (0..world)
|
||||
.map(|rank| {
|
||||
let model_dir = model_dir.clone();
|
||||
let config = config.clone();
|
||||
let prompt_ids = prompt_ids.clone();
|
||||
let device = devices[rank];
|
||||
thread::spawn(move || {
|
||||
run_rank(
|
||||
rank, world, device, id, config, model_dir, prompt_ids, gen_tokens, eos,
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut rank0: Option<Vec<PromptResult>> = None;
|
||||
for (rank, h) in handles.into_iter().enumerate() {
|
||||
let r = h.join().expect("rank thread panicked");
|
||||
if rank == 0 {
|
||||
rank0 = r;
|
||||
}
|
||||
}
|
||||
|
||||
let results = rank0.expect("rank 0 produced no results");
|
||||
println!("\n=== TP={world} (devices {devices:?}) — Qwen3 E2E benchmark ===");
|
||||
println!(
|
||||
"{:<45} {:>10} {:>12} {:>8}",
|
||||
"prompt", "TTFT(ms)", "decode tok/s", "gen"
|
||||
);
|
||||
let mut tps_sum = 0.0;
|
||||
for (i, r) in results.iter().enumerate() {
|
||||
let text = tokenizer.decode(&r.gen_ids).replace('\n', " ");
|
||||
let short: String = text.chars().take(50).collect();
|
||||
let p: String = prompts[i].chars().take(43).collect();
|
||||
println!(
|
||||
"{:<45} {:>10.1} {:>12.1} {:>8} | {}",
|
||||
p,
|
||||
r.ttft_ms,
|
||||
r.decode_tok_s,
|
||||
r.gen_ids.len(),
|
||||
short
|
||||
);
|
||||
tps_sum += r.decode_tok_s;
|
||||
}
|
||||
println!(
|
||||
"--- mean decode throughput: {:.1} tok/s ---",
|
||||
tps_sum / results.len() as f64
|
||||
);
|
||||
|
||||
// Machine-readable line for cross-TP correctness diffing (rank 0 token ids).
|
||||
let all_ids: Vec<String> = results
|
||||
.iter()
|
||||
.map(|r| {
|
||||
r.gen_ids
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
})
|
||||
.collect();
|
||||
println!("CORRECTNESS_IDS tp={world} {}", all_ids.join(" | "));
|
||||
}
|
||||
|
||||
fn run_rank(
|
||||
rank: usize,
|
||||
world: usize,
|
||||
device: u32,
|
||||
id: Option<xserv_distributed::UniqueId>,
|
||||
config: ModelConfig,
|
||||
model_dir: PathBuf,
|
||||
prompt_ids: Vec<Vec<u32>>,
|
||||
gen_tokens: usize,
|
||||
eos: Option<u32>,
|
||||
) -> Option<Vec<PromptResult>> {
|
||||
// Bind this thread to its GPU and set up the TP communicator.
|
||||
let tp = if world > 1 {
|
||||
Some(Arc::new(xserv_distributed::TpContext::init(
|
||||
rank,
|
||||
world,
|
||||
id.unwrap(),
|
||||
device,
|
||||
)))
|
||||
} else {
|
||||
xserv_cuda::device::set_device(device).unwrap();
|
||||
None
|
||||
};
|
||||
|
||||
// Load this rank's own CPU copy of the weights and shard in-thread.
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
|
||||
let model = Qwen3::from_weights_tp(config.clone(), weights, rank, world, device, tp.clone());
|
||||
|
||||
// Per-rank paged KV cache holds only this rank's local KV heads.
|
||||
let local_kv = config.num_kv_heads() / world;
|
||||
let max_seq = 2048usize;
|
||||
let max_blocks_per_seq = max_seq.div_ceil(BLOCK_SIZE);
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
let mut cache = PagedKVCache::new_tp(
|
||||
&config,
|
||||
local_kv,
|
||||
total_blocks,
|
||||
0,
|
||||
1,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
device,
|
||||
);
|
||||
|
||||
// Warmup (init kernels / allocator / NCCL channels) — not timed.
|
||||
cache.register_sequence(0).unwrap();
|
||||
let _ = model.forward_prefill_paged(&[1u32, 2, 3], 0, &mut cache);
|
||||
cache.free_sequence(0);
|
||||
|
||||
let mut out = Vec::new();
|
||||
for ids in &prompt_ids {
|
||||
cache.register_sequence(0).unwrap();
|
||||
|
||||
// Prefill (TTFT).
|
||||
let t0 = Instant::now();
|
||||
let logits = model.forward_prefill_paged(ids, 0, &mut cache);
|
||||
let first = sample_greedy(&logits);
|
||||
let ttft_ms = t0.elapsed().as_secs_f64() * 1000.0;
|
||||
|
||||
let mut generated = vec![first];
|
||||
|
||||
// Decode.
|
||||
let t1 = Instant::now();
|
||||
let mut steps = 0usize;
|
||||
for _ in 1..gen_tokens {
|
||||
let last = *generated.last().unwrap();
|
||||
if eos == Some(last) {
|
||||
break;
|
||||
}
|
||||
let pos = cache.seq_len(0);
|
||||
let logits = model.forward_decode_paged(&[last], &[pos], &[0], &mut cache);
|
||||
let next = sample_greedy(&logits);
|
||||
generated.push(next);
|
||||
steps += 1;
|
||||
}
|
||||
let decode_s = t1.elapsed().as_secs_f64();
|
||||
let decode_tok_s = if steps > 0 && decode_s > 0.0 {
|
||||
steps as f64 / decode_s
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
cache.free_sequence(0);
|
||||
|
||||
if rank == 0 {
|
||||
out.push(PromptResult {
|
||||
gen_ids: generated,
|
||||
ttft_ms,
|
||||
decode_tok_s,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if rank == 0 { Some(out) } else { None }
|
||||
}
|
||||
|
||||
/// Returns the value following the first occurrence of `flag` in `args`.
///
/// Yields `None` when the flag is absent or is the last argument (no value follows).
fn arg<'a>(args: &'a [String], flag: &str) -> Option<&'a str> {
    args.iter()
        .position(|a| a == flag)
        .and_then(|i| args.get(i + 1))
        .map(|s| s.as_str())
}
|
||||
134
crates/xserv-model/src/bin/bench-verify-cost.rs
Normal file
134
crates/xserv-model/src/bin/bench-verify-cost.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
//! Micro-benchmark: measure the cost of forward_verify_paged_decode_attention
|
||||
//! at different batch sizes (γ+1 values), to understand where speedup comes
|
||||
//! from (or doesn't).
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
|
||||
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!(
|
||||
"Usage: bench-verify-cost <target-dir> [--prompt-len N] [--iters N] [--device N]"
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
let target_dir = PathBuf::from(&args[1]);
|
||||
let prompt_len = arg_usize(&args, "--prompt-len", 100);
|
||||
let iters = arg_usize(&args, "--iters", 30);
|
||||
let device = arg_usize(&args, "--device", 0) as u32;
|
||||
|
||||
xserv_cuda::device::set_device(device).unwrap();
|
||||
|
||||
let cfg = ModelConfig::from_file(&target_dir.join("config.json"));
|
||||
eprintln!("Loading target...");
|
||||
let weights = loader::load_model_dir(&target_dir, Device::Cuda(device));
|
||||
let target = Qwen3::from_weights(cfg.clone(), weights);
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
|
||||
let tok = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
|
||||
let ids = tok.encode(&"the ".repeat(prompt_len))[..prompt_len].to_vec();
|
||||
|
||||
let max_seq_len = 2048;
|
||||
let num_blocks = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE + 4;
|
||||
let mut cache = PagedKVCache::new(&cfg, num_blocks, 0, 16, num_blocks, DType::BF16, device);
|
||||
cache.register_sequence(0).unwrap();
|
||||
|
||||
// Prefill
|
||||
let _ = target.forward_prefill_paged(&ids, 0, &mut cache);
|
||||
sync();
|
||||
|
||||
// Warmup one of each
|
||||
for &n in &[1, 2, 3, 5, 9] {
|
||||
let toks: Vec<u32> = (0..n).map(|_| ids[0]).collect();
|
||||
let _ = target.forward_decode_paged(
|
||||
&toks,
|
||||
&(0..n).map(|i| ids.len() + i).collect::<Vec<_>>(),
|
||||
&vec![0; n],
|
||||
&mut cache,
|
||||
);
|
||||
cache.truncate_sequence(0, ids.len()).unwrap();
|
||||
}
|
||||
sync();
|
||||
|
||||
// Benchmark single-token decode
|
||||
let mut t = 0.0f64;
|
||||
for i in 0..iters {
|
||||
cache.truncate_sequence(0, ids.len()).unwrap();
|
||||
let t0 = Instant::now();
|
||||
let _ = target.forward_decode_paged(&[ids[0]], &[ids.len()], &[0], &mut cache);
|
||||
sync();
|
||||
t += t0.elapsed().as_secs_f64();
|
||||
let _ = i;
|
||||
}
|
||||
let single = t * 1000.0 / iters as f64;
|
||||
println!(
|
||||
"single-token decode: {:.3} ms (mean of {} iters)",
|
||||
single, iters
|
||||
);
|
||||
|
||||
// Benchmark forward_verify_paged_decode_attention at various batch sizes
|
||||
// (batched-GEMV path).
|
||||
for &n in &[1usize, 2, 3, 5, 9] {
|
||||
let toks: Vec<u32> = (0..n).map(|_| ids[0]).collect();
|
||||
let mut t = 0.0f64;
|
||||
for _ in 0..iters {
|
||||
cache.truncate_sequence(0, ids.len()).unwrap();
|
||||
let t0 = Instant::now();
|
||||
let _ = target.forward_verify_paged_decode_attention(&toks, 0, &mut cache);
|
||||
sync();
|
||||
t += t0.elapsed().as_secs_f64();
|
||||
}
|
||||
let ms = t * 1000.0 / iters as f64;
|
||||
println!(
|
||||
"verify (batched-GEMV) batch={}: {:.3} ms ({:.2}× single)",
|
||||
n,
|
||||
ms,
|
||||
ms / single
|
||||
);
|
||||
}
|
||||
|
||||
// Benchmark _with_hidden variant which uses cuBLAS GEMM after Phase 26 fast-verify.
|
||||
let hooks_layers = [2usize, 18, 33];
|
||||
for &n in &[1usize, 2, 3, 5, 9] {
|
||||
let toks: Vec<u32> = (0..n).map(|_| ids[0]).collect();
|
||||
let mut t = 0.0f64;
|
||||
for _ in 0..iters {
|
||||
cache.truncate_sequence(0, ids.len()).unwrap();
|
||||
let t0 = Instant::now();
|
||||
let _ = target.forward_verify_paged_decode_attention_with_hidden(
|
||||
&toks,
|
||||
0,
|
||||
&mut cache,
|
||||
&hooks_layers,
|
||||
);
|
||||
sync();
|
||||
t += t0.elapsed().as_secs_f64();
|
||||
}
|
||||
let ms = t * 1000.0 / iters as f64;
|
||||
println!(
|
||||
"verify (cuBLAS GEMM) batch={}: {:.3} ms ({:.2}× single)",
|
||||
n,
|
||||
ms,
|
||||
ms / single
|
||||
);
|
||||
}
|
||||
|
||||
cache.free_sequence(0);
|
||||
}
|
||||
|
||||
fn sync() {
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
|
||||
/// Parses the value following `flag` in `args` as a `usize`.
///
/// Returns `default` when the flag is absent, has no following value, or the value
/// does not parse as a `usize`.
fn arg_usize(args: &[String], flag: &str, default: usize) -> usize {
    args.iter()
        .position(|a| a == flag)
        .and_then(|i| args.get(i + 1))
        .and_then(|s| s.parse().ok())
        .unwrap_or(default)
}
|
||||
174
crates/xserv-model/src/bin/check-eagle3.rs
Normal file
174
crates/xserv-model/src/bin/check-eagle3.rs
Normal file
@@ -0,0 +1,174 @@
|
||||
//! EAGLE3 sanity check: load weights, run one draft step, print top-5 predictions.
|
||||
//!
|
||||
//! This verifies that:
|
||||
//! - Eagle3Head weights load without shape mismatches
|
||||
//! - Target hidden states can be captured via decode_core_with_hidden
|
||||
//! - Eagle3Head::step produces a valid token id (in target vocab)
|
||||
//!
|
||||
//! Does NOT measure speedup — that requires a full γ≥2 speculative loop, which
|
||||
//! is more complex integration work.
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
use xserv_model::eagle3::{EAGLE_HOOK_LAYERS, Eagle3Head};
|
||||
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 3 {
|
||||
eprintln!("Usage: check-eagle3 <target-model-dir> <eagle3-model-dir> [prompt]");
|
||||
std::process::exit(1);
|
||||
}
|
||||
let target_dir = PathBuf::from(&args[1]);
|
||||
let eagle_dir = PathBuf::from(&args[2]);
|
||||
let prompt = args
|
||||
.get(3)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "The capital of France is".to_string());
|
||||
let device: u32 = 0;
|
||||
|
||||
xserv_cuda::device::set_device(device).unwrap();
|
||||
|
||||
let target_config = ModelConfig::from_file(&target_dir.join("config.json"));
|
||||
eprintln!("Loading target Qwen3-8B...");
|
||||
let target_weights = loader::load_model_dir(&target_dir, Device::Cuda(device));
|
||||
let target = Qwen3::from_weights(target_config.clone(), target_weights);
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
|
||||
eprintln!("Loading EAGLE3 head from {}", eagle_dir.display());
|
||||
let mut eagle = Eagle3Head::load(&eagle_dir, device);
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
|
||||
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
|
||||
let embed_tokens = target.embed_tokens_tensor();
|
||||
|
||||
let ids = tokenizer.encode(&prompt);
|
||||
let max_seq_len = 512;
|
||||
|
||||
let num_blocks = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE + 2;
|
||||
let mut cache = PagedKVCache::new(
|
||||
&target_config,
|
||||
num_blocks,
|
||||
0,
|
||||
1,
|
||||
num_blocks,
|
||||
DType::BF16,
|
||||
device,
|
||||
);
|
||||
cache.register_sequence(0).unwrap();
|
||||
|
||||
// Prefill target.
|
||||
let logits = target.forward_prefill_paged(&ids, 0, &mut cache);
|
||||
let target_first = *xserv_kernels::argmax_bf16_to_host(&logits).last().unwrap();
|
||||
let target_first_text = tokenizer.decode(&[target_first]);
|
||||
println!("Prompt: {:?}", prompt);
|
||||
println!(
|
||||
"Target argmax after prefill: {} ({:?})",
|
||||
target_first, target_first_text
|
||||
);
|
||||
|
||||
// Now run one target decode step with target_first to get hidden states at the
|
||||
// hook layers.
|
||||
let pos = cache.seq_len(0);
|
||||
target.decode_prepare(&[pos], &[0], &mut cache);
|
||||
let ids_gpu = upload_u32(&[target_first]);
|
||||
let pos_gpu = upload_u32(&[pos as u32]);
|
||||
let (target_next_logits, hooks) = target.decode_core_with_hidden(
|
||||
ids_gpu.as_ptr() as *const std::ffi::c_void,
|
||||
pos_gpu.as_ptr() as *const std::ffi::c_void,
|
||||
1,
|
||||
&[0],
|
||||
&mut cache,
|
||||
&EAGLE_HOOK_LAYERS,
|
||||
);
|
||||
let target_next = xserv_kernels::argmax_bf16_single(&target_next_logits);
|
||||
let target_next_text = tokenizer.decode(&[target_next]);
|
||||
println!(
|
||||
"Target argmax after 1 decode step: {} ({:?})",
|
||||
target_next, target_next_text
|
||||
);
|
||||
|
||||
for (i, h) in hooks.iter().enumerate() {
|
||||
println!(
|
||||
"hook[{}] (layer {}): shape={:?} dtype={:?}",
|
||||
i,
|
||||
EAGLE_HOOK_LAYERS[i],
|
||||
h.shape(),
|
||||
h.dtype()
|
||||
);
|
||||
}
|
||||
|
||||
// Ask EAGLE what it thinks the NEXT token is (given target_first as prev_token
|
||||
// and the hidden states from the position where target_first lives).
|
||||
// EAGLE should predict target_next (or close to it) to be useful.
|
||||
eagle.reset();
|
||||
let (eagle_pred, eagle_logits) = eagle.step(&hooks, embed_tokens, target_first, pos);
|
||||
let eagle_pred_text = tokenizer.decode(&[eagle_pred]);
|
||||
println!(
|
||||
"EAGLE draft prediction (pairing A: prev=target_first): {} ({:?})",
|
||||
eagle_pred, eagle_pred_text
|
||||
);
|
||||
|
||||
if eagle_pred == target_next {
|
||||
println!("MATCH: EAGLE agrees with target on next token.");
|
||||
} else {
|
||||
println!(
|
||||
"MISMATCH: EAGLE draft={} vs target={} (this is fine per-step; check top-5 below)",
|
||||
eagle_pred, target_next
|
||||
);
|
||||
}
|
||||
|
||||
// Show top-5 from eagle logits (in draft vocab space, mapped to target).
|
||||
print_top5(
|
||||
&eagle_logits,
|
||||
"EAGLE draft top-5 (pairing A)",
|
||||
&eagle,
|
||||
&tokenizer,
|
||||
);
|
||||
|
||||
// Alternative pairing B: pair hooks with target_next (the token those hooks produced
|
||||
// via lm_head), predict token after target_next. Position advances by 1.
|
||||
eagle.reset();
|
||||
let (eagle_pred_b, eagle_logits_b) = eagle.step(&hooks, embed_tokens, target_next, pos + 1);
|
||||
let eagle_pred_b_text = tokenizer.decode(&[eagle_pred_b]);
|
||||
println!(
|
||||
"\nEAGLE draft prediction (pairing B: prev=target_next): {} ({:?})",
|
||||
eagle_pred_b, eagle_pred_b_text
|
||||
);
|
||||
print_top5(
|
||||
&eagle_logits_b,
|
||||
"EAGLE draft top-5 (pairing B)",
|
||||
&eagle,
|
||||
&tokenizer,
|
||||
);
|
||||
}
|
||||
|
||||
/// Copies `vals` into a freshly allocated GPU buffer and returns that buffer.
///
/// The buffer is sized to exactly the byte length of `vals`.
///
/// # Panics
/// Panics if the allocation or the host-to-device copy fails.
fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer {
    // SAFETY: `vals` is a live, initialized `[u32]`. Reinterpreting it as bytes is sound
    // because `u8` has alignment 1 and no invalid bit patterns, the length is exactly
    // `size_of_val(vals)`, and the byte view does not outlive `vals`.
    let bytes = unsafe {
        std::slice::from_raw_parts(vals.as_ptr().cast::<u8>(), std::mem::size_of_val(vals))
    };
    let mut buf = xserv_cuda::allocator::cached_alloc(bytes.len()).unwrap();
    buf.copy_from_host(bytes).unwrap();
    buf
}
|
||||
|
||||
fn print_top5(logits: &Tensor, label: &str, eagle: &Eagle3Head, tokenizer: &Tokenizer) {
|
||||
use half::bf16;
|
||||
let cpu = logits.to_device(Device::Cpu);
|
||||
let data = cpu.as_slice::<bf16>();
|
||||
let mut vals: Vec<(usize, f32)> = data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, v)| (i, v.to_f32()))
|
||||
.collect();
|
||||
vals.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
println!("{label}:");
|
||||
for (i, val) in vals.iter().take(5) {
|
||||
let target_id = eagle.map_draft_to_target(*i as u32);
|
||||
let text = tokenizer.decode(&[target_id]);
|
||||
println!(
|
||||
" draft_id={} target_id={} val={:.3} text={:?}",
|
||||
i, target_id, val, text
|
||||
);
|
||||
}
|
||||
}
|
||||
49
crates/xserv-model/src/bin/dump-logits.rs
Normal file
49
crates/xserv-model/src/bin/dump-logits.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use half::bf16;
|
||||
use std::path::PathBuf;
|
||||
use xserv_model::{KVCache, ModelConfig, Qwen3, loader};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let prompt = &args[2];
|
||||
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cuda(0));
|
||||
let model = Qwen3::from_weights(config.clone(), weights);
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
|
||||
let token_ids = tokenizer.encode(prompt);
|
||||
eprintln!("Prompt: {prompt}");
|
||||
eprintln!("Token IDs: {token_ids:?}");
|
||||
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(),
|
||||
config.num_kv_heads(),
|
||||
config.head_dim(),
|
||||
DType::BF16,
|
||||
Device::Cuda(0),
|
||||
);
|
||||
let logits = model.forward_with_cache(&token_ids, &mut cache);
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let data = logits_cpu.as_slice::<bf16>();
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
|
||||
// Print top-20 logits for the last position
|
||||
let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
let mut indexed: Vec<(usize, f32)> = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, v)| (i, v.to_f32()))
|
||||
.collect();
|
||||
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
|
||||
println!("Top-20 logits (last position):");
|
||||
for (rank, (id, val)) in indexed.iter().take(20).enumerate() {
|
||||
let tok = tokenizer.decode(&[*id as u32]);
|
||||
println!(" [{rank:>2}] id={id:>6} logit={val:>10.4} token={tok:?}");
|
||||
}
|
||||
}
|
||||
1159
crates/xserv-model/src/bin/xserv-chat.rs
Normal file
1159
crates/xserv-model/src/bin/xserv-chat.rs
Normal file
File diff suppressed because it is too large
Load Diff
212
crates/xserv-model/src/bin/xserv-cli.rs
Normal file
212
crates/xserv-model/src/bin/xserv-cli.rs
Normal file
@@ -0,0 +1,212 @@
|
||||
use std::io::{self, Write};
|
||||
use std::path::PathBuf;
|
||||
use xserv_model::{
|
||||
BLOCK_SIZE, KVCache, ModelConfig, PagedKVCache, SamplingParams, loader, sample,
|
||||
sample_greedy_penalized,
|
||||
};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
/// Parses the value following `name` in `args` as a `T`.
///
/// Returns `default` when the flag is absent, has no following value, or the value
/// fails to parse.
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
    args.iter()
        .position(|a| a == name)
        .and_then(|i| args.get(i + 1))
        .and_then(|s| s.parse().ok())
        .unwrap_or(default)
}
|
||||
|
||||
fn pick_next(
|
||||
logits: &xserv_tensor::Tensor,
|
||||
sampling: &SamplingParams,
|
||||
history: &[u32],
|
||||
rep_penalty: f32,
|
||||
) -> u32 {
|
||||
if rep_penalty > 1.0 && sampling.temperature == 0.0 {
|
||||
sample_greedy_penalized(logits, history, rep_penalty)
|
||||
} else {
|
||||
sample(logits, sampling)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!(
|
||||
"Usage: xserv-cli <model-dir> [--max-tokens N] [--temperature F] [--top-k N] [--top-p F] [--rep-penalty F] [--rep-window N]"
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let max_tokens = flag(&args, "--max-tokens", 100usize);
|
||||
let sampling = SamplingParams {
|
||||
temperature: flag(&args, "--temperature", 0.0f32),
|
||||
top_k: flag(&args, "--top-k", 0usize),
|
||||
top_p: flag(&args, "--top-p", 1.0f32),
|
||||
};
|
||||
let rep_penalty = flag(&args, "--rep-penalty", 1.0f32);
|
||||
let rep_window = flag(&args, "--rep-window", 512usize);
|
||||
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
let info = xserv_cuda::device::device_info(0).unwrap();
|
||||
eprintln!(
|
||||
"GPU: {} ({} MB free)",
|
||||
info.name,
|
||||
info.free_memory / 1024 / 1024
|
||||
);
|
||||
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
let model_type = config.model_type.as_deref().unwrap_or("unknown");
|
||||
eprintln!(
|
||||
"Model: {model_type}, layers={}, hidden={}, heads={}/{} kv, vocab={}",
|
||||
config.num_layers(),
|
||||
config.hidden(),
|
||||
config.num_heads(),
|
||||
config.num_kv_heads(),
|
||||
config.vocab_size
|
||||
);
|
||||
|
||||
eprintln!("Loading weights...");
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cuda(0));
|
||||
eprintln!("Loaded {} tensors", weights.len());
|
||||
|
||||
let is_qwen3 = model_type.contains("qwen");
|
||||
let is_gpt_oss = model_type.contains("gpt_oss");
|
||||
let dtype = if is_qwen3 || is_gpt_oss {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
|
||||
// Build model
|
||||
enum Model {
|
||||
GPT2(xserv_model::GPT2),
|
||||
Qwen3(xserv_model::Qwen3),
|
||||
GptOss(xserv_model::GptOss),
|
||||
}
|
||||
let model = if is_gpt_oss {
|
||||
Model::GptOss(xserv_model::GptOss::from_weights(config.clone(), weights))
|
||||
} else if is_qwen3 {
|
||||
Model::Qwen3(xserv_model::Qwen3::from_weights(config.clone(), weights))
|
||||
} else {
|
||||
Model::GPT2(xserv_model::GPT2::from_weights(config.clone(), weights))
|
||||
};
|
||||
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
eprintln!(
|
||||
"Ready (KV cache, dtype={dtype}, temperature={}, top_k={}, top_p={}, rep_penalty={}, rep_window={}).\n",
|
||||
sampling.temperature, sampling.top_k, sampling.top_p, rep_penalty, rep_window
|
||||
);
|
||||
|
||||
loop {
|
||||
print!("xserv> ");
|
||||
io::stdout().flush().unwrap();
|
||||
let mut input = String::new();
|
||||
if io::stdin().read_line(&mut input).unwrap() == 0 {
|
||||
break;
|
||||
}
|
||||
let raw_input = input.trim();
|
||||
if raw_input.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if raw_input == "quit" || raw_input == "exit" {
|
||||
break;
|
||||
}
|
||||
let input = raw_input.replace("\\n", "\n");
|
||||
|
||||
let token_ids = tokenizer.encode(&input);
|
||||
|
||||
if is_gpt_oss {
|
||||
// GptOss uses paged KV cache
|
||||
let max_seq = 2048;
|
||||
let max_blocks_per_seq = (max_seq + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = max_blocks_per_seq + 64;
|
||||
let mut paged_cache = PagedKVCache::new(
|
||||
&config,
|
||||
total_blocks,
|
||||
0,
|
||||
4,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
0,
|
||||
);
|
||||
let slot = 0;
|
||||
paged_cache.register_sequence(slot).expect("register slot");
|
||||
|
||||
let model = match &model {
|
||||
Model::GptOss(m) => m,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let logits = model.forward_prefill_paged(&token_ids, slot, &mut paged_cache);
|
||||
let mut history = token_ids.clone();
|
||||
let start = history.len().saturating_sub(rep_window);
|
||||
let mut next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
|
||||
|
||||
print!("{input}");
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
for _ in 0..max_tokens {
|
||||
let text = tokenizer.decode(&[next]);
|
||||
print!("{text}");
|
||||
io::stdout().flush().unwrap();
|
||||
history.push(next);
|
||||
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = paged_cache.seq_len(slot);
|
||||
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut paged_cache);
|
||||
let start = history.len().saturating_sub(rep_window);
|
||||
next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
|
||||
}
|
||||
println!();
|
||||
paged_cache.free_sequence(slot);
|
||||
} else {
|
||||
let kv_heads = if is_qwen3 {
|
||||
config.num_kv_heads()
|
||||
} else {
|
||||
config.num_heads()
|
||||
};
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(),
|
||||
kv_heads,
|
||||
config.head_dim(),
|
||||
dtype,
|
||||
Device::Cuda(0),
|
||||
);
|
||||
|
||||
let logits = match &model {
|
||||
Model::GPT2(m) => m.forward_with_cache(&token_ids, &mut cache),
|
||||
Model::Qwen3(m) => m.forward_with_cache(&token_ids, &mut cache),
|
||||
Model::GptOss(_) => unreachable!(),
|
||||
};
|
||||
let mut history = token_ids.clone();
|
||||
let start = history.len().saturating_sub(rep_window);
|
||||
let mut next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
|
||||
|
||||
print!("{input}");
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
for _ in 0..max_tokens {
|
||||
let text = tokenizer.decode(&[next]);
|
||||
print!("{text}");
|
||||
io::stdout().flush().unwrap();
|
||||
history.push(next);
|
||||
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
|
||||
let logits = match &model {
|
||||
Model::GPT2(m) => m.forward_with_cache(&[next], &mut cache),
|
||||
Model::Qwen3(m) => m.forward_with_cache(&[next], &mut cache),
|
||||
Model::GptOss(_) => unreachable!(),
|
||||
};
|
||||
let start = history.len().saturating_sub(rep_window);
|
||||
next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
}
|
||||
166
crates/xserv-model/src/config.rs
Normal file
166
crates/xserv-model/src/config.rs
Normal file
@@ -0,0 +1,166 @@
|
||||
use serde::Deserialize;
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct RopeScaling {
|
||||
pub rope_type: Option<String>,
|
||||
pub factor: Option<f64>,
|
||||
pub original_max_position_embeddings: Option<usize>,
|
||||
pub beta_fast: Option<f64>,
|
||||
pub beta_slow: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ModelConfig {
|
||||
pub architectures: Option<Vec<String>>,
|
||||
pub model_type: Option<String>,
|
||||
|
||||
// Modern HF naming
|
||||
#[serde(default)]
|
||||
pub hidden_size: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub intermediate_size: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub num_attention_heads: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub num_key_value_heads: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub num_hidden_layers: Option<usize>,
|
||||
pub vocab_size: usize,
|
||||
#[serde(default)]
|
||||
pub max_position_embeddings: Option<usize>,
|
||||
|
||||
// GPT-2 naming
|
||||
#[serde(default)]
|
||||
pub n_embd: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub n_head: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub n_layer: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub n_positions: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub n_inner: Option<usize>,
|
||||
|
||||
// Normalization
|
||||
#[serde(default)]
|
||||
pub layer_norm_eps: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub layer_norm_epsilon: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub rms_norm_eps: Option<f64>,
|
||||
|
||||
// Other
|
||||
#[serde(default)]
|
||||
pub rope_theta: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub tie_word_embeddings: Option<bool>,
|
||||
|
||||
// MoE (gpt-oss)
|
||||
#[serde(default)]
|
||||
pub num_local_experts: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub num_experts_per_tok: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub layer_types: Option<Vec<String>>,
|
||||
#[serde(default)]
|
||||
pub sliding_window: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub attention_bias: Option<bool>,
|
||||
#[serde(default, rename = "head_dim")]
|
||||
pub explicit_head_dim: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub rope_scaling: Option<RopeScaling>,
|
||||
#[serde(default)]
|
||||
pub swiglu_limit: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub geglu_alpha: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub hidden_act: Option<String>,
|
||||
}
|
||||
|
||||
impl ModelConfig {
|
||||
pub fn from_file(path: &Path) -> Self {
|
||||
let data = std::fs::read_to_string(path)
|
||||
.unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
|
||||
serde_json::from_str(&data)
|
||||
.unwrap_or_else(|e| panic!("failed to parse {}: {e}", path.display()))
|
||||
}
|
||||
|
||||
pub fn hidden(&self) -> usize {
|
||||
self.hidden_size
|
||||
.or(self.n_embd)
|
||||
.expect("hidden_size or n_embd required")
|
||||
}
|
||||
|
||||
pub fn num_heads(&self) -> usize {
|
||||
self.num_attention_heads
|
||||
.or(self.n_head)
|
||||
.expect("num_attention_heads or n_head required")
|
||||
}
|
||||
|
||||
pub fn num_layers(&self) -> usize {
|
||||
self.num_hidden_layers
|
||||
.or(self.n_layer)
|
||||
.expect("num_hidden_layers or n_layer required")
|
||||
}
|
||||
|
||||
pub fn max_seq_len(&self) -> usize {
|
||||
self.max_position_embeddings
|
||||
.or(self.n_positions)
|
||||
.unwrap_or(2048)
|
||||
}
|
||||
|
||||
pub fn ffn_hidden(&self) -> usize {
|
||||
self.intermediate_size
|
||||
.or(self.n_inner)
|
||||
.unwrap_or(self.hidden() * 4)
|
||||
}
|
||||
|
||||
pub fn num_kv_heads(&self) -> usize {
|
||||
self.num_key_value_heads.unwrap_or(self.num_heads())
|
||||
}
|
||||
|
||||
pub fn head_dim(&self) -> usize {
|
||||
self.explicit_head_dim
|
||||
.unwrap_or_else(|| self.hidden() / self.num_heads())
|
||||
}
|
||||
|
||||
pub fn ln_eps(&self) -> f32 {
|
||||
self.layer_norm_eps
|
||||
.or(self.layer_norm_epsilon)
|
||||
.unwrap_or(1e-5) as f32
|
||||
}
|
||||
|
||||
pub fn tied_embeddings(&self) -> bool {
|
||||
self.tie_word_embeddings.unwrap_or(true)
|
||||
}
|
||||
|
||||
pub fn num_experts(&self) -> usize {
|
||||
self.num_local_experts.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn experts_per_token(&self) -> usize {
|
||||
self.num_experts_per_tok.unwrap_or(1)
|
||||
}
|
||||
|
||||
pub fn is_moe(&self) -> bool {
|
||||
self.num_local_experts.unwrap_or(0) > 1
|
||||
}
|
||||
|
||||
pub fn is_sliding_layer(&self, layer_idx: usize) -> bool {
|
||||
self.layer_types
|
||||
.as_ref()
|
||||
.and_then(|lt| lt.get(layer_idx))
|
||||
.map(|t| t == "sliding_attention")
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub fn window_size(&self) -> usize {
|
||||
self.sliding_window.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn geglu_alpha(&self) -> f32 {
|
||||
self.geglu_alpha.unwrap_or(1.702) as f32
|
||||
}
|
||||
}
|
||||
649
crates/xserv-model/src/decode_graph.rs
Normal file
649
crates/xserv-model/src/decode_graph.rs
Normal file
@@ -0,0 +1,649 @@
|
||||
//! CUDA Graph integration for batch=1 single-sequence decode.
|
||||
//!
|
||||
//! Uses a per-layer split graph approach:
|
||||
//! - Pre-attention graph: RMSNorm + QKV projections + reshape + QK-norm + RoPE
|
||||
//! - Ungraphed: KV cache append + decode attention (variable kv_len)
|
||||
//! - Post-attention graph: merge_heads + O-proj + add_rmsnorm + FFN + residual
|
||||
//! - Final graph: last RMSNorm + lm_head GEMV
|
||||
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer};
|
||||
use xserv_kernels::dispatch;
|
||||
use xserv_kernels::gemm::{cublas_handle, gemv_scratch_elems};
|
||||
|
||||
use crate::config::ModelConfig;
|
||||
use crate::kv_cache::GpuKVCache;
|
||||
|
||||
/// Pre-allocated intermediate buffers for decode (batch=1).
|
||||
/// All buffers have stable GPU addresses for CUDA Graph replay.
|
||||
struct DecodeBuffers {
|
||||
// Hidden-size buffers: [1, hidden]
|
||||
x: GpuBuffer, // running hidden state
|
||||
normed: GpuBuffer, // rmsnorm output
|
||||
attn_out: GpuBuffer, // attention output [1, num_heads, 1, head_dim]
|
||||
attn_merged: GpuBuffer, // merge_heads output [1, hidden]
|
||||
o_proj: GpuBuffer, // O projection output [1, hidden]
|
||||
normed2: GpuBuffer, // post-attn norm output [1, hidden]
|
||||
sum_out: GpuBuffer, // add_rmsnorm sum output [1, hidden]
|
||||
down: GpuBuffer, // down projection output [1, hidden]
|
||||
|
||||
// QKV projection outputs
|
||||
q_proj: GpuBuffer, // [1, num_heads * head_dim]
|
||||
k_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
|
||||
v_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
|
||||
|
||||
// Reshaped: [1, H, 1, D]
|
||||
q_reshaped: GpuBuffer,
|
||||
k_reshaped: GpuBuffer,
|
||||
v_reshaped: GpuBuffer,
|
||||
|
||||
// After QK-norm (same shape as reshaped)
|
||||
q_normed: GpuBuffer,
|
||||
k_normed: GpuBuffer,
|
||||
|
||||
// RoPE transposed: [1, H, D]
|
||||
q_rope: GpuBuffer,
|
||||
k_rope: GpuBuffer,
|
||||
|
||||
// After RoPE transpose back: [1, H, 1, D]
|
||||
q_final: GpuBuffer,
|
||||
k_final: GpuBuffer,
|
||||
|
||||
// FFN intermediates
|
||||
gate: GpuBuffer, // [1, intermediate]
|
||||
up: GpuBuffer, // [1, intermediate]
|
||||
silu_out: GpuBuffer, // [1, intermediate]
|
||||
|
||||
// GEMV fp32 scratch for deterministic K-block partials.
|
||||
fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs
|
||||
fp32_q: GpuBuffer, // for Q projection
|
||||
fp32_kv: GpuBuffer, // for K/V projection
|
||||
fp32_intermediate: GpuBuffer, // for gate/up projections
|
||||
fp32_vocab: GpuBuffer, // for lm_head
|
||||
|
||||
// Token ID and position (GPU-resident, updated before replay)
|
||||
token_id_gpu: GpuBuffer, // 4 bytes (u32)
|
||||
position_gpu: GpuBuffer, // 4 bytes (u32)
|
||||
|
||||
// Final output
|
||||
logits: GpuBuffer, // [1, vocab_size]
|
||||
}
|
||||
|
||||
pub struct DecodeGraphState {
|
||||
stream: CudaStream,
|
||||
buffers: DecodeBuffers,
|
||||
|
||||
// Per-layer graph pairs
|
||||
pre_attn_graphs: Vec<CudaGraph>,
|
||||
post_attn_graphs: Vec<CudaGraph>,
|
||||
final_graph: CudaGraph,
|
||||
|
||||
captured: bool,
|
||||
|
||||
// Model dimensions
|
||||
hidden: usize,
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
intermediate: usize,
|
||||
vocab_size: usize,
|
||||
num_layers: usize,
|
||||
eps: f32,
|
||||
}
|
||||
|
||||
impl DecodeGraphState {
|
||||
pub fn new(config: &ModelConfig) -> Self {
|
||||
let hidden = config.hidden();
|
||||
let num_heads = config.num_heads();
|
||||
let num_kv_heads = config.num_kv_heads();
|
||||
let head_dim = config.head_dim();
|
||||
let intermediate = config.ffn_hidden();
|
||||
let vocab_size = config.vocab_size;
|
||||
let num_layers = config.num_layers();
|
||||
let eps = config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
let es = 2usize; // BF16 = 2 bytes
|
||||
|
||||
let stream = CudaStream::new().expect("create CUDA stream for graph");
|
||||
|
||||
let alloc = |size: usize| -> GpuBuffer {
|
||||
GpuBuffer::alloc(size).expect("alloc decode graph buffer")
|
||||
};
|
||||
|
||||
let buffers = DecodeBuffers {
|
||||
x: alloc(hidden * es),
|
||||
normed: alloc(hidden * es),
|
||||
attn_out: alloc(num_heads * head_dim * es),
|
||||
attn_merged: alloc(hidden * es),
|
||||
o_proj: alloc(hidden * es),
|
||||
normed2: alloc(hidden * es),
|
||||
sum_out: alloc(hidden * es),
|
||||
down: alloc(hidden * es),
|
||||
|
||||
q_proj: alloc(num_heads * head_dim * es),
|
||||
k_proj: alloc(num_kv_heads * head_dim * es),
|
||||
v_proj: alloc(num_kv_heads * head_dim * es),
|
||||
|
||||
q_reshaped: alloc(num_heads * head_dim * es),
|
||||
k_reshaped: alloc(num_kv_heads * head_dim * es),
|
||||
v_reshaped: alloc(num_kv_heads * head_dim * es),
|
||||
|
||||
q_normed: alloc(num_heads * head_dim * es),
|
||||
k_normed: alloc(num_kv_heads * head_dim * es),
|
||||
|
||||
q_rope: alloc(num_heads * head_dim * es),
|
||||
k_rope: alloc(num_kv_heads * head_dim * es),
|
||||
|
||||
q_final: alloc(num_heads * head_dim * es),
|
||||
k_final: alloc(num_kv_heads * head_dim * es),
|
||||
|
||||
gate: alloc(intermediate * es),
|
||||
up: alloc(intermediate * es),
|
||||
silu_out: alloc(intermediate * es),
|
||||
|
||||
fp32_hidden: alloc(
|
||||
gemv_scratch_elems(hidden, hidden).max(gemv_scratch_elems(intermediate, hidden))
|
||||
* 4,
|
||||
),
|
||||
fp32_q: alloc(gemv_scratch_elems(hidden, num_heads * head_dim) * 4),
|
||||
fp32_kv: alloc(gemv_scratch_elems(hidden, num_kv_heads * head_dim) * 4),
|
||||
fp32_intermediate: alloc(gemv_scratch_elems(hidden, intermediate) * 4),
|
||||
fp32_vocab: alloc(gemv_scratch_elems(hidden, vocab_size) * 4),
|
||||
|
||||
token_id_gpu: alloc(4),
|
||||
position_gpu: alloc(4),
|
||||
|
||||
logits: alloc(vocab_size * es),
|
||||
};
|
||||
|
||||
let pre_attn_graphs = (0..num_layers).map(|_| CudaGraph::new()).collect();
|
||||
let post_attn_graphs = (0..num_layers).map(|_| CudaGraph::new()).collect();
|
||||
|
||||
Self {
|
||||
stream,
|
||||
buffers,
|
||||
pre_attn_graphs,
|
||||
post_attn_graphs,
|
||||
final_graph: CudaGraph::new(),
|
||||
captured: false,
|
||||
hidden,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
intermediate,
|
||||
vocab_size,
|
||||
num_layers,
|
||||
eps,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_captured(&self) -> bool {
|
||||
self.captured
|
||||
}
|
||||
|
||||
/// Capture all per-layer graphs. Called once after the first decode step.
|
||||
pub fn capture(
|
||||
&mut self,
|
||||
layers: &[LayerWeightPtrs],
|
||||
norm_weight: *const c_void,
|
||||
lm_head_wt: *const c_void,
|
||||
_embed_table: *const c_void,
|
||||
rope_cos: *const c_void,
|
||||
rope_sin: *const c_void,
|
||||
) {
|
||||
let s = self.stream.as_raw();
|
||||
let h = self.hidden as i32;
|
||||
let nh = self.num_heads as i32;
|
||||
let nkv = self.num_kv_heads as i32;
|
||||
let hd = self.head_dim as i32;
|
||||
let inter = self.intermediate as i32;
|
||||
let vocab = self.vocab_size as i32;
|
||||
let eps = self.eps;
|
||||
|
||||
let cublas = cublas_handle();
|
||||
|
||||
// Set cuBLAS to use our stream
|
||||
unsafe {
|
||||
dispatch::set_cublas_stream(cublas, s);
|
||||
}
|
||||
|
||||
for (l, lw) in layers.iter().enumerate() {
|
||||
// === Pre-attention graph ===
|
||||
self.pre_attn_graphs[l]
|
||||
.begin_capture(&self.stream)
|
||||
.expect("begin pre-attn capture");
|
||||
unsafe {
|
||||
// RMSNorm
|
||||
dispatch::rmsnorm_bf16(
|
||||
self.buffers.x.as_ptr() as _,
|
||||
lw.input_norm,
|
||||
self.buffers.normed.as_mut_ptr() as _,
|
||||
1,
|
||||
h,
|
||||
eps,
|
||||
s,
|
||||
);
|
||||
|
||||
// Q projection (GEMV)
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _,
|
||||
lw.q_proj_wt,
|
||||
self.buffers.q_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_q.as_mut_ptr() as _,
|
||||
h,
|
||||
nh * hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// K projection (GEMV)
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _,
|
||||
lw.k_proj_wt,
|
||||
self.buffers.k_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_kv.as_mut_ptr() as _,
|
||||
h,
|
||||
nkv * hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// V projection (GEMV)
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _,
|
||||
lw.v_proj_wt,
|
||||
self.buffers.v_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_kv.as_mut_ptr() as _,
|
||||
h,
|
||||
nkv * hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// Reshape heads: [1, H*D] -> [1, H, 1, D]
|
||||
dispatch::reshape_heads_bf16(
|
||||
self.buffers.q_proj.as_ptr() as _,
|
||||
self.buffers.q_reshaped.as_mut_ptr() as _,
|
||||
1,
|
||||
nh,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
dispatch::reshape_heads_bf16(
|
||||
self.buffers.k_proj.as_ptr() as _,
|
||||
self.buffers.k_reshaped.as_mut_ptr() as _,
|
||||
1,
|
||||
nkv,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
dispatch::reshape_heads_bf16(
|
||||
self.buffers.v_proj.as_ptr() as _,
|
||||
self.buffers.v_reshaped.as_mut_ptr() as _,
|
||||
1,
|
||||
nkv,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// QK norm (head-level rmsnorm: treat [1,H,1,D] as [H, D])
|
||||
dispatch::rmsnorm_bf16(
|
||||
self.buffers.q_reshaped.as_ptr() as _,
|
||||
lw.q_norm,
|
||||
self.buffers.q_normed.as_mut_ptr() as _,
|
||||
nh,
|
||||
hd,
|
||||
eps,
|
||||
s,
|
||||
);
|
||||
dispatch::rmsnorm_bf16(
|
||||
self.buffers.k_reshaped.as_ptr() as _,
|
||||
lw.k_norm,
|
||||
self.buffers.k_normed.as_mut_ptr() as _,
|
||||
nkv,
|
||||
hd,
|
||||
eps,
|
||||
s,
|
||||
);
|
||||
|
||||
// Transpose for RoPE: [1,H,1,D] -> [1,H,D]
|
||||
dispatch::transpose_hsd_to_shd_bf16(
|
||||
self.buffers.q_normed.as_ptr() as _,
|
||||
self.buffers.q_rope.as_mut_ptr() as _,
|
||||
1,
|
||||
nh,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
dispatch::transpose_hsd_to_shd_bf16(
|
||||
self.buffers.k_normed.as_ptr() as _,
|
||||
self.buffers.k_rope.as_mut_ptr() as _,
|
||||
1,
|
||||
nkv,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// RoPE (in-place, reads position_gpu)
|
||||
dispatch::rope_bf16(
|
||||
self.buffers.q_rope.as_mut_ptr() as _,
|
||||
rope_cos,
|
||||
rope_sin,
|
||||
self.buffers.position_gpu.as_ptr() as _,
|
||||
1,
|
||||
nh,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
dispatch::rope_bf16(
|
||||
self.buffers.k_rope.as_mut_ptr() as _,
|
||||
rope_cos,
|
||||
rope_sin,
|
||||
self.buffers.position_gpu.as_ptr() as _,
|
||||
1,
|
||||
nkv,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// Transpose back: [1,H,D] -> [1,H,1,D]
|
||||
dispatch::transpose_shd_to_hsd_bf16(
|
||||
self.buffers.q_rope.as_ptr() as _,
|
||||
self.buffers.q_final.as_mut_ptr() as _,
|
||||
1,
|
||||
nh,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
dispatch::transpose_shd_to_hsd_bf16(
|
||||
self.buffers.k_rope.as_ptr() as _,
|
||||
self.buffers.k_final.as_mut_ptr() as _,
|
||||
1,
|
||||
nkv,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
}
|
||||
self.pre_attn_graphs[l]
|
||||
.end_capture(&self.stream)
|
||||
.expect("end pre-attn capture");
|
||||
|
||||
// === Post-attention graph ===
|
||||
self.post_attn_graphs[l]
|
||||
.begin_capture(&self.stream)
|
||||
.expect("begin post-attn capture");
|
||||
unsafe {
|
||||
// Merge heads: [1,H,1,D] -> [1, hidden]
|
||||
// attn_out is written by ungraphed attention
|
||||
dispatch::merge_heads_bf16(
|
||||
self.buffers.attn_out.as_ptr() as _,
|
||||
self.buffers.attn_merged.as_mut_ptr() as _,
|
||||
1,
|
||||
nh,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// O projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.attn_merged.as_ptr() as _,
|
||||
lw.o_proj_wt,
|
||||
self.buffers.o_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_hidden.as_mut_ptr() as _,
|
||||
nh * hd,
|
||||
h,
|
||||
s,
|
||||
);
|
||||
|
||||
// Fused Add+RMSNorm: normed2 = rmsnorm(o_proj + x), sum_out = o_proj + x
|
||||
dispatch::add_rmsnorm_bf16(
|
||||
self.buffers.o_proj.as_ptr() as _,
|
||||
self.buffers.x.as_ptr() as _,
|
||||
lw.post_norm,
|
||||
self.buffers.normed2.as_mut_ptr() as _,
|
||||
self.buffers.sum_out.as_mut_ptr() as _,
|
||||
1,
|
||||
h,
|
||||
eps,
|
||||
s,
|
||||
);
|
||||
|
||||
// Gate projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed2.as_ptr() as _,
|
||||
lw.gate_proj_wt,
|
||||
self.buffers.gate.as_mut_ptr() as _,
|
||||
self.buffers.fp32_intermediate.as_mut_ptr() as _,
|
||||
h,
|
||||
inter,
|
||||
s,
|
||||
);
|
||||
|
||||
// Up projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed2.as_ptr() as _,
|
||||
lw.up_proj_wt,
|
||||
self.buffers.up.as_mut_ptr() as _,
|
||||
self.buffers.fp32_intermediate.as_mut_ptr() as _,
|
||||
h,
|
||||
inter,
|
||||
s,
|
||||
);
|
||||
|
||||
// Fused SiLU x Mul
|
||||
dispatch::silu_mul_bf16(
|
||||
self.buffers.gate.as_ptr() as _,
|
||||
self.buffers.up.as_ptr() as _,
|
||||
self.buffers.silu_out.as_mut_ptr() as _,
|
||||
inter,
|
||||
s,
|
||||
);
|
||||
|
||||
// Down projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.silu_out.as_ptr() as _,
|
||||
lw.down_proj_wt,
|
||||
self.buffers.down.as_mut_ptr() as _,
|
||||
self.buffers.fp32_hidden.as_mut_ptr() as _,
|
||||
inter,
|
||||
h,
|
||||
s,
|
||||
);
|
||||
|
||||
// x = sum_out + down (residual connection for next layer)
|
||||
dispatch::add_bf16(
|
||||
self.buffers.sum_out.as_ptr() as _,
|
||||
self.buffers.down.as_ptr() as _,
|
||||
self.buffers.x.as_mut_ptr() as _,
|
||||
h,
|
||||
s,
|
||||
);
|
||||
}
|
||||
self.post_attn_graphs[l]
|
||||
.end_capture(&self.stream)
|
||||
.expect("end post-attn capture");
|
||||
}
|
||||
|
||||
// === Final graph: norm + lm_head ===
|
||||
self.final_graph
|
||||
.begin_capture(&self.stream)
|
||||
.expect("begin final capture");
|
||||
unsafe {
|
||||
dispatch::rmsnorm_bf16(
|
||||
self.buffers.x.as_ptr() as _,
|
||||
norm_weight,
|
||||
self.buffers.normed.as_mut_ptr() as _,
|
||||
1,
|
||||
h,
|
||||
eps,
|
||||
s,
|
||||
);
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _,
|
||||
lm_head_wt,
|
||||
self.buffers.logits.as_mut_ptr() as _,
|
||||
self.buffers.fp32_vocab.as_mut_ptr() as _,
|
||||
h,
|
||||
vocab,
|
||||
s,
|
||||
);
|
||||
}
|
||||
self.final_graph
|
||||
.end_capture(&self.stream)
|
||||
.expect("end final capture");
|
||||
|
||||
// Reset cuBLAS back to null stream
|
||||
unsafe {
|
||||
dispatch::set_cublas_stream(cublas, std::ptr::null_mut());
|
||||
}
|
||||
|
||||
self.captured = true;
|
||||
}
|
||||
|
||||
/// Execute a single decode step using captured graphs.
|
||||
pub fn execute(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
position: u32,
|
||||
cache: &mut GpuKVCache,
|
||||
_layers: &[LayerWeightPtrs],
|
||||
embed_table: *const c_void,
|
||||
vocab_size: i32,
|
||||
hidden_size: i32,
|
||||
) {
|
||||
assert!(self.captured, "must call capture() before execute()");
|
||||
let s = self.stream.as_raw();
|
||||
let nkv = self.num_kv_heads;
|
||||
let nh = self.num_heads;
|
||||
let hd = self.head_dim;
|
||||
let es = 2usize; // BF16
|
||||
|
||||
// Upload token ID and position to fixed GPU buffers
|
||||
self.buffers
|
||||
.token_id_gpu
|
||||
.copy_from_host(&token_id.to_le_bytes())
|
||||
.unwrap();
|
||||
self.buffers
|
||||
.position_gpu
|
||||
.copy_from_host(&position.to_le_bytes())
|
||||
.unwrap();
|
||||
|
||||
// Embedding (outside graph since token_id changes each step)
|
||||
unsafe {
|
||||
dispatch::embedding_bf16(
|
||||
embed_table,
|
||||
self.buffers.token_id_gpu.as_ptr() as _,
|
||||
self.buffers.x.as_mut_ptr() as _,
|
||||
1,
|
||||
hidden_size,
|
||||
vocab_size,
|
||||
s,
|
||||
);
|
||||
}
|
||||
|
||||
for l in 0..self.num_layers {
|
||||
// Pre-attention graph (norm + QKV + reshape + QK-norm + RoPE)
|
||||
self.pre_attn_graphs[l]
|
||||
.launch(&self.stream)
|
||||
.expect("launch pre-attn graph");
|
||||
|
||||
// Ungraphed: KV cache append
|
||||
// k_final shape: [1, num_kv_heads, 1, head_dim] (after RoPE pipeline)
|
||||
// v_reshaped shape: [1, num_kv_heads, 1, head_dim] (V skips RoPE)
|
||||
let pos = position as usize;
|
||||
|
||||
let k_buf_size = nkv * hd * es;
|
||||
let v_buf_size = nkv * hd * es;
|
||||
let shape = [1usize, nkv, 1, hd];
|
||||
|
||||
// Synchronize before accessing buffers for KV cache append
|
||||
self.stream.synchronize().expect("sync before kv cache");
|
||||
|
||||
let k_view = unsafe {
|
||||
crate::kv_cache::tensor_from_gpu_buffer_pub(
|
||||
GpuBuffer::borrow_raw(self.buffers.k_final.as_mut_ptr(), k_buf_size),
|
||||
&shape,
|
||||
xserv_tensor::DType::BF16,
|
||||
0,
|
||||
)
|
||||
};
|
||||
let v_view = unsafe {
|
||||
crate::kv_cache::tensor_from_gpu_buffer_pub(
|
||||
GpuBuffer::borrow_raw(self.buffers.v_reshaped.as_mut_ptr(), v_buf_size),
|
||||
&shape,
|
||||
xserv_tensor::DType::BF16,
|
||||
0,
|
||||
)
|
||||
};
|
||||
cache.append(l, &k_view, &v_view, 1, pos);
|
||||
|
||||
// Ungraphed: get full KV cache and run decode attention
|
||||
let (k_full, v_full) = cache.get_kv_len(l, pos + 1);
|
||||
let kv_len = (pos + 1) as i32;
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
|
||||
// Attention output written to attn_out (separate from q_final)
|
||||
unsafe {
|
||||
dispatch::decode_attention_bf16(
|
||||
self.buffers.q_final.as_ptr() as _,
|
||||
k_full.data_ptr() as _,
|
||||
v_full.data_ptr() as _,
|
||||
self.buffers.attn_out.as_mut_ptr() as _,
|
||||
1,
|
||||
nh as i32,
|
||||
nkv as i32,
|
||||
kv_len,
|
||||
hd as i32,
|
||||
scale,
|
||||
s,
|
||||
);
|
||||
}
|
||||
|
||||
// Synchronize before post-attention graph reads attn_out
|
||||
self.stream.synchronize().expect("sync before post-attn");
|
||||
|
||||
// Post-attention graph (merge + O-proj + add_rmsnorm + FFN + residual)
|
||||
self.post_attn_graphs[l]
|
||||
.launch(&self.stream)
|
||||
.expect("launch post-attn graph");
|
||||
}
|
||||
|
||||
// Final graph (norm + lm_head)
|
||||
self.final_graph
|
||||
.launch(&self.stream)
|
||||
.expect("launch final graph");
|
||||
|
||||
// Sync to ensure logits are ready
|
||||
self.stream.synchronize().expect("sync after decode");
|
||||
}
|
||||
|
||||
/// Get the logits buffer (for reading results after execute).
|
||||
pub fn logits_buffer(&self) -> &GpuBuffer {
|
||||
&self.buffers.logits
|
||||
}
|
||||
|
||||
/// Invalidate captured graphs (e.g. when switching sequences).
|
||||
pub fn invalidate(&mut self) {
|
||||
self.captured = false;
|
||||
self.pre_attn_graphs = (0..self.num_layers).map(|_| CudaGraph::new()).collect();
|
||||
self.post_attn_graphs = (0..self.num_layers).map(|_| CudaGraph::new()).collect();
|
||||
self.final_graph = CudaGraph::new();
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE(review): this manual `Send` impl carries no justification in the source.
// Confirm that the CUDA stream, graphs and buffers owned by `DecodeGraphState`
// may be moved to, and then used from, a different thread.
unsafe impl Send for DecodeGraphState {}
|
||||
|
||||
/// Raw pointers to the weight tensors of one decoder layer.
///
/// Exists so the graph-capture code can be handed just the addresses it needs
/// instead of the full model struct. The struct does not own the memory.
pub struct LayerWeightPtrs {
    /// Weight for the RMSNorm applied before attention.
    pub input_norm: *const c_void,
    /// Q projection weight.
    pub q_proj_wt: *const c_void,
    /// K projection weight.
    pub k_proj_wt: *const c_void,
    /// V projection weight.
    pub v_proj_wt: *const c_void,
    /// O projection weight.
    pub o_proj_wt: *const c_void,
    /// Per-head RMSNorm weight applied to Q.
    pub q_norm: *const c_void,
    /// Per-head RMSNorm weight applied to K.
    pub k_norm: *const c_void,
    /// Weight for the RMSNorm fused with the post-attention residual add.
    pub post_norm: *const c_void,
    /// FFN gate projection weight.
    pub gate_proj_wt: *const c_void,
    /// FFN up projection weight.
    pub up_proj_wt: *const c_void,
    /// FFN down projection weight.
    pub down_proj_wt: *const c_void,
}

// NOTE(review): these impls assert that the pointed-to weights may be accessed
// from any thread; nothing visible here proves it. Confirm the weights are not
// mutated or freed while a `LayerWeightPtrs` is alive.
unsafe impl Send for LayerWeightPtrs {}
unsafe impl Sync for LayerWeightPtrs {}
|
||||
425
crates/xserv-model/src/eagle3.rs
Normal file
425
crates/xserv-model/src/eagle3.rs
Normal file
@@ -0,0 +1,425 @@
|
||||
//! EAGLE3 speculative draft head for Qwen3-8B (Phase 25).
|
||||
//!
|
||||
//! Loads the AngelSlim/Qwen3-8B_eagle3 pytorch_model.bin and provides a
|
||||
//! single-step forward pass that takes 3 target hidden states + the previous
|
||||
//! token and returns a draft token in the target vocabulary.
|
||||
//!
|
||||
//! Architecture (from weights):
|
||||
//! - fc: [hidden, 3*hidden] → fuse 3 target hidden states
|
||||
//! - midlayer: 1 decoder layer (attn input dim = 2*hidden)
|
||||
//! - norm + lm_head: → [draft_vocab_size=32000]
|
||||
//! - d2t: draft_id → target_id offset mapping
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
/// Indices of the target-model layers whose hidden states feed the EAGLE3 head,
/// for Qwen3-8B (36 layers).
///
/// Taken from the AngelSlim/vLLM speculators training config
/// `dflash_qwen3_8b_sharegpt_online_5k.sh`, which sets target_layer_ids = "2 18 33".
/// These must match the training-time selection or EAGLE outputs are wrong.
pub const EAGLE_HOOK_LAYERS: [usize; 3] = [2, 18, 33];

/// Size of the draft vocabulary, i.e. the number of rows expected in the d2t mapping.
const DRAFT_VOCAB_SIZE: usize = 32000;
|
||||
|
||||
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert_eq!(a.ndim(), 2);
|
||||
assert_eq!(b.ndim(), 2);
|
||||
matmul(a, b, GemmBackend::CuBlas)
|
||||
}
|
||||
|
||||
pub struct Eagle3Head {
|
||||
fc_wt: Tensor, // [hidden, 3*hidden] transposed for matmul
|
||||
hidden_norm: Tensor, // [hidden]
|
||||
input_layernorm: Tensor, // [hidden]
|
||||
q_proj_wt: Tensor, // [num_heads*head_dim, 2*hidden]
|
||||
k_proj_wt: Tensor, // [num_kv_heads*head_dim, 2*hidden]
|
||||
v_proj_wt: Tensor, // [num_kv_heads*head_dim, 2*hidden]
|
||||
o_proj_wt: Tensor, // [hidden, num_heads*head_dim]
|
||||
gate_proj_wt: Tensor, // [intermediate, hidden]
|
||||
up_proj_wt: Tensor, // [intermediate, hidden]
|
||||
down_proj_wt: Tensor, // [hidden, intermediate]
|
||||
post_attention_layernorm: Tensor, // [hidden]
|
||||
norm: Tensor, // [hidden] final
|
||||
lm_head_wt: Tensor, // [draft_vocab, hidden]
|
||||
d2t: Vec<i64>, // [draft_vocab] offset mapping
|
||||
/// t2d[target_id] = true iff target_id has a corresponding draft-vocab id
|
||||
/// (i.e. can potentially be produced by EAGLE). Used to measure the
|
||||
/// coverage cap on acceptance.
|
||||
t2d: Vec<bool>,
|
||||
hidden_size: usize,
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
max_seq_len: usize,
|
||||
rope_cache: RopeCache,
|
||||
// Stateful 1-layer KV cache: [1, num_kv_heads, max_seq_len, head_dim] BF16.
|
||||
// We slice `..current_len` for attention. The head is tiny (~64 KB per
|
||||
// 1000 tokens) so pre-allocating max_seq_len wastes negligible memory.
|
||||
k_cache: Tensor,
|
||||
v_cache: Tensor,
|
||||
current_len: usize,
|
||||
}
|
||||
|
||||
impl Eagle3Head {
|
||||
pub fn load(dir: &Path, device: u32) -> Self {
|
||||
let (weights, d2t, t2d) = load_eagle3_weights(dir, device);
|
||||
let hidden_size = 4096;
|
||||
let num_heads = 32;
|
||||
let num_kv_heads = 8;
|
||||
let head_dim = 128;
|
||||
let intermediate_size = 12288;
|
||||
let max_seq_len = 2048;
|
||||
let rope_theta = 1_000_000.0f32;
|
||||
|
||||
let get = |name: &str| -> Tensor {
|
||||
weights
|
||||
.get(name)
|
||||
.unwrap_or_else(|| panic!("missing eagle3 weight: {name}"))
|
||||
.clone()
|
||||
};
|
||||
|
||||
let fc_wt = get("fc.weight").transpose(0, 1).contiguous();
|
||||
let q_proj_wt = get("midlayer.self_attn.q_proj.weight")
|
||||
.transpose(0, 1)
|
||||
.contiguous();
|
||||
let k_proj_wt = get("midlayer.self_attn.k_proj.weight")
|
||||
.transpose(0, 1)
|
||||
.contiguous();
|
||||
let v_proj_wt = get("midlayer.self_attn.v_proj.weight")
|
||||
.transpose(0, 1)
|
||||
.contiguous();
|
||||
let o_proj_wt = get("midlayer.self_attn.o_proj.weight")
|
||||
.transpose(0, 1)
|
||||
.contiguous();
|
||||
let gate_proj_wt = get("midlayer.mlp.gate_proj.weight")
|
||||
.transpose(0, 1)
|
||||
.contiguous();
|
||||
let up_proj_wt = get("midlayer.mlp.up_proj.weight")
|
||||
.transpose(0, 1)
|
||||
.contiguous();
|
||||
let down_proj_wt = get("midlayer.mlp.down_proj.weight")
|
||||
.transpose(0, 1)
|
||||
.contiguous();
|
||||
let hidden_norm = get("midlayer.hidden_norm.weight");
|
||||
let input_layernorm = get("midlayer.input_layernorm.weight");
|
||||
let post_attention_layernorm = get("midlayer.post_attention_layernorm.weight");
|
||||
let norm = get("norm.weight");
|
||||
let lm_head_wt = get("lm_head.weight").transpose(0, 1).contiguous();
|
||||
|
||||
assert_eq!(d2t.len(), DRAFT_VOCAB_SIZE);
|
||||
|
||||
let rope_cache = RopeCache::new(max_seq_len, head_dim, rope_theta);
|
||||
|
||||
let k_cache = Tensor::zeros(
|
||||
&[1, num_kv_heads, max_seq_len, head_dim],
|
||||
DType::BF16,
|
||||
Device::Cuda(device),
|
||||
);
|
||||
let v_cache = Tensor::zeros(
|
||||
&[1, num_kv_heads, max_seq_len, head_dim],
|
||||
DType::BF16,
|
||||
Device::Cuda(device),
|
||||
);
|
||||
|
||||
Self {
|
||||
fc_wt,
|
||||
hidden_norm,
|
||||
input_layernorm,
|
||||
q_proj_wt,
|
||||
k_proj_wt,
|
||||
v_proj_wt,
|
||||
o_proj_wt,
|
||||
gate_proj_wt,
|
||||
up_proj_wt,
|
||||
down_proj_wt,
|
||||
post_attention_layernorm,
|
||||
norm,
|
||||
lm_head_wt,
|
||||
d2t,
|
||||
t2d,
|
||||
hidden_size,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
max_seq_len,
|
||||
rope_cache,
|
||||
k_cache,
|
||||
v_cache,
|
||||
current_len: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the internal KV cache for a fresh sequence.
|
||||
pub fn reset(&mut self) {
|
||||
self.current_len = 0;
|
||||
}
|
||||
|
||||
/// Truncate the internal KV cache to `new_len` entries. Used to discard
|
||||
/// K/V of rejected drafts after a speculative round.
|
||||
pub fn truncate_to(&mut self, new_len: usize) {
|
||||
assert!(new_len <= self.current_len);
|
||||
self.current_len = new_len;
|
||||
}
|
||||
|
||||
/// Current number of committed K/V entries in the internal EAGLE cache.
|
||||
pub fn current_len(&self) -> usize {
|
||||
self.current_len
|
||||
}
|
||||
|
||||
/// One draft step: produce a token in target vocabulary space.
|
||||
///
|
||||
/// - `target_hidden`: 3 tensors [1, hidden_size] from target hook layers
|
||||
/// - `embed_table`: the target model's embed_tokens (shared, not copied)
|
||||
/// - `prev_token`: the previous committed token
|
||||
/// - `position`: the decode position for RoPE
|
||||
///
|
||||
/// Returns (draft_token_in_target_vocab, draft_logits_tensor).
|
||||
pub fn step(
|
||||
&mut self,
|
||||
target_hidden: &[Tensor; 3],
|
||||
embed_table: &Tensor,
|
||||
prev_token: u32,
|
||||
position: usize,
|
||||
) -> (u32, Tensor) {
|
||||
let (id, logits, _) = self.step_with_aux(target_hidden, embed_table, prev_token, position);
|
||||
(id, logits)
|
||||
}
|
||||
|
||||
/// Like `step`, but also returns the final hidden state (aux) usable as
|
||||
/// the fused_h for a subsequent recursive draft step via `step_recursive`.
|
||||
pub fn step_with_aux(
|
||||
&mut self,
|
||||
target_hidden: &[Tensor; 3],
|
||||
embed_table: &Tensor,
|
||||
prev_token: u32,
|
||||
position: usize,
|
||||
) -> (u32, Tensor, Tensor) {
|
||||
// Fuse 3 target hidden states into fused_h via fc.
|
||||
let h_cat = concat_hidden(target_hidden);
|
||||
let fused_h = matmul_2d(&h_cat, &self.fc_wt);
|
||||
self.forward_from_fused(fused_h, embed_table, prev_token, position)
|
||||
}
|
||||
|
||||
/// Recursive draft step: reuses the previous EAGLE step's aux as fused_h,
|
||||
/// bypassing the fc+3-hidden fusion. Used for γ≥2 chained drafts.
|
||||
pub fn step_recursive(
|
||||
&mut self,
|
||||
fused_h: Tensor,
|
||||
embed_table: &Tensor,
|
||||
prev_token: u32,
|
||||
position: usize,
|
||||
) -> (u32, Tensor, Tensor) {
|
||||
self.forward_from_fused(fused_h, embed_table, prev_token, position)
|
||||
}
|
||||
|
||||
fn forward_from_fused(
|
||||
&mut self,
|
||||
fused_h: Tensor,
|
||||
embed_table: &Tensor,
|
||||
prev_token: u32,
|
||||
position: usize,
|
||||
) -> (u32, Tensor, Tensor) {
|
||||
let eps = 1e-6f32;
|
||||
assert!(
|
||||
self.current_len < self.max_seq_len,
|
||||
"EAGLE KV cache overflow: {} >= {}",
|
||||
self.current_len,
|
||||
self.max_seq_len
|
||||
);
|
||||
|
||||
let emb = embedding(embed_table, &[prev_token]);
|
||||
let residual = fused_h.clone();
|
||||
let emb_normed = rmsnorm(&emb, &self.input_layernorm, eps);
|
||||
let h_normed = rmsnorm(&fused_h, &self.hidden_norm, eps);
|
||||
let attn_in = concat_last_dim(&emb_normed, &h_normed);
|
||||
|
||||
let q = matmul_2d(&attn_in, &self.q_proj_wt);
|
||||
let k = matmul_2d(&attn_in, &self.k_proj_wt);
|
||||
let v = matmul_2d(&attn_in, &self.v_proj_wt);
|
||||
|
||||
let q_3d = q.reshape(&[1, self.num_heads, self.head_dim]);
|
||||
let k_3d = k.reshape(&[1, self.num_kv_heads, self.head_dim]);
|
||||
let positions = [position as u32];
|
||||
rope_inplace(&q_3d, &self.rope_cache, &positions);
|
||||
rope_inplace(&k_3d, &self.rope_cache, &positions);
|
||||
|
||||
let v_3d = v.reshape(&[1, self.num_kv_heads, self.head_dim]);
|
||||
self.append_to_kv_cache(&k_3d, &v_3d);
|
||||
self.current_len += 1;
|
||||
let kv_len = self.current_len;
|
||||
let k_view = self.k_cache.narrow(2, 0, kv_len).contiguous();
|
||||
let v_view = self.v_cache.narrow(2, 0, kv_len).contiguous();
|
||||
|
||||
let q_4d = q_3d.reshape(&[1, self.num_heads, 1, self.head_dim]);
|
||||
let attn_out = decode_attention(&q_4d, &k_view, &v_view);
|
||||
|
||||
let attn_merged = attn_out.reshape(&[1, self.num_heads * self.head_dim]);
|
||||
let attn_proj = matmul_2d(&attn_merged, &self.o_proj_wt);
|
||||
|
||||
let (mlp_in, residual) =
|
||||
add_rmsnorm(&attn_proj, &residual, &self.post_attention_layernorm, eps);
|
||||
|
||||
let gate = matmul_2d(&mlp_in, &self.gate_proj_wt);
|
||||
let up = matmul_2d(&mlp_in, &self.up_proj_wt);
|
||||
let hidden = silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden, &self.down_proj_wt);
|
||||
|
||||
let (x, prenorm) = add_rmsnorm(&down, &residual, &self.norm, eps);
|
||||
let logits = matmul_2d(&x, &self.lm_head_wt);
|
||||
|
||||
let draft_id = argmax_bf16_single(&logits);
|
||||
let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32;
|
||||
// aux for recursive drafting = PRE-norm hidden (default norm_output=False
|
||||
// in vllm/llama_eagle3.py). Feeding the pre-norm state matches training.
|
||||
(target_id, logits, prenorm)
|
||||
}
|
||||
|
||||
/// Write new K/V rows (shape [1, num_kv_heads, head_dim]) at position
|
||||
/// `current_len` inside the [1, num_kv_heads, max_seq_len, head_dim] cache.
|
||||
fn append_to_kv_cache(&mut self, new_k: &Tensor, new_v: &Tensor) {
|
||||
let head_bytes = self.head_dim * self.k_cache.dtype().size_bytes();
|
||||
for h in 0..self.num_kv_heads {
|
||||
for (cache, src) in [(&self.k_cache, new_k), (&self.v_cache, new_v)] {
|
||||
let dst = unsafe {
|
||||
(cache.data_ptr() as *mut u8)
|
||||
.add(((h * self.max_seq_len) + self.current_len) * head_bytes)
|
||||
};
|
||||
let s = unsafe { (src.data_ptr() as *const u8).add(h * head_bytes) };
|
||||
d2d(dst, s, head_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Map a draft-vocab token id to the full target-vocab id via d2t.
|
||||
pub fn map_draft_to_target(&self, draft_id: u32) -> u32 {
|
||||
(draft_id as i64 + self.d2t[draft_id as usize]) as u32
|
||||
}
|
||||
|
||||
/// Returns true iff `target_id` is representable in the draft vocabulary
|
||||
/// (i.e., EAGLE could in principle produce it).
|
||||
pub fn target_id_in_draft_vocab(&self, target_id: u32) -> bool {
|
||||
self.t2d.get(target_id as usize).copied().unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
fn d2d(dst: *mut u8, src: *const u8, bytes: usize) {
|
||||
unsafe {
|
||||
xserv_cuda::ffi::cudaMemcpy(dst, src, bytes, xserv_cuda::ffi::CUDA_MEMCPY_D2D);
|
||||
}
|
||||
}
|
||||
|
||||
fn concat_hidden(hidden: &[Tensor; 3]) -> Tensor {
|
||||
let h = hidden[0].shape()[1];
|
||||
let dtype = hidden[0].dtype();
|
||||
let device = hidden[0].device();
|
||||
let elem_bytes = dtype.size_bytes();
|
||||
let out = Tensor::empty(&[1, 3 * h], dtype, device);
|
||||
for (i, t) in hidden.iter().enumerate() {
|
||||
assert!(t.is_contiguous());
|
||||
let dst = unsafe { (out.data_ptr() as *mut u8).add(i * h * elem_bytes) };
|
||||
d2d(dst, t.data_ptr() as *const u8, h * elem_bytes);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn concat_last_dim(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
let da = a.shape()[1];
|
||||
let db = b.shape()[1];
|
||||
let dtype = a.dtype();
|
||||
let device = a.device();
|
||||
let elem_bytes = dtype.size_bytes();
|
||||
let out = Tensor::empty(&[1, da + db], dtype, device);
|
||||
d2d(
|
||||
out.data_ptr() as *mut u8,
|
||||
a.data_ptr() as *const u8,
|
||||
da * elem_bytes,
|
||||
);
|
||||
let dst = unsafe { (out.data_ptr() as *mut u8).add(da * elem_bytes) };
|
||||
d2d(dst, b.data_ptr() as *const u8, db * elem_bytes);
|
||||
out
|
||||
}
|
||||
|
||||
fn repeat_kv_for_single_token(kv: &Tensor, repeats: usize) -> Tensor {
|
||||
if repeats == 1 {
|
||||
return kv.clone();
|
||||
}
|
||||
let nkv = kv.shape()[1];
|
||||
let d = kv.shape()[2];
|
||||
let dtype = kv.dtype();
|
||||
let device = kv.device();
|
||||
let head_bytes = d * dtype.size_bytes();
|
||||
let out = Tensor::empty(&[1, nkv * repeats, d], dtype, device);
|
||||
for h in 0..nkv {
|
||||
let src = unsafe { (kv.data_ptr() as *const u8).add(h * head_bytes) };
|
||||
for r in 0..repeats {
|
||||
let dst = unsafe { (out.data_ptr() as *mut u8).add((h * repeats + r) * head_bytes) };
|
||||
d2d(dst, src, head_bytes);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Load EAGLE3 weights from safetensors, handling int64 d2t + bool t2d specially.
|
||||
fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap<String, Tensor>, Vec<i64>, Vec<bool>) {
|
||||
let st_path = dir.join("model.safetensors");
|
||||
assert!(
|
||||
st_path.exists(),
|
||||
"Eagle3 model.safetensors not found in {}. Convert with:\n\
|
||||
python3 -c \"import torch; from safetensors.torch import save_file; \
|
||||
sd=torch.load('pytorch_model.bin', map_location='cpu', weights_only=False); \
|
||||
save_file(sd, 'model.safetensors')\"",
|
||||
dir.display()
|
||||
);
|
||||
|
||||
let data = std::fs::read(&st_path)
|
||||
.unwrap_or_else(|e| panic!("failed to read {}: {e}", st_path.display()));
|
||||
let st = safetensors::SafeTensors::deserialize(&data)
|
||||
.unwrap_or_else(|e| panic!("failed to parse {}: {e}", st_path.display()));
|
||||
|
||||
let mut tensors = HashMap::new();
|
||||
let mut d2t_vec: Vec<i64> = Vec::new();
|
||||
let mut t2d_vec: Vec<bool> = Vec::new();
|
||||
|
||||
for (name, view) in st.tensors() {
|
||||
if name == "t2d" {
|
||||
let raw = view.data();
|
||||
assert_eq!(view.dtype(), safetensors::Dtype::BOOL);
|
||||
t2d_vec = raw.iter().map(|&b| b != 0).collect();
|
||||
continue;
|
||||
}
|
||||
if name == "d2t" {
|
||||
let raw = view.data();
|
||||
assert_eq!(view.dtype(), safetensors::Dtype::I64);
|
||||
let n = raw.len() / 8;
|
||||
d2t_vec = (0..n)
|
||||
.map(|i| i64::from_le_bytes(raw[i * 8..(i + 1) * 8].try_into().unwrap()))
|
||||
.collect();
|
||||
continue;
|
||||
}
|
||||
let dtype = match view.dtype() {
|
||||
safetensors::Dtype::BF16 => DType::BF16,
|
||||
safetensors::Dtype::F32 => DType::F32,
|
||||
safetensors::Dtype::F16 => DType::F16,
|
||||
other => {
|
||||
eprintln!("eagle3: skipping {name} with unsupported dtype {other:?}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let shape: Vec<usize> = view.shape().to_vec();
|
||||
let raw = view.data();
|
||||
let t = crate::loader::make_tensor(raw, &shape, dtype);
|
||||
let t = t.to_device(Device::Cuda(device));
|
||||
tensors.insert(name.to_string(), t);
|
||||
}
|
||||
|
||||
assert!(
|
||||
!d2t_vec.is_empty(),
|
||||
"d2t tensor not found in eagle3 weights"
|
||||
);
|
||||
assert!(
|
||||
!t2d_vec.is_empty(),
|
||||
"t2d tensor not found in eagle3 weights"
|
||||
);
|
||||
(tensors, d2t_vec, t2d_vec)
|
||||
}
|
||||
437
crates/xserv-model/src/gpt2.rs
Normal file
437
crates/xserv-model/src/gpt2.rs
Normal file
@@ -0,0 +1,437 @@
|
||||
use std::collections::HashMap;
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
use crate::config::ModelConfig;
|
||||
|
||||
pub struct GPT2 {
|
||||
pub config: ModelConfig,
|
||||
wte: Tensor,
|
||||
wpe: Tensor,
|
||||
layers: Vec<GPT2Block>,
|
||||
ln_f_g: Tensor,
|
||||
ln_f_b: Tensor,
|
||||
lm_head: Tensor, // precomputed wte^T
|
||||
}
|
||||
|
||||
struct GPT2Block {
|
||||
ln_1_g: Tensor,
|
||||
ln_1_b: Tensor,
|
||||
attn_qkv_w: Tensor,
|
||||
attn_qkv_b: Tensor,
|
||||
attn_out_w: Tensor,
|
||||
attn_out_b: Tensor,
|
||||
ln_2_g: Tensor,
|
||||
ln_2_b: Tensor,
|
||||
mlp_fc_w: Tensor,
|
||||
mlp_fc_b: Tensor,
|
||||
mlp_proj_w: Tensor,
|
||||
mlp_proj_b: Tensor,
|
||||
}
|
||||
|
||||
pub struct KVCache {
|
||||
// Per layer, per head: raw bytes (works for both f32 and bf16)
|
||||
k: Vec<Vec<Vec<u8>>>, // [num_layers][num_heads][seq_len * head_dim * elem_size]
|
||||
v: Vec<Vec<Vec<u8>>>,
|
||||
len: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
elem_size: usize,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
pub fn new(
|
||||
num_layers: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
) -> Self {
|
||||
Self {
|
||||
k: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(),
|
||||
v: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(),
|
||||
len: 0,
|
||||
num_heads,
|
||||
head_dim,
|
||||
elem_size: dtype.size_bytes(),
|
||||
dtype,
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn seq_len(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
/// Append from a CPU tensor with shape [1, H, new_tokens, D].
|
||||
pub fn append_kv_tensor(
|
||||
&mut self,
|
||||
layer: usize,
|
||||
k_cpu: &Tensor,
|
||||
v_cpu: &Tensor,
|
||||
new_tokens: usize,
|
||||
) {
|
||||
let hd = self.head_dim;
|
||||
let es = self.elem_size;
|
||||
let k_bytes = k_cpu.storage().as_cpu_bytes();
|
||||
let v_bytes = v_cpu.storage().as_cpu_bytes();
|
||||
let chunk = new_tokens * hd * es;
|
||||
for h in 0..self.num_heads {
|
||||
let off = h * chunk;
|
||||
self.k[layer][h].extend_from_slice(&k_bytes[off..off + chunk]);
|
||||
self.v[layer][h].extend_from_slice(&v_bytes[off..off + chunk]);
|
||||
}
|
||||
if layer == 0 {
|
||||
self.len += new_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
/// Reconstruct [1, H, seq_len, D] tensors.
|
||||
pub fn get_kv_tensors(&self, layer: usize) -> (Tensor, Tensor) {
|
||||
let sl = self.len;
|
||||
let hd = self.head_dim;
|
||||
let nh = self.num_heads;
|
||||
let es = self.elem_size;
|
||||
let head_bytes = sl * hd * es;
|
||||
let total = nh * head_bytes;
|
||||
let mut k_data = vec![0u8; total];
|
||||
let mut v_data = vec![0u8; total];
|
||||
for h in 0..nh {
|
||||
let off = h * head_bytes;
|
||||
k_data[off..off + head_bytes].copy_from_slice(&self.k[layer][h]);
|
||||
v_data[off..off + head_bytes].copy_from_slice(&self.v[layer][h]);
|
||||
}
|
||||
let shape = &[1, nh, sl, hd];
|
||||
let k = tensor_from_raw_bytes(&k_data, shape, self.dtype).to_device(self.device);
|
||||
let v = tensor_from_raw_bytes(&v_data, shape, self.dtype).to_device(self.device);
|
||||
(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
fn tensor_from_raw_bytes(bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
|
||||
match dtype {
|
||||
DType::F32 => {
|
||||
let data: &[f32] = unsafe {
|
||||
std::slice::from_raw_parts(bytes.as_ptr() as *const f32, bytes.len() / 4)
|
||||
};
|
||||
Tensor::from_slice(data, shape)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data: &[half::bf16] = unsafe {
|
||||
std::slice::from_raw_parts(bytes.as_ptr() as *const half::bf16, bytes.len() / 2)
|
||||
};
|
||||
Tensor::from_slice(data, shape)
|
||||
}
|
||||
_ => panic!("unsupported dtype for KV cache"),
|
||||
}
|
||||
}
|
||||
|
||||
impl GPT2 {
|
||||
pub fn from_weights(config: ModelConfig, mut w: HashMap<String, Tensor>) -> Self {
|
||||
crate::init_kernels();
|
||||
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
|
||||
w.remove(name)
|
||||
.unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
};
|
||||
|
||||
let wte = take(&mut w, "wte.weight");
|
||||
let wpe = take(&mut w, "wpe.weight");
|
||||
let ln_f_g = take(&mut w, "ln_f.weight");
|
||||
let ln_f_b = take(&mut w, "ln_f.bias");
|
||||
let lm_head = wte.transpose(0, 1).contiguous();
|
||||
|
||||
let num_layers = config.num_layers();
|
||||
let mut layers = Vec::with_capacity(num_layers);
|
||||
for i in 0..num_layers {
|
||||
let p = format!("h.{i}");
|
||||
layers.push(GPT2Block {
|
||||
ln_1_g: take(&mut w, &format!("{p}.ln_1.weight")),
|
||||
ln_1_b: take(&mut w, &format!("{p}.ln_1.bias")),
|
||||
attn_qkv_w: take(&mut w, &format!("{p}.attn.c_attn.weight")),
|
||||
attn_qkv_b: take(&mut w, &format!("{p}.attn.c_attn.bias")),
|
||||
attn_out_w: take(&mut w, &format!("{p}.attn.c_proj.weight")),
|
||||
attn_out_b: take(&mut w, &format!("{p}.attn.c_proj.bias")),
|
||||
ln_2_g: take(&mut w, &format!("{p}.ln_2.weight")),
|
||||
ln_2_b: take(&mut w, &format!("{p}.ln_2.bias")),
|
||||
mlp_fc_w: take(&mut w, &format!("{p}.mlp.c_fc.weight")),
|
||||
mlp_fc_b: take(&mut w, &format!("{p}.mlp.c_fc.bias")),
|
||||
mlp_proj_w: take(&mut w, &format!("{p}.mlp.c_proj.weight")),
|
||||
mlp_proj_b: take(&mut w, &format!("{p}.mlp.c_proj.bias")),
|
||||
});
|
||||
}
|
||||
|
||||
Self {
|
||||
config,
|
||||
wte,
|
||||
wpe,
|
||||
layers,
|
||||
ln_f_g,
|
||||
ln_f_b,
|
||||
lm_head,
|
||||
}
|
||||
}
|
||||
|
||||
/// Full forward pass without KV cache (for testing / correctness comparison).
|
||||
pub fn forward(&self, token_ids: &[u32]) -> Tensor {
|
||||
let seq_len = token_ids.len();
|
||||
let hidden = self.config.hidden();
|
||||
let num_heads = self.config.num_heads();
|
||||
let head_dim = self.config.head_dim();
|
||||
|
||||
let tok_emb = embedding(&self.wte, token_ids);
|
||||
let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
|
||||
let pos_emb = embedding(&self.wpe, &pos_ids);
|
||||
let mut x = add_tensors(&tok_emb, &pos_emb);
|
||||
|
||||
for layer in &self.layers {
|
||||
x = self.transformer_block(layer, &x, None, 0, seq_len, num_heads, head_dim, hidden);
|
||||
}
|
||||
|
||||
let x = layernorm(&x, &self.ln_f_g, &self.ln_f_b, self.config.ln_eps());
|
||||
matmul_2d(&x, &self.lm_head)
|
||||
}
|
||||
|
||||
/// Forward pass with KV cache. First call = prefill, subsequent = decode.
|
||||
pub fn forward_with_cache(&self, token_ids: &[u32], cache: &mut KVCache) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
let pos_offset = cache.seq_len();
|
||||
let hidden = self.config.hidden();
|
||||
let num_heads = self.config.num_heads();
|
||||
let head_dim = self.config.head_dim();
|
||||
|
||||
let tok_emb = embedding(&self.wte, token_ids);
|
||||
let pos_ids: Vec<u32> = (pos_offset..pos_offset + new_tokens)
|
||||
.map(|p| p as u32)
|
||||
.collect();
|
||||
let pos_emb = embedding(&self.wpe, &pos_ids);
|
||||
let mut x = add_tensors(&tok_emb, &pos_emb);
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
x = self.transformer_block(
|
||||
layer,
|
||||
&x,
|
||||
Some((cache, layer_idx)),
|
||||
pos_offset,
|
||||
new_tokens,
|
||||
num_heads,
|
||||
head_dim,
|
||||
hidden,
|
||||
);
|
||||
}
|
||||
|
||||
let x = layernorm(&x, &self.ln_f_g, &self.ln_f_b, self.config.ln_eps());
|
||||
matmul_2d(&x, &self.lm_head)
|
||||
}
|
||||
|
||||
fn transformer_block(
|
||||
&self,
|
||||
layer: &GPT2Block,
|
||||
x: &Tensor,
|
||||
cache: Option<(&mut KVCache, usize)>,
|
||||
_pos_offset: usize,
|
||||
new_tokens: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
hidden: usize,
|
||||
) -> Tensor {
|
||||
let residual = x.clone();
|
||||
let normed = layernorm(x, &layer.ln_1_g, &layer.ln_1_b, self.config.ln_eps());
|
||||
|
||||
let qkv = linear(&normed, &layer.attn_qkv_w, Some(&layer.attn_qkv_b));
|
||||
let (q, k_new, v_new) = split_qkv(&qkv, num_heads, head_dim, new_tokens);
|
||||
|
||||
let (k_full, v_full) = if let Some((cache, layer_idx)) = cache {
|
||||
let k_cpu = k_new.to_device(Device::Cpu);
|
||||
let v_cpu = v_new.to_device(Device::Cpu);
|
||||
cache.append_kv_tensor(layer_idx, &k_cpu, &v_cpu, new_tokens);
|
||||
cache.get_kv_tensors(layer_idx)
|
||||
} else {
|
||||
(k_new, v_new)
|
||||
};
|
||||
|
||||
let attn_out = attention(&q, &k_full, &v_full, true);
|
||||
let attn_out = merge_heads(&attn_out, new_tokens, hidden);
|
||||
let attn_out = linear(&attn_out, &layer.attn_out_w, Some(&layer.attn_out_b));
|
||||
let x = add_tensors(&residual, &attn_out);
|
||||
|
||||
let residual = x.clone();
|
||||
let normed = layernorm(&x, &layer.ln_2_g, &layer.ln_2_b, self.config.ln_eps());
|
||||
let fc = linear(&normed, &layer.mlp_fc_w, Some(&layer.mlp_fc_b));
|
||||
let activated = gelu(&fc);
|
||||
let proj = linear(&activated, &layer.mlp_proj_w, Some(&layer.mlp_proj_b));
|
||||
add_tensors(&residual, &proj)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper ops (unchanged) ---
|
||||
|
||||
fn linear(x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Tensor {
|
||||
let out = matmul_2d(x, weight);
|
||||
if let Some(b) = bias {
|
||||
add_bias(&out, b)
|
||||
} else {
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert_eq!(a.ndim(), 2);
|
||||
assert_eq!(b.ndim(), 2);
|
||||
matmul(a, b, GemmBackend::CuBlas)
|
||||
}
|
||||
|
||||
fn add_tensors(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
xserv_kernels::add(a, b)
|
||||
}
|
||||
|
||||
fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor {
|
||||
// bias: [N], x: [S, N] — broadcast add via reshape
|
||||
assert_eq!(x.ndim(), 2);
|
||||
assert_eq!(bias.ndim(), 1);
|
||||
let n = bias.shape()[0];
|
||||
assert_eq!(x.shape()[1], n);
|
||||
let rows = x.shape()[0];
|
||||
// Broadcast: tile bias to [S, N] on CPU, then GPU add
|
||||
let b_cpu = bias.to_device(Device::Cpu);
|
||||
match x.dtype() {
|
||||
DType::F32 => {
|
||||
let bd = b_cpu.as_slice::<f32>();
|
||||
let tiled: Vec<f32> = (0..rows).flat_map(|_| bd.iter().copied()).collect();
|
||||
let b_full = Tensor::from_slice(&tiled, x.shape()).to_device(x.device());
|
||||
xserv_kernels::add(x, &b_full)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let bd = b_cpu.as_slice::<half::bf16>();
|
||||
let tiled: Vec<half::bf16> = (0..rows).flat_map(|_| bd.iter().copied()).collect();
|
||||
let b_full = Tensor::from_slice(&tiled, x.shape()).to_device(x.device());
|
||||
xserv_kernels::add(x, &b_full)
|
||||
}
|
||||
_ => panic!("unsupported dtype"),
|
||||
}
|
||||
}
|
||||
|
||||
fn split_qkv(
|
||||
qkv: &Tensor,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
seq_len: usize,
|
||||
) -> (Tensor, Tensor, Tensor) {
|
||||
let hidden = num_heads * head_dim;
|
||||
let qkv_cpu = qkv.to_device(Device::Cpu);
|
||||
let device = qkv.device();
|
||||
let dtype = qkv.dtype();
|
||||
|
||||
match dtype {
|
||||
DType::F32 => {
|
||||
let data = qkv_cpu.as_slice::<f32>();
|
||||
let mut q_data = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
let mut k_data = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
let mut v_data = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
for s in 0..seq_len {
|
||||
let row = &data[s * 3 * hidden..(s + 1) * 3 * hidden];
|
||||
for h in 0..num_heads {
|
||||
let src_off = h * head_dim;
|
||||
let dst_off = (h * seq_len + s) * head_dim;
|
||||
q_data[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&row[src_off..src_off + head_dim]);
|
||||
k_data[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
|
||||
v_data[dst_off..dst_off + head_dim].copy_from_slice(
|
||||
&row[2 * hidden + src_off..2 * hidden + src_off + head_dim],
|
||||
);
|
||||
}
|
||||
}
|
||||
let q =
|
||||
Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let k =
|
||||
Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let v =
|
||||
Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
(q, k, v)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = qkv_cpu.as_slice::<half::bf16>();
|
||||
let mut q_data = vec![half::bf16::ZERO; num_heads * seq_len * head_dim];
|
||||
let mut k_data = vec![half::bf16::ZERO; num_heads * seq_len * head_dim];
|
||||
let mut v_data = vec![half::bf16::ZERO; num_heads * seq_len * head_dim];
|
||||
for s in 0..seq_len {
|
||||
let row = &data[s * 3 * hidden..(s + 1) * 3 * hidden];
|
||||
for h in 0..num_heads {
|
||||
let src_off = h * head_dim;
|
||||
let dst_off = (h * seq_len + s) * head_dim;
|
||||
q_data[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&row[src_off..src_off + head_dim]);
|
||||
k_data[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
|
||||
v_data[dst_off..dst_off + head_dim].copy_from_slice(
|
||||
&row[2 * hidden + src_off..2 * hidden + src_off + head_dim],
|
||||
);
|
||||
}
|
||||
}
|
||||
let q =
|
||||
Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let k =
|
||||
Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let v =
|
||||
Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
(q, k, v)
|
||||
}
|
||||
_ => panic!("unsupported dtype {:?} in split_qkv", dtype),
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_heads(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor {
|
||||
let num_heads = x.shape()[1];
|
||||
let head_dim = x.shape()[3];
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let device = x.device();
|
||||
let dtype = x.dtype();
|
||||
|
||||
match dtype {
|
||||
DType::F32 => {
|
||||
let src = x_cpu.as_slice::<f32>();
|
||||
let mut out = vec![0.0f32; seq_len * hidden];
|
||||
for s in 0..seq_len {
|
||||
for h in 0..num_heads {
|
||||
let src_off = (h * seq_len + s) * head_dim;
|
||||
let dst_off = s * hidden + h * head_dim;
|
||||
out[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&src[src_off..src_off + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(device)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let src = x_cpu.as_slice::<half::bf16>();
|
||||
let mut out = vec![half::bf16::ZERO; seq_len * hidden];
|
||||
for s in 0..seq_len {
|
||||
for h in 0..num_heads {
|
||||
let src_off = (h * seq_len + s) * head_dim;
|
||||
let dst_off = s * hidden + h * head_dim;
|
||||
out[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&src[src_off..src_off + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(device)
|
||||
}
|
||||
_ => panic!("unsupported dtype {:?} in merge_heads", dtype),
|
||||
}
|
||||
}
|
||||
|
||||
/// Greedy sampling: return the argmax token ID from the last position's logits.
|
||||
pub fn sample_greedy(logits: &Tensor) -> u32 {
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let data = logits_cpu.as_slice::<f32>();
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
||||
.map(|(idx, _)| idx as u32)
|
||||
.unwrap()
|
||||
}
|
||||
1046
crates/xserv-model/src/gpt_oss.rs
Normal file
1046
crates/xserv-model/src/gpt_oss.rs
Normal file
File diff suppressed because it is too large
Load Diff
195
crates/xserv-model/src/gpt_oss_graph.rs
Normal file
195
crates/xserv-model/src/gpt_oss_graph.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
//! CUDA-graph replay for gpt-oss batch=1 decode (Phase 21).
|
||||
//!
|
||||
//! A decode step launches ~200 kernels; with sparse MoE the GPU work is only
|
||||
//! a few ms, so launch overhead dominates TPOT. The whole step (embedding →
|
||||
//! 24 layers → logits) is captured ONCE into a CUDA graph and replayed per
|
||||
//! token with a single `cudaGraphLaunch`.
|
||||
//!
|
||||
//! Why the existing forward is capturable as-is:
|
||||
//! - Every per-step variable input lives in a stable-address device buffer
|
||||
//! whose CONTENTS are updated outside the captured region: token id and
|
||||
//! position (persistent buffers owned here), block table and context lens
|
||||
//! (PagedKVCache GPU buffers, refreshed by `decode_prepare`). The KV scatter
|
||||
//! and paged attention kernels read their write/read positions from those
|
||||
//! buffers, and the sparse-MoE GEMVs read expert ids from `topk_ids` written
|
||||
//! earlier in the same graph — all data-dependent, no host branching.
|
||||
//! - Kernel launches go through the thread-local launch stream
|
||||
//! (`xserv_cuda::stream::push_stream`), so the capture stream sees them.
|
||||
//! - Intermediate tensors come from the caching allocator. Blocks freed while
|
||||
//! capturing are quarantined (`allocator::begin_retain`) for the graph's
|
||||
//! lifetime so no later allocation can take ownership of memory the graph
|
||||
//! still references on every replay.
|
||||
//!
|
||||
//! Capture preconditions: at least one EAGER decode step must have run first,
|
||||
//! so the allocator pool already holds every bucket size the step needs
|
||||
//! (a pool-miss inside capture would call cudaMalloc — illegal while
|
||||
//! capturing) and cuBLAS has finished its one-time per-shape setup.
|
||||
|
||||
use std::ffi::c_void;
|
||||
|
||||
use xserv_cuda::allocator::{self, RetainedBlocks};
|
||||
use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer};
|
||||
use xserv_tensor::Tensor;
|
||||
|
||||
use crate::gpt_oss::GptOss;
|
||||
use crate::paged_kv_cache::PagedKVCache;
|
||||
|
||||
pub struct GptOssDecodeGraph {
|
||||
stream: CudaStream,
|
||||
graph: CudaGraph,
|
||||
ids_buf: GpuBuffer, // [1] u32, persistent graph input
|
||||
pos_buf: GpuBuffer, // [1] u32, persistent graph input
|
||||
logits: Tensor, // graph output; rewritten in place by every replay
|
||||
_arena: RetainedBlocks,
|
||||
}
|
||||
|
||||
impl GptOssDecodeGraph {
|
||||
/// Capture one batch=1 decode step and replay it once (capture records
|
||||
/// without executing, so the replay performs this token's computation).
|
||||
pub fn capture(
|
||||
model: &GptOss,
|
||||
token: u32,
|
||||
position: usize,
|
||||
slot: usize,
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Self {
|
||||
let stream = CudaStream::new().expect("create capture stream");
|
||||
let mut ids_buf = allocator::cached_alloc(4).expect("alloc ids buf");
|
||||
let mut pos_buf = allocator::cached_alloc(4).expect("alloc pos buf");
|
||||
|
||||
model.decode_prepare(&[position], &[slot], cache);
|
||||
ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
|
||||
pos_buf
|
||||
.copy_from_host(&(position as u32).to_le_bytes())
|
||||
.unwrap();
|
||||
|
||||
// Retained warmup: run the exact step once eagerly with the quarantine
|
||||
// ON. Freed intermediates are held back instead of recycled, so the
|
||||
// pool ends up stocked with a dedicated block for EVERY allocation the
|
||||
// step performs. The capture below repeats the same allocation
|
||||
// sequence and therefore never misses the pool — a pool miss would
|
||||
// call cudaMalloc, which is illegal while a stream is capturing (this
|
||||
// is also why one block per bucket is not enough: the capture's own
|
||||
// quarantine keeps freed blocks out of reuse). Re-running the step is
|
||||
// idempotent: the KV scatter rewrites the same cache position.
|
||||
allocator::begin_retain();
|
||||
{
|
||||
let _guard = xserv_cuda::push_stream(&stream);
|
||||
let _ = model.decode_core(
|
||||
ids_buf.as_ptr() as *const c_void,
|
||||
pos_buf.as_ptr() as *const c_void,
|
||||
1,
|
||||
cache,
|
||||
);
|
||||
}
|
||||
drop(allocator::end_retain()); // release the warmup blocks to the pool
|
||||
stream.synchronize().expect("warmup sync");
|
||||
|
||||
allocator::begin_retain();
|
||||
let mut graph = CudaGraph::new();
|
||||
let logits;
|
||||
{
|
||||
let _guard = xserv_cuda::stream::push_stream(&stream);
|
||||
graph
|
||||
.begin_capture(&stream)
|
||||
.expect("begin decode-graph capture");
|
||||
logits = model.decode_core(
|
||||
ids_buf.as_ptr() as *const c_void,
|
||||
pos_buf.as_ptr() as *const c_void,
|
||||
1,
|
||||
cache,
|
||||
);
|
||||
graph
|
||||
.end_capture(&stream)
|
||||
.expect("end decode-graph capture");
|
||||
}
|
||||
let arena = allocator::end_retain();
|
||||
|
||||
graph.launch(&stream).expect("first decode-graph replay");
|
||||
cache.advance_seq_len(slot, 1);
|
||||
|
||||
Self {
|
||||
stream,
|
||||
graph,
|
||||
ids_buf,
|
||||
pos_buf,
|
||||
logits,
|
||||
_arena: arena,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run one decode step by replaying the captured graph.
|
||||
pub fn step(
|
||||
&mut self,
|
||||
model: &GptOss,
|
||||
token: u32,
|
||||
position: usize,
|
||||
slot: usize,
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
model.decode_prepare(&[position], &[slot], cache);
|
||||
self.ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
|
||||
self.pos_buf
|
||||
.copy_from_host(&(position as u32).to_le_bytes())
|
||||
.unwrap();
|
||||
self.graph
|
||||
.launch(&self.stream)
|
||||
.expect("decode-graph replay");
|
||||
cache.advance_seq_len(slot, 1);
|
||||
// Shallow clone: the caller reads these logits before the next replay
|
||||
// rewrites the underlying buffer.
|
||||
self.logits.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Lazy capture policy: first decode step of the process runs eager (warms the
|
||||
/// allocator pool + cuBLAS so capture performs no "unsafe" CUDA calls), the
|
||||
/// second is captured, the rest replay. Batch>1 always falls back to eager.
|
||||
/// Disable with XSERV_DECODE_GRAPH=0.
|
||||
pub struct GraphedGptOssDecoder {
|
||||
graph: Option<GptOssDecodeGraph>,
|
||||
eager_steps: u32,
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl GraphedGptOssDecoder {
|
||||
pub fn new() -> Self {
|
||||
let enabled = std::env::var("XSERV_DECODE_GRAPH")
|
||||
.map(|v| v != "0")
|
||||
.unwrap_or(true);
|
||||
Self {
|
||||
graph: None,
|
||||
eager_steps: 0,
|
||||
enabled,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode(
|
||||
&mut self,
|
||||
model: &GptOss,
|
||||
tokens: &[u32],
|
||||
positions: &[usize],
|
||||
slots: &[usize],
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
if self.enabled && tokens.len() == 1 {
|
||||
if let Some(g) = self.graph.as_mut() {
|
||||
return g.step(model, tokens[0], positions[0], slots[0], cache);
|
||||
}
|
||||
if self.eager_steps >= 1 {
|
||||
let g = GptOssDecodeGraph::capture(model, tokens[0], positions[0], slots[0], cache);
|
||||
let logits = g.logits.clone();
|
||||
self.graph = Some(g);
|
||||
return logits;
|
||||
}
|
||||
}
|
||||
self.eager_steps += 1;
|
||||
model.forward_decode_paged(tokens, positions, slots, cache)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GraphedGptOssDecoder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
214
crates/xserv-model/src/kv_cache.rs
Normal file
214
crates/xserv-model/src/kv_cache.rs
Normal file
@@ -0,0 +1,214 @@
|
||||
use crate::config::ModelConfig;
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_tensor::{DType, Tensor};
|
||||
|
||||
/// GPU-resident KV cache. Pre-allocates max_seq_len on GPU,
|
||||
/// appends new K/V via D2D copy at offset (no CPU round-trip).
|
||||
pub struct GpuKVCache {
|
||||
// Per layer: contiguous GPU buffer for K and V
|
||||
// Layout: [num_kv_heads, max_seq_len, head_dim] — contiguous per head
|
||||
k_bufs: Vec<GpuBuffer>,
|
||||
v_bufs: Vec<GpuBuffer>,
|
||||
// Per layer: pre-allocated staging buffers for get_kv_len output.
|
||||
// Size: num_kv_heads * max_seq_len * head_dim * elem_size (max possible output).
|
||||
// Avoids cudaMalloc/cudaFree on every get_kv_len call.
|
||||
k_staging: Vec<GpuBuffer>,
|
||||
v_staging: Vec<GpuBuffer>,
|
||||
seq_len: usize,
|
||||
max_seq_len: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
elem_size: usize,
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
}
|
||||
|
||||
impl GpuKVCache {
|
||||
pub fn new(config: &ModelConfig, max_seq_len: usize, dtype: DType, device: u32) -> Self {
|
||||
let num_layers = config.num_layers();
|
||||
let num_kv_heads = config.num_kv_heads();
|
||||
let head_dim = config.head_dim();
|
||||
let elem_size = dtype.size_bytes();
|
||||
let buf_size = num_kv_heads * max_seq_len * head_dim * elem_size;
|
||||
|
||||
let mut k_bufs = Vec::with_capacity(num_layers);
|
||||
let mut v_bufs = Vec::with_capacity(num_layers);
|
||||
let mut k_staging = Vec::with_capacity(num_layers);
|
||||
let mut v_staging = Vec::with_capacity(num_layers);
|
||||
for _ in 0..num_layers {
|
||||
let mut k = GpuBuffer::alloc(buf_size).expect("alloc KV cache K");
|
||||
let mut v = GpuBuffer::alloc(buf_size).expect("alloc KV cache V");
|
||||
k.zero().unwrap();
|
||||
v.zero().unwrap();
|
||||
k_bufs.push(k);
|
||||
v_bufs.push(v);
|
||||
k_staging.push(GpuBuffer::alloc(buf_size).expect("alloc KV staging K"));
|
||||
v_staging.push(GpuBuffer::alloc(buf_size).expect("alloc KV staging V"));
|
||||
}
|
||||
|
||||
Self {
|
||||
k_bufs,
|
||||
v_bufs,
|
||||
k_staging,
|
||||
v_staging,
|
||||
seq_len: 0,
|
||||
max_seq_len,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
elem_size,
|
||||
dtype,
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn seq_len(&self) -> usize {
|
||||
self.seq_len
|
||||
}
|
||||
pub fn max_seq_len(&self) -> usize {
|
||||
self.max_seq_len
|
||||
}
|
||||
|
||||
/// Append new K/V tensors for a given layer.
|
||||
/// k_new, v_new: [1, num_kv_heads, new_tokens, head_dim] on GPU, contiguous.
|
||||
/// `write_pos` is the sequence position to write at (caller manages this).
|
||||
pub fn append(
|
||||
&mut self,
|
||||
layer: usize,
|
||||
k_new: &Tensor,
|
||||
v_new: &Tensor,
|
||||
new_tokens: usize,
|
||||
write_pos: usize,
|
||||
) {
|
||||
assert!(
|
||||
write_pos + new_tokens <= self.max_seq_len,
|
||||
"KV cache overflow"
|
||||
);
|
||||
let es = self.elem_size;
|
||||
let hd = self.head_dim;
|
||||
let max_s = self.max_seq_len;
|
||||
let nh = self.num_kv_heads;
|
||||
|
||||
let k_src = k_new.storage().gpu_buffer();
|
||||
let v_src = v_new.storage().gpu_buffer();
|
||||
|
||||
for h in 0..nh {
|
||||
let src_off = h * new_tokens * hd * es;
|
||||
let dst_off = (h * max_s + write_pos) * hd * es;
|
||||
let count = new_tokens * hd * es;
|
||||
self.k_bufs[layer]
|
||||
.copy_from_device_at(k_src, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
self.v_bufs[layer]
|
||||
.copy_from_device_at(v_src, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn advance_seq_len(&mut self, new_tokens: usize) {
|
||||
self.seq_len += new_tokens;
|
||||
assert!(
|
||||
self.seq_len <= self.max_seq_len,
|
||||
"KV cache seq_len ({}) exceeds max_seq_len ({})",
|
||||
self.seq_len,
|
||||
self.max_seq_len
|
||||
);
|
||||
}
|
||||
|
||||
/// Get K/V cache tensors for a layer up to `seq_len` tokens: [1, num_kv_heads, seq_len, head_dim]
|
||||
pub fn get_kv(&mut self, layer: usize) -> (Tensor, Tensor) {
|
||||
let sl = self.seq_len;
|
||||
self.get_kv_len(layer, sl)
|
||||
}
|
||||
|
||||
pub fn get_kv_len(&mut self, layer: usize, sl: usize) -> (Tensor, Tensor) {
|
||||
assert!(
|
||||
sl <= self.max_seq_len,
|
||||
"get_kv_len: sl ({sl}) exceeds max_seq_len ({})",
|
||||
self.max_seq_len
|
||||
);
|
||||
let hd = self.head_dim;
|
||||
let nh = self.num_kv_heads;
|
||||
let es = self.elem_size;
|
||||
let max_s = self.max_seq_len;
|
||||
|
||||
// Copy each head's valid portion into pre-allocated staging buffers.
|
||||
// Split borrows: staging (mut) vs cache (shared) are separate struct fields,
|
||||
// so the borrow checker allows simultaneous &mut staging + &cache.
|
||||
let out_size = nh * sl * hd * es;
|
||||
let k_stg = &mut self.k_staging[layer];
|
||||
let k_buf = &self.k_bufs[layer];
|
||||
let v_stg = &mut self.v_staging[layer];
|
||||
let v_buf = &self.v_bufs[layer];
|
||||
for h in 0..nh {
|
||||
let src_off = (h * max_s) * hd * es;
|
||||
let dst_off = (h * sl) * hd * es;
|
||||
let count = sl * hd * es;
|
||||
k_stg
|
||||
.copy_from_device_at(k_buf, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
v_stg
|
||||
.copy_from_device_at(v_buf, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
}
|
||||
// Grab raw pointers before dropping the mutable borrows
|
||||
let k_ptr = k_stg.as_mut_ptr();
|
||||
let v_ptr = v_stg.as_mut_ptr();
|
||||
|
||||
// Create Tensors that borrow from the staging buffers (no cudaMalloc/cudaFree).
|
||||
// Safety: staging buffers are owned by GpuKVCache and outlive the returned Tensors
|
||||
// in practice (Tensors are consumed within the same forward pass before the next
|
||||
// get_kv_len call overwrites the staging buffer).
|
||||
let shape = &[1usize, nh, sl, hd];
|
||||
let k = unsafe {
|
||||
tensor_from_gpu_buffer(
|
||||
GpuBuffer::borrow_raw(k_ptr, out_size),
|
||||
shape,
|
||||
self.dtype,
|
||||
self.device,
|
||||
)
|
||||
};
|
||||
let v = unsafe {
|
||||
tensor_from_gpu_buffer(
|
||||
GpuBuffer::borrow_raw(v_ptr, out_size),
|
||||
shape,
|
||||
self.dtype,
|
||||
self.device,
|
||||
)
|
||||
};
|
||||
(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a Tensor from a GpuBuffer (takes ownership).
|
||||
unsafe fn tensor_from_gpu_buffer(
|
||||
buf: GpuBuffer,
|
||||
shape: &[usize],
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Tensor {
|
||||
use smallvec::SmallVec;
|
||||
use xserv_tensor::shape::contiguous_strides;
|
||||
use xserv_tensor::storage::Storage;
|
||||
|
||||
let storage = Storage::cuda(buf, device);
|
||||
Tensor::from_storage(
|
||||
storage,
|
||||
SmallVec::from_slice(shape),
|
||||
contiguous_strides(shape),
|
||||
0,
|
||||
dtype,
|
||||
)
|
||||
}
|
||||
|
||||
/// Public entry point to `tensor_from_gpu_buffer` for use by other modules
/// (e.g., batched decode concat). Forwards all arguments unchanged.
|
||||
///
|
||||
/// # Safety
|
||||
/// `buf` must be a valid GPU allocation with at least `product(shape) * dtype.size_bytes()` bytes.
|
||||
pub unsafe fn tensor_from_gpu_buffer_pub(
|
||||
buf: GpuBuffer,
|
||||
shape: &[usize],
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Tensor {
|
||||
tensor_from_gpu_buffer(buf, shape, dtype, device)
|
||||
}
|
||||
28
crates/xserv-model/src/lib.rs
Normal file
28
crates/xserv-model/src/lib.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
pub mod config;
|
||||
pub mod decode_graph;
|
||||
pub mod eagle3;
|
||||
pub mod gpt2;
|
||||
pub mod gpt_oss;
|
||||
pub mod gpt_oss_graph;
|
||||
pub mod kv_cache;
|
||||
pub mod loader;
|
||||
pub mod paged_kv_cache;
|
||||
pub mod qwen3;
|
||||
pub mod qwen3_graph;
|
||||
pub mod sampling;
|
||||
|
||||
pub use config::ModelConfig;
|
||||
pub use decode_graph::{DecodeGraphState, LayerWeightPtrs};
|
||||
pub use gpt_oss::GptOss;
|
||||
pub use gpt_oss_graph::{GptOssDecodeGraph, GraphedGptOssDecoder};
|
||||
pub use gpt2::{GPT2, KVCache};
|
||||
pub use kv_cache::GpuKVCache;
|
||||
pub use paged_kv_cache::{BLOCK_SIZE, BlockAllocator, Location, PagedKVCache};
|
||||
pub use qwen3::Qwen3;
|
||||
pub use sampling::{SamplingParams, sample, sample_greedy_penalized};
|
||||
|
||||
/// Initialize GPU kernel hooks by delegating to `xserv_kernels::init`.
/// Called automatically by model constructors,
|
||||
/// but safe to call multiple times (idempotent via OnceLock).
|
||||
pub fn init_kernels() {
|
||||
xserv_kernels::init();
|
||||
}
|
||||
93
crates/xserv-model/src/loader.rs
Normal file
93
crates/xserv-model/src/loader.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use half::{bf16, f16};
|
||||
use safetensors::SafeTensors;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
pub fn load_safetensors(path: &Path, device: Device) -> HashMap<String, Tensor> {
|
||||
let data =
|
||||
std::fs::read(path).unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
|
||||
let st = SafeTensors::deserialize(&data)
|
||||
.unwrap_or_else(|e| panic!("failed to parse safetensors {}: {e}", path.display()));
|
||||
|
||||
let mut tensors = HashMap::new();
|
||||
|
||||
for (name, view) in st.tensors() {
|
||||
let shape: Vec<usize> = view.shape().to_vec();
|
||||
let raw_bytes = view.data();
|
||||
let dtype = match view.dtype() {
|
||||
safetensors::Dtype::F32 => DType::F32,
|
||||
safetensors::Dtype::F16 => DType::F16,
|
||||
safetensors::Dtype::BF16 => DType::BF16,
|
||||
safetensors::Dtype::F8_E4M3 => DType::FP8E4M3,
|
||||
other => {
|
||||
eprintln!("skipping tensor {name}: unsupported dtype {other:?}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let tensor = make_tensor(raw_bytes, &shape, dtype);
|
||||
let tensor = tensor.to_device(device);
|
||||
tensors.insert(name.to_string(), tensor);
|
||||
}
|
||||
|
||||
tensors
|
||||
}
|
||||
|
||||
/// Load from a directory containing model.safetensors (or sharded files) + config.json.
|
||||
pub fn load_model_dir(dir: &Path, device: Device) -> HashMap<String, Tensor> {
|
||||
let single = dir.join("model.safetensors");
|
||||
if single.exists() {
|
||||
return load_safetensors(&single, device);
|
||||
}
|
||||
|
||||
// Try sharded: model-00001-of-NNNNN.safetensors
|
||||
let mut all_tensors = HashMap::new();
|
||||
let mut entries: Vec<_> = std::fs::read_dir(dir)
|
||||
.unwrap()
|
||||
.filter_map(|e| e.ok())
|
||||
.filter(|e| {
|
||||
e.path()
|
||||
.file_name()
|
||||
.map(|f| f.to_string_lossy().ends_with(".safetensors"))
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.collect();
|
||||
entries.sort_by_key(|e| e.file_name());
|
||||
|
||||
for entry in entries {
|
||||
let tensors = load_safetensors(&entry.path(), device);
|
||||
all_tensors.extend(tensors);
|
||||
}
|
||||
|
||||
assert!(
|
||||
!all_tensors.is_empty(),
|
||||
"no safetensors files found in {}",
|
||||
dir.display()
|
||||
);
|
||||
all_tensors
|
||||
}
|
||||
|
||||
pub(crate) fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
|
||||
match dtype {
|
||||
DType::F32 => {
|
||||
let floats: &[f32] = unsafe {
|
||||
std::slice::from_raw_parts(raw_bytes.as_ptr() as *const f32, raw_bytes.len() / 4)
|
||||
};
|
||||
Tensor::from_slice(floats, shape)
|
||||
}
|
||||
DType::F16 => {
|
||||
let halfs: &[f16] = unsafe {
|
||||
std::slice::from_raw_parts(raw_bytes.as_ptr() as *const f16, raw_bytes.len() / 2)
|
||||
};
|
||||
Tensor::from_slice(halfs, shape)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let bfs: &[bf16] = unsafe {
|
||||
std::slice::from_raw_parts(raw_bytes.as_ptr() as *const bf16, raw_bytes.len() / 2)
|
||||
};
|
||||
Tensor::from_slice(bfs, shape)
|
||||
}
|
||||
DType::FP8E4M3 => Tensor::from_raw_bytes(raw_bytes, shape, DType::FP8E4M3),
|
||||
}
|
||||
}
|
||||
908
crates/xserv-model/src/paged_kv_cache.rs
Normal file
908
crates/xserv-model/src/paged_kv_cache.rs
Normal file
@@ -0,0 +1,908 @@
|
||||
//! Paged KV cache: vLLM-style block-based KV cache with O(1) allocation
|
||||
//! and indirection via per-sequence block tables.
|
||||
//!
|
||||
//! Physical layout per layer:
|
||||
//! K pool: [total_blocks, num_kv_heads, BLOCK_SIZE, head_dim] BF16
|
||||
//! V pool: same
|
||||
//!
|
||||
//! Logical view per sequence: a list of physical block ids. Token at logical
|
||||
//! position p lives in block_ids[p / BLOCK_SIZE] at slot (p % BLOCK_SIZE).
|
||||
|
||||
use crate::config::ModelConfig;
|
||||
use xserv_cuda::{GpuBuffer, PinnedBuffer};
|
||||
use xserv_tensor::{DType, Tensor};
|
||||
|
||||
pub const BLOCK_SIZE: usize = 16;
|
||||
|
||||
/// LIFO free-list allocator for physical block ids; `alloc` and `free` are
/// a single `Vec` pop / push each.
///
/// Block id 0 is never handed out: it is reserved as a "null" sentinel.
pub struct BlockAllocator {
    /// Ids currently available, most recently freed on top.
    free_stack: Vec<u32>,
    /// Total number of block ids managed, including the reserved id 0.
    total: usize,
}

impl BlockAllocator {
    /// Create an allocator over ids `0..total_blocks`, with id 0 reserved.
    ///
    /// The free list is seeded as `[total-1, total-2, ..., 1]`, so the first
    /// `alloc` returns 1. `total_blocks == 0` (e.g. swap disabled) and
    /// `total_blocks == 1` both yield an allocator with nothing to hand out.
    pub fn new(total_blocks: usize) -> Self {
        let free_stack: Vec<u32> = (1..total_blocks).rev().map(|id| id as u32).collect();
        Self {
            free_stack,
            total: total_blocks,
        }
    }

    /// Take a free block id, or `None` when the pool is exhausted.
    pub fn alloc(&mut self) -> Option<u32> {
        self.free_stack.pop()
    }

    /// Return `block` to the pool.
    ///
    /// Debug builds assert that `block` is a non-sentinel id within range;
    /// double frees are not detected.
    pub fn free(&mut self, block: u32) {
        debug_assert!(block != 0 && (block as usize) < self.total);
        self.free_stack.push(block);
    }

    /// Number of block ids currently available.
    pub fn free_count(&self) -> usize {
        self.free_stack.len()
    }

    /// Total number of block ids managed (including the reserved id 0).
    pub fn total(&self) -> usize {
        self.total
    }

    /// Whether at least `n` blocks are currently free.
    pub fn can_alloc(&self, n: usize) -> bool {
        n <= self.free_stack.len()
    }
}
|
||||
|
||||
/// Where a sequence's KV blocks currently live. Selects which allocator and
/// pool a sequence's `block_ids` refer to.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
|
||||
pub enum Location {
|
||||
// Blocks are in the GPU pools.
Gpu,
|
||||
// Blocks are in the CPU (pinned host) swap pools.
Cpu,
|
||||
}
|
||||
|
||||
/// Per-sequence state held in the cache, one per occupied slot.
|
||||
#[derive(Clone)]
|
||||
pub struct SeqState {
|
||||
/// Block ids into the GPU pool when `location == Gpu`, or into the CPU
|
||||
/// (pinned host) pool when `location == Cpu`.
/// Index `i` holds the block for logical positions
/// `[i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE)`.
|
||||
pub block_ids: Vec<u32>,
|
||||
// Logical number of tokens stored for this sequence.
pub seq_len: usize,
|
||||
// Which pool `block_ids` currently refers to.
pub location: Location,
|
||||
}
|
||||
|
||||
/// Block-based KV cache: per-layer GPU K/V pools carved into fixed-size
/// blocks, optional pinned-host swap pools, per-slot sequence state, and the
/// block table / context-length metadata mirrored to the GPU.
pub struct PagedKVCache {
|
||||
// [layer]: GpuBuffer of size total_blocks * nkv * BLOCK_SIZE * hd * elem_size
|
||||
k_pools: Vec<GpuBuffer>,
|
||||
v_pools: Vec<GpuBuffer>,
|
||||
|
||||
// CPU (pinned host) swap pools, same per-layer layout as the GPU pools but
|
||||
// sized for `cpu_total_blocks`. Empty when swap is disabled.
|
||||
cpu_k_pools: Vec<PinnedBuffer>,
|
||||
cpu_v_pools: Vec<PinnedBuffer>,
|
||||
// Allocator for block ids in the CPU swap pools (empty when swap is off).
cpu_allocator: BlockAllocator,
|
||||
|
||||
// Bytes occupied by one block within a single layer pool:
|
||||
// num_kv_heads * BLOCK_SIZE * head_dim * elem_size.
|
||||
block_bytes: usize,
|
||||
|
||||
// Allocator for block ids in the GPU pools.
allocator: BlockAllocator,
|
||||
// One entry per slot (length max_seqs); None means the slot is free.
seq_states: Vec<Option<SeqState>>,
|
||||
|
||||
// GPU-resident per-sequence metadata, uploaded from the host mirrors below
// by the sync_* methods.
|
||||
// block_table_gpu: i32 [max_seqs, max_blocks_per_seq]
|
||||
// context_lens_gpu: i32 [max_seqs]
|
||||
block_table_gpu: GpuBuffer,
|
||||
context_lens_gpu: GpuBuffer,
|
||||
// Host-side staging mirroring the GPU buffers above.
|
||||
block_table_host: Vec<i32>,
|
||||
context_lens_host: Vec<i32>,
|
||||
|
||||
// Config
|
||||
num_layers: usize,
|
||||
// Number of KV heads stored locally (may be a per-rank subset under
// tensor parallelism, see `new_tp`).
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
// Size in bytes of one element of `dtype`.
elem_size: usize,
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
max_seqs: usize,
|
||||
max_blocks_per_seq: usize,
|
||||
}
|
||||
|
||||
impl PagedKVCache {
|
||||
/// Bytes occupied by all KV blocks for ONE physical block across the whole
|
||||
/// model (both K and V, all layers). Use this to size pools against VRAM.
|
||||
pub fn bytes_per_block(config: &ModelConfig, dtype: DType) -> usize {
|
||||
2 * config.num_layers()
|
||||
* config.num_kv_heads()
|
||||
* BLOCK_SIZE
|
||||
* config.head_dim()
|
||||
* dtype.size_bytes()
|
||||
}
|
||||
|
||||
/// Create a new paged cache.
|
||||
/// - `total_blocks`: total number of physical GPU blocks across all sequences.
|
||||
/// - `cpu_total_blocks`: physical blocks in the pinned-host swap pool (0 = swap off).
|
||||
/// - `max_seqs`: max number of concurrent sequences (slots), incl. swapped.
|
||||
/// - `max_blocks_per_seq`: capacity of the block table per slot
|
||||
/// (must be >= ceil(max_seq_len / BLOCK_SIZE)).
|
||||
pub fn new(
|
||||
config: &ModelConfig,
|
||||
total_blocks: usize,
|
||||
cpu_total_blocks: usize,
|
||||
max_seqs: usize,
|
||||
max_blocks_per_seq: usize,
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Self {
|
||||
Self::new_tp(
|
||||
config,
|
||||
config.num_kv_heads(),
|
||||
total_blocks,
|
||||
cpu_total_blocks,
|
||||
max_seqs,
|
||||
max_blocks_per_seq,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
}
|
||||
|
||||
/// Like `new`, but with an explicit `num_kv_heads` — under tensor parallelism
|
||||
/// each rank only stores its `num_kv_heads / world` heads, so the pool is
|
||||
/// sized for the local head count, not the model's full count.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new_tp(
|
||||
config: &ModelConfig,
|
||||
num_kv_heads: usize,
|
||||
total_blocks: usize,
|
||||
cpu_total_blocks: usize,
|
||||
max_seqs: usize,
|
||||
max_blocks_per_seq: usize,
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Self {
|
||||
assert!(
|
||||
total_blocks >= 2,
|
||||
"need at least 2 blocks (one is sentinel)"
|
||||
);
|
||||
let num_layers = config.num_layers();
|
||||
let head_dim = config.head_dim();
|
||||
let elem_size = dtype.size_bytes();
|
||||
let block_bytes = num_kv_heads * BLOCK_SIZE * head_dim * elem_size;
|
||||
let pool_bytes = total_blocks * block_bytes;
|
||||
|
||||
let mut k_pools = Vec::with_capacity(num_layers);
|
||||
let mut v_pools = Vec::with_capacity(num_layers);
|
||||
for _ in 0..num_layers {
|
||||
let mut k = GpuBuffer::alloc(pool_bytes).expect("alloc paged K pool");
|
||||
let mut v = GpuBuffer::alloc(pool_bytes).expect("alloc paged V pool");
|
||||
k.zero().unwrap();
|
||||
v.zero().unwrap();
|
||||
k_pools.push(k);
|
||||
v_pools.push(v);
|
||||
}
|
||||
|
||||
// Pinned-host swap pools (one per layer, mirroring the GPU layout).
|
||||
let mut cpu_k_pools = Vec::new();
|
||||
let mut cpu_v_pools = Vec::new();
|
||||
if cpu_total_blocks >= 2 {
|
||||
let cpu_pool_bytes = cpu_total_blocks * block_bytes;
|
||||
for _ in 0..num_layers {
|
||||
cpu_k_pools
|
||||
.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU K swap pool"));
|
||||
cpu_v_pools
|
||||
.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU V swap pool"));
|
||||
}
|
||||
}
|
||||
let cpu_allocator = BlockAllocator::new(if cpu_total_blocks >= 2 {
|
||||
cpu_total_blocks
|
||||
} else {
|
||||
0
|
||||
});
|
||||
|
||||
let block_table_gpu =
|
||||
GpuBuffer::alloc(max_seqs * max_blocks_per_seq * std::mem::size_of::<i32>())
|
||||
.expect("alloc block table");
|
||||
let context_lens_gpu =
|
||||
GpuBuffer::alloc(max_seqs * std::mem::size_of::<i32>()).expect("alloc context lens");
|
||||
|
||||
let block_table_host = vec![0i32; max_seqs * max_blocks_per_seq];
|
||||
let context_lens_host = vec![0i32; max_seqs];
|
||||
|
||||
let seq_states = (0..max_seqs).map(|_| None).collect();
|
||||
|
||||
Self {
|
||||
k_pools,
|
||||
v_pools,
|
||||
cpu_k_pools,
|
||||
cpu_v_pools,
|
||||
cpu_allocator,
|
||||
block_bytes,
|
||||
allocator: BlockAllocator::new(total_blocks),
|
||||
seq_states,
|
||||
block_table_gpu,
|
||||
context_lens_gpu,
|
||||
block_table_host,
|
||||
context_lens_host,
|
||||
num_layers,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
elem_size,
|
||||
dtype,
|
||||
device,
|
||||
max_seqs,
|
||||
max_blocks_per_seq,
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of layers this cache holds K/V pools for.
pub fn num_layers(&self) -> usize {
|
||||
self.num_layers
|
||||
}
|
||||
/// Number of KV heads stored in this cache (the value given at construction,
/// which may be a per-rank subset under tensor parallelism).
pub fn num_kv_heads(&self) -> usize {
|
||||
self.num_kv_heads
|
||||
}
|
||||
/// Per-head dimension the pools were sized with.
pub fn head_dim(&self) -> usize {
|
||||
self.head_dim
|
||||
}
|
||||
/// Element type of the cached K/V data.
pub fn dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
/// Number of sequence slots in this cache.
pub fn max_seqs(&self) -> usize {
|
||||
self.max_seqs
|
||||
}
|
||||
/// Block-table capacity per slot (row stride of the block table).
pub fn max_blocks_per_seq(&self) -> usize {
|
||||
self.max_blocks_per_seq
|
||||
}
|
||||
/// Number of GPU blocks currently free in the GPU allocator.
pub fn free_blocks(&self) -> usize {
|
||||
self.allocator.free_count()
|
||||
}
|
||||
/// Total number of GPU blocks managed by the GPU allocator, as reported by
/// its `total()` (this count includes the reserved sentinel block 0).
pub fn total_blocks(&self) -> usize {
|
||||
self.allocator.total()
|
||||
}
|
||||
|
||||
/// Borrow the GPU K pool for `layer`. Panics if `layer` is out of range.
pub fn k_pool(&self, layer: usize) -> &GpuBuffer {
|
||||
&self.k_pools[layer]
|
||||
}
|
||||
/// Borrow the GPU V pool for `layer`. Panics if `layer` is out of range.
pub fn v_pool(&self, layer: usize) -> &GpuBuffer {
|
||||
&self.v_pools[layer]
|
||||
}
|
||||
/// Borrow the GPU-resident block table (i32, `max_seqs * max_blocks_per_seq`
/// entries). Its contents are whatever the last metadata upload wrote.
pub fn block_table_gpu(&self) -> &GpuBuffer {
|
||||
&self.block_table_gpu
|
||||
}
|
||||
/// Borrow the GPU-resident context-length array (i32, `max_seqs` entries).
/// Its contents are whatever the last metadata upload wrote.
pub fn context_lens_gpu(&self) -> &GpuBuffer {
|
||||
&self.context_lens_gpu
|
||||
}
|
||||
|
||||
pub fn seq_len(&self, slot: usize) -> usize {
|
||||
self.seq_states[slot]
|
||||
.as_ref()
|
||||
.map(|s| s.seq_len)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Whether `slot` currently holds no sequence. Panics if `slot` is out of range.
pub fn is_slot_free(&self, slot: usize) -> bool {
|
||||
self.seq_states[slot].is_none()
|
||||
}
|
||||
|
||||
/// Register a new sequence at `slot`. Allocates the first block.
|
||||
/// Returns Err(()) if no slot or no blocks are available.
|
||||
pub fn register_sequence(&mut self, slot: usize) -> Result<(), &'static str> {
|
||||
if slot >= self.max_seqs {
|
||||
return Err("slot out of range");
|
||||
}
|
||||
if self.seq_states[slot].is_some() {
|
||||
return Err("slot already in use");
|
||||
}
|
||||
let block = self.allocator.alloc().ok_or("out of blocks")?;
|
||||
self.seq_states[slot] = Some(SeqState {
|
||||
block_ids: vec![block],
|
||||
seq_len: 0,
|
||||
location: Location::Gpu,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Free all blocks for `slot` and clear the slot. Frees from whichever pool
|
||||
/// (GPU or CPU) the sequence currently lives in.
|
||||
pub fn free_sequence(&mut self, slot: usize) {
|
||||
if let Some(state) = self.seq_states[slot].take() {
|
||||
let alloc = match state.location {
|
||||
Location::Gpu => &mut self.allocator,
|
||||
Location::Cpu => &mut self.cpu_allocator,
|
||||
};
|
||||
for b in state.block_ids {
|
||||
alloc.free(b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of blocks needed to hold `seq_len + new_tokens` tokens, beyond
|
||||
/// what is currently allocated for `slot`.
|
||||
pub fn additional_blocks_needed(&self, slot: usize, new_tokens: usize) -> usize {
|
||||
let state = self.seq_states[slot].as_ref().expect("unregistered slot");
|
||||
let cur = state.block_ids.len();
|
||||
let needed_total = (state.seq_len + new_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
if needed_total > cur {
|
||||
needed_total - cur
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// Pre-allocate enough physical blocks in `slot` to cover positions
|
||||
/// `[0, end_pos)`. Call once before the per-layer append loop so that
|
||||
/// every layer's append uses the same block table.
|
||||
pub fn ensure_capacity(&mut self, slot: usize, end_pos: usize) {
|
||||
let state = self.seq_states[slot].as_mut().expect("unregistered slot");
|
||||
let needed_total = (end_pos + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
while state.block_ids.len() < needed_total {
|
||||
let b = self
|
||||
.allocator
|
||||
.alloc()
|
||||
.expect("out of blocks (caller must check)");
|
||||
assert!(
|
||||
state.block_ids.len() < self.max_blocks_per_seq,
|
||||
"block table overflow"
|
||||
);
|
||||
state.block_ids.push(b);
|
||||
}
|
||||
}
|
||||
|
||||
/// Append `num_tokens` of K/V into the paged pool for `slot` at logical
|
||||
/// position `start_pos`. Caller must have called `ensure_capacity(slot, start_pos + num_tokens)`
|
||||
/// first (or accept that this method may also extend block list).
|
||||
/// Does NOT touch `seq_len`. Call `advance_seq_len(slot, num_tokens)` after
|
||||
/// every layer has been written.
|
||||
///
|
||||
/// `k_new`, `v_new`: GPU tensors with logical shape
|
||||
/// [1, num_kv_heads, num_tokens, head_dim]
|
||||
/// stored contiguously (head-major, then tokens, then dim).
|
||||
///
|
||||
/// Implementation: a single `reshape_and_cache` kernel per call. The
|
||||
/// previous Rust loop fired `num_tokens * num_kv_heads` cudaMemcpys per
|
||||
/// layer (≈290k for a 1024-token Qwen3 prefill across 36 layers).
|
||||
pub fn append_tokens(
|
||||
&mut self,
|
||||
slot: usize,
|
||||
layer: usize,
|
||||
k_new: &Tensor,
|
||||
v_new: &Tensor,
|
||||
num_tokens: usize,
|
||||
start_pos: usize,
|
||||
) {
|
||||
if num_tokens == 0 {
|
||||
return;
|
||||
}
|
||||
// Make sure blocks exist for the target range.
|
||||
self.ensure_capacity(slot, start_pos + num_tokens);
|
||||
|
||||
let nkv = self.num_kv_heads;
|
||||
let hd = self.head_dim;
|
||||
let bs = BLOCK_SIZE;
|
||||
|
||||
// Stage block_ids on the GPU. Pool-allocated so this is essentially
|
||||
// free after the first call (same bucket every step).
|
||||
let block_ids: Vec<i32> = self.seq_states[slot]
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.block_ids
|
||||
.iter()
|
||||
.map(|&b| b as i32)
|
||||
.collect();
|
||||
let bytes = block_ids.len() * std::mem::size_of::<i32>();
|
||||
let mut block_ids_gpu =
|
||||
xserv_cuda::allocator::cached_alloc(bytes).expect("alloc append block_ids");
|
||||
let block_ids_bytes =
|
||||
unsafe { std::slice::from_raw_parts(block_ids.as_ptr() as *const u8, bytes) };
|
||||
block_ids_gpu
|
||||
.copy_from_host(block_ids_bytes)
|
||||
.expect("upload block_ids");
|
||||
|
||||
let k_src = k_new.data_ptr() as *const std::ffi::c_void;
|
||||
let v_src = v_new.data_ptr() as *const std::ffi::c_void;
|
||||
let k_pool_ptr = self.k_pools[layer].as_mut_ptr() as *mut std::ffi::c_void;
|
||||
let v_pool_ptr = self.v_pools[layer].as_mut_ptr() as *mut std::ffi::c_void;
|
||||
|
||||
unsafe {
|
||||
xserv_kernels::reshape_and_cache_bf16(
|
||||
k_src,
|
||||
v_src,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
block_ids_gpu.as_ptr() as *const i32,
|
||||
num_tokens,
|
||||
nkv,
|
||||
hd,
|
||||
start_pos,
|
||||
bs,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
// block_ids_gpu drops here; the launch on the null stream will have
|
||||
// finished consuming it before any subsequent op alloc()s the same
|
||||
// bucket (null stream is sequential).
|
||||
}
|
||||
|
||||
/// Batched append for the multi-sequence decode step: writes one new
|
||||
/// K/V token per active sequence into `layer`'s pool, using
|
||||
/// `block_table_gpu` and `context_lens_gpu` directly. Caller must have
|
||||
/// just run `sync_active_batch_with_lens(slots, kv_lens)` so that:
|
||||
/// - row `i` of block_table_gpu holds the block ids for `slots[i]`
|
||||
/// - context_lens_gpu[i] == seq_len(slots[i]) + 1 (the kv_len **after**
|
||||
/// this step — i.e., the new token will be written at index kv_len-1)
|
||||
///
|
||||
/// `k_new`, `v_new`: GPU tensors, contiguous, BF16, shape
|
||||
/// `[batch, num_kv_heads, head_dim]`.
|
||||
///
|
||||
/// Like `append_tokens`, this does **not** touch `seq_len`. Call
|
||||
/// `advance_seq_len(slot, 1)` for each slot after every layer has been
|
||||
/// written.
|
||||
pub fn append_tokens_batched(
|
||||
&mut self,
|
||||
layer: usize,
|
||||
k_new: &Tensor,
|
||||
v_new: &Tensor,
|
||||
batch: usize,
|
||||
) {
|
||||
if batch == 0 {
|
||||
return;
|
||||
}
|
||||
let nkv = self.num_kv_heads;
|
||||
let hd = self.head_dim;
|
||||
debug_assert_eq!(k_new.shape(), &[batch, nkv, hd]);
|
||||
debug_assert_eq!(v_new.shape(), &[batch, nkv, hd]);
|
||||
|
||||
let k_src = k_new.data_ptr() as *const std::ffi::c_void;
|
||||
let v_src = v_new.data_ptr() as *const std::ffi::c_void;
|
||||
let k_pool_ptr = self.k_pools[layer].as_mut_ptr() as *mut std::ffi::c_void;
|
||||
let v_pool_ptr = self.v_pools[layer].as_mut_ptr() as *mut std::ffi::c_void;
|
||||
let bt_ptr = self.block_table_gpu.as_ptr() as *const i32;
|
||||
let cl_ptr = self.context_lens_gpu.as_ptr() as *const i32;
|
||||
|
||||
unsafe {
|
||||
xserv_kernels::reshape_and_cache_batched_bf16(
|
||||
k_src,
|
||||
v_src,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
bt_ptr,
|
||||
cl_ptr,
|
||||
batch,
|
||||
nkv,
|
||||
hd,
|
||||
BLOCK_SIZE,
|
||||
self.max_blocks_per_seq,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Advance the logical seq_len after append_tokens for ALL layers has completed.
/// Adds `num_tokens` to the slot's `seq_len`; no bound is checked here.
/// Panics if `slot` is not registered.
|
||||
pub fn advance_seq_len(&mut self, slot: usize, num_tokens: usize) {
|
||||
let state = self.seq_states[slot].as_mut().expect("unregistered slot");
|
||||
state.seq_len += num_tokens;
|
||||
}
|
||||
|
||||
/// Roll a registered sequence back to `new_len` tokens.
|
||||
///
|
||||
/// This only changes cache metadata and frees whole physical blocks that are
|
||||
/// no longer reachable. Bytes inside retained blocks are left untouched; the
|
||||
/// logical `seq_len` prevents attention from reading them, and later writes
|
||||
/// to the same positions overwrite them.
|
||||
pub fn truncate_sequence(&mut self, slot: usize, new_len: usize) -> Result<(), &'static str> {
|
||||
if slot >= self.max_seqs {
|
||||
return Err("truncate_sequence: slot out of range");
|
||||
}
|
||||
let state = self.seq_states[slot]
|
||||
.as_mut()
|
||||
.ok_or("truncate_sequence: empty slot")?;
|
||||
if new_len > state.seq_len {
|
||||
return Err("truncate_sequence: cannot extend");
|
||||
}
|
||||
|
||||
let needed_blocks = ((new_len + BLOCK_SIZE - 1) / BLOCK_SIZE).max(1);
|
||||
while state.block_ids.len() > needed_blocks {
|
||||
let block = state.block_ids.pop().expect("checked len");
|
||||
match state.location {
|
||||
Location::Gpu => self.allocator.free(block),
|
||||
Location::Cpu => self.cpu_allocator.free(block),
|
||||
}
|
||||
}
|
||||
state.seq_len = new_len;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Copy K/V data from `src_pos` to `dst_pos` within the same slot, across
|
||||
/// all layers. Used by tree speculative decoding to remap an accepted
|
||||
/// sibling's K/V to the canonical sequential position after acceptance.
|
||||
///
|
||||
/// Requires: both positions within the currently-allocated block range.
|
||||
pub fn copy_kv_position(&self, slot: usize, src_pos: usize, dst_pos: usize) {
|
||||
let state = self.seq_states[slot]
|
||||
.as_ref()
|
||||
.expect("copy_kv_position: slot not registered");
|
||||
assert!(
|
||||
src_pos < state.seq_len && dst_pos < state.seq_len,
|
||||
"copy_kv_position: positions must be within seq_len"
|
||||
);
|
||||
// Upload this sequence's block_ids to a small GPU buffer.
|
||||
let block_ids_host: Vec<i32> = state.block_ids.iter().map(|&b| b as i32).collect();
|
||||
let bytes: &[u8] = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
block_ids_host.as_ptr() as *const u8,
|
||||
block_ids_host.len() * 4,
|
||||
)
|
||||
};
|
||||
let mut ids_buf =
|
||||
xserv_cuda::allocator::cached_alloc(bytes.len()).expect("alloc block_ids for copy");
|
||||
ids_buf.copy_from_host(bytes).unwrap();
|
||||
let ids_ptr = ids_buf.as_ptr() as *const i32;
|
||||
|
||||
let stream = xserv_cuda::current_stream_raw();
|
||||
let num_layers = self.k_pools.len();
|
||||
for layer in 0..num_layers {
|
||||
unsafe {
|
||||
xserv_kernels::copy_kv_position(
|
||||
self.k_pools[layer].as_ptr() as *mut std::ffi::c_void,
|
||||
self.v_pools[layer].as_ptr() as *mut std::ffi::c_void,
|
||||
ids_ptr,
|
||||
src_pos,
|
||||
dst_pos,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
BLOCK_SIZE,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Refresh the host-side block table + context lens from `seq_states`,
|
||||
/// then upload to GPU. Call once per decode step before the paged kernel.
|
||||
pub fn sync_to_gpu(&mut self) {
|
||||
let stride = self.max_blocks_per_seq;
|
||||
for slot in 0..self.max_seqs {
|
||||
let row = &mut self.block_table_host[slot * stride..(slot + 1) * stride];
|
||||
row.fill(0);
|
||||
let len = match &self.seq_states[slot] {
|
||||
Some(s) => {
|
||||
for (i, b) in s.block_ids.iter().enumerate() {
|
||||
row[i] = *b as i32;
|
||||
}
|
||||
s.seq_len as i32
|
||||
}
|
||||
None => 0,
|
||||
};
|
||||
self.context_lens_host[slot] = len;
|
||||
}
|
||||
|
||||
self.upload_metadata();
|
||||
}
|
||||
|
||||
/// Pack the given active slots into rows 0..slots.len() of block_table_gpu
|
||||
/// and context_lens_gpu, then upload. Used by paged decode where the kernel
|
||||
/// iterates over `batch` active sequences in order.
|
||||
pub fn sync_active_batch_to_gpu(&mut self, slots: &[usize]) {
|
||||
let lens: Vec<i32> = slots
|
||||
.iter()
|
||||
.map(|&s| self.seq_states[s].as_ref().unwrap().seq_len as i32)
|
||||
.collect();
|
||||
self.sync_active_batch_with_lens(slots, &lens);
|
||||
}
|
||||
|
||||
/// Like sync_active_batch_to_gpu but uses caller-supplied kv_lens (number
|
||||
/// of valid K/V tokens to attend over per active row). Useful when the
|
||||
/// kv_len for the current step differs from the cached seq_len (e.g.
|
||||
/// before advance_seq_len has run).
|
||||
pub fn sync_active_batch_with_lens(&mut self, slots: &[usize], kv_lens: &[i32]) {
|
||||
assert_eq!(slots.len(), kv_lens.len());
|
||||
assert!(
|
||||
slots.len() <= self.max_seqs,
|
||||
"active batch exceeds max_seqs"
|
||||
);
|
||||
let stride = self.max_blocks_per_seq;
|
||||
for row in &mut self.block_table_host {
|
||||
*row = 0;
|
||||
}
|
||||
for cl in &mut self.context_lens_host {
|
||||
*cl = 0;
|
||||
}
|
||||
for (i, &slot) in slots.iter().enumerate() {
|
||||
let s = self.seq_states[slot]
|
||||
.as_ref()
|
||||
.expect("unregistered slot in active batch");
|
||||
let row = &mut self.block_table_host[i * stride..(i + 1) * stride];
|
||||
for (j, b) in s.block_ids.iter().enumerate() {
|
||||
row[j] = *b as i32;
|
||||
}
|
||||
self.context_lens_host[i] = kv_lens[i];
|
||||
}
|
||||
self.upload_metadata();
|
||||
}
|
||||
|
||||
fn upload_metadata(&mut self) {
|
||||
let bt_bytes = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
self.block_table_host.as_ptr() as *const u8,
|
||||
self.block_table_host.len() * std::mem::size_of::<i32>(),
|
||||
)
|
||||
};
|
||||
self.block_table_gpu.copy_from_host(bt_bytes).unwrap();
|
||||
|
||||
let cl_bytes = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
self.context_lens_host.as_ptr() as *const u8,
|
||||
self.context_lens_host.len() * std::mem::size_of::<i32>(),
|
||||
)
|
||||
};
|
||||
self.context_lens_gpu.copy_from_host(cl_bytes).unwrap();
|
||||
}
|
||||
|
||||
/// Materialize a contiguous K/V tensor for a sequence at `layer`, shaped
|
||||
/// [1, num_kv_heads, seq_len, head_dim]. Used for prefill, where Flash
|
||||
/// Attention 2 expects contiguous K/V.
|
||||
///
|
||||
/// Allocates from the cached allocator; the returned Tensors own their storage.
|
||||
pub fn gather_kv_contiguous(&self, slot: usize, layer: usize) -> (Tensor, Tensor) {
|
||||
let state = self.seq_states[slot].as_ref().expect("unregistered slot");
|
||||
let sl = state.seq_len;
|
||||
let nkv = self.num_kv_heads;
|
||||
let hd = self.head_dim;
|
||||
let es = self.elem_size;
|
||||
let bs = BLOCK_SIZE;
|
||||
|
||||
let out_bytes = nkv * sl * hd * es;
|
||||
let mut k_dst = xserv_cuda::allocator::cached_alloc(out_bytes).expect("alloc gather K");
|
||||
let mut v_dst = xserv_cuda::allocator::cached_alloc(out_bytes).expect("alloc gather V");
|
||||
|
||||
let k_pool = &self.k_pools[layer];
|
||||
let v_pool = &self.v_pools[layer];
|
||||
|
||||
let mut p = 0usize;
|
||||
while p < sl {
|
||||
let logical_blk = p / bs;
|
||||
let slot_in_blk = p % bs;
|
||||
let chunk = (bs - slot_in_blk).min(sl - p);
|
||||
let phys = state.block_ids[logical_blk] as usize;
|
||||
|
||||
for h in 0..nkv {
|
||||
let src_off = ((phys * nkv + h) * bs + slot_in_blk) * hd * es;
|
||||
let dst_off = (h * sl + p) * hd * es;
|
||||
let count = chunk * hd * es;
|
||||
k_dst
|
||||
.copy_from_device_at(k_pool, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
v_dst
|
||||
.copy_from_device_at(v_pool, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
}
|
||||
p += chunk;
|
||||
}
|
||||
|
||||
let shape = &[1usize, nkv, sl, hd];
|
||||
let k = unsafe { tensor_from_owned_buf(k_dst, shape, self.dtype, self.device) };
|
||||
let v = unsafe { tensor_from_owned_buf(v_dst, shape, self.dtype, self.device) };
|
||||
(k, v)
|
||||
}
|
||||
|
||||
// ----- Swapping (vLLM-style preemption to pinned host memory) -----
|
||||
|
||||
pub fn free_cpu_blocks(&self) -> usize {
|
||||
self.cpu_allocator.free_count()
|
||||
}
|
||||
pub fn swap_enabled(&self) -> bool {
|
||||
!self.cpu_k_pools.is_empty()
|
||||
}
|
||||
|
||||
pub fn is_swapped(&self, slot: usize) -> bool {
|
||||
matches!(
|
||||
self.seq_states[slot].as_ref().map(|s| s.location),
|
||||
Some(Location::Cpu)
|
||||
)
|
||||
}
|
||||
|
||||
/// Number of physical blocks currently held by `slot` (in either pool).
|
||||
pub fn block_count(&self, slot: usize) -> usize {
|
||||
self.seq_states[slot]
|
||||
.as_ref()
|
||||
.map(|s| s.block_ids.len())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Whether a swapped sequence at `slot` can be brought back (enough free GPU blocks).
|
||||
pub fn can_swap_in(&self, slot: usize) -> bool {
|
||||
self.allocator.can_alloc(self.block_count(slot))
|
||||
}
|
||||
|
||||
/// Whether the GPU sequence at `slot` can be evicted (enough free CPU blocks).
|
||||
pub fn can_swap_out(&self, slot: usize) -> bool {
|
||||
self.cpu_allocator.can_alloc(self.block_count(slot))
|
||||
}
|
||||
|
||||
/// Evict `slot`'s KV from GPU to pinned host memory and free its GPU blocks.
|
||||
/// The slot stays registered (location = Cpu); the sequence is paused.
|
||||
pub fn swap_out(&mut self, slot: usize) -> Result<(), &'static str> {
|
||||
let state = self.seq_states[slot]
|
||||
.as_ref()
|
||||
.ok_or("swap_out: empty slot")?;
|
||||
if state.location == Location::Cpu {
|
||||
return Ok(());
|
||||
}
|
||||
let gpu_ids = state.block_ids.clone();
|
||||
let n = gpu_ids.len();
|
||||
if !self.cpu_allocator.can_alloc(n) {
|
||||
return Err("swap_out: CPU pool full");
|
||||
}
|
||||
|
||||
let cpu_ids: Vec<u32> = (0..n)
|
||||
.map(|_| self.cpu_allocator.alloc().expect("checked can_alloc"))
|
||||
.collect();
|
||||
|
||||
let bb = self.block_bytes;
|
||||
for layer in 0..self.num_layers {
|
||||
for i in 0..n {
|
||||
let g_off = gpu_ids[i] as usize * bb;
|
||||
let c_off = cpu_ids[i] as usize * bb;
|
||||
self.k_pools[layer]
|
||||
.copy_to_host_at(
|
||||
&mut self.cpu_k_pools[layer].as_mut_slice()[c_off..c_off + bb],
|
||||
g_off,
|
||||
bb,
|
||||
)
|
||||
.unwrap();
|
||||
self.v_pools[layer]
|
||||
.copy_to_host_at(
|
||||
&mut self.cpu_v_pools[layer].as_mut_slice()[c_off..c_off + bb],
|
||||
g_off,
|
||||
bb,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
for b in gpu_ids {
|
||||
self.allocator.free(b);
|
||||
}
|
||||
let state = self.seq_states[slot].as_mut().unwrap();
|
||||
state.block_ids = cpu_ids;
|
||||
state.location = Location::Cpu;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Bring `slot`'s KV back from host to GPU and free its CPU blocks.
|
||||
pub fn swap_in(&mut self, slot: usize) -> Result<(), &'static str> {
|
||||
let state = self.seq_states[slot]
|
||||
.as_ref()
|
||||
.ok_or("swap_in: empty slot")?;
|
||||
if state.location == Location::Gpu {
|
||||
return Ok(());
|
||||
}
|
||||
let cpu_ids = state.block_ids.clone();
|
||||
let n = cpu_ids.len();
|
||||
if !self.allocator.can_alloc(n) {
|
||||
return Err("swap_in: GPU pool full");
|
||||
}
|
||||
|
||||
let gpu_ids: Vec<u32> = (0..n)
|
||||
.map(|_| self.allocator.alloc().expect("checked can_alloc"))
|
||||
.collect();
|
||||
|
||||
let bb = self.block_bytes;
|
||||
for layer in 0..self.num_layers {
|
||||
for i in 0..n {
|
||||
let g_off = gpu_ids[i] as usize * bb;
|
||||
let c_off = cpu_ids[i] as usize * bb;
|
||||
self.k_pools[layer]
|
||||
.copy_from_host_at(
|
||||
&self.cpu_k_pools[layer].as_slice()[c_off..c_off + bb],
|
||||
g_off,
|
||||
bb,
|
||||
)
|
||||
.unwrap();
|
||||
self.v_pools[layer]
|
||||
.copy_from_host_at(
|
||||
&self.cpu_v_pools[layer].as_slice()[c_off..c_off + bb],
|
||||
g_off,
|
||||
bb,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
for b in cpu_ids {
|
||||
self.cpu_allocator.free(b);
|
||||
}
|
||||
let state = self.seq_states[slot].as_mut().unwrap();
|
||||
state.block_ids = gpu_ids;
|
||||
state.location = Location::Gpu;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal single-layer, single-head Qwen3 config built from JSON.
    fn tiny_config() -> ModelConfig {
        let raw = serde_json::json!({
            "model_type": "qwen3",
            "hidden_size": 8,
            "intermediate_size": 16,
            "num_attention_heads": 1,
            "num_key_value_heads": 1,
            "num_hidden_layers": 1,
            "vocab_size": 32,
            "max_position_embeddings": 64
        });
        serde_json::from_value(raw).unwrap()
    }

    /// Truncation returns whole blocks to the allocator as the length shrinks,
    /// keeps one block at length 0, and rejects bad slots and extension.
    /// Skipped when CUDA device 0 cannot be selected.
    #[test]
    fn truncate_sequence_frees_whole_blocks_and_keeps_slot_registered() {
        if xserv_cuda::device::set_device(0).is_err() {
            eprintln!("skipping CUDA-backed PagedKVCache test: device 0 unavailable");
            return;
        }

        let config = tiny_config();
        let mut cache = PagedKVCache::new(&config, 5, 0, 1, 4, DType::BF16, 0);

        // Error paths before anything is registered.
        let out_of_range = cache.truncate_sequence(1, 0);
        assert_eq!(out_of_range, Err("truncate_sequence: slot out of range"));
        let unregistered = cache.truncate_sequence(0, 0);
        assert_eq!(unregistered, Err("truncate_sequence: empty slot"));

        // Fill three full blocks plus one token: four blocks, none left free.
        let initial_len = BLOCK_SIZE * 3 + 1;
        cache.register_sequence(0).unwrap();
        cache.ensure_capacity(0, initial_len);
        cache.advance_seq_len(0, initial_len);
        assert_eq!(cache.seq_len(0), initial_len);
        assert_eq!(cache.block_count(0), 4);
        assert_eq!(cache.free_blocks(), 0);

        // One token into the second block: two blocks kept, two freed.
        cache.truncate_sequence(0, BLOCK_SIZE + 1).unwrap();
        assert_eq!(cache.seq_len(0), BLOCK_SIZE + 1);
        assert_eq!(cache.block_count(0), 2);
        assert_eq!(cache.free_blocks(), 2);

        // Exactly one full block.
        cache.truncate_sequence(0, BLOCK_SIZE).unwrap();
        assert_eq!(cache.seq_len(0), BLOCK_SIZE);
        assert_eq!(cache.block_count(0), 1);
        assert_eq!(cache.free_blocks(), 3);

        // Length zero still keeps one block and the slot stays registered.
        cache.truncate_sequence(0, 0).unwrap();
        assert_eq!(cache.seq_len(0), 0);
        assert_eq!(cache.block_count(0), 1);
        assert_eq!(cache.free_blocks(), 3);

        // Truncation can never grow a sequence.
        let grow = cache.truncate_sequence(0, 1);
        assert_eq!(grow, Err("truncate_sequence: cannot extend"));
    }
}
|
||||
|
||||
unsafe fn tensor_from_owned_buf(
|
||||
buf: GpuBuffer,
|
||||
shape: &[usize],
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Tensor {
|
||||
use smallvec::SmallVec;
|
||||
use xserv_tensor::shape::contiguous_strides;
|
||||
use xserv_tensor::storage::Storage;
|
||||
|
||||
let storage = Storage::cuda(buf, device);
|
||||
Tensor::from_storage(
|
||||
storage,
|
||||
SmallVec::from_slice(shape),
|
||||
contiguous_strides(shape),
|
||||
0,
|
||||
dtype,
|
||||
)
|
||||
}
|
||||
1747
crates/xserv-model/src/qwen3.rs
Normal file
1747
crates/xserv-model/src/qwen3.rs
Normal file
File diff suppressed because it is too large
Load Diff
185
crates/xserv-model/src/qwen3_graph.rs
Normal file
185
crates/xserv-model/src/qwen3_graph.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
//! CUDA-graph replay for Qwen3 batch=1 decode (Phase 24 / speculative draft).
|
||||
//!
|
||||
//! Same pattern as `gpt_oss_graph.rs`, but for the Qwen3 dense decode path used
|
||||
//! by speculative decoding's draft model. A Qwen3-0.6B decode step is ~140
|
||||
//! kernel launches; wrapping the whole step into one `cudaGraphLaunch` cuts
|
||||
//! the ~4× γ draft cost per speculative round.
|
||||
//!
|
||||
//! See `gpt_oss_graph.rs` for the design commentary; the capture preconditions,
|
||||
//! retained-warmup mechanism, and quarantine lifetime are all identical here.
|
||||
|
||||
use std::ffi::c_void;
|
||||
|
||||
use xserv_cuda::allocator::{self, RetainedBlocks};
|
||||
use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer};
|
||||
use xserv_tensor::Tensor;
|
||||
|
||||
use crate::paged_kv_cache::PagedKVCache;
|
||||
use crate::qwen3::Qwen3;
|
||||
|
||||
pub struct Qwen3DecodeGraph {
|
||||
stream: CudaStream,
|
||||
graph: CudaGraph,
|
||||
ids_buf: GpuBuffer, // [1] u32, persistent graph input
|
||||
pos_buf: GpuBuffer, // [1] u32, persistent graph input
|
||||
logits: Tensor, // graph output; rewritten in place by every replay
|
||||
_arena: RetainedBlocks,
|
||||
}
|
||||
|
||||
impl Qwen3DecodeGraph {
|
||||
/// Capture one batch=1 decode step and replay it once.
|
||||
pub fn capture(
|
||||
model: &Qwen3,
|
||||
token: u32,
|
||||
position: usize,
|
||||
slot: usize,
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Self {
|
||||
let stream = CudaStream::new().expect("create capture stream");
|
||||
let mut ids_buf = allocator::cached_alloc(4).expect("alloc ids buf");
|
||||
let mut pos_buf = allocator::cached_alloc(4).expect("alloc pos buf");
|
||||
|
||||
model.decode_prepare(&[position], &[slot], cache);
|
||||
ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
|
||||
pos_buf
|
||||
.copy_from_host(&(position as u32).to_le_bytes())
|
||||
.unwrap();
|
||||
|
||||
// Retained warmup: run the exact step once eagerly with the quarantine
|
||||
// ON to stock the pool. See gpt_oss_graph.rs:66-86 for the full
|
||||
// rationale. Re-running the step is idempotent: the KV scatter
|
||||
// overwrites the same cache position and advance_seq_len is *inside*
|
||||
// decode_core, so we roll it back afterwards.
|
||||
let seq_len_before = cache.seq_len(slot);
|
||||
allocator::begin_retain();
|
||||
{
|
||||
let _guard = xserv_cuda::push_stream(&stream);
|
||||
let _ = model.decode_core(
|
||||
ids_buf.as_ptr() as *const c_void,
|
||||
pos_buf.as_ptr() as *const c_void,
|
||||
1,
|
||||
&[slot],
|
||||
cache,
|
||||
);
|
||||
}
|
||||
drop(allocator::end_retain());
|
||||
stream.synchronize().expect("warmup sync");
|
||||
// decode_core advanced seq_len; roll back so capture starts from the
|
||||
// same logical state as the eager warmup.
|
||||
cache
|
||||
.truncate_sequence(slot, seq_len_before)
|
||||
.expect("rollback after warmup");
|
||||
|
||||
allocator::begin_retain();
|
||||
let mut graph = CudaGraph::new();
|
||||
let logits;
|
||||
{
|
||||
let _guard = xserv_cuda::stream::push_stream(&stream);
|
||||
graph
|
||||
.begin_capture(&stream)
|
||||
.expect("begin decode-graph capture");
|
||||
logits = model.decode_core(
|
||||
ids_buf.as_ptr() as *const c_void,
|
||||
pos_buf.as_ptr() as *const c_void,
|
||||
1,
|
||||
&[slot],
|
||||
cache,
|
||||
);
|
||||
graph
|
||||
.end_capture(&stream)
|
||||
.expect("end decode-graph capture");
|
||||
}
|
||||
let arena = allocator::end_retain();
|
||||
|
||||
// The capture path called advance_seq_len (host-side) but the actual
|
||||
// GPU compute has not yet run. Roll back and let the first replay
|
||||
// advance it exactly once with real K/V writes.
|
||||
cache
|
||||
.truncate_sequence(slot, seq_len_before)
|
||||
.expect("rollback after capture");
|
||||
|
||||
graph.launch(&stream).expect("first decode-graph replay");
|
||||
cache.advance_seq_len(slot, 1);
|
||||
|
||||
Self {
|
||||
stream,
|
||||
graph,
|
||||
ids_buf,
|
||||
pos_buf,
|
||||
logits,
|
||||
_arena: arena,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run one decode step by replaying the captured graph.
|
||||
pub fn step(
|
||||
&mut self,
|
||||
model: &Qwen3,
|
||||
token: u32,
|
||||
position: usize,
|
||||
slot: usize,
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
model.decode_prepare(&[position], &[slot], cache);
|
||||
self.ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
|
||||
self.pos_buf
|
||||
.copy_from_host(&(position as u32).to_le_bytes())
|
||||
.unwrap();
|
||||
self.graph
|
||||
.launch(&self.stream)
|
||||
.expect("decode-graph replay");
|
||||
cache.advance_seq_len(slot, 1);
|
||||
self.logits.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Lazy capture policy: first decode step of the process runs eager, the
|
||||
/// second is captured, the rest replay. Batch>1 always falls back to eager.
|
||||
/// Disable with `XSERV_DECODE_GRAPH=0`.
|
||||
pub struct GraphedQwen3Decoder {
|
||||
graph: Option<Qwen3DecodeGraph>,
|
||||
eager_steps: u32,
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl GraphedQwen3Decoder {
|
||||
pub fn new() -> Self {
|
||||
let enabled = std::env::var("XSERV_DECODE_GRAPH")
|
||||
.map(|v| v != "0")
|
||||
.unwrap_or(true);
|
||||
Self {
|
||||
graph: None,
|
||||
eager_steps: 0,
|
||||
enabled,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode(
|
||||
&mut self,
|
||||
model: &Qwen3,
|
||||
tokens: &[u32],
|
||||
positions: &[usize],
|
||||
slots: &[usize],
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
if self.enabled && tokens.len() == 1 {
|
||||
if let Some(g) = self.graph.as_mut() {
|
||||
return g.step(model, tokens[0], positions[0], slots[0], cache);
|
||||
}
|
||||
if self.eager_steps >= 1 {
|
||||
let g = Qwen3DecodeGraph::capture(model, tokens[0], positions[0], slots[0], cache);
|
||||
let logits = g.logits.clone();
|
||||
self.graph = Some(g);
|
||||
return logits;
|
||||
}
|
||||
}
|
||||
self.eager_steps += 1;
|
||||
model.forward_decode_paged(tokens, positions, slots, cache)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GraphedQwen3Decoder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
197
crates/xserv-model/src/sampling.rs
Normal file
197
crates/xserv-model/src/sampling.rs
Normal file
@@ -0,0 +1,197 @@
|
||||
use half::bf16;
|
||||
use rand::Rng;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
/// Knobs that control how `sample` picks the next token.
#[derive(Clone)]
pub struct SamplingParams {
    /// Logits are divided by this value; exactly `0.0` selects greedy argmax.
    pub temperature: f32,
    /// Keep only the `top_k` highest logits; `0` (or a value not below the
    /// vocabulary size) disables the filter.
    pub top_k: usize,
    /// Nucleus threshold; values of `1.0` or more disable the filter.
    pub top_p: f32,
}
|
||||
|
||||
impl Default for SamplingParams {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
temperature: 0.0,
|
||||
top_k: 0,
|
||||
top_p: 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sample a token from logits with shape [seq_len, vocab_size].
|
||||
/// Uses the last position's logits. Handles both F32 and BF16 dtypes.
|
||||
pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
// Greedy fast path: GPU argmax + 4-byte D2H instead of copying the whole
|
||||
// [seq, vocab] logits to the host and scanning it (~201k bf16/token).
|
||||
// NaN logits lose every `>` comparison in the kernel, matching the
|
||||
// NaN-safe host argmax below.
|
||||
if params.temperature == 0.0
|
||||
&& logits.dtype() == DType::BF16
|
||||
&& matches!(logits.device(), Device::Cuda(_))
|
||||
&& logits.is_contiguous()
|
||||
{
|
||||
let ids = xserv_kernels::argmax_bf16_to_host(logits);
|
||||
return *ids.last().unwrap();
|
||||
}
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
|
||||
// Extract last row as f32
|
||||
let mut last_row: Vec<f32> = match logits.dtype() {
|
||||
DType::F32 => {
|
||||
let data = logits_cpu.as_slice::<f32>();
|
||||
data[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec()
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = logits_cpu.as_slice::<bf16>();
|
||||
data[(seq_len - 1) * vocab_size..seq_len * vocab_size]
|
||||
.iter()
|
||||
.map(|v| v.to_f32())
|
||||
.collect()
|
||||
}
|
||||
_ => panic!("unsupported dtype for sampling: {:?}", logits.dtype()),
|
||||
};
|
||||
|
||||
// Greedy
|
||||
if params.temperature == 0.0 {
|
||||
return argmax(&last_row);
|
||||
}
|
||||
|
||||
// NaN-safe: sampling path uses partial_cmp().unwrap() in top-k/top-p
|
||||
// sorts and softmax; a single NaN logit would panic the engine thread.
|
||||
// Replace NaN with -inf (equivalent to masking) instead.
|
||||
let mut nan_seen = false;
|
||||
for v in last_row.iter_mut() {
|
||||
if v.is_nan() {
|
||||
nan_seen = true;
|
||||
*v = f32::NEG_INFINITY;
|
||||
}
|
||||
}
|
||||
if nan_seen {
|
||||
eprintln!("[sampling] WARNING: NaN logits encountered in sample()");
|
||||
}
|
||||
|
||||
// Apply temperature
|
||||
let mut logits_f32: Vec<f32> = last_row.iter().map(|v| v / params.temperature).collect();
|
||||
|
||||
// Top-k filtering
|
||||
if params.top_k > 0 && params.top_k < vocab_size {
|
||||
let mut indices: Vec<usize> = (0..vocab_size).collect();
|
||||
indices.select_nth_unstable_by(params.top_k, |&a, &b| {
|
||||
logits_f32[b].partial_cmp(&logits_f32[a]).unwrap()
|
||||
});
|
||||
// Everything after top_k should be masked
|
||||
for &i in &indices[params.top_k..] {
|
||||
logits_f32[i] = f32::NEG_INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
// Top-p (nucleus) filtering
|
||||
if params.top_p < 1.0 {
|
||||
// Sort indices by descending logit value
|
||||
let mut indices: Vec<usize> = (0..vocab_size).collect();
|
||||
indices.sort_unstable_by(|&a, &b| logits_f32[b].partial_cmp(&logits_f32[a]).unwrap());
|
||||
|
||||
// Compute softmax probabilities for the sorted order
|
||||
let max_val = logits_f32[indices[0]];
|
||||
let sorted_probs: Vec<f32> = indices
|
||||
.iter()
|
||||
.map(|&i| (logits_f32[i] - max_val).exp())
|
||||
.collect();
|
||||
let sum: f32 = sorted_probs.iter().sum();
|
||||
let sorted_probs: Vec<f32> = sorted_probs.iter().map(|v| v / sum).collect();
|
||||
|
||||
// Cumulative sum, find cutoff
|
||||
let mut cumsum = 0.0f32;
|
||||
let mut cutoff = indices.len();
|
||||
for (rank, &prob) in sorted_probs.iter().enumerate() {
|
||||
cumsum += prob;
|
||||
if cumsum > params.top_p {
|
||||
cutoff = rank + 1; // keep at least this many
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Mask everything beyond cutoff
|
||||
for &i in &indices[cutoff..] {
|
||||
logits_f32[i] = f32::NEG_INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
// Softmax
|
||||
let max_val = logits_f32.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exps: Vec<f32> = logits_f32.iter().map(|v| (v - max_val).exp()).collect();
|
||||
let sum: f32 = exps.iter().sum();
|
||||
let probs: Vec<f32> = exps.iter().map(|v| v / sum).collect();
|
||||
|
||||
// Weighted random sampling
|
||||
let mut rng = rand::thread_rng();
|
||||
let r: f32 = rng.r#gen();
|
||||
let mut cumsum = 0.0f32;
|
||||
for (i, &p) in probs.iter().enumerate() {
|
||||
cumsum += p;
|
||||
if cumsum > r {
|
||||
return i as u32;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback (rounding edge case)
|
||||
(vocab_size - 1) as u32
|
||||
}
|
||||
|
||||
/// Greedy argmax with a repetition penalty applied to `recent` token ids
|
||||
/// (HF-style: divide positive logits by `penalty`, multiply negative by it).
|
||||
/// `penalty <= 1.0` is a no-op. Mitigates greedy repetition loops on reasoning
|
||||
/// models without changing the forward pass. NaN-safe.
|
||||
pub fn sample_greedy_penalized(logits: &Tensor, recent: &[u32], penalty: f32) -> u32 {
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let mut last_row: Vec<f32> = match logits.dtype() {
|
||||
DType::F32 => {
|
||||
logits_cpu.as_slice::<f32>()[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec()
|
||||
}
|
||||
DType::BF16 => logits_cpu.as_slice::<bf16>()
|
||||
[(seq_len - 1) * vocab_size..seq_len * vocab_size]
|
||||
.iter()
|
||||
.map(|v| v.to_f32())
|
||||
.collect(),
|
||||
_ => panic!("unsupported dtype for sampling: {:?}", logits.dtype()),
|
||||
};
|
||||
if penalty > 1.0 {
|
||||
for &id in recent {
|
||||
let i = id as usize;
|
||||
if i < last_row.len() {
|
||||
let v = last_row[i];
|
||||
last_row[i] = if v > 0.0 { v / penalty } else { v * penalty };
|
||||
}
|
||||
}
|
||||
}
|
||||
argmax(&last_row)
|
||||
}
|
||||
|
||||
/// Index of the largest non-NaN value in `data`; the first one wins on ties.
///
/// NaN entries are skipped, so a single NaN logit cannot crash the engine
/// thread; a warning is printed once per call if any were seen. Returns `0`
/// when `data` is empty or holds no value greater than `-inf`.
fn argmax(data: &[f32]) -> u32 {
    let mut saw_nan = false;
    let (winner, _) = data.iter().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(top_idx, top_val), (idx, &val)| {
            if val.is_nan() {
                saw_nan = true;
                (top_idx, top_val)
            } else if val > top_val {
                // Strict `>` keeps the earliest index among equal maxima.
                (idx, val)
            } else {
                (top_idx, top_val)
            }
        },
    );
    if saw_nan {
        eprintln!("[sampling] WARNING: NaN logits encountered in argmax");
    }
    winner as u32
}
|
||||
24
crates/xserv-server/Cargo.toml
Normal file
24
crates/xserv-server/Cargo.toml
Normal file
@@ -0,0 +1,24 @@
|
||||
[package]
|
||||
name = "xserv-server"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "xserv-server"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
xserv-cuda = { path = "../xserv-cuda" }
|
||||
xserv-tensor = { path = "../xserv-tensor" }
|
||||
xserv-kernels = { path = "../xserv-kernels" }
|
||||
xserv-model = { path = "../xserv-model" }
|
||||
xserv-tokenizer = { path = "../xserv-tokenizer" }
|
||||
xserv-distributed = { path = "../xserv-distributed" }
|
||||
half.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
tokio.workspace = true
|
||||
axum.workspace = true
|
||||
uuid.workspace = true
|
||||
tokio-stream.workspace = true
|
||||
minijinja.workspace = true
|
||||
573
crates/xserv-server/src/api.rs
Normal file
573
crates/xserv-server/src/api.rs
Normal file
@@ -0,0 +1,573 @@
|
||||
use axum::Extension;
|
||||
use axum::Json;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::convert::Infallible;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
use crate::engine::{GenerateEvent, GenerateRequest};
|
||||
use xserv_model::SamplingParams;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ChatRequest {
|
||||
#[serde(default)]
|
||||
pub model: Option<String>,
|
||||
pub messages: Vec<Message>,
|
||||
#[serde(default = "default_max_tokens")]
|
||||
pub max_tokens: usize,
|
||||
#[serde(default)]
|
||||
pub stream: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(default)]
|
||||
pub top_k: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub top_p: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// Serde default for `ChatRequest::max_tokens` when the client omits it.
fn default_max_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 256;
    DEFAULT_MAX_TOKENS
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ModelsResponse {
|
||||
object: &'static str,
|
||||
data: Vec<ModelInfo>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ModelInfo {
|
||||
id: String,
|
||||
object: &'static str,
|
||||
owned_by: &'static str,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Chat Template: Jinja2 rendering via minijinja
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Prompt renderer backed by a Jinja chat template, with a hardcoded
/// per-model fallback.
pub struct ChatTemplate {
    /// Jinja template text; empty means "no template, use the fallback".
    source: String,
    /// Model family name, used to pick the hardcoded fallback format.
    model_type: String,
}
|
||||
|
||||
impl ChatTemplate {
|
||||
pub fn load(model_dir: &Path, model_type: &str) -> Self {
|
||||
// 1. Try standalone chat_template.jinja file
|
||||
let jinja_path = model_dir.join("chat_template.jinja");
|
||||
if jinja_path.exists() {
|
||||
let source = std::fs::read_to_string(&jinja_path)
|
||||
.unwrap_or_else(|e| panic!("failed to read {}: {e}", jinja_path.display()));
|
||||
eprintln!("[chat-template] loaded from {}", jinja_path.display());
|
||||
return Self {
|
||||
source,
|
||||
model_type: model_type.to_string(),
|
||||
};
|
||||
}
|
||||
|
||||
// 2. Try tokenizer_config.json → chat_template field
|
||||
let tok_cfg_path = model_dir.join("tokenizer_config.json");
|
||||
if tok_cfg_path.exists() {
|
||||
if let Ok(data) = std::fs::read_to_string(&tok_cfg_path) {
|
||||
if let Ok(v) = serde_json::from_str::<serde_json::Value>(&data) {
|
||||
if let Some(ct) = v.get("chat_template").and_then(|v| v.as_str()) {
|
||||
eprintln!("[chat-template] loaded from tokenizer_config.json");
|
||||
return Self {
|
||||
source: ct.to_string(),
|
||||
model_type: model_type.to_string(),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. No template found — use empty source, will fall back to hardcoded
|
||||
eprintln!("[chat-template] no Jinja template found, using hardcoded fallback");
|
||||
Self {
|
||||
source: String::new(),
|
||||
model_type: model_type.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn render(&self, messages: &[Message]) -> String {
|
||||
if self.source.is_empty() {
|
||||
return build_prompt_hardcoded(messages, &self.model_type);
|
||||
}
|
||||
|
||||
match self.render_jinja(messages) {
|
||||
Ok(prompt) => prompt,
|
||||
Err(e) => {
|
||||
eprintln!("[chat-template] Jinja render error: {e}, falling back to hardcoded");
|
||||
build_prompt_hardcoded(messages, &self.model_type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn render_jinja(&self, messages: &[Message]) -> Result<String, minijinja::Error> {
|
||||
let mut env = minijinja::Environment::new();
|
||||
|
||||
// Register custom functions the template may call.
|
||||
env.add_function("strftime_now", strftime_now);
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
|
||||
// Python str methods used by harmony/gpt-oss templates.
|
||||
env.add_filter("startswith", |s: String, prefix: String| -> bool {
|
||||
s.starts_with(&prefix)
|
||||
});
|
||||
|
||||
env.add_template("chat", &self.source)?;
|
||||
let tmpl = env.get_template("chat")?;
|
||||
|
||||
let ctx = minijinja::context! {
|
||||
messages => minijinja::Value::from_serialize(messages),
|
||||
add_generation_prompt => true,
|
||||
bos_token => "",
|
||||
eos_token => "",
|
||||
};
|
||||
|
||||
tmpl.render(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
fn strftime_now(fmt: String) -> String {
|
||||
use std::time::SystemTime;
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
// Only support %Y-%m-%d (the only format used by known templates)
|
||||
let days = now / 86400;
|
||||
let (y, m, d) = days_to_ymd(days);
|
||||
fmt.replace("%Y", &format!("{y:04}"))
|
||||
.replace("%m", &format!("{m:02}"))
|
||||
.replace("%d", &format!("{d:02}"))
|
||||
}
|
||||
|
||||
/// Convert a day count since 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)` triple, with `month` in `1..=12` and `day` in `1..=31`.
///
/// Works on 400-year eras whose internal year starts on 1 March, which puts
/// the leap day at the very end of that year.
fn days_to_ymd(days_since_epoch: u64) -> (u32, u32, u32) {
    const DAYS_PER_ERA: i64 = 146_097;

    // Shift the origin from 1970-01-01 to 0000-03-01.
    let shifted = days_since_epoch as i64 + 719_468;
    let era = shifted.div_euclid(DAYS_PER_ERA);
    let day_of_era = shifted.rem_euclid(DAYS_PER_ERA) as u32;

    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36524 - day_of_era / 146096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);

    // Month index counted from March (0 = March, ..., 11 = February).
    let march_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_month + 2) / 5 + 1;
    let month = if march_month < 10 {
        march_month + 3
    } else {
        march_month - 9
    };

    // January and February belong to the following civil year.
    let march_year = year_of_era as i64 + era * 400;
    let year = if month <= 2 { march_year + 1 } else { march_year };

    (year as u32, month, day)
}
|
||||
|
||||
fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
|
||||
Err(minijinja::Error::new(
|
||||
minijinja::ErrorKind::InvalidOperation,
|
||||
msg,
|
||||
))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Hardcoded fallback templates (for models without a Jinja template)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn build_prompt_hardcoded(messages: &[Message], model_type: &str) -> String {
|
||||
if model_type == "gpt_oss" {
|
||||
return build_prompt_gpt_oss(messages);
|
||||
}
|
||||
// Default: Qwen3 ChatML format
|
||||
let mut prompt = String::new();
|
||||
for msg in messages {
|
||||
match msg.role.as_str() {
|
||||
"system" | "user" | "assistant" => {
|
||||
prompt.push_str("<|im_start|>");
|
||||
prompt.push_str(&msg.role);
|
||||
prompt.push('\n');
|
||||
prompt.push_str(&msg.content);
|
||||
prompt.push_str("<|im_end|>\n");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
prompt.push_str("<|im_start|>assistant\n");
|
||||
prompt.push_str("<think>\n\n</think>\n\n");
|
||||
prompt
|
||||
}
|
||||
|
||||
fn build_prompt_gpt_oss(messages: &[Message]) -> String {
|
||||
let mut prompt = String::new();
|
||||
// Canonical harmony system message (mirrors the model's chat_template.jinja
|
||||
// build_system_message macro). A hand-rolled substitute puts gpt-oss out of
|
||||
// distribution and destabilizes channel selection. This hardcoded builder is
|
||||
// only a fallback for gpt-oss models that ship no Jinja template; the
|
||||
// gpt-oss-20b release does ship one, so the template path is normally used.
|
||||
prompt.push_str("<|start|>system<|message|>");
|
||||
prompt.push_str("You are ChatGPT, a large language model trained by OpenAI.\n");
|
||||
prompt.push_str("Knowledge cutoff: 2024-06\n");
|
||||
prompt.push_str(&format!(
|
||||
"Current date: {}\n\n",
|
||||
strftime_now("%Y-%m-%d".to_string())
|
||||
));
|
||||
prompt.push_str("Reasoning: low\n\n");
|
||||
prompt.push_str("# Valid channels: analysis, commentary, final. Channel must be included for every message.");
|
||||
prompt.push_str("<|end|>");
|
||||
let dev_instructions: String = messages
|
||||
.iter()
|
||||
.filter(|m| m.role == "system")
|
||||
.map(|m| m.content.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n");
|
||||
if !dev_instructions.is_empty() {
|
||||
prompt.push_str("<|start|>developer<|message|># Instructions\n\n");
|
||||
prompt.push_str(&dev_instructions);
|
||||
prompt.push_str("<|end|>");
|
||||
}
|
||||
for msg in messages {
|
||||
match msg.role.as_str() {
|
||||
"user" => {
|
||||
prompt.push_str("<|start|>user<|message|>");
|
||||
prompt.push_str(&msg.content);
|
||||
prompt.push_str("<|end|>");
|
||||
}
|
||||
"assistant" => {
|
||||
prompt.push_str("<|start|>assistant<|channel|>final<|message|>");
|
||||
prompt.push_str(&msg.content);
|
||||
prompt.push_str("<|end|>");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
prompt.push_str("<|start|>assistant<|channel|>final<|message|>");
|
||||
prompt
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HTTP handlers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Liveness probe handler: always resolves immediately to the static body `"ok"`.
pub async fn health() -> &'static str {
    const BODY: &str = "ok";
    BODY
}
|
||||
|
||||
pub async fn list_models(Extension(state): Extension<Arc<AppState>>) -> Json<ModelsResponse> {
|
||||
Json(ModelsResponse {
|
||||
object: "list",
|
||||
data: vec![ModelInfo {
|
||||
id: state.model_name.clone(),
|
||||
object: "model",
|
||||
owned_by: "xserv",
|
||||
}],
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn chat_completions(
|
||||
Extension(state): Extension<Arc<AppState>>,
|
||||
Json(req): Json<ChatRequest>,
|
||||
) -> Response {
|
||||
if req.stream == Some(true) {
|
||||
chat_stream(state, req)
|
||||
} else {
|
||||
chat_non_stream(state, req).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
|
||||
let id = format!("chatcmpl-{}", Uuid::new_v4());
|
||||
let model_name = state.model_name.clone();
|
||||
let created = unix_timestamp();
|
||||
|
||||
if let Some(response) = validate_request(&req, &model_name) {
|
||||
return response;
|
||||
}
|
||||
|
||||
let prompt = state.chat_template.render(&req.messages);
|
||||
let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt);
|
||||
let prompt_token_count = prompt_tokens.len();
|
||||
|
||||
let max_seq_len = state.max_seq_len;
|
||||
if prompt_token_count >= max_seq_len {
|
||||
return bad_request(format!(
|
||||
"prompt is {} tokens, exceeds max_seq_len {}",
|
||||
prompt_token_count, max_seq_len
|
||||
));
|
||||
}
|
||||
let max_tokens = req.max_tokens.min(max_seq_len - prompt_token_count);
|
||||
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<GenerateEvent>(64);
|
||||
let gen_req = GenerateRequest {
|
||||
prompt_tokens,
|
||||
max_tokens,
|
||||
sampling: sampling_params(&req),
|
||||
sender: tx,
|
||||
};
|
||||
if let Err(resp) = submit_to_engine(&state, gen_req) {
|
||||
return resp;
|
||||
}
|
||||
|
||||
let mut content = String::new();
|
||||
let mut completion_token_count: usize = 0;
|
||||
let mut finish_reason = "length".to_string();
|
||||
while let Some(event) = rx.recv().await {
|
||||
match event {
|
||||
GenerateEvent::Token { text, .. } => {
|
||||
completion_token_count += 1;
|
||||
content.push_str(&text);
|
||||
}
|
||||
GenerateEvent::Done { finish_reason: fr } => {
|
||||
finish_reason = fr;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let fr_value = match normalize_finish_reason(&finish_reason) {
|
||||
Some(s) => serde_json::Value::String(s.to_string()),
|
||||
None => serde_json::Value::Null,
|
||||
};
|
||||
Json(serde_json::json!({
|
||||
"id": id,
|
||||
"object": "chat.completion",
|
||||
"created": created,
|
||||
"model": model_name,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": { "role": "assistant", "content": content },
|
||||
"finish_reason": fr_value,
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": prompt_token_count + completion_token_count
|
||||
}
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn chat_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
|
||||
let id = format!("chatcmpl-{}", Uuid::new_v4());
|
||||
let model_name = state.model_name.clone();
|
||||
let created = unix_timestamp();
|
||||
|
||||
if let Some(response) = validate_request(&req, &model_name) {
|
||||
return response;
|
||||
}
|
||||
|
||||
let prompt = state.chat_template.render(&req.messages);
|
||||
let prompt_tokens = state.engine_tokenizer.lock().unwrap().encode(&prompt);
|
||||
|
||||
let max_seq_len = state.max_seq_len;
|
||||
if prompt_tokens.len() >= max_seq_len {
|
||||
return bad_request(format!(
|
||||
"prompt is {} tokens, exceeds max_seq_len {}",
|
||||
prompt_tokens.len(),
|
||||
max_seq_len
|
||||
));
|
||||
}
|
||||
let max_tokens = req.max_tokens.min(max_seq_len - prompt_tokens.len());
|
||||
|
||||
let (engine_tx, engine_rx) = tokio::sync::mpsc::channel::<GenerateEvent>(64);
|
||||
let gen_req = GenerateRequest {
|
||||
prompt_tokens,
|
||||
max_tokens,
|
||||
sampling: sampling_params(&req),
|
||||
sender: engine_tx,
|
||||
};
|
||||
if let Err(resp) = submit_to_engine(&state, gen_req) {
|
||||
return resp;
|
||||
}
|
||||
|
||||
// SSE event channel: engine events -> SSE events
|
||||
let (sse_tx, sse_rx) = tokio::sync::mpsc::channel::<Result<Event, Infallible>>(64);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut engine_stream = ReceiverStream::new(engine_rx);
|
||||
let mut first = true;
|
||||
|
||||
while let Some(event) = engine_stream.next().await {
|
||||
match event {
|
||||
GenerateEvent::Token { text, .. } => {
|
||||
if first {
|
||||
// First chunk: role announcement
|
||||
let chunk =
|
||||
make_chunk(&id, &model_name, created, None, Some("assistant"), None);
|
||||
let _ = sse_tx.send(Ok(Event::default().data(chunk))).await;
|
||||
first = false;
|
||||
}
|
||||
let chunk = make_chunk(&id, &model_name, created, Some(&text), None, None);
|
||||
if sse_tx.send(Ok(Event::default().data(chunk))).await.is_err() {
|
||||
return; // client disconnected
|
||||
}
|
||||
}
|
||||
GenerateEvent::Done { finish_reason } => {
|
||||
if first {
|
||||
// Edge case: Done arrived with no tokens
|
||||
let chunk =
|
||||
make_chunk(&id, &model_name, created, None, Some("assistant"), None);
|
||||
let _ = sse_tx.send(Ok(Event::default().data(chunk))).await;
|
||||
}
|
||||
// Only "stop" and "length" are OpenAI-standard values. Internal
|
||||
// codes like "error" (client-stalled from tp/pp engine) map to
|
||||
// null so SDK clients see a clean stream close.
|
||||
let fr = normalize_finish_reason(&finish_reason);
|
||||
let chunk = make_chunk(&id, &model_name, created, None, None, fr);
|
||||
let _ = sse_tx.send(Ok(Event::default().data(chunk))).await;
|
||||
let _ = sse_tx
|
||||
.send(Ok(Event::default().data("[DONE]".to_string())))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Sse::new(ReceiverStream::new(sse_rx))
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn validate_request(req: &ChatRequest, model_name: &str) -> Option<Response> {
|
||||
if let Some(model) = &req.model {
|
||||
if model != model_name {
|
||||
return Some(bad_request(format!(
|
||||
"model '{model}' is not loaded; available model is '{model_name}'"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if req.max_tokens == 0 {
|
||||
return Some(bad_request("max_tokens must be greater than 0"));
|
||||
}
|
||||
|
||||
if let Some(t) = req.temperature {
|
||||
if !t.is_finite() || t < 0.0 {
|
||||
return Some(bad_request("temperature must be a finite value >= 0"));
|
||||
}
|
||||
}
|
||||
if let Some(p) = req.top_p {
|
||||
if !p.is_finite() || !(0.0..=1.0).contains(&p) {
|
||||
return Some(bad_request("top_p must be in [0, 1]"));
|
||||
}
|
||||
}
|
||||
if let Some(k) = req.top_k {
|
||||
if k > 1_000_000 {
|
||||
return Some(bad_request("top_k must be <= 1_000_000"));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Hand a request to the engine thread. Poison-tolerant (recovers the lock if a
|
||||
/// prior handler panicked) and returns a clean 503 instead of panicking when the
|
||||
/// engine thread is gone, so one dead engine doesn't cascade into every request.
|
||||
fn submit_to_engine(state: &AppState, req: GenerateRequest) -> Result<(), Response> {
|
||||
let sender = state
|
||||
.engine_sender
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
sender.try_send(req).map_err(|err| match err {
|
||||
std::sync::mpsc::TrySendError::Full(_) => {
|
||||
service_unavailable("inference engine is busy, retry later")
|
||||
}
|
||||
std::sync::mpsc::TrySendError::Disconnected(_) => {
|
||||
service_unavailable("inference engine is not available")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn service_unavailable(message: impl Into<String>) -> Response {
|
||||
(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": message.into(), "type": "server_error" }
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn bad_request(message: impl Into<String>) -> Response {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": message.into(),
|
||||
"type": "invalid_request_error"
|
||||
}
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn make_chunk(
|
||||
id: &str,
|
||||
model: &str,
|
||||
created: u64,
|
||||
content: Option<&str>,
|
||||
role: Option<&str>,
|
||||
finish_reason: Option<&str>,
|
||||
) -> String {
|
||||
let mut delta = serde_json::Map::new();
|
||||
if let Some(r) = role {
|
||||
delta.insert("role".into(), serde_json::Value::String(r.into()));
|
||||
// Role chunk also includes empty content per OpenAI spec
|
||||
delta.insert("content".into(), serde_json::Value::String(String::new()));
|
||||
}
|
||||
if let Some(c) = content {
|
||||
delta.insert("content".into(), serde_json::Value::String(c.into()));
|
||||
}
|
||||
|
||||
let fr = match finish_reason {
|
||||
Some(r) => serde_json::Value::String(r.into()),
|
||||
None => serde_json::Value::Null,
|
||||
};
|
||||
|
||||
serde_json::json!({
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"finish_reason": fr,
|
||||
}]
|
||||
})
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 instead of panicking if the system clock reports a time before
/// the epoch, so a misconfigured clock cannot take down a request handler.
fn unix_timestamp() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs())
}
|
||||
|
||||
fn sampling_params(req: &ChatRequest) -> SamplingParams {
|
||||
SamplingParams {
|
||||
temperature: req.temperature.unwrap_or(0.0),
|
||||
top_k: req.top_k.unwrap_or(0),
|
||||
top_p: req.top_p.unwrap_or(1.0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Reduce an engine finish_reason to the OpenAI-standard vocabulary.
///
/// `"stop"` and `"length"` pass through unchanged; every other string
/// (e.g. the engine-internal `"error"`) becomes `None`, which callers
/// serialize as JSON null rather than exposing an unknown value to clients.
fn normalize_finish_reason(fr: &str) -> Option<&'static str> {
    const STANDARD: [&str; 2] = ["stop", "length"];
    STANDARD.iter().copied().find(|known| *known == fr)
}
|
||||
460
crates/xserv-server/src/engine.rs
Normal file
460
crates/xserv-server/src/engine.rs
Normal file
@@ -0,0 +1,460 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::path::Path;
|
||||
use std::sync::Once;
|
||||
use std::sync::mpsc;
|
||||
use std::time::Instant;
|
||||
use xserv_model::loader;
|
||||
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, SamplingParams, sample};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
/// Inference engine state owned by the scheduler thread: the loaded model,
/// its tokenizer, the paged KV cache and the batching limits.
pub struct Engine {
|
||||
/// Model built from the weights loaded out of the model directory.
model: Qwen3,
|
||||
/// Configuration parsed from the model directory's `config.json`.
config: ModelConfig,
|
||||
/// Tokenizer used to turn sampled token ids back into text.
tokenizer: Tokenizer,
|
||||
/// Upper bound on the number of sequences in the running set.
max_batch_size: usize,
|
||||
/// Maximum sequence length; determines the per-sequence block-table size.
max_seq_len: usize,
|
||||
/// Paged KV cache holding the GPU block pool and the host swap pool.
paged_cache: PagedKVCache,
|
||||
}
|
||||
|
||||
/// One generation job handed from an HTTP handler to the engine thread.
pub struct GenerateRequest {
|
||||
/// Already-tokenized prompt.
pub prompt_tokens: Vec<u32>,
|
||||
/// Maximum number of tokens to generate for this request.
pub max_tokens: usize,
|
||||
/// Sampling settings (temperature / top-k / top-p) for this request.
pub sampling: SamplingParams,
|
||||
/// Channel on which the engine reports `GenerateEvent`s back to the handler.
pub sender: tokio::sync::mpsc::Sender<GenerateEvent>,
|
||||
}
|
||||
|
||||
/// Event sent from the engine to the request handler over the per-request
/// channel.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GenerateEvent {
    /// A decoded text fragment together with the id of the token that produced it.
    Token { id: u32, text: String },
    /// Generation ended; `finish_reason` is the engine's reason string
    /// (e.g. `"stop"` or `"length"`).
    Done { finish_reason: String },
}
|
||||
|
||||
/// Scheduler-side bookkeeping for one in-flight generation request.
struct Sequence {
|
||||
/// Monotonically increasing id assigned by the scheduler.
id: u64,
|
||||
/// Tokenized prompt, consumed by the prefill pass.
prompt_tokens: Vec<u32>,
|
||||
/// Tokens sampled so far (the last one feeds the next decode step).
generated_tokens: Vec<u32>,
|
||||
/// Generation budget; the sequence finishes once this many tokens exist.
max_tokens: usize,
|
||||
/// Sampling settings copied from the request.
sampling: SamplingParams,
|
||||
/// Paged-cache slot, `None` until the sequence is admitted.
seq_slot: Option<usize>,
|
||||
/// Channel back to the HTTP handler.
sender: tokio::sync::mpsc::Sender<GenerateEvent>,
|
||||
/// True once the prompt has been run through prefill.
prefilled: bool,
|
||||
/// Set when a `try_send` failed (client too slow or gone). The scheduler
|
||||
/// reaps the sequence next iteration instead of blocking the decode thread.
|
||||
client_stalled: bool,
|
||||
/// End-of-sequence token id reported by the tokenizer, if it has one.
eos_token_id: Option<u32>,
|
||||
/// Byte buffer passed to the tokenizer's streaming decode calls.
decode_buffer: Vec<u8>,
|
||||
/// Time at which the sequence was created.
created_at: Instant,
|
||||
}
|
||||
|
||||
impl Engine {
|
||||
pub fn load(model_dir: &Path, max_batch_size: usize, max_seq_len: usize) -> Self {
|
||||
Self::load_with_swap(model_dir, max_batch_size, max_seq_len, 8)
|
||||
}
|
||||
|
||||
pub fn load_with_swap(
|
||||
model_dir: &Path,
|
||||
max_batch_size: usize,
|
||||
max_seq_len: usize,
|
||||
swap_space_gb: usize,
|
||||
) -> Self {
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
eprintln!("[engine] Loading weights...");
|
||||
let weights = loader::load_model_dir(model_dir, Device::Cuda(0));
|
||||
eprintln!("[engine] Loaded {} tensors", weights.len());
|
||||
let model = Qwen3::from_weights(config.clone(), weights);
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
|
||||
// Tier-1 sizing: size the GPU block pool to *available VRAM* after the
|
||||
// weights are resident, not to worst-case max_batch * max_ctx. This is
|
||||
// what makes paged attention elastic — sequences share the pool on
|
||||
// demand, and overflow is swapped to host (Tier-2) rather than reserved.
|
||||
let bytes_per_block = PagedKVCache::bytes_per_block(&config, DType::BF16);
|
||||
let info = xserv_cuda::device::device_info(0).expect("device info");
|
||||
// Reserve headroom for activations, cuBLAS workspace and the [B, vocab]
|
||||
// logits buffer; the transpose peak during load is already behind us.
|
||||
const ACTIVATION_RESERVE: usize = 3 * 1024 * 1024 * 1024; // 3 GiB
|
||||
let util_num = 90; // use 90% of remaining free memory for KV
|
||||
let usable = info.free_memory.saturating_sub(ACTIVATION_RESERVE);
|
||||
let mut total_blocks = (usable * util_num / 100) / bytes_per_block;
|
||||
// Cap at a sane upper bound and ensure a floor.
|
||||
total_blocks = total_blocks.max(256);
|
||||
// Test hook: force a small GPU pool to exercise the swap path. Must stay
|
||||
// >= max_blocks_per_seq so a single max-length sequence still fits.
|
||||
if let Ok(v) = std::env::var("XSERV_MAX_KV_BLOCKS") {
|
||||
if let Ok(n) = v.parse::<usize>() {
|
||||
total_blocks = total_blocks.min(n);
|
||||
eprintln!("[engine] XSERV_MAX_KV_BLOCKS override: gpu_blocks={total_blocks}");
|
||||
}
|
||||
}
|
||||
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
// Slots must cover running + swapped sequences, so be generous (cheap:
|
||||
// each slot is just a block-table row of i32s).
|
||||
let max_seqs_slots = (max_batch_size * 8).max(32);
|
||||
// CPU swap pool: swap_space_gb of pinned host memory.
|
||||
let cpu_total_blocks = (swap_space_gb * 1024 * 1024 * 1024) / bytes_per_block;
|
||||
|
||||
let paged_cache = PagedKVCache::new(
|
||||
&config,
|
||||
total_blocks,
|
||||
cpu_total_blocks,
|
||||
max_seqs_slots,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
0,
|
||||
);
|
||||
|
||||
eprintln!(
|
||||
"[engine] Ready (max_batch={max_batch_size}, max_seq_len={max_seq_len}, \
|
||||
gpu_blocks={total_blocks} ({:.1} GiB), swap_blocks={cpu_total_blocks} ({swap_space_gb} GiB), \
|
||||
free_vram={:.1} GiB)",
|
||||
(total_blocks * bytes_per_block) as f64 / 1e9,
|
||||
info.free_memory as f64 / 1e9,
|
||||
);
|
||||
Self {
|
||||
model,
|
||||
config,
|
||||
tokenizer,
|
||||
max_batch_size,
|
||||
max_seq_len,
|
||||
paged_cache,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tokenizer(&self) -> &Tokenizer {
|
||||
&self.tokenizer
|
||||
}
|
||||
|
||||
pub fn max_seq_len(&self) -> usize {
|
||||
self.max_seq_len
|
||||
}
|
||||
|
||||
/// Main scheduler loop. Receives requests from channel, manages concurrent sequences.
|
||||
///
|
||||
/// Sequences move between three sets:
|
||||
/// waiting — admitted to the queue, no GPU slot yet
|
||||
/// running — KV resident on GPU, actively prefilling/decoding
|
||||
/// swapped — KV evicted to pinned host memory (preempted), paused
|
||||
/// When running sequences grow past the GPU block pool, the newest are
|
||||
/// swapped out to host (vLLM-style) and swapped back in when blocks free up.
|
||||
pub fn run(&mut self, rx: mpsc::Receiver<GenerateRequest>) {
|
||||
let mut waiting: VecDeque<Sequence> = VecDeque::new();
|
||||
let mut running: Vec<Sequence> = Vec::new();
|
||||
let mut swapped: Vec<Sequence> = Vec::new();
|
||||
let mut next_id: u64 = 0;
|
||||
|
||||
eprintln!("[scheduler] Listening for requests...");
|
||||
|
||||
loop {
|
||||
// Step 1: Remove finished sequences and return their slots.
|
||||
let finished_slots: Vec<usize> = running
|
||||
.iter()
|
||||
.filter(|s| is_finished(s))
|
||||
.filter_map(|s| s.seq_slot)
|
||||
.collect();
|
||||
for slot in finished_slots {
|
||||
self.paged_cache.free_sequence(slot);
|
||||
}
|
||||
running.retain(|seq| !is_finished(seq));
|
||||
|
||||
// Step 2: Swap previously-evicted sequences back in when there is
|
||||
// room (oldest first). They resume decoding from where they paused.
|
||||
while running.len() < self.max_batch_size && !swapped.is_empty() {
|
||||
let slot = swapped[0].seq_slot.expect("swapped slot");
|
||||
if !self.paged_cache.can_swap_in(slot) {
|
||||
break;
|
||||
}
|
||||
self.paged_cache.swap_in(slot).expect("swap_in");
|
||||
let seq = swapped.remove(0);
|
||||
eprintln!(
|
||||
"[scheduler] swapped in seq {} ({} blocks)",
|
||||
seq.id,
|
||||
self.paged_cache.block_count(slot)
|
||||
);
|
||||
running.push(seq);
|
||||
}
|
||||
|
||||
// Step 3: Admit new sequences (block-aware). Only admit if the GPU
|
||||
// pool can hold the prompt AND leave one block of decode headroom
|
||||
// per already-running sequence, so admission never starves decode.
|
||||
{
|
||||
let mut avail = self.paged_cache.free_blocks();
|
||||
let decode_reserve = running.len();
|
||||
while running.len() < self.max_batch_size {
|
||||
let Some(front) = waiting.front() else {
|
||||
break;
|
||||
};
|
||||
let prompt_blocks = front.prompt_tokens.len().div_ceil(BLOCK_SIZE).max(1);
|
||||
if avail < prompt_blocks + decode_reserve {
|
||||
break;
|
||||
}
|
||||
let free_slot = (0..self.paged_cache.max_seqs())
|
||||
.find(|&s| self.paged_cache.is_slot_free(s));
|
||||
let Some(slot) = free_slot else {
|
||||
break;
|
||||
};
|
||||
let mut seq = waiting.pop_front().unwrap();
|
||||
self.paged_cache
|
||||
.register_sequence(slot)
|
||||
.expect("register paged slot");
|
||||
seq.seq_slot = Some(slot);
|
||||
running.push(seq);
|
||||
avail -= prompt_blocks; // projected free after this seq prefills
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: If nothing to do, blocking wait for new request.
|
||||
if running.is_empty() && waiting.is_empty() && swapped.is_empty() {
|
||||
match rx.recv() {
|
||||
Ok(req) => {
|
||||
let seq = self.make_sequence(req, &mut next_id);
|
||||
waiting.push_back(seq);
|
||||
continue;
|
||||
}
|
||||
Err(_) => break, // channel closed
|
||||
}
|
||||
}
|
||||
// Nothing runnable this iteration (e.g. all swapped, waiting on
|
||||
// blocks to free): loop to retry swap-in/admission next iteration.
|
||||
if running.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Step 5a: Process prefills (one at a time — different prompt lengths).
|
||||
// Admission guaranteed block headroom, so ensure_capacity won't starve.
|
||||
let mut newly_prefilled = Vec::new();
|
||||
for seq in running.iter_mut() {
|
||||
if !seq.prefilled {
|
||||
let slot = seq.seq_slot.expect("slot");
|
||||
let logits = self.model.forward_prefill_paged(
|
||||
&seq.prompt_tokens,
|
||||
slot,
|
||||
&mut self.paged_cache,
|
||||
);
|
||||
let next = sample(&logits, &seq.sampling);
|
||||
seq.generated_tokens.push(next);
|
||||
seq.prefilled = true;
|
||||
emit_token(&self.tokenizer, seq, next);
|
||||
newly_prefilled.push(seq.id);
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5b: Ensure block headroom for this decode step; preempt the
|
||||
// newest running sequences to host if the pool can't cover it.
|
||||
let mut needed = decode_block_need(&self.paged_cache, &running, &newly_prefilled);
|
||||
while self.paged_cache.free_blocks() < needed {
|
||||
// Victim: newest prefilled, decoding (not just-prefilled) sequence.
|
||||
let victim = (0..running.len()).rev().find(|&p| {
|
||||
running[p].prefilled
|
||||
&& !newly_prefilled.contains(&running[p].id)
|
||||
&& running[p].seq_slot.is_some()
|
||||
});
|
||||
let Some(pos) = victim else {
|
||||
break;
|
||||
};
|
||||
let seq = running.remove(pos);
|
||||
let slot = seq.seq_slot.unwrap();
|
||||
if self.paged_cache.can_swap_out(slot) {
|
||||
let nblocks = self.paged_cache.block_count(slot);
|
||||
self.paged_cache.swap_out(slot).expect("swap_out");
|
||||
eprintln!(
|
||||
"[scheduler] preempt: swapped out seq {} ({nblocks} blocks) to host",
|
||||
seq.id
|
||||
);
|
||||
swapped.push(seq);
|
||||
needed = decode_block_need(&self.paged_cache, &running, &newly_prefilled);
|
||||
} else {
|
||||
running.insert(pos, seq); // CPU pool full — can't evict further
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5c: Batched paged decode for the surviving prefilled sequences.
|
||||
let decode_indices: Vec<usize> = running
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, s)| s.prefilled && !newly_prefilled.contains(&s.id))
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
if !decode_indices.is_empty() {
|
||||
static LOG_ONCE: Once = Once::new();
|
||||
LOG_ONCE.call_once(|| {
|
||||
eprintln!("[scheduler] paged decode active");
|
||||
});
|
||||
|
||||
let tokens: Vec<u32> = decode_indices
|
||||
.iter()
|
||||
.map(|&i| *running[i].generated_tokens.last().unwrap())
|
||||
.collect();
|
||||
let positions: Vec<usize> = decode_indices
|
||||
.iter()
|
||||
.map(|&i| self.paged_cache.seq_len(running[i].seq_slot.unwrap()))
|
||||
.collect();
|
||||
let slots: Vec<usize> = decode_indices
|
||||
.iter()
|
||||
.map(|&i| running[i].seq_slot.unwrap())
|
||||
.collect();
|
||||
|
||||
let logits = self.model.forward_decode_paged(
|
||||
&tokens,
|
||||
&positions,
|
||||
&slots,
|
||||
&mut self.paged_cache,
|
||||
);
|
||||
|
||||
// Fast path: every active sequence is greedy → run argmax on
|
||||
// the GPU and only D2H the chosen token ids (a few bytes per
|
||||
// sequence) instead of the full [B, vocab_size] BF16 logits
|
||||
// (~1.2 MB for B=4, Qwen3 vocab=152K).
|
||||
let all_greedy = decode_indices
|
||||
.iter()
|
||||
.all(|&i| running[i].sampling.temperature == 0.0);
|
||||
if all_greedy {
|
||||
let next_ids = xserv_kernels::argmax_bf16_to_host(&logits);
|
||||
for (j, &i) in decode_indices.iter().enumerate() {
|
||||
let next = next_ids[j];
|
||||
running[i].generated_tokens.push(next);
|
||||
emit_token(&self.tokenizer, &mut running[i], next);
|
||||
}
|
||||
} else {
|
||||
// Mixed sampling: keep the CPU path for now (top-k/top-p
|
||||
// sampling still runs there). Only the rows that need it
|
||||
// get exercised; greedy rows could in principle reuse the
|
||||
// GPU argmax but the CPU pass is short for B<=4.
|
||||
let vocab_size = logits.shape()[1];
|
||||
let logits_cpu = logits.to_device(xserv_tensor::Device::Cpu);
|
||||
let data = logits_cpu.as_slice::<half::bf16>();
|
||||
for (j, &i) in decode_indices.iter().enumerate() {
|
||||
let row_start = j * vocab_size;
|
||||
let row_logits = &data[row_start..row_start + vocab_size];
|
||||
let next = if running[i].sampling.temperature == 0.0 {
|
||||
row_logits
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(idx, _)| idx as u32)
|
||||
.unwrap()
|
||||
} else {
|
||||
let row_tensor =
|
||||
xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
|
||||
sample(&row_tensor, &running[i].sampling)
|
||||
};
|
||||
running[i].generated_tokens.push(next);
|
||||
emit_token(&self.tokenizer, &mut running[i], next);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 6: Check for newly arrived requests (non-blocking)
|
||||
loop {
|
||||
match rx.try_recv() {
|
||||
Ok(req) => {
|
||||
let seq = self.make_sequence(req, &mut next_id);
|
||||
waiting.push_back(seq);
|
||||
}
|
||||
Err(mpsc::TryRecvError::Empty) => break,
|
||||
Err(mpsc::TryRecvError::Disconnected) => return,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn make_sequence(&mut self, req: GenerateRequest, next_id: &mut u64) -> Sequence {
|
||||
let id = *next_id;
|
||||
*next_id += 1;
|
||||
Sequence {
|
||||
id,
|
||||
prompt_tokens: req.prompt_tokens,
|
||||
generated_tokens: Vec::new(),
|
||||
max_tokens: req.max_tokens,
|
||||
sampling: req.sampling,
|
||||
seq_slot: None,
|
||||
sender: req.sender,
|
||||
prefilled: false,
|
||||
client_stalled: false,
|
||||
eos_token_id: self.tokenizer.eos_token_id(),
|
||||
decode_buffer: Vec::new(),
|
||||
created_at: Instant::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Total additional GPU blocks the next decode step needs across all
|
||||
/// currently-decoding (prefilled, not just-prefilled) sequences.
|
||||
fn decode_block_need(paged: &PagedKVCache, running: &[Sequence], newly_prefilled: &[u64]) -> usize {
|
||||
running
|
||||
.iter()
|
||||
.filter(|s| s.prefilled && !newly_prefilled.contains(&s.id))
|
||||
.filter_map(|s| s.seq_slot)
|
||||
.map(|slot| paged.additional_blocks_needed(slot, 1))
|
||||
.sum()
|
||||
}
|
||||
|
||||
fn emit_token(tokenizer: &Tokenizer, seq: &mut Sequence, token_id: u32) {
|
||||
if tokenizer.eos_token_id() == Some(token_id) {
|
||||
let tail = tokenizer.flush_decode_stream(&mut seq.decode_buffer);
|
||||
send_token_if_nonempty(seq, tail);
|
||||
try_send_event(
|
||||
seq,
|
||||
GenerateEvent::Done {
|
||||
finish_reason: "stop".to_string(),
|
||||
},
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let text = tokenizer.decode_token_stream(token_id, &mut seq.decode_buffer);
|
||||
if seq.generated_tokens.len() >= seq.max_tokens {
|
||||
let tail = tokenizer.flush_decode_stream(&mut seq.decode_buffer);
|
||||
send_token_if_nonempty(seq, text);
|
||||
send_token_if_nonempty(seq, tail);
|
||||
try_send_event(
|
||||
seq,
|
||||
GenerateEvent::Done {
|
||||
finish_reason: "length".to_string(),
|
||||
},
|
||||
);
|
||||
} else {
|
||||
send_token_if_nonempty(seq, text);
|
||||
}
|
||||
}
|
||||
|
||||
fn send_token_if_nonempty(seq: &mut Sequence, text: String) {
|
||||
if !text.is_empty() {
|
||||
let id = *seq.generated_tokens.last().unwrap_or(&0);
|
||||
try_send_event(seq, GenerateEvent::Token { id, text });
|
||||
}
|
||||
}
|
||||
|
||||
/// Send an event without blocking the shared decode thread. If the client is
|
||||
/// too slow (channel full) or gone (closed), flag the sequence for eviction
|
||||
/// instead of blocking — one slow consumer must never stall the whole
|
||||
/// continuous-batching loop. When the sequence is reaped its `sender` drops,
|
||||
/// closing the channel so the client's receive loop ends rather than hanging.
|
||||
fn try_send_event(seq: &mut Sequence, event: GenerateEvent) {
|
||||
if let Err(err) = seq.sender.try_send(event) {
|
||||
seq.client_stalled = true;
|
||||
if let tokio::sync::mpsc::error::TrySendError::Full(_) = err {
|
||||
eprintln!(
|
||||
"[scheduler] seq {}: client too slow (stream channel full), evicting",
|
||||
seq.id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_finished(seq: &Sequence) -> bool {
|
||||
if seq.client_stalled {
|
||||
return true;
|
||||
}
|
||||
if seq.generated_tokens.is_empty() {
|
||||
return false;
|
||||
}
|
||||
let last = *seq.generated_tokens.last().unwrap();
|
||||
if seq.generated_tokens.len() >= seq.max_tokens {
|
||||
return true;
|
||||
}
|
||||
seq.sender.is_closed() || seq.eos_token_id == Some(last)
|
||||
}
|
||||
153
crates/xserv-server/src/main.rs
Normal file
153
crates/xserv-server/src/main.rs
Normal file
@@ -0,0 +1,153 @@
|
||||
mod api;
|
||||
mod engine;
|
||||
mod pp_engine;
|
||||
mod tp_engine;
|
||||
|
||||
use axum::{
|
||||
Extension, Router,
|
||||
extract::DefaultBodyLimit,
|
||||
routing::{get, post},
|
||||
};
|
||||
use engine::GenerateRequest;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex, mpsc};
|
||||
use xserv_model::ModelConfig;
|
||||
|
||||
/// State shared by all HTTP handlers (installed as an `Extension` layer).
pub struct AppState {
|
||||
/// Name of the served model, derived from the model directory's file name.
pub model_name: String,
|
||||
/// Chat template loaded for this model directory and model type.
pub chat_template: api::ChatTemplate,
|
||||
/// Sending half of the bounded request queue to the engine thread.
pub engine_sender: Mutex<mpsc::SyncSender<GenerateRequest>>,
|
||||
/// Tokenizer loaded from the model directory's `tokenizer.json`.
pub engine_tokenizer: Mutex<xserv_tokenizer::Tokenizer>,
|
||||
/// Effective context limit: the requested value capped at the model's limit.
pub max_seq_len: usize,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!(
|
||||
"Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N] [--pp N]"
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let port: u16 = args
|
||||
.iter()
|
||||
.position(|a| a == "--port")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(8080);
|
||||
let max_batch: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--max-batch")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(4)
|
||||
.max(1);
|
||||
let requested_max_seq_len: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--max-seq-len")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(2048)
|
||||
.max(1);
|
||||
let swap_space_gb: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--swap-space-gb")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(8);
|
||||
let tp: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--tp")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1)
|
||||
.max(1);
|
||||
let pp: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--pp")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1)
|
||||
.max(1);
|
||||
if tp > 1 && pp > 1 {
|
||||
eprintln!("--tp and --pp cannot be combined yet (2D TP×PP is future work)");
|
||||
std::process::exit(1);
|
||||
}
|
||||
let model_config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
// gpt-oss is only implemented in the TP engine; route it there even at
|
||||
// tp=1 (single-rank world) so quantized models can serve on one GPU.
|
||||
let is_gpt_oss = model_config.model_type.as_deref() == Some("gpt_oss");
|
||||
if pp > 1 && is_gpt_oss {
|
||||
eprintln!(
|
||||
"gpt-oss is not supported by the pipeline-parallel engine (Qwen3 only); use --tp instead"
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
let model_max_seq_len = model_config.max_seq_len();
|
||||
if model_max_seq_len == 0 {
|
||||
eprintln!("model config has invalid max_seq_len=0");
|
||||
std::process::exit(1);
|
||||
}
|
||||
let max_seq_len = requested_max_seq_len.min(model_max_seq_len);
|
||||
if max_seq_len != requested_max_seq_len {
|
||||
eprintln!(
|
||||
"[server] --max-seq-len {requested_max_seq_len} exceeds model limit {model_max_seq_len}; using {max_seq_len}"
|
||||
);
|
||||
}
|
||||
|
||||
let model_name = model_dir
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let tokenizer = xserv_tokenizer::Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
|
||||
// Bounded channel to backpressure incoming requests when the engine falls
|
||||
// behind, instead of letting them pile up in RAM. try_send in the API
|
||||
// handler surfaces this as 503 to the client.
|
||||
let (tx, rx) = mpsc::sync_channel::<GenerateRequest>(256);
|
||||
|
||||
let model_dir_clone = model_dir.clone();
|
||||
std::thread::spawn(move || {
|
||||
if pp > 1 {
|
||||
// Pipeline-parallel path: stage-0 coordinator + worker stage threads.
|
||||
pp_engine::run_pp(&model_dir_clone, pp, max_seq_len, rx);
|
||||
} else if tp <= 1 && !is_gpt_oss {
|
||||
let mut engine = engine::Engine::load_with_swap(
|
||||
&model_dir_clone,
|
||||
max_batch,
|
||||
max_seq_len,
|
||||
swap_space_gb,
|
||||
);
|
||||
engine.run(rx);
|
||||
} else {
|
||||
// Tensor-parallel path: rank-0 coordinator + worker rank threads.
|
||||
tp_engine::run_tp(&model_dir_clone, tp, max_seq_len, rx);
|
||||
}
|
||||
});
|
||||
|
||||
let model_type = model_config.model_type.clone().unwrap_or_default();
|
||||
let chat_template = api::ChatTemplate::load(&model_dir, &model_type);
|
||||
let state = Arc::new(AppState {
|
||||
model_name,
|
||||
chat_template,
|
||||
engine_sender: Mutex::new(tx),
|
||||
engine_tokenizer: Mutex::new(tokenizer),
|
||||
max_seq_len,
|
||||
});
|
||||
|
||||
let app = Router::new()
|
||||
.route("/health", get(api::health))
|
||||
.route("/v1/models", get(api::list_models))
|
||||
.route("/v1/chat/completions", post(api::chat_completions))
|
||||
.layer(DefaultBodyLimit::max(4 * 1024 * 1024))
|
||||
.layer(Extension(state));
|
||||
|
||||
let addr = format!("0.0.0.0:{port}");
|
||||
eprintln!("[server] Listening on {addr} (max_batch={max_batch}, max_seq_len={max_seq_len})");
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
}
|
||||
338
crates/xserv-server/src/pp_engine.rs
Normal file
338
crates/xserv-server/src/pp_engine.rs
Normal file
@@ -0,0 +1,338 @@
|
||||
//! Pipeline-parallel inference engine for the HTTP server (Phase 18).
|
||||
//!
|
||||
//! Layer-wise split: stage `s` holds layers `[s*L, (s+1)*L)`. Stage 0 owns the
|
||||
//! token embedding and acts as the coordinator (scheduler + tokenizer + response
|
||||
//! sender + stop logic); the last stage owns `norm`/`lm_head` and does sampling.
|
||||
//! Hidden states are handed off stage->stage via NCCL P2P (`PpContext`); the
|
||||
//! sampled token id (a single u32) is returned last-stage -> stage0 over an
|
||||
//! in-process channel (same process, so no NCCL needed for that).
|
||||
//!
|
||||
//! v1 is serial: one request at a time, one token per step, the pipeline is
|
||||
//! filled and drained each step (stage0's decode step t+1 depends on the token
|
||||
//! the last stage sampled at step t). This gives correctness + per-GPU memory
|
||||
//! savings; throughput via microbatch/1F1B overlap is future work
|
||||
//! (see docs/18-pipeline-parallelism.md).
|
||||
|
||||
use std::ffi::c_void;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::sync::mpsc;
|
||||
use std::thread;
|
||||
|
||||
use half::bf16;
|
||||
use xserv_distributed::{PpContext, UniqueId};
|
||||
use xserv_model::loader;
|
||||
use xserv_model::sampling::SamplingParams;
|
||||
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, sample};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
use crate::engine::{GenerateEvent, GenerateRequest};
|
||||
|
||||
/// Control messages from the coordinator (stage 0) to a worker stage. The heavy
|
||||
/// hidden-state tensors do NOT travel here — they go GPU->GPU over NCCL. Only
|
||||
/// tiny control info (slot ids, token count, sampling params) is sent.
|
||||
#[derive(Clone)]
|
||||
enum PpCommand {
|
||||
Register(usize),
|
||||
Free(usize),
|
||||
/// Receive `[n_tokens, hidden]` from the previous stage, run this stage's
|
||||
/// layers; if last stage, sample with `sampling` and return the token.
|
||||
Prefill {
|
||||
n_tokens: usize,
|
||||
slot: usize,
|
||||
sampling: SamplingParams,
|
||||
},
|
||||
/// Receive `[1, hidden]`, run this stage's layers; last stage samples.
|
||||
Decode {
|
||||
slot: usize,
|
||||
sampling: SamplingParams,
|
||||
},
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
struct StageCtx {
|
||||
model: Qwen3,
|
||||
cache: PagedKVCache,
|
||||
pp: Arc<PpContext>,
|
||||
hidden: usize,
|
||||
device: u32,
|
||||
}
|
||||
|
||||
/// Build this stage: NCCL init, load + slice weights, size a per-stage KV pool
|
||||
/// for THIS stage's layers only (so per-GPU KV is ~1/P).
|
||||
fn build_stage(
|
||||
model_dir: &Path,
|
||||
config: &ModelConfig,
|
||||
stage: usize,
|
||||
world: usize,
|
||||
device: u32,
|
||||
max_seq_len: usize,
|
||||
id: UniqueId,
|
||||
) -> StageCtx {
|
||||
let pp = Arc::new(PpContext::init(stage, world, id, device));
|
||||
let weights = loader::load_model_dir(model_dir, Device::Cpu);
|
||||
let model = Qwen3::from_weights_pp(config.clone(), weights, stage, world, device);
|
||||
|
||||
// The KV cache only needs this stage's layers; build it from a config clone
|
||||
// whose layer count is the per-stage count (heads are NOT split under PP).
|
||||
let per_stage = config.num_layers() / world;
|
||||
let mut stage_config = config.clone();
|
||||
stage_config.num_hidden_layers = Some(per_stage);
|
||||
|
||||
let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE);
|
||||
let total_blocks = max_blocks_per_seq + 8; // v1 serial: one active sequence
|
||||
let cache = PagedKVCache::new(
|
||||
&stage_config,
|
||||
total_blocks,
|
||||
0,
|
||||
4,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
device,
|
||||
);
|
||||
StageCtx {
|
||||
model,
|
||||
cache,
|
||||
pp,
|
||||
hidden: config.hidden(),
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate a zeroed `[n, hidden]` device tensor and receive into it from `peer`.
|
||||
fn recv_hidden(sc: &StageCtx, n: usize, peer: usize) -> Tensor {
|
||||
let zeros = vec![bf16::ZERO; n * sc.hidden];
|
||||
let x = Tensor::from_slice(&zeros, &[n, sc.hidden]).to_device(Device::Cuda(sc.device));
|
||||
let ptr = x.storage().gpu_buffer().as_ptr() as *mut c_void;
|
||||
sc.pp.recv_bf16_ptr(ptr, n * sc.hidden, peer);
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
x
|
||||
}
|
||||
|
||||
/// Send the `[*, hidden]` hidden state to `peer`, then synchronize so NCCL has
|
||||
/// finished reading `x` before it is dropped/reused.
|
||||
fn send_hidden(sc: &StageCtx, x: &Tensor, peer: usize) {
|
||||
let ptr = x.storage().gpu_buffer().as_ptr() as *const c_void;
|
||||
sc.pp.send_bf16_ptr(ptr, x.numel(), peer);
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
|
||||
fn worker_loop(
|
||||
stage: usize,
|
||||
world: usize,
|
||||
id: UniqueId,
|
||||
model_dir: PathBuf,
|
||||
config: ModelConfig,
|
||||
max_seq_len: usize,
|
||||
cmd_rx: mpsc::Receiver<PpCommand>,
|
||||
ack_tx: mpsc::Sender<()>,
|
||||
token_tx: mpsc::Sender<u32>,
|
||||
) {
|
||||
let mut sc = build_stage(
|
||||
&model_dir,
|
||||
&config,
|
||||
stage,
|
||||
world,
|
||||
stage as u32,
|
||||
max_seq_len,
|
||||
id,
|
||||
);
|
||||
let is_last = stage == world - 1;
|
||||
let prev = stage - 1;
|
||||
let next = stage + 1;
|
||||
|
||||
while let Ok(cmd) = cmd_rx.recv() {
|
||||
match cmd {
|
||||
PpCommand::Register(slot) => {
|
||||
let _ = sc.cache.register_sequence(slot);
|
||||
let _ = ack_tx.send(());
|
||||
}
|
||||
PpCommand::Free(slot) => {
|
||||
sc.cache.free_sequence(slot);
|
||||
let _ = ack_tx.send(());
|
||||
}
|
||||
PpCommand::Prefill {
|
||||
n_tokens,
|
||||
slot,
|
||||
sampling,
|
||||
} => {
|
||||
let x = recv_hidden(&sc, n_tokens, prev);
|
||||
let x = sc.model.forward_layers_prefill(x, slot, &mut sc.cache);
|
||||
if is_last {
|
||||
let logits = sc.model.head(&x);
|
||||
let _ = token_tx.send(sample(&logits, &sampling));
|
||||
} else {
|
||||
send_hidden(&sc, &x, next);
|
||||
}
|
||||
}
|
||||
PpCommand::Decode { slot, sampling } => {
|
||||
let x = recv_hidden(&sc, 1, prev);
|
||||
let x = sc.model.forward_layers_decode(x, &[slot], &mut sc.cache);
|
||||
if is_last {
|
||||
let logits = sc.model.head(&x);
|
||||
let _ = token_tx.send(sample(&logits, &sampling));
|
||||
} else {
|
||||
send_hidden(&sc, &x, next);
|
||||
}
|
||||
}
|
||||
PpCommand::Shutdown => {
|
||||
let _ = ack_tx.send(());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the PP coordinator (stage 0) on the calling thread. Spawns worker stages
|
||||
/// 1..world and consumes generation requests from `rx`.
|
||||
pub fn run_pp(
|
||||
model_dir: &Path,
|
||||
world: usize,
|
||||
max_seq_len: usize,
|
||||
rx: mpsc::Receiver<GenerateRequest>,
|
||||
) {
|
||||
assert!(world >= 2, "run_pp requires world >= 2");
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
assert!(
|
||||
config.num_layers() % world == 0,
|
||||
"num_layers {} not divisible by pp {world}",
|
||||
config.num_layers()
|
||||
);
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
let id = xserv_distributed::get_unique_id();
|
||||
|
||||
// Worker stages 1..world. Each gets a control channel; all share one ack
|
||||
// channel and one token channel (only the last stage actually sends tokens).
|
||||
let (ack_tx, ack_rx) = mpsc::channel::<()>();
|
||||
let (token_tx, token_rx) = mpsc::channel::<u32>();
|
||||
let mut cmd_txs: Vec<mpsc::Sender<PpCommand>> = Vec::new();
|
||||
for stage in 1..world {
|
||||
let (ctx_tx, ctx_rx) = mpsc::channel::<PpCommand>();
|
||||
cmd_txs.push(ctx_tx);
|
||||
let ack_tx = ack_tx.clone();
|
||||
let token_tx = token_tx.clone();
|
||||
let model_dir = model_dir.to_path_buf();
|
||||
let config = config.clone();
|
||||
thread::spawn(move || {
|
||||
worker_loop(
|
||||
stage,
|
||||
world,
|
||||
id,
|
||||
model_dir,
|
||||
config,
|
||||
max_seq_len,
|
||||
ctx_rx,
|
||||
ack_tx,
|
||||
token_tx,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
// Stage 0 (this thread): coordinator + embedding + first layers.
|
||||
let mut sc = build_stage(model_dir, &config, 0, world, 0, max_seq_len, id);
|
||||
eprintln!("[pp-engine] ready (pp={world}, max_seq_len={max_seq_len})");
|
||||
|
||||
let n_workers = world - 1;
|
||||
let next_peer = 1usize;
|
||||
let broadcast = |txs: &[mpsc::Sender<PpCommand>], cmd: PpCommand| {
|
||||
for t in txs {
|
||||
let _ = t.send(cmd.clone());
|
||||
}
|
||||
};
|
||||
let wait_acks = |rx: &mpsc::Receiver<()>| {
|
||||
for _ in 0..n_workers {
|
||||
let _ = rx.recv();
|
||||
}
|
||||
};
|
||||
|
||||
let slot = 0usize;
|
||||
while let Ok(req) = rx.recv() {
|
||||
broadcast(&cmd_txs, PpCommand::Register(slot));
|
||||
sc.cache.register_sequence(slot).expect("register slot");
|
||||
wait_acks(&ack_rx);
|
||||
|
||||
// Prefill: embed prompt, run stage-0 layers, push hidden into the pipe.
|
||||
broadcast(
|
||||
&cmd_txs,
|
||||
PpCommand::Prefill {
|
||||
n_tokens: req.prompt_tokens.len(),
|
||||
slot,
|
||||
sampling: req.sampling.clone(),
|
||||
},
|
||||
);
|
||||
let x = sc.model.embed(&req.prompt_tokens);
|
||||
let x = sc.model.forward_layers_prefill(x, slot, &mut sc.cache);
|
||||
send_hidden(&sc, &x, next_peer);
|
||||
let mut next = token_rx.recv().expect("prefill token");
|
||||
|
||||
let mut decode_buf: Vec<u8> = Vec::new();
|
||||
let mut generated = 1usize;
|
||||
let mut stalled = !emit_text(&tokenizer, &req, next, &mut decode_buf);
|
||||
|
||||
let finish = loop {
|
||||
if stalled {
|
||||
break "error";
|
||||
}
|
||||
if tokenizer.is_eos(next) {
|
||||
break "stop";
|
||||
}
|
||||
if generated >= req.max_tokens {
|
||||
break "length";
|
||||
}
|
||||
broadcast(
|
||||
&cmd_txs,
|
||||
PpCommand::Decode {
|
||||
slot,
|
||||
sampling: req.sampling.clone(),
|
||||
},
|
||||
);
|
||||
let x = sc.model.embed(&[next]);
|
||||
let x = sc.model.forward_layers_decode(x, &[slot], &mut sc.cache);
|
||||
send_hidden(&sc, &x, next_peer);
|
||||
next = token_rx.recv().expect("decode token");
|
||||
generated += 1;
|
||||
stalled = !emit_text(&tokenizer, &req, next, &mut decode_buf);
|
||||
};
|
||||
|
||||
let tail = tokenizer.flush_decode_stream(&mut decode_buf);
|
||||
if !tail.is_empty() {
|
||||
let _ = req.sender.try_send(GenerateEvent::Token {
|
||||
id: next,
|
||||
text: tail,
|
||||
});
|
||||
}
|
||||
let _ = req.sender.try_send(GenerateEvent::Done {
|
||||
finish_reason: finish.to_string(),
|
||||
});
|
||||
|
||||
broadcast(&cmd_txs, PpCommand::Free(slot));
|
||||
sc.cache.free_sequence(slot);
|
||||
wait_acks(&ack_rx);
|
||||
}
|
||||
|
||||
broadcast(&cmd_txs, PpCommand::Shutdown);
|
||||
}
|
||||
|
||||
/// Stream a token's decoded text to the client (EOS contributes no text).
|
||||
/// Returns false if the send would block (client too slow) or the client is
|
||||
/// gone — the caller stops generating so the coordinator thread is free to
|
||||
/// admit the next request instead of blocking on one slow consumer.
|
||||
fn emit_text(
|
||||
tokenizer: &Tokenizer,
|
||||
req: &GenerateRequest,
|
||||
token_id: u32,
|
||||
buf: &mut Vec<u8>,
|
||||
) -> bool {
|
||||
if tokenizer.is_eos(token_id) {
|
||||
return true;
|
||||
}
|
||||
let text = tokenizer.decode_token_stream(token_id, buf);
|
||||
if !text.is_empty() {
|
||||
return req
|
||||
.sender
|
||||
.try_send(GenerateEvent::Token { id: token_id, text })
|
||||
.is_ok();
|
||||
}
|
||||
true
|
||||
}
|
||||
366
crates/xserv-server/src/tp_engine.rs
Normal file
366
crates/xserv-server/src/tp_engine.rs
Normal file
@@ -0,0 +1,366 @@
|
||||
//! Tensor-parallel inference engine for the HTTP server.
|
||||
//!
|
||||
//! Serial coordinator model: one rank-0 coordinator thread (the caller) drives
|
||||
//! generation and owns the scheduler; ranks 1..world are worker threads. For
|
||||
//! each step the coordinator broadcasts a command (Register/Prefill/Decode/Free)
|
||||
//! to the workers and runs the same op on its own shard; the per-layer NCCL
|
||||
//! AllReduces keep all ranks in lockstep. Only the coordinator samples — the
|
||||
//! chosen token is carried in the next Decode command, so this is correct for
|
||||
//! both greedy and stochastic sampling.
|
||||
//!
|
||||
//! Requests are processed one at a time (sufficient for the quality benchmark,
|
||||
//! which issues serial requests). Continuous batching across ranks is future
|
||||
//! work; the single-GPU `Engine` still handles TP=1.
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::sync::mpsc;
|
||||
use std::thread;
|
||||
|
||||
use xserv_distributed::{TpContext, UniqueId};
|
||||
use xserv_model::loader;
|
||||
use xserv_model::{
|
||||
BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, sample,
|
||||
sample_greedy_penalized,
|
||||
};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
use crate::engine::{GenerateEvent, GenerateRequest};
|
||||
|
||||
/// Command broadcast by the rank-0 coordinator to every worker rank.
///
/// Workers replay the same operation on their own shard; the token ids to feed
/// are carried in the command itself, so workers never need to sample.
#[derive(Clone)]
enum TpCommand {
    /// Register the given sequence slot in the rank's KV cache.
    Register(usize),
    /// Release the given sequence slot in the rank's KV cache.
    Free(usize),
    /// Run paged prefill over `tokens` for `slot`.
    Prefill {
        tokens: Vec<u32>,
        slot: usize,
    },
    /// Run one paged decode step; the three vectors are passed through to the
    /// model's decode call as-is.
    Decode {
        tokens: Vec<u32>,
        positions: Vec<usize>,
        slots: Vec<usize>,
    },
    /// Acknowledge and leave the worker loop.
    Shutdown,
}
|
||||
|
||||
enum TpModel {
|
||||
Qwen3(Qwen3),
|
||||
GptOss(GptOss),
|
||||
}
|
||||
|
||||
impl TpModel {
|
||||
fn forward_prefill_paged(
|
||||
&self,
|
||||
tokens: &[u32],
|
||||
slot: usize,
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
match self {
|
||||
TpModel::Qwen3(m) => m.forward_prefill_paged(tokens, slot, cache),
|
||||
TpModel::GptOss(m) => m.forward_prefill_paged(tokens, slot, cache),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward_decode_paged(
|
||||
&self,
|
||||
tokens: &[u32],
|
||||
positions: &[usize],
|
||||
slots: &[usize],
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
match self {
|
||||
TpModel::Qwen3(m) => m.forward_decode_paged(tokens, positions, slots, cache),
|
||||
TpModel::GptOss(m) => m.forward_decode_paged(tokens, positions, slots, cache),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RankCtx {
|
||||
model: TpModel,
|
||||
cache: PagedKVCache,
|
||||
decoder: GraphedGptOssDecoder,
|
||||
}
|
||||
|
||||
/// Decode one step: gpt-oss batch=1 goes through the CUDA-graph decoder
|
||||
/// (lazy capture, replay thereafter); everything else runs eager.
|
||||
fn rank_decode(rc: &mut RankCtx, tokens: &[u32], positions: &[usize], slots: &[usize]) -> Tensor {
|
||||
match &rc.model {
|
||||
TpModel::GptOss(m) => rc
|
||||
.decoder
|
||||
.decode(m, tokens, positions, slots, &mut rc.cache),
|
||||
TpModel::Qwen3(_) => rc
|
||||
.model
|
||||
.forward_decode_paged(tokens, positions, slots, &mut rc.cache),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_rank(
|
||||
model_dir: &Path,
|
||||
config: &ModelConfig,
|
||||
rank: usize,
|
||||
world: usize,
|
||||
device: u32,
|
||||
max_seq_len: usize,
|
||||
tp: Option<Arc<TpContext>>,
|
||||
) -> RankCtx {
|
||||
let weights = loader::load_model_dir(model_dir, Device::Cpu);
|
||||
let model = if config.is_moe() {
|
||||
TpModel::GptOss(GptOss::from_weights_tp(
|
||||
config.clone(),
|
||||
weights,
|
||||
rank,
|
||||
world,
|
||||
device,
|
||||
tp,
|
||||
))
|
||||
} else {
|
||||
TpModel::Qwen3(Qwen3::from_weights_tp(
|
||||
config.clone(),
|
||||
weights,
|
||||
rank,
|
||||
world,
|
||||
device,
|
||||
tp,
|
||||
))
|
||||
};
|
||||
let local_kv = config.num_kv_heads() / world;
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
let cache = PagedKVCache::new_tp(
|
||||
config,
|
||||
local_kv,
|
||||
total_blocks,
|
||||
0,
|
||||
4,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
device,
|
||||
);
|
||||
RankCtx {
|
||||
model,
|
||||
cache,
|
||||
decoder: GraphedGptOssDecoder::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn worker_loop(
|
||||
rank: usize,
|
||||
world: usize,
|
||||
id: UniqueId,
|
||||
model_dir: PathBuf,
|
||||
config: ModelConfig,
|
||||
max_seq_len: usize,
|
||||
cmd_rx: mpsc::Receiver<TpCommand>,
|
||||
ack_tx: mpsc::Sender<()>,
|
||||
) {
|
||||
let tp = Arc::new(TpContext::init(rank, world, id, rank as u32));
|
||||
let mut rc = build_rank(
|
||||
&model_dir,
|
||||
&config,
|
||||
rank,
|
||||
world,
|
||||
rank as u32,
|
||||
max_seq_len,
|
||||
Some(tp),
|
||||
);
|
||||
while let Ok(cmd) = cmd_rx.recv() {
|
||||
match cmd {
|
||||
TpCommand::Register(slot) => {
|
||||
let _ = rc.cache.register_sequence(slot);
|
||||
}
|
||||
TpCommand::Free(slot) => rc.cache.free_sequence(slot),
|
||||
TpCommand::Prefill { tokens, slot } => {
|
||||
let _ = rc.model.forward_prefill_paged(&tokens, slot, &mut rc.cache);
|
||||
}
|
||||
TpCommand::Decode {
|
||||
tokens,
|
||||
positions,
|
||||
slots,
|
||||
} => {
|
||||
let _ = rank_decode(&mut rc, &tokens, &positions, &slots);
|
||||
}
|
||||
TpCommand::Shutdown => {
|
||||
let _ = ack_tx.send(());
|
||||
break;
|
||||
}
|
||||
}
|
||||
let _ = ack_tx.send(());
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the TP coordinator (rank 0) on the calling thread. Spawns worker ranks
|
||||
/// internally and consumes generation requests from `rx`.
|
||||
pub fn run_tp(
|
||||
model_dir: &Path,
|
||||
world: usize,
|
||||
max_seq_len: usize,
|
||||
rx: mpsc::Receiver<GenerateRequest>,
|
||||
) {
|
||||
// world=1 is a valid single-rank configuration (gpt-oss has no
|
||||
// single-GPU engine path; NCCL init and all_reduce no-op at world=1).
|
||||
assert!(world >= 1, "run_tp requires world >= 1");
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
assert!(
|
||||
config.num_kv_heads() % world == 0,
|
||||
"num_kv_heads {} not divisible by tp {world}",
|
||||
config.num_kv_heads()
|
||||
);
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
let id = xserv_distributed::get_unique_id();
|
||||
|
||||
// Spawn worker ranks 1..world.
|
||||
let (ack_tx, ack_rx) = mpsc::channel::<()>();
|
||||
let mut cmd_txs: Vec<mpsc::Sender<TpCommand>> = Vec::new();
|
||||
for rank in 1..world {
|
||||
let (ctx_tx, ctx_rx) = mpsc::channel::<TpCommand>();
|
||||
cmd_txs.push(ctx_tx);
|
||||
let ack_tx = ack_tx.clone();
|
||||
let model_dir = model_dir.to_path_buf();
|
||||
let config = config.clone();
|
||||
thread::spawn(move || {
|
||||
worker_loop(
|
||||
rank,
|
||||
world,
|
||||
id,
|
||||
model_dir,
|
||||
config,
|
||||
max_seq_len,
|
||||
ctx_rx,
|
||||
ack_tx,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
// Rank 0 (this thread).
|
||||
let tp = Arc::new(TpContext::init(0, world, id, 0));
|
||||
let mut rc = build_rank(model_dir, &config, 0, world, 0, max_seq_len, Some(tp));
|
||||
eprintln!("[tp-engine] ready (tp={world}, max_seq_len={max_seq_len})");
|
||||
|
||||
// Optional repetition penalty to break greedy repetition loops (reasoning
|
||||
// models loop under pure greedy when numerics diverge from the reference).
|
||||
// Off by default; XSERV_REP_PENALTY>1 enables it over the last
|
||||
// XSERV_REP_WINDOW generated tokens. Applied only on the greedy path.
|
||||
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1.0);
|
||||
let rep_window: usize = std::env::var("XSERV_REP_WINDOW")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(128);
|
||||
let pick = |logits: &Tensor, sp: &xserv_model::SamplingParams, history: &[u32]| -> u32 {
|
||||
if rep_penalty > 1.0 && sp.temperature == 0.0 {
|
||||
let start = history.len().saturating_sub(rep_window);
|
||||
sample_greedy_penalized(logits, &history[start..], rep_penalty)
|
||||
} else {
|
||||
sample(logits, sp)
|
||||
}
|
||||
};
|
||||
|
||||
let n_workers = world - 1;
|
||||
let broadcast = |txs: &[mpsc::Sender<TpCommand>], cmd: TpCommand| {
|
||||
for t in txs {
|
||||
let _ = t.send(cmd.clone());
|
||||
}
|
||||
};
|
||||
let wait_acks = |rx: &mpsc::Receiver<()>| {
|
||||
for _ in 0..n_workers {
|
||||
let _ = rx.recv();
|
||||
}
|
||||
};
|
||||
|
||||
let slot = 0usize;
|
||||
while let Ok(req) = rx.recv() {
|
||||
broadcast(&cmd_txs, TpCommand::Register(slot));
|
||||
rc.cache.register_sequence(slot).expect("register slot");
|
||||
wait_acks(&ack_rx);
|
||||
|
||||
// Prefill.
|
||||
broadcast(
|
||||
&cmd_txs,
|
||||
TpCommand::Prefill {
|
||||
tokens: req.prompt_tokens.clone(),
|
||||
slot,
|
||||
},
|
||||
);
|
||||
let logits = rc
|
||||
.model
|
||||
.forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache);
|
||||
wait_acks(&ack_rx);
|
||||
let mut gen_ids: Vec<u32> = Vec::new();
|
||||
let mut next = pick(&logits, &req.sampling, &gen_ids);
|
||||
gen_ids.push(next);
|
||||
|
||||
let mut decode_buf: Vec<u8> = Vec::new();
|
||||
let mut generated = 1usize;
|
||||
let mut stalled = !emit_text(&tokenizer, &req, next, &mut decode_buf);
|
||||
|
||||
let finish = loop {
|
||||
if stalled {
|
||||
break "error";
|
||||
}
|
||||
if tokenizer.is_eos(next) {
|
||||
break "stop";
|
||||
}
|
||||
if generated >= req.max_tokens {
|
||||
break "length";
|
||||
}
|
||||
let pos = rc.cache.seq_len(slot);
|
||||
broadcast(
|
||||
&cmd_txs,
|
||||
TpCommand::Decode {
|
||||
tokens: vec![next],
|
||||
positions: vec![pos],
|
||||
slots: vec![slot],
|
||||
},
|
||||
);
|
||||
let logits = rank_decode(&mut rc, &[next], &[pos], &[slot]);
|
||||
wait_acks(&ack_rx);
|
||||
next = pick(&logits, &req.sampling, &gen_ids);
|
||||
gen_ids.push(next);
|
||||
generated += 1;
|
||||
stalled = !emit_text(&tokenizer, &req, next, &mut decode_buf);
|
||||
};
|
||||
|
||||
let tail = tokenizer.flush_decode_stream(&mut decode_buf);
|
||||
if !tail.is_empty() {
|
||||
let _ = req.sender.try_send(GenerateEvent::Token {
|
||||
id: next,
|
||||
text: tail,
|
||||
});
|
||||
}
|
||||
let _ = req.sender.try_send(GenerateEvent::Done {
|
||||
finish_reason: finish.to_string(),
|
||||
});
|
||||
|
||||
broadcast(&cmd_txs, TpCommand::Free(slot));
|
||||
rc.cache.free_sequence(slot);
|
||||
wait_acks(&ack_rx);
|
||||
}
|
||||
|
||||
broadcast(&cmd_txs, TpCommand::Shutdown);
|
||||
}
|
||||
|
||||
/// Stream a token's decoded text to the client (EOS contributes no text).
|
||||
/// Returns false if the send would block (client too slow) or the client is
|
||||
/// gone — the caller stops generating so the serial coordinator thread is free
|
||||
/// to admit the next request instead of blocking on one slow consumer.
|
||||
fn emit_text(
|
||||
tokenizer: &Tokenizer,
|
||||
req: &GenerateRequest,
|
||||
token_id: u32,
|
||||
buf: &mut Vec<u8>,
|
||||
) -> bool {
|
||||
if tokenizer.is_eos(token_id) {
|
||||
return true;
|
||||
}
|
||||
let text = tokenizer.decode_token_stream(token_id, buf);
|
||||
if !text.is_empty() {
|
||||
return req
|
||||
.sender
|
||||
.try_send(GenerateEvent::Token { id: token_id, text })
|
||||
.is_ok();
|
||||
}
|
||||
true
|
||||
}
|
||||
@@ -5,6 +5,7 @@ pub enum DType {
|
||||
F32,
|
||||
F16,
|
||||
BF16,
|
||||
FP8E4M3,
|
||||
}
|
||||
|
||||
impl DType {
|
||||
@@ -13,6 +14,7 @@ impl DType {
|
||||
DType::F32 => 4,
|
||||
DType::F16 => 2,
|
||||
DType::BF16 => 2,
|
||||
DType::FP8E4M3 => 1,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +23,7 @@ impl DType {
|
||||
DType::F32 => "f32",
|
||||
DType::F16 => "f16",
|
||||
DType::BF16 => "bf16",
|
||||
DType::FP8E4M3 => "fp8e4m3",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -40,18 +43,30 @@ pub trait TensorDType: Copy + Send + Sync + 'static {
|
||||
|
||||
impl TensorDType for f32 {
|
||||
const DTYPE: DType = DType::F32;
|
||||
fn to_f64(self) -> f64 { self as f64 }
|
||||
fn from_f64(v: f64) -> Self { v as f32 }
|
||||
fn to_f64(self) -> f64 {
|
||||
self as f64
|
||||
}
|
||||
fn from_f64(v: f64) -> Self {
|
||||
v as f32
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorDType for f16 {
|
||||
const DTYPE: DType = DType::F16;
|
||||
fn to_f64(self) -> f64 { self.to_f32() as f64 }
|
||||
fn from_f64(v: f64) -> Self { f16::from_f32(v as f32) }
|
||||
fn to_f64(self) -> f64 {
|
||||
self.to_f32() as f64
|
||||
}
|
||||
fn from_f64(v: f64) -> Self {
|
||||
f16::from_f32(v as f32)
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorDType for bf16 {
|
||||
const DTYPE: DType = DType::BF16;
|
||||
fn to_f64(self) -> f64 { self.to_f32() as f64 }
|
||||
fn from_f64(v: f64) -> Self { bf16::from_f32(v as f32) }
|
||||
fn to_f64(self) -> f64 {
|
||||
self.to_f32() as f64
|
||||
}
|
||||
fn from_f64(v: f64) -> Self {
|
||||
bf16::from_f32(v as f32)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,5 +4,6 @@ pub mod storage;
|
||||
pub mod tensor;
|
||||
|
||||
pub use dtype::{DType, TensorDType};
|
||||
pub use storage::Device;
|
||||
pub use tensor::Tensor;
|
||||
pub use shape::Dims;
|
||||
pub use storage::{Device, Storage};
|
||||
pub use tensor::{Tensor, register_gpu_contiguous};
|
||||
|
||||
@@ -18,12 +18,21 @@ pub fn contiguous_strides(shape: &[usize]) -> Dims {
|
||||
}
|
||||
|
||||
/// Check if the given strides represent contiguous (row-major) layout for the shape.
///
/// A stride mismatch on a dimension of size 1 is allowed because that
/// dimension is never stepped. An empty shape is always contiguous. For a
/// non-empty shape, `strides` must have exactly one entry per dimension;
/// a rank mismatch yields `false` rather than an out-of-bounds panic.
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }
    if shape.len() != strides.len() {
        return false;
    }
    // Walk from the innermost dimension outwards, tracking the stride a
    // row-major layout would have at each dimension.
    let mut expected_stride = 1usize;
    for (&dim, &stride) in shape.iter().zip(strides).rev() {
        if dim != 1 && stride != expected_stride {
            return false;
        }
        expected_stride *= dim;
    }
    true
}
|
||||
|
||||
/// Total number of elements given a shape.
|
||||
@@ -37,8 +46,16 @@ pub fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Dims> {
|
||||
let ndim = a.len().max(b.len());
|
||||
let mut result = SmallVec::with_capacity(ndim);
|
||||
for i in 0..ndim {
|
||||
let da = if i < ndim - a.len() { 1 } else { a[i - (ndim - a.len())] };
|
||||
let db = if i < ndim - b.len() { 1 } else { b[i - (ndim - b.len())] };
|
||||
let da = if i < ndim - a.len() {
|
||||
1
|
||||
} else {
|
||||
a[i - (ndim - a.len())]
|
||||
};
|
||||
let db = if i < ndim - b.len() {
|
||||
1
|
||||
} else {
|
||||
b[i - (ndim - b.len())]
|
||||
};
|
||||
if da == db {
|
||||
result.push(da);
|
||||
} else if da == 1 {
|
||||
@@ -91,8 +108,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_broadcast_shape() {
|
||||
assert_eq!(broadcast_shape(&[3, 1], &[1, 4]).unwrap().as_slice(), &[3, 4]);
|
||||
assert_eq!(broadcast_shape(&[2, 3, 4], &[4]).unwrap().as_slice(), &[2, 3, 4]);
|
||||
assert_eq!(
|
||||
broadcast_shape(&[3, 1], &[1, 4]).unwrap().as_slice(),
|
||||
&[3, 4]
|
||||
);
|
||||
assert_eq!(
|
||||
broadcast_shape(&[2, 3, 4], &[4]).unwrap().as_slice(),
|
||||
&[2, 3, 4]
|
||||
);
|
||||
assert_eq!(broadcast_shape(&[1], &[5, 3]).unwrap().as_slice(), &[5, 3]);
|
||||
assert!(broadcast_shape(&[3], &[4]).is_none());
|
||||
}
|
||||
@@ -100,6 +123,9 @@ mod tests {
|
||||
#[test]
|
||||
fn test_broadcast_strides() {
|
||||
// [3,1] with strides [1,1] broadcast to [3,4]
|
||||
assert_eq!(broadcast_strides(&[3, 1], &[1, 1], &[3, 4]).as_slice(), &[1, 0]);
|
||||
assert_eq!(
|
||||
broadcast_strides(&[3, 1], &[1, 1], &[3, 4]).as_slice(),
|
||||
&[1, 0]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use xserv_cuda::{GpuBuffer, Result as CudaResult};
|
||||
|
||||
enum StorageInner {
|
||||
Cpu { data: Vec<u8> },
|
||||
Cuda { buffer: GpuBuffer },
|
||||
Cuda { buffer: GpuBuffer, device: u32 },
|
||||
}
|
||||
|
||||
/// Reference-counted storage for tensor data. Multiple tensors can share
|
||||
@@ -31,21 +31,21 @@ impl Storage {
|
||||
Self(Arc::new(StorageInner::Cpu { data }))
|
||||
}
|
||||
|
||||
pub fn cuda(buffer: GpuBuffer) -> Self {
|
||||
Self(Arc::new(StorageInner::Cuda { buffer }))
|
||||
pub fn cuda(buffer: GpuBuffer, device: u32) -> Self {
|
||||
Self(Arc::new(StorageInner::Cuda { buffer, device }))
|
||||
}
|
||||
|
||||
pub fn device(&self) -> Device {
|
||||
match self.0.as_ref() {
|
||||
StorageInner::Cpu { .. } => Device::Cpu,
|
||||
StorageInner::Cuda { .. } => Device::Cuda(0),
|
||||
StorageInner::Cuda { device, .. } => Device::Cuda(*device),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn len_bytes(&self) -> usize {
|
||||
match self.0.as_ref() {
|
||||
StorageInner::Cpu { data } => data.len(),
|
||||
StorageInner::Cuda { buffer } => buffer.len(),
|
||||
StorageInner::Cuda { buffer, .. } => buffer.len(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ impl Storage {
|
||||
|
||||
pub fn gpu_buffer(&self) -> &GpuBuffer {
|
||||
match self.0.as_ref() {
|
||||
StorageInner::Cuda { buffer } => buffer,
|
||||
StorageInner::Cuda { buffer, .. } => buffer,
|
||||
StorageInner::Cpu { .. } => panic!("cannot access CPU storage as GPU buffer"),
|
||||
}
|
||||
}
|
||||
@@ -71,11 +71,11 @@ impl Storage {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
match (current, target) {
|
||||
(Device::Cpu, Device::Cuda(_dev)) => {
|
||||
(Device::Cpu, Device::Cuda(dev)) => {
|
||||
let cpu_data = self.as_cpu_bytes();
|
||||
let mut buf = GpuBuffer::alloc(cpu_data.len())?;
|
||||
let mut buf = xserv_cuda::allocator::cached_alloc(cpu_data.len())?;
|
||||
buf.copy_from_host(cpu_data)?;
|
||||
Ok(Storage::cuda(buf))
|
||||
Ok(Storage::cuda(buf, dev))
|
||||
}
|
||||
(Device::Cuda(_), Device::Cpu) => {
|
||||
let gpu_buf = self.gpu_buffer();
|
||||
@@ -83,11 +83,11 @@ impl Storage {
|
||||
gpu_buf.copy_to_host(&mut data)?;
|
||||
Ok(Storage::cpu(data))
|
||||
}
|
||||
(Device::Cuda(_), Device::Cuda(_)) => {
|
||||
(Device::Cuda(_), Device::Cuda(dev)) => {
|
||||
let src = self.gpu_buffer();
|
||||
let mut dst = GpuBuffer::alloc(src.len())?;
|
||||
let mut dst = xserv_cuda::allocator::cached_alloc(src.len())?;
|
||||
dst.copy_from_device(src)?;
|
||||
Ok(Storage::cuda(dst))
|
||||
Ok(Storage::cuda(dst, dev))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
@@ -97,10 +97,10 @@ impl Storage {
|
||||
pub fn deep_copy(&self) -> CudaResult<Self> {
|
||||
match self.0.as_ref() {
|
||||
StorageInner::Cpu { data } => Ok(Storage::cpu(data.clone())),
|
||||
StorageInner::Cuda { buffer } => {
|
||||
let mut dst = GpuBuffer::alloc(buffer.len())?;
|
||||
StorageInner::Cuda { buffer, device } => {
|
||||
let mut dst = xserv_cuda::allocator::cached_alloc(buffer.len())?;
|
||||
dst.copy_from_device(buffer)?;
|
||||
Ok(Storage::cuda(dst))
|
||||
Ok(Storage::cuda(dst, *device))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -109,10 +109,24 @@ impl Storage {
|
||||
pub fn zeros(len_bytes: usize, device: Device) -> CudaResult<Self> {
|
||||
match device {
|
||||
Device::Cpu => Ok(Storage::cpu(vec![0u8; len_bytes])),
|
||||
Device::Cuda(_) => {
|
||||
let mut buf = GpuBuffer::alloc(len_bytes)?;
|
||||
Device::Cuda(dev) => {
|
||||
let mut buf = xserv_cuda::allocator::cached_alloc(len_bytes)?;
|
||||
buf.zero()?;
|
||||
Ok(Storage::cuda(buf))
|
||||
Ok(Storage::cuda(buf, dev))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate storage **without zeroing** on the given device.
|
||||
/// The buffer may contain stale data from the caching allocator's pool.
|
||||
/// Only use when the caller guarantees the kernel will fully overwrite
|
||||
/// every element before any read.
|
||||
pub fn empty(len_bytes: usize, device: Device) -> CudaResult<Self> {
|
||||
match device {
|
||||
Device::Cpu => Ok(Storage::cpu(vec![0u8; len_bytes])), // CPU still zeros (cheap)
|
||||
Device::Cuda(dev) => {
|
||||
let buf = xserv_cuda::allocator::cached_alloc(len_bytes)?;
|
||||
Ok(Storage::cuda(buf, dev))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,21 @@
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use crate::dtype::{DType, TensorDType};
|
||||
use crate::shape::{self, Dims};
|
||||
use crate::storage::{Device, Storage};
|
||||
|
||||
/// Global hook for GPU strided-to-contiguous copy.
|
||||
/// Set by `xserv-kernels` (or any crate that provides a GPU kernel) via
|
||||
/// `register_gpu_contiguous`. When set, `contiguous()` on a non-contiguous
|
||||
/// GPU tensor calls this instead of doing a CPU round-trip.
|
||||
static GPU_CONTIGUOUS_FN: OnceLock<fn(&Tensor) -> Tensor> = OnceLock::new();
|
||||
|
||||
/// Register a function that makes a non-contiguous GPU tensor contiguous.
/// Intended to be called once by the kernel crate at startup.
///
/// The hook is stored in a `OnceLock`, so only the first registration takes
/// effect; the result of `set` is discarded, which means later calls are
/// silently ignored rather than reported.
pub fn register_gpu_contiguous(f: fn(&Tensor) -> Tensor) {
    let _ = GPU_CONTIGUOUS_FN.set(f);
}
|
||||
|
||||
/// Multi-dimensional array with CPU or GPU storage.
|
||||
///
|
||||
/// Tensors support view semantics: transpose, slice, etc. share
|
||||
@@ -18,6 +32,23 @@ pub struct Tensor {
|
||||
impl Tensor {
|
||||
// --- Creation ---
|
||||
|
||||
/// Create a tensor from raw components (for advanced use like GPU KV cache).
///
/// The arguments are stored as given: this constructor performs no check that
/// `shape`, `strides` and `offset` describe a region inside `storage`, or that
/// `dtype` matches the stored bytes. The caller is responsible for consistency.
pub fn from_storage(
    storage: Storage,
    shape: Dims,
    strides: Dims,
    offset: usize,
    dtype: DType,
) -> Self {
    Self {
        storage,
        shape,
        strides,
        offset,
        dtype,
    }
}
|
||||
|
||||
pub fn from_slice<T: TensorDType>(data: &[T], shape: &[usize]) -> Self {
|
||||
let numel: usize = shape.iter().product();
|
||||
assert_eq!(data.len(), numel, "data length mismatch with shape");
|
||||
@@ -33,6 +64,28 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a tensor from raw bytes. Used for dtypes without a Rust type
|
||||
/// (e.g. FP8 E4M3) where we store the bit pattern as-is.
|
||||
pub fn from_raw_bytes(data: &[u8], shape: &[usize], dtype: DType) -> Self {
|
||||
let numel: usize = shape.iter().product();
|
||||
assert_eq!(
|
||||
data.len(),
|
||||
numel * dtype.size_bytes(),
|
||||
"raw bytes length {} != expected {} (numel={} * elem_size={})",
|
||||
data.len(),
|
||||
numel * dtype.size_bytes(),
|
||||
numel,
|
||||
dtype.size_bytes()
|
||||
);
|
||||
Self {
|
||||
storage: Storage::cpu(data.to_vec()),
|
||||
shape: Dims::from_slice(shape),
|
||||
strides: shape::contiguous_strides(shape),
|
||||
offset: 0,
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
|
||||
let numel = shape::num_elements(shape);
|
||||
let len_bytes = numel * dtype.size_bytes();
|
||||
@@ -46,25 +99,56 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate a tensor **without zeroing** the backing memory.
|
||||
/// The buffer may contain stale data. Only use when the calling kernel
|
||||
/// will fully overwrite every element before any read.
|
||||
pub fn empty(shape: &[usize], dtype: DType, device: Device) -> Self {
|
||||
let numel = shape::num_elements(shape);
|
||||
let len_bytes = numel * dtype.size_bytes();
|
||||
let storage = Storage::empty(len_bytes, device).expect("alloc failed");
|
||||
Self {
|
||||
storage,
|
||||
shape: Dims::from_slice(shape),
|
||||
strides: shape::contiguous_strides(shape),
|
||||
offset: 0,
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a tensor of the given `shape` with every element set to 1.
///
/// The data is materialised as a host-side `Vec` and handed to `from_slice`.
///
/// # Panics
/// Panics for `DType::FP8E4M3`, which has no Rust element type to fill with.
pub fn ones(shape: &[usize], dtype: DType) -> Self {
    let numel = shape::num_elements(shape);
    match dtype {
        DType::F32 => Self::from_slice(&vec![1.0f32; numel], shape),
        DType::F16 => Self::from_slice(&vec![half::f16::from_f32(1.0); numel], shape),
        DType::BF16 => Self::from_slice(&vec![half::bf16::from_f32(1.0); numel], shape),
        DType::FP8E4M3 => panic!("ones() not supported for FP8E4M3"),
    }
}
|
||||
|
||||
// --- Properties ---
|
||||
|
||||
pub fn shape(&self) -> &[usize] { &self.shape }
|
||||
pub fn strides(&self) -> &[usize] { &self.strides }
|
||||
pub fn dtype(&self) -> DType { self.dtype }
|
||||
pub fn ndim(&self) -> usize { self.shape.len() }
|
||||
pub fn numel(&self) -> usize { shape::num_elements(&self.shape) }
|
||||
pub fn offset(&self) -> usize { self.offset }
|
||||
pub fn shape(&self) -> &[usize] {
|
||||
&self.shape
|
||||
}
|
||||
pub fn strides(&self) -> &[usize] {
|
||||
&self.strides
|
||||
}
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
pub fn ndim(&self) -> usize {
|
||||
self.shape.len()
|
||||
}
|
||||
pub fn numel(&self) -> usize {
|
||||
shape::num_elements(&self.shape)
|
||||
}
|
||||
pub fn offset(&self) -> usize {
|
||||
self.offset
|
||||
}
|
||||
|
||||
pub fn device(&self) -> Device { self.storage.device() }
|
||||
pub fn device(&self) -> Device {
|
||||
self.storage.device()
|
||||
}
|
||||
|
||||
pub fn is_contiguous(&self) -> bool {
|
||||
shape::is_contiguous(&self.shape, &self.strides)
|
||||
@@ -85,6 +169,21 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Zero-copy slice along `dim`: keeps elements `[start, start+len)`.
|
||||
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Self {
|
||||
assert!(dim < self.ndim());
|
||||
assert!(start + len <= self.shape[dim], "narrow out of bounds");
|
||||
let mut new_shape = self.shape.clone();
|
||||
new_shape[dim] = len;
|
||||
Self {
|
||||
storage: self.storage.clone(),
|
||||
shape: new_shape,
|
||||
strides: self.strides.clone(),
|
||||
offset: self.offset + start * self.strides[dim],
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn transpose(&self, dim0: usize, dim1: usize) -> Self {
|
||||
assert!(dim0 < self.ndim() && dim1 < self.ndim());
|
||||
let mut new_shape = self.shape.clone();
|
||||
@@ -118,10 +217,19 @@ impl Tensor {
|
||||
pub fn unsqueeze(&self, dim: usize) -> Self {
|
||||
assert!(dim <= self.ndim());
|
||||
let mut new_shape = self.shape.clone();
|
||||
let mut new_strides = self.strides.clone();
|
||||
new_shape.insert(dim, 1);
|
||||
let stride_val = if dim < self.strides.len() { self.strides[dim] } else { 1 };
|
||||
new_strides.insert(dim, stride_val);
|
||||
let new_strides = if self.is_contiguous() {
|
||||
shape::contiguous_strides(&new_shape)
|
||||
} else {
|
||||
let mut s = self.strides.clone();
|
||||
let stride_val = if dim < self.strides.len() {
|
||||
self.strides[dim]
|
||||
} else {
|
||||
1
|
||||
};
|
||||
s.insert(dim, stride_val);
|
||||
s
|
||||
};
|
||||
Self {
|
||||
storage: self.storage.clone(),
|
||||
shape: new_shape,
|
||||
@@ -137,8 +245,16 @@ impl Tensor {
|
||||
if self.is_contiguous() {
|
||||
return self.clone();
|
||||
}
|
||||
// Copy to contiguous layout on CPU
|
||||
assert_eq!(self.device(), Device::Cpu, "contiguous() on GPU not yet supported");
|
||||
// For GPU tensors: use the registered GPU kernel if available,
|
||||
// otherwise fall back to CPU round-trip.
|
||||
if matches!(self.device(), Device::Cuda(_)) {
|
||||
if let Some(gpu_fn) = GPU_CONTIGUOUS_FN.get() {
|
||||
return gpu_fn(self);
|
||||
}
|
||||
let cpu = self.to_device(Device::Cpu);
|
||||
let contig = cpu.contiguous();
|
||||
return contig.to_device(self.device());
|
||||
}
|
||||
let numel = self.numel();
|
||||
let elem_size = self.dtype.size_bytes();
|
||||
let src_bytes = self.storage.as_cpu_bytes();
|
||||
@@ -147,7 +263,12 @@ impl Tensor {
|
||||
let ndim = self.ndim();
|
||||
let mut idx = vec![0usize; ndim];
|
||||
for flat in 0..numel {
|
||||
let src_offset = self.offset + idx.iter().zip(self.strides.iter()).map(|(i, s)| i * s).sum::<usize>();
|
||||
let src_offset = self.offset
|
||||
+ idx
|
||||
.iter()
|
||||
.zip(self.strides.iter())
|
||||
.map(|(i, s)| i * s)
|
||||
.sum::<usize>();
|
||||
let src_byte_offset = src_offset * elem_size;
|
||||
let dst_byte_offset = flat * elem_size;
|
||||
dst[dst_byte_offset..dst_byte_offset + elem_size]
|
||||
@@ -173,17 +294,21 @@ impl Tensor {
|
||||
// --- Device transfer ---
|
||||
|
||||
pub fn to_device(&self, device: Device) -> Self {
|
||||
let t = if self.is_contiguous() { self.clone() } else { self.contiguous() };
|
||||
if t.device() == device {
|
||||
return t;
|
||||
if self.device() == device {
|
||||
return self.clone();
|
||||
}
|
||||
let new_storage = t.storage.to_device(device).expect("device transfer failed");
|
||||
// Transfer the raw storage (preserving strides/offset).
|
||||
// Non-contiguous layout is preserved — the user can call contiguous() after.
|
||||
let new_storage = self
|
||||
.storage
|
||||
.to_device(device)
|
||||
.expect("device transfer failed");
|
||||
Self {
|
||||
storage: new_storage,
|
||||
shape: t.shape,
|
||||
strides: t.strides,
|
||||
offset: 0,
|
||||
dtype: t.dtype,
|
||||
shape: self.shape.clone(),
|
||||
strides: self.strides.clone(),
|
||||
offset: self.offset,
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,6 +326,17 @@ impl Tensor {
|
||||
unsafe { std::slice::from_raw_parts(bytes[start..].as_ptr() as *const T, len) }
|
||||
}
|
||||
|
||||
/// Raw byte access for dtypes without a Rust type (e.g. FP8).
|
||||
pub fn as_raw_bytes(&self) -> &[u8] {
|
||||
assert!(self.is_contiguous(), "as_raw_bytes requires contiguous");
|
||||
assert_eq!(self.device(), Device::Cpu, "as_raw_bytes requires CPU");
|
||||
let bytes = self.storage.as_cpu_bytes();
|
||||
let elem_size = self.dtype.size_bytes();
|
||||
let start = self.offset * elem_size;
|
||||
let len = self.numel() * elem_size;
|
||||
&bytes[start..start + len]
|
||||
}
|
||||
|
||||
/// Raw pointer to storage start (for GPU kernel launch).
|
||||
pub fn data_ptr(&self) -> *const u8 {
|
||||
match self.device() {
|
||||
@@ -215,14 +351,75 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn storage(&self) -> &Storage { &self.storage }
|
||||
pub fn storage(&self) -> &Storage {
|
||||
&self.storage
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Tensor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f, "Tensor(shape={:?}, dtype={}, device={}, contiguous={})",
|
||||
self.shape.as_slice(), self.dtype, self.device(), self.is_contiguous()
|
||||
f,
|
||||
"Tensor(shape={:?}, dtype={}, device={}, contiguous={})",
|
||||
self.shape.as_slice(),
|
||||
self.dtype,
|
||||
self.device(),
|
||||
self.is_contiguous()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Unit tests for `unsqueeze` stride handling on contiguous and
// non-contiguous tensors.
#[cfg(test)]
mod tests {
    use super::*;

    /// A contiguous 3x4 f32 tensor (all ones) used as the common fixture.
    fn contiguous_2d() -> Tensor {
        Tensor::from_slice(&[1.0f32; 12], &[3, 4])
    }

    // Inserting a leading axis keeps the tensor contiguous.
    #[test]
    fn unsqueeze_dim0_contiguous() {
        let t = contiguous_2d();
        let u = t.unsqueeze(0);
        assert_eq!(u.shape(), &[1, 3, 4]);
        assert!(u.is_contiguous());
        assert_eq!(u.strides(), &[12, 4, 1]);
    }

    // Inserting a middle axis: the new size-1 axis gets the contiguous stride.
    #[test]
    fn unsqueeze_dim1_contiguous() {
        let t = contiguous_2d();
        let u = t.unsqueeze(1);
        assert_eq!(u.shape(), &[3, 1, 4]);
        assert!(u.is_contiguous());
        assert_eq!(u.strides(), &[4, 4, 1]);
    }

    // Inserting a trailing axis: the new innermost stride is 1.
    #[test]
    fn unsqueeze_dim2_contiguous() {
        let t = contiguous_2d();
        let u = t.unsqueeze(2);
        assert_eq!(u.shape(), &[3, 4, 1]);
        assert!(u.is_contiguous());
        assert_eq!(u.strides(), &[4, 1, 1]);
    }

    // A non-contiguous input keeps its own strides; only one entry is inserted.
    #[test]
    fn unsqueeze_noncontiguous() {
        // Transpose makes [3,4] into [4,3] with strides [1,4] (non-contiguous)
        let t = contiguous_2d().transpose(0, 1);
        assert!(!t.is_contiguous());
        let u = t.unsqueeze(0);
        assert_eq!(u.shape(), &[1, 4, 3]);
        // Non-contiguous path: stride_val copied from strides[0]=1
        assert_eq!(u.strides(), &[1, 1, 4]);
    }

    // unsqueeze followed by squeeze on the same axis restores the shape.
    #[test]
    fn unsqueeze_squeeze_roundtrip() {
        let t = contiguous_2d();
        let u = t.unsqueeze(1).squeeze(1);
        assert_eq!(u.shape(), t.shape());
        assert!(u.is_contiguous());
    }
}
|
||||
|
||||
@@ -32,7 +32,11 @@ fn test_zeros_and_ones() {
|
||||
|
||||
#[test]
|
||||
fn test_bf16_tensor() {
|
||||
let data: Vec<bf16> = vec![bf16::from_f32(1.0), bf16::from_f32(2.5), bf16::from_f32(-3.0)];
|
||||
let data: Vec<bf16> = vec![
|
||||
bf16::from_f32(1.0),
|
||||
bf16::from_f32(2.5),
|
||||
bf16::from_f32(-3.0),
|
||||
];
|
||||
let t = Tensor::from_slice(&data, &[3]);
|
||||
assert_eq!(t.dtype(), DType::BF16);
|
||||
let out = t.as_slice::<bf16>();
|
||||
|
||||
9
crates/xserv-tokenizer/Cargo.toml
Normal file
9
crates/xserv-tokenizer/Cargo.toml
Normal file
@@ -0,0 +1,9 @@
|
||||
[package]
|
||||
name = "xserv-tokenizer"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
regex.workspace = true
|
||||
452
crates/xserv-tokenizer/src/bpe.rs
Normal file
452
crates/xserv-tokenizer/src/bpe.rs
Normal file
@@ -0,0 +1,452 @@
|
||||
use regex::Regex;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
/// Byte-pair-encoding tokenizer loaded from a `tokenizer.json` file.
pub struct Tokenizer {
    /// Token bytes → token ID.
    encoder: HashMap<Vec<u8>, u32>,
    /// Token ID → token bytes (indexed by ID; unused IDs hold an empty vec).
    decoder: Vec<Vec<u8>>,
    /// (left ID, right ID) → merge rank, i.e. the pair's index in the merges list.
    merge_ranks: HashMap<(u32, u32), usize>,
    /// Added-token text → ID; these are matched as indivisible tokens in `encode`.
    special_tokens: HashMap<String, u32>,
    /// Reverse of `special_tokens` (ID → text).
    #[allow(dead_code)]
    special_token_ids: HashMap<u32, String>,
    /// Regex that splits ordinary text into words before BPE merging.
    pre_tokenize_re: Regex,
    /// First entry of `eos_token_ids`, if any.
    eos_token_id: Option<u32>,
    /// All recognised end-of-generation token IDs, in priority order.
    eos_token_ids: Vec<u32>,
    /// Copied from the model section's `byte_fallback` flag; enables the
    /// `<0xNN>` lookup for bytes that have no single-byte vocab entry.
    byte_fallback: bool,
}
|
||||
|
||||
/// Top-level shape of `tokenizer.json`; only the fields read by
/// `Tokenizer::from_file` are declared.
#[derive(Deserialize)]
struct TokenizerJson {
    /// The BPE model: vocab, merges and the byte-fallback flag.
    model: ModelSection,
    /// Tokens added on top of the model vocab (empty when absent).
    #[serde(default)]
    added_tokens: Vec<AddedToken>,
    /// Optional pre-tokenizer description, used to find a split regex.
    #[serde(default)]
    pre_tokenizer: Option<PreTokenizerSection>,
}
|
||||
|
||||
/// One pre-tokenizer node. It may be a leaf (e.g. `"Split"` with a pattern)
/// or a container that lists nested nodes in `pretokenizers`.
#[derive(Deserialize)]
struct PreTokenizerSection {
    /// The JSON `type` field, e.g. `"Split"`.
    #[serde(default, rename = "type")]
    kind: Option<String>,
    /// Pattern of a split node, when present.
    #[serde(default)]
    pattern: Option<PatternSpec>,
    /// Nested pre-tokenizers, when present.
    #[serde(default)]
    pretokenizers: Option<Vec<PreTokenizerSection>>,
}
|
||||
|
||||
/// Pattern of a split pre-tokenizer; only the `Regex` form is captured.
#[derive(Deserialize)]
struct PatternSpec {
    /// Regex source taken from the JSON key `Regex`; `None` when that key is absent.
    #[serde(rename = "Regex")]
    regex: Option<String>,
}
|
||||
|
||||
/// The `model` section of `tokenizer.json`.
#[derive(Deserialize)]
struct ModelSection {
    /// Token string → token ID.
    vocab: HashMap<String, u32>,
    /// Merge rules; a rule's position in this list is its rank.
    merges: Vec<MergeEntry>,
    /// Whether unknown bytes may be encoded through `<0xNN>` tokens
    /// (defaults to `false` when the key is missing).
    #[serde(default)]
    byte_fallback: bool,
}
|
||||
|
||||
/// A single merge rule in either of the two accepted JSON encodings.
#[derive(Deserialize)]
#[serde(untagged)]
enum MergeEntry {
    /// `"a b"`: both halves in one string, separated by the first space.
    Str(String),
    /// `["a", "b"]`: the halves as an array; entries that are not exactly
    /// two elements long are skipped by the loader.
    Pair(Vec<String>),
}
|
||||
|
||||
/// One entry of the `added_tokens` array.
#[derive(Deserialize)]
struct AddedToken {
    /// Token ID.
    id: u32,
    /// Literal text of the token.
    content: String,
    /// The JSON `special` flag. It is parsed but not consulted: every added
    /// token is treated as indivisible regardless of this value.
    #[allow(dead_code)]
    special: bool,
}
|
||||
|
||||
impl Tokenizer {
|
||||
pub fn from_file(path: &Path) -> Self {
|
||||
let data = std::fs::read_to_string(path)
|
||||
.unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
|
||||
let tj: TokenizerJson = serde_json::from_str(&data)
|
||||
.unwrap_or_else(|e| panic!("failed to parse tokenizer.json: {e}"));
|
||||
|
||||
// Build encoder: token bytes → ID
|
||||
// All HF tokenizers use GPT-2 byte-to-unicode mapping for vocab keys.
|
||||
let mut encoder = HashMap::new();
|
||||
for (token_str, &id) in &tj.model.vocab {
|
||||
let bytes = token_str_to_bytes(token_str);
|
||||
encoder.insert(bytes, id);
|
||||
}
|
||||
|
||||
// Build decoder: ID → token bytes
|
||||
let max_id = tj.model.vocab.values().copied().max().unwrap_or(0);
|
||||
let added_max = tj.added_tokens.iter().map(|t| t.id).max().unwrap_or(0);
|
||||
let vocab_size = (max_id.max(added_max) + 1) as usize;
|
||||
let mut decoder = vec![vec![]; vocab_size];
|
||||
for (token_str, &id) in &tj.model.vocab {
|
||||
decoder[id as usize] = token_str_to_bytes(token_str);
|
||||
}
|
||||
|
||||
// Parse merges (supports both "a b" string format and ["a", "b"] array format)
|
||||
let byte_fallback = tj.model.byte_fallback;
|
||||
let mut merge_ranks = HashMap::new();
|
||||
for (rank, entry) in tj.model.merges.iter().enumerate() {
|
||||
let (a_str, b_str) = match entry {
|
||||
MergeEntry::Str(s) => {
|
||||
let parts: Vec<&str> = s.splitn(2, ' ').collect();
|
||||
if parts.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
(parts[0].to_string(), parts[1].to_string())
|
||||
}
|
||||
MergeEntry::Pair(v) => {
|
||||
if v.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
(v[0].clone(), v[1].clone())
|
||||
}
|
||||
};
|
||||
let a_bytes = token_str_to_bytes(&a_str);
|
||||
let b_bytes = token_str_to_bytes(&b_str);
|
||||
if let (Some(&a_id), Some(&b_id)) = (encoder.get(&a_bytes), encoder.get(&b_bytes)) {
|
||||
merge_ranks.insert((a_id, b_id), rank);
|
||||
}
|
||||
}
|
||||
|
||||
// Added tokens are matched as indivisible tokens by HF tokenizers,
|
||||
// even when their `special` flag is false (for example Qwen3's
|
||||
// <think> and </think> tokens).
|
||||
let mut special_tokens = HashMap::new();
|
||||
let mut special_token_ids = HashMap::new();
|
||||
for at in &tj.added_tokens {
|
||||
special_tokens.insert(at.content.clone(), at.id);
|
||||
special_token_ids.insert(at.id, at.content.clone());
|
||||
decoder.resize(decoder.len().max(at.id as usize + 1), vec![]);
|
||||
decoder[at.id as usize] = at.content.as_bytes().to_vec();
|
||||
}
|
||||
// End-of-generation tokens, in priority order. Families differ:
|
||||
// Qwen uses <|im_end|>, Llama <|end_of_text|>, GPT-2 <|endoftext|>.
|
||||
// gpt-oss (harmony) ends the assistant turn with <|return|> and also
|
||||
// treats <|call|> (tool call) and <|endoftext|> as terminators
|
||||
// (see generation_config.json eos_token_id = [200002, 199999, 200012]).
|
||||
let eos_names = [
|
||||
"<|im_end|>",
|
||||
"<|end_of_text|>",
|
||||
"<|return|>",
|
||||
"<|call|>",
|
||||
"<|endoftext|>",
|
||||
];
|
||||
let mut eos_token_ids: Vec<u32> = Vec::new();
|
||||
for name in eos_names {
|
||||
if let Some(&id) = special_tokens.get(name) {
|
||||
if !eos_token_ids.contains(&id) {
|
||||
eos_token_ids.push(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
let eos_token_id = eos_token_ids.first().copied();
|
||||
|
||||
// Pre-tokenization regex: prefer the model's own regex from tokenizer.json,
|
||||
// fall back to GPT-2/Qwen heuristic if not present or unsupported.
|
||||
let model_regex = tj.pre_tokenizer.as_ref().and_then(|pt| {
|
||||
// Direct Split with regex
|
||||
if pt.kind.as_deref() == Some("Split") {
|
||||
return pt.pattern.as_ref().and_then(|p| p.regex.clone());
|
||||
}
|
||||
// Sequence → find the Split entry
|
||||
if let Some(subs) = &pt.pretokenizers {
|
||||
for sub in subs {
|
||||
if sub.kind.as_deref() == Some("Split") {
|
||||
if let Some(r) = sub.pattern.as_ref().and_then(|p| p.regex.clone()) {
|
||||
return Some(r);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
});
|
||||
|
||||
let pre_tokenize_re = if let Some(ref pat) = model_regex {
|
||||
// Strip unsupported lookahead (?!\S) — Rust regex doesn't support it.
|
||||
// The lookahead only affects trailing-whitespace edge cases.
|
||||
let cleaned = pat.replace(r"(?!\S)", "");
|
||||
match Regex::new(&cleaned) {
|
||||
Ok(re) => re,
|
||||
Err(e) => {
|
||||
eprintln!("warning: model pre_tokenizer regex failed ({e}), using fallback");
|
||||
if byte_fallback {
|
||||
Regex::new(r"[\p{L}\p{N}]+|[^\s\p{L}\p{N}]|\s+").unwrap()
|
||||
} else {
|
||||
Regex::new(
|
||||
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+",
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if byte_fallback {
|
||||
Regex::new(r"[\p{L}\p{N}]+|[^\s\p{L}\p{N}]|\s+").unwrap()
|
||||
} else {
|
||||
Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+").unwrap()
|
||||
};
|
||||
|
||||
Self {
|
||||
encoder,
|
||||
decoder,
|
||||
merge_ranks,
|
||||
special_tokens,
|
||||
special_token_ids,
|
||||
pre_tokenize_re,
|
||||
eos_token_id,
|
||||
eos_token_ids,
|
||||
byte_fallback,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encode(&self, text: &str) -> Vec<u32> {
|
||||
let mut tokens = Vec::new();
|
||||
|
||||
// Check for special tokens first (split around them)
|
||||
let mut remaining = text;
|
||||
while !remaining.is_empty() {
|
||||
// Find earliest special token
|
||||
let mut earliest: Option<(usize, &str, u32)> = None;
|
||||
for (st, &id) in &self.special_tokens {
|
||||
if let Some(pos) = remaining.find(st.as_str()) {
|
||||
if earliest.is_none() || pos < earliest.unwrap().0 {
|
||||
earliest = Some((pos, st, id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((pos, st, id)) = earliest {
|
||||
if pos > 0 {
|
||||
self.encode_ordinary(&remaining[..pos], &mut tokens);
|
||||
}
|
||||
tokens.push(id);
|
||||
remaining = &remaining[pos + st.len()..];
|
||||
} else {
|
||||
self.encode_ordinary(remaining, &mut tokens);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
tokens
|
||||
}
|
||||
|
||||
fn encode_ordinary(&self, text: &str, out: &mut Vec<u32>) {
|
||||
for mat in self.pre_tokenize_re.find_iter(text) {
|
||||
let word = mat.as_str();
|
||||
// Try to encode the whole word first
|
||||
if let Some(&id) = self.encoder.get(word.as_bytes()) {
|
||||
out.push(id);
|
||||
continue;
|
||||
}
|
||||
// Fall back to per-byte encoding
|
||||
let word_bytes: Vec<u8> = word.bytes().collect();
|
||||
let mut token_ids: Vec<u32> = word_bytes.iter().filter_map(|&b| {
|
||||
if let Some(&id) = self.encoder.get(&vec![b]) {
|
||||
Some(id)
|
||||
} else if self.byte_fallback {
|
||||
let hex_token = format!("<0x{:02X}>", b);
|
||||
if let Some(&id) = self.special_tokens.get(&hex_token) {
|
||||
Some(id)
|
||||
} else if let Some(&id) = self.encoder.get(hex_token.as_bytes()) {
|
||||
Some(id)
|
||||
} else if let Some(&unk_id) = self.special_tokens.get("<unk>") {
|
||||
eprintln!("warning: byte 0x{b:02X} not in vocab, using <unk> token");
|
||||
Some(unk_id)
|
||||
} else {
|
||||
eprintln!("warning: byte 0x{b:02X} not in vocab and no fallback token, using token 0");
|
||||
Some(0)
|
||||
}
|
||||
} else {
|
||||
eprintln!("warning: byte {b} (0x{b:02X}) not in vocab, skipping");
|
||||
None
|
||||
}
|
||||
}).collect();
|
||||
|
||||
// BPE merges
|
||||
loop {
|
||||
if token_ids.len() < 2 {
|
||||
break;
|
||||
}
|
||||
let mut best_rank = usize::MAX;
|
||||
let mut best_idx = 0;
|
||||
for i in 0..token_ids.len() - 1 {
|
||||
if let Some(&rank) = self.merge_ranks.get(&(token_ids[i], token_ids[i + 1])) {
|
||||
if rank < best_rank {
|
||||
best_rank = rank;
|
||||
best_idx = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
if best_rank == usize::MAX {
|
||||
break;
|
||||
}
|
||||
|
||||
let merged_bytes = [
|
||||
self.decoder[token_ids[best_idx] as usize].as_slice(),
|
||||
self.decoder[token_ids[best_idx + 1] as usize].as_slice(),
|
||||
]
|
||||
.concat();
|
||||
let merged_id = *self.encoder.get(&merged_bytes).unwrap_or_else(|| {
|
||||
panic!("merged token not in vocab");
|
||||
});
|
||||
token_ids[best_idx] = merged_id;
|
||||
token_ids.remove(best_idx + 1);
|
||||
}
|
||||
|
||||
out.extend_from_slice(&token_ids);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode(&self, token_ids: &[u32]) -> String {
|
||||
let mut bytes = Vec::new();
|
||||
for &id in token_ids {
|
||||
if let Some(b) = self.decoder.get(id as usize) {
|
||||
bytes.extend_from_slice(b);
|
||||
}
|
||||
}
|
||||
String::from_utf8_lossy(&bytes).into_owned()
|
||||
}
|
||||
|
||||
/// Streaming decode of one token.
///
/// Appends the bytes of `token_id` to `pending` (unknown IDs add nothing) and
/// returns the text that `take_valid_utf8` extracts from the buffer. Bytes
/// that are not returned stay in `pending` for a later call or for
/// `flush_decode_stream`.
pub fn decode_token_stream(&self, token_id: u32, pending: &mut Vec<u8>) -> String {
    if let Some(bytes) = self.decoder.get(token_id as usize) {
        pending.extend_from_slice(bytes);
    }
    take_valid_utf8(pending)
}
|
||||
|
||||
/// Finish a streaming decode: lossily convert whatever is left in `pending`
/// (invalid or incomplete sequences become U+FFFD) and clear the buffer.
pub fn flush_decode_stream(&self, pending: &mut Vec<u8>) -> String {
    let text = String::from_utf8_lossy(pending).into_owned();
    pending.clear();
    text
}
|
||||
|
||||
/// The highest-priority end-of-generation token ID, or `None` if none of the
/// known EOS token names was present among the added tokens.
pub fn eos_token_id(&self) -> Option<u32> {
    self.eos_token_id
}
|
||||
|
||||
/// True if `id` is any end-of-generation token (a model may have several;
/// gpt-oss/harmony ends on <|return|>, <|call|>, or <|endoftext|>).
/// Checked with a linear scan over the EOS ID list.
pub fn is_eos(&self, id: u32) -> bool {
    self.eos_token_ids.contains(&id)
}
|
||||
|
||||
/// Size of the ID space, i.e. the length of the decoder table (largest token
/// ID + 1 across vocab and added tokens). This can exceed the number of
/// distinct tokens when IDs are sparse.
pub fn vocab_size(&self) -> usize {
    self.decoder.len()
}
|
||||
|
||||
/// Look up the ID of an added token by its exact text, e.g. `"<|im_end|>"`.
pub fn special_token_id(&self, name: &str) -> Option<u32> {
    self.special_tokens.get(name).copied()
}
|
||||
}
|
||||
|
||||
/// Convert a token string from HF vocab (which uses Unicode replacements for bytes)
|
||||
/// back to raw bytes. GPT-2 uses a byte-to-unicode mapping where e.g. byte 0x20 (space)
|
||||
/// is represented as 'Ġ' (U+0120).
|
||||
fn token_str_to_bytes(s: &str) -> Vec<u8> {
|
||||
s.chars().map(|c| unicode_to_byte(c)).collect()
|
||||
}
|
||||
|
||||
/// Drain every decodable byte from the front of `pending` and return it as text.
///
/// Valid UTF-8 is returned as-is, and each definitely-invalid sequence becomes
/// one U+FFFD. Only a trailing *incomplete* sequence (one that more bytes
/// could still complete) is left in `pending`. The whole buffer is processed
/// in one call, so text that follows an invalid byte is not held back until
/// the next token arrives.
fn take_valid_utf8(pending: &mut Vec<u8>) -> String {
    let mut text = String::new();
    let mut consumed = 0;

    loop {
        let rest = &pending[consumed..];
        match std::str::from_utf8(rest) {
            Ok(valid) => {
                text.push_str(valid);
                consumed = pending.len();
                break;
            }
            Err(err) => {
                // Everything before `valid_up_to` is known-good UTF-8.
                let valid_up_to = err.valid_up_to();
                text.push_str(&String::from_utf8_lossy(&rest[..valid_up_to]));
                consumed += valid_up_to;

                match err.error_len() {
                    // A definitely-invalid sequence: replace it and keep going.
                    Some(invalid_len) => {
                        text.push(char::REPLACEMENT_CHARACTER);
                        consumed += invalid_len;
                    }
                    // Unexpected end of input: keep the partial sequence buffered.
                    None => break,
                }
            }
        }
    }

    pending.drain(..consumed);
    text
}
|
||||
|
||||
/// Map a character of GPT-2's byte-to-unicode alphabet back to its byte.
///
/// Printable bytes (0x21..=0x7E, 0xA1..=0xAC, 0xAE..=0xFF) stand for
/// themselves; each remaining byte is assigned the next code point from
/// U+0100 upward, in ascending byte order. The inverse table is built once on
/// first use and cached.
///
/// # Panics
/// Panics if `c` is not one of the 256 characters in that alphabet.
fn unicode_to_byte(c: char) -> u8 {
    use std::sync::OnceLock;
    static INV_MAP: OnceLock<HashMap<u32, u8>> = OnceLock::new();

    let inverse = INV_MAP.get_or_init(|| {
        // Next code point to hand out to a non-printable byte.
        let mut next_remapped = 256u32;
        (0..=u8::MAX)
            .map(|byte| {
                let code_point = if matches!(byte, 0x21..=0x7E | 0xA1..=0xAC | 0xAE..=0xFF) {
                    u32::from(byte)
                } else {
                    let assigned = next_remapped;
                    next_remapped += 1;
                    assigned
                };
                (code_point, byte)
            })
            .collect()
    });

    let code_point = u32::from(c);
    match inverse.get(&code_point) {
        Some(&byte) => byte,
        None => panic!("unmapped unicode char U+{:04X} in tokenizer", code_point),
    }
}
|
||||
|
||||
// Unit tests for added-token handling and streaming UTF-8 decoding.
#[cfg(test)]
mod tests {
    use super::{Tokenizer, take_valid_utf8};

    // Loads a minimal tokenizer.json (empty vocab, Qwen-style added tokens)
    // from a temp file and checks EOS selection plus added-token encode/decode.
    #[test]
    fn qwen_added_tokens_are_indivisible_and_im_end_is_eos() {
        // The process id keeps the temp file name unique per test process.
        let path =
            std::env::temp_dir().join(format!("xserv-tokenizer-test-{}.json", std::process::id()));
        std::fs::write(
            &path,
            r#"{
"model": {
"vocab": {},
"merges": [],
"byte_fallback": false
},
"added_tokens": [
{"id":151643,"content":"<|endoftext|>","special":true},
{"id":151644,"content":"<|im_start|>","special":true},
{"id":151645,"content":"<|im_end|>","special":true},
{"id":151667,"content":"<think>","special":false},
{"id":151668,"content":"</think>","special":false}
]
}"#,
        )
        .unwrap();

        let tokenizer = Tokenizer::from_file(&path);
        // Best-effort cleanup; a failure to delete does not fail the test.
        let _ = std::fs::remove_file(&path);

        // <|im_end|> outranks <|endoftext|> in the EOS priority list.
        assert_eq!(tokenizer.eos_token_id(), Some(151645));
        assert_eq!(tokenizer.encode("<think>"), vec![151667]);
        assert_eq!(tokenizer.encode("</think>"), vec![151668]);
        assert_eq!(tokenizer.decode(&[151645]), "<|im_end|>");
    }

    // A 4-byte emoji split across two chunks is held back until complete.
    #[test]
    fn stream_decode_buffers_incomplete_utf8() {
        let mut pending = vec![0xF0, 0x9F];
        assert_eq!(take_valid_utf8(&mut pending), "");
        pending.extend_from_slice(&[0x98, 0x8A, b'!']);
        assert_eq!(take_valid_utf8(&mut pending), "😊!");
        assert!(pending.is_empty());
    }
}
|
||||
3
crates/xserv-tokenizer/src/lib.rs
Normal file
3
crates/xserv-tokenizer/src/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod bpe;
|
||||
|
||||
pub use bpe::Tokenizer;
|
||||
@@ -1,5 +1,6 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <math.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// GELU (tanh approximation):
|
||||
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
@@ -35,12 +36,85 @@ __global__ void silu_bf16(const __nv_bfloat16* x, __nv_bfloat16* out, int n) {
|
||||
if (idx < n) out[idx] = __float2bfloat16(silu_f(__bfloat162float(x[idx])));
|
||||
}
|
||||
|
||||
// Element-wise scale: out[i] = x[i] * scale for i in [0, n).
// Each thread handles the single element at its global index.
__global__ void scale_f32_kernel(const float* x, float* out, float scale, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    out[i] = x[i] * scale;
}
|
||||
|
||||
// Element-wise scale for bf16: the value is widened to float, multiplied by
// `scale`, and rounded back to bf16.
__global__ void scale_bf16_kernel(const __nv_bfloat16* x, __nv_bfloat16* out, float scale, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    const float v = __bfloat162float(x[i]);
    out[i] = __float2bfloat16(v * scale);
}
|
||||
|
||||
// Fused SiLU×Mul: out[i] = silu(gate[i]) * up[i], computed in float and
// stored as bf16. silu(g) = g / (1 + exp(-g)).
__global__ void silu_mul_bf16_kernel(const __nv_bfloat16* gate, const __nv_bfloat16* up,
                                     __nv_bfloat16* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    const float g = __bfloat162float(gate[i]);
    const float u = __bfloat162float(up[i]);
    const float silu_g = g / (1.0f + expf(-g));
    out[i] = __float2bfloat16(silu_g * u);
}
|
||||
|
||||
// gpt-oss GLU. `gate_up` is [N, 2*D] with interleaved columns (gate = even,
// up = odd); the output is [N, D] and `n_elements` counts output elements.
//   gate = gate_up[::2].clamp(max=limit)
//   up   = gate_up[1::2].clamp(-limit, limit)
//   glu  = gate * sigmoid(gate * alpha)
//   out  = (up + 1) * glu
__global__ void gpt_oss_glu_bf16_kernel(const __nv_bfloat16* gate_up, __nv_bfloat16* out,
                                        int n_elements, float alpha, float limit) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n_elements) return;

    // Output element i reads the adjacent (gate, up) pair at 2*i and 2*i + 1.
    float g = __bfloat162float(gate_up[i * 2]);
    float u = __bfloat162float(gate_up[i * 2 + 1]);
    g = fminf(g, limit);                  // gate is clamped from above only
    u = fmaxf(fminf(u, limit), -limit);   // up is clamped on both sides
    const float glu = g / (1.0f + expf(-g * alpha));
    out[i] = __float2bfloat16((u + 1.0f) * glu);
}
|
||||
|
||||
// Element-wise add (float): out[i] = a[i] + b[i] for i in [0, n).
__global__ void add_f32_kernel(const float* a, const float* b, float* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    out[i] = a[i] + b[i];
}
|
||||
// Element-wise add (bf16): operands are widened to float, summed, and the
// result is rounded back to bf16.
__global__ void add_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    const float lhs = __bfloat162float(a[i]);
    const float rhs = __bfloat162float(b[i]);
    out[i] = __float2bfloat16(lhs + rhs);
}
|
||||
|
||||
// Row-broadcast bias add over a row-major [rows, cols] matrix:
//   out[r, c] = x[r, c] + bias[c]
// One thread per element; the column is recovered from the flat index.
__global__ void bias_add_2d_bf16_kernel(
    const __nv_bfloat16* __restrict__ x, const __nv_bfloat16* __restrict__ bias,
    __nv_bfloat16* __restrict__ out, int rows, int cols
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < rows * cols) {
        const int col = idx % cols;
        const float v = __bfloat162float(x[idx]) + __bfloat162float(bias[col]);
        out[idx] = __float2bfloat16(v);
    }
}
|
||||
|
||||
// Element-wise FP32 product over n values: out[i] = a[i] * b[i]. One element per thread.
__global__ void mul_f32_kernel(const float* a, const float* b, float* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;  // surplus threads in the last block
    out[i] = a[i] * b[i];
}
|
||||
// BF16 element-wise product: operands are widened to FP32, multiplied, and the
// result is rounded back to BF16. One element per thread.
__global__ void mul_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    const float lhs = __bfloat162float(a[i]);
    const float rhs = __bfloat162float(b[i]);
    out[i] = __float2bfloat16(lhs * rhs);
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Launches gelu_f32 over n FP32 elements.
//   x, out  device buffers, reinterpreted as float*
//   n       element count; n <= 0 is a no-op
//   stream  cudaStream_t passed as an opaque pointer
void launch_gelu_f32(const void* x, void* out, int n, void* stream) {
    // A zero-block grid is an invalid launch configuration, so an empty
    // input must return before the launch instead of raising an error.
    if (n <= 0) return;

    int block = 256;
    int grid = (n - 1) / block + 1;  // ceil(n / block) without overflowing near INT_MAX
    gelu_f32<<<grid, block, 0, (cudaStream_t)stream>>>((const float*)x, (float*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
void launch_gelu_bf16(const void* x, void* out, int n, void* stream) {
|
||||
@@ -48,12 +122,14 @@ void launch_gelu_bf16(const void* x, void* out, int n, void* stream) {
|
||||
int grid = (n + block - 1) / block;
|
||||
gelu_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
// Launches silu_f32 over n FP32 elements.
//   x, out  device buffers, reinterpreted as float*
//   n       element count; n <= 0 is a no-op
//   stream  cudaStream_t passed as an opaque pointer
void launch_silu_f32(const void* x, void* out, int n, void* stream) {
    // A zero-block grid is an invalid launch configuration, so an empty
    // input must return before the launch instead of raising an error.
    if (n <= 0) return;

    int block = 256;
    int grid = (n - 1) / block + 1;  // ceil(n / block) without overflowing near INT_MAX
    silu_f32<<<grid, block, 0, (cudaStream_t)stream>>>((const float*)x, (float*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
void launch_silu_bf16(const void* x, void* out, int n, void* stream) {
|
||||
@@ -61,6 +137,77 @@ void launch_silu_bf16(const void* x, void* out, int n, void* stream) {
|
||||
int grid = (n + block - 1) / block;
|
||||
silu_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
// Launches scale_f32_kernel over n FP32 elements with the given scale factor.
//   x, out  device buffers, reinterpreted as float*
//   scale   scalar forwarded to the kernel
//   n       element count; n <= 0 is a no-op
//   stream  cudaStream_t passed as an opaque pointer
void launch_scale_f32(const void* x, void* out, float scale, int n, void* stream) {
    // A zero-block grid is an invalid launch configuration, so an empty
    // input must return before the launch instead of raising an error.
    if (n <= 0) return;

    int block = 256;
    int grid = (n - 1) / block + 1;  // ceil(n / block) without overflowing near INT_MAX
    scale_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const float*)x, (float*)out, scale, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Launches scale_bf16_kernel over n BF16 elements with the given scale factor.
//   x, out  device buffers, reinterpreted as __nv_bfloat16*
//   scale   scalar forwarded to the kernel
//   n       element count; n <= 0 is a no-op
//   stream  cudaStream_t passed as an opaque pointer
void launch_scale_bf16(const void* x, void* out, float scale, int n, void* stream) {
    // A zero-block grid is an invalid launch configuration, so an empty
    // input must return before the launch instead of raising an error.
    if (n <= 0) return;

    int block = 256;
    int grid = (n - 1) / block + 1;  // ceil(n / block) without overflowing near INT_MAX
    scale_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (__nv_bfloat16*)out, scale, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Launches add_f32_kernel: out[i] = a[i] + b[i] for n FP32 elements.
//   a, b, out  device buffers, reinterpreted as float*
//   n          element count; n <= 0 is a no-op
//   stream     cudaStream_t passed as an opaque pointer
void launch_add_f32(const void* a, const void* b, void* out, int n, void* stream) {
    // A zero-block grid is an invalid launch configuration, so an empty
    // input must return before the launch instead of raising an error.
    if (n <= 0) return;

    int block = 256;
    int grid = (n - 1) / block + 1;  // ceil(n / block) without overflowing near INT_MAX
    add_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const float*)a, (const float*)b, (float*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
// Launches add_bf16_kernel: out[i] = a[i] + b[i] for n BF16 elements.
//   a, b, out  device buffers, reinterpreted as __nv_bfloat16*
//   n          element count; n <= 0 is a no-op
//   stream     cudaStream_t passed as an opaque pointer
void launch_add_bf16(const void* a, const void* b, void* out, int n, void* stream) {
    // A zero-block grid is an invalid launch configuration, so an empty
    // input must return before the launch instead of raising an error.
    if (n <= 0) return;

    int block = 256;
    int grid = (n - 1) / block + 1;  // ceil(n / block) without overflowing near INT_MAX
    add_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
// Launches bias_add_2d_bf16_kernel: out[r, c] = x[r, c] + bias[c].
//   x, out      device buffers of rows * cols BF16 values
//   bias        device buffer indexed by column
//   rows, cols  matrix shape; a non-positive dimension is a no-op
//   stream      cudaStream_t passed as an opaque pointer
// NOTE(review): rows * cols is computed in int here and in the kernel, so the
// product must fit in 32 bits.
void launch_bias_add_2d_bf16(const void* x, const void* bias, void* out, int rows, int cols, void* stream) {
    // An empty matrix would yield a zero-block grid, which is an invalid
    // launch configuration; treat it as nothing to do.
    if (rows <= 0 || cols <= 0) return;

    int n = rows * cols;
    int block = 256;
    int grid = (n - 1) / block + 1;  // ceil(n / block) without overflowing near INT_MAX
    bias_add_2d_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const __nv_bfloat16*)bias, (__nv_bfloat16*)out, rows, cols);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
// Launches mul_f32_kernel: out[i] = a[i] * b[i] for n FP32 elements.
//   a, b, out  device buffers, reinterpreted as float*
//   n          element count; n <= 0 is a no-op
//   stream     cudaStream_t passed as an opaque pointer
void launch_mul_f32(const void* a, const void* b, void* out, int n, void* stream) {
    // A zero-block grid is an invalid launch configuration, so an empty
    // input must return before the launch instead of raising an error.
    if (n <= 0) return;

    int block = 256;
    int grid = (n - 1) / block + 1;  // ceil(n / block) without overflowing near INT_MAX
    mul_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const float*)a, (const float*)b, (float*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
// Launches mul_bf16_kernel: out[i] = a[i] * b[i] for n BF16 elements.
//   a, b, out  device buffers, reinterpreted as __nv_bfloat16*
//   n          element count; n <= 0 is a no-op
//   stream     cudaStream_t passed as an opaque pointer
void launch_mul_bf16(const void* a, const void* b, void* out, int n, void* stream) {
    // A zero-block grid is an invalid launch configuration, so an empty
    // input must return before the launch instead of raising an error.
    if (n <= 0) return;

    int block = 256;
    int grid = (n - 1) / block + 1;  // ceil(n / block) without overflowing near INT_MAX
    mul_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Launches silu_mul_bf16_kernel: out[i] = silu(gate[i]) * up[i] for n BF16 elements.
//   gate, up, out  device buffers, reinterpreted as __nv_bfloat16*
//   n              element count; n <= 0 is a no-op
//   stream         cudaStream_t passed as an opaque pointer
void launch_silu_mul_bf16(const void* gate, const void* up, void* out, int n, void* stream) {
    // A zero-block grid is an invalid launch configuration, so an empty
    // input must return before the launch instead of raising an error.
    if (n <= 0) return;

    int block = 256;
    int grid = (n - 1) / block + 1;  // ceil(n / block) without overflowing near INT_MAX
    silu_mul_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)gate, (const __nv_bfloat16*)up, (__nv_bfloat16*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Launches gpt_oss_glu_bf16_kernel.
//   gate_up       device buffer of interleaved gate/up BF16 values (two per output)
//   out           device buffer of n_elements BF16 outputs
//   n_elements    number of OUTPUT elements; n_elements <= 0 is a no-op
//   alpha, limit  forwarded unchanged to the kernel
//   stream        cudaStream_t passed as an opaque pointer
void launch_gpt_oss_glu_bf16(const void* gate_up, void* out, int n_elements,
                             float alpha, float limit, void* stream) {
    // A zero-block grid is an invalid launch configuration, so an empty
    // input must return before the launch instead of raising an error.
    if (n_elements <= 0) return;

    int block = 256;
    int grid = (n_elements - 1) / block + 1;  // ceil division, safe near INT_MAX
    gpt_oss_glu_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)gate_up, (__nv_bfloat16*)out, n_elements, alpha, limit);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
|
||||
59
csrc/attention/causal_mask.cu
Normal file
59
csrc/attention/causal_mask.cu
Normal file
@@ -0,0 +1,59 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// Causal masking of FP32 attention scores laid out as [batch, rows, cols]
// (batch is the flattened batch×heads dimension).
// Every entry with col > row + offset is overwritten with -inf; all other
// entries are left untouched.
// `offset` supports a KV cache: when the query starts at position `offset`,
// row r may attend to positions [0, offset + r].
// Indexing: blockIdx.z = batch, blockIdx.y = row, x dimension covers columns.
__global__ void causal_mask_f32(
    float* __restrict__ scores,
    int rows, int cols, int offset
) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y;
    const int batch_idx = blockIdx.z;

    if (col >= cols) return;          // past the end of the row
    if (col <= row + offset) return;  // visible position, keep the score

    // 64-bit index: batch * rows * cols overflows int32 at moderate batch
    // and long context (e.g. batch=128 * heads=28 * seq=32768).
    const long long flat = ((long long)batch_idx * rows + row) * cols + col;
    scores[flat] = -INFINITY;
}
|
||||
|
||||
// BF16 causal mask over scores laid out as [batch, rows, cols]: entries with
// col > row + offset are overwritten with -inf, the rest are left untouched.
// Indexing: blockIdx.z = batch, blockIdx.y = row, x dimension covers columns.
__global__ void causal_mask_bf16(
    __nv_bfloat16* __restrict__ scores,
    int rows, int cols, int offset
) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y;
    const int batch_idx = blockIdx.z;

    if (col >= cols) return;          // past the end of the row
    if (col <= row + offset) return;  // visible position, keep the score

    // Flattened offset is formed in 64 bits so large batch * rows * cols cannot wrap.
    const long long flat = ((long long)batch_idx * rows + row) * cols + col;
    scores[flat] = __float2bfloat16(-INFINITY);
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Launches causal_mask_f32 in place on scores of shape [batch, rows, cols].
//   scores  device buffer, reinterpreted as float*
//   batch   flattened batch×heads count (mapped to grid.z)
//   rows    query rows (mapped to grid.y)
//   cols    key columns (tiled over grid.x in chunks of 256 threads)
//   offset  diagonal shift forwarded to the kernel
//   stream  cudaStream_t passed as an opaque pointer
// Any non-positive dimension is a no-op.
// NOTE(review): grid.y / grid.z are limited to 65535 by CUDA, so rows and batch
// beyond that will fail the launch — confirm callers stay within the limit.
void launch_causal_mask_f32(void* scores, int batch, int rows, int cols,
                            int offset, void* stream) {
    // A grid with a zero dimension is an invalid launch configuration;
    // an empty score tensor simply has nothing to mask.
    if (batch <= 0 || rows <= 0 || cols <= 0) return;

    int block = 256;
    dim3 grid((cols - 1) / block + 1, rows, batch);  // ceil(cols / block) column tiles
    causal_mask_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
        (float*)scores, rows, cols, offset);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Launches causal_mask_bf16 in place on scores of shape [batch, rows, cols].
//   scores  device buffer, reinterpreted as __nv_bfloat16*
//   batch   flattened batch×heads count (mapped to grid.z)
//   rows    query rows (mapped to grid.y)
//   cols    key columns (tiled over grid.x in chunks of 256 threads)
//   offset  diagonal shift forwarded to the kernel
//   stream  cudaStream_t passed as an opaque pointer
// Any non-positive dimension is a no-op.
// NOTE(review): grid.y / grid.z are limited to 65535 by CUDA, so rows and batch
// beyond that will fail the launch — confirm callers stay within the limit.
void launch_causal_mask_bf16(void* scores, int batch, int rows, int cols,
                             int offset, void* stream) {
    // A grid with a zero dimension is an invalid launch configuration;
    // an empty score tensor simply has nothing to mask.
    if (batch <= 0 || rows <= 0 || cols <= 0) return;

    int block = 256;
    dim3 grid((cols - 1) / block + 1, rows, batch);  // ceil(cols / block) column tiles
    causal_mask_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)scores, rows, cols, offset);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
616
csrc/attention/flash_attention.cu
Normal file
616
csrc/attention/flash_attention.cu
Normal file
@@ -0,0 +1,616 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <float.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// Flash Attention 2 forward kernel for BF16 with FP32 accumulation.
|
||||
//
|
||||
// Algorithm: outer loop over Q tiles (BR rows), inner loop over K/V tiles (BC rows).
|
||||
// Uses online softmax — no O(S^2) memory.
|
||||
//
|
||||
// Layout: Q [batch, num_q_heads, q_len, head_dim]
|
||||
// K [batch, num_kv_heads, kv_len, head_dim]
|
||||
// V [batch, num_kv_heads, kv_len, head_dim]
|
||||
// O [batch, num_q_heads, q_len, head_dim]
|
||||
//
|
||||
// Shared memory (BF16):
|
||||
// smem_q[BR][head_dim] — 64 * 128 * 2 = 16 KB (loaded once per Q tile)
|
||||
// smem_kv[BC][head_dim] — 64 * 128 * 2 = 16 KB (alternates K and V)
|
||||
// Total: 32 KB (fits in default 48 KB shared memory)
|
||||
|
||||
#define BR 64
|
||||
#define BC 64
|
||||
#define THREADS_PER_BLOCK 128
|
||||
|
||||
// Tiled attention forward pass: BF16 tensors, FP32 accumulation, online softmax.
//
// Launch contract:
//   grid  = (ceil(q_len / BR), batch * num_q_heads)
//   block = THREADS_PER_BLOCK threads
//   dynamic shared memory = (BR + BC) * head_dim BF16 values
//
// Parameters:
//   Q, O    [batch, num_q_heads, q_len, head_dim]
//   K, V    [batch, num_kv_heads, kv_len, head_dim]
//   scale   multiplier applied to every Q·K dot product
//   causal  non-zero masks keys with kv_pos > q_pos + (kv_len - q_len)
//
// Thread t < q_tile_rows owns query row t of the tile; all threads take part in
// the cooperative shared-memory loads. head_dim must not exceed 128, the size of
// the per-thread accumulator. Rows that end up with no visible key produce zeros.
__global__ void flash_attention_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K,
    const __nv_bfloat16* __restrict__ V,
    __nv_bfloat16* __restrict__ O,
    int num_q_heads, int num_kv_heads,
    int q_len, int kv_len, int head_dim,
    float scale, int causal
) {
    // Grid: (ceil(q_len / BR), batch * num_q_heads)
    int q_tile_idx = blockIdx.x;
    int bh = blockIdx.y;
    int batch_idx = bh / num_q_heads;
    int q_head = bh % num_q_heads;

    // GQA: map Q head to KV head
    int heads_per_group = num_q_heads / num_kv_heads;
    int kv_head = q_head / heads_per_group;

    int q_tile_start = q_tile_idx * BR;
    if (q_tile_start >= q_len) return;  // block-uniform exit, before any barrier
    int q_tile_rows = min(BR, q_len - q_tile_start);

    // Pointers to this batch/head's data (offsets formed in 64 bits)
    const __nv_bfloat16* Q_head = Q + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;
    const __nv_bfloat16* K_head = K + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    const __nv_bfloat16* V_head = V + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    __nv_bfloat16* O_head = O + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;

    int tid = threadIdx.x;

    // Dynamic shared memory
    extern __shared__ __nv_bfloat16 smem[];
    __nv_bfloat16* smem_q = smem;                    // BR * head_dim elements
    __nv_bfloat16* smem_kv = smem + BR * head_dim;   // BC * head_dim elements

    // ---- Load Q tile into shared memory (cooperative) ----
    int q_elems = q_tile_rows * head_dim;
    for (int i = tid; i < q_elems; i += THREADS_PER_BLOCK) {
        int row = i / head_dim;
        int col = i % head_dim;
        smem_q[row * head_dim + col] = Q_head[(q_tile_start + row) * head_dim + col];
    }
    // Zero-pad if q_tile_rows < BR
    for (int i = q_elems + tid; i < BR * head_dim; i += THREADS_PER_BLOCK) {
        smem_q[i] = __float2bfloat16(0.0f);
    }
    __syncthreads();

    // Thread t (0 <= t < q_tile_rows) owns Q row t
    bool owns_row = (tid < q_tile_rows);

    // Per-thread FP32 accumulators (head_dim up to 128)
    float O_acc[128];
    float m_val = -INFINITY;  // running row maximum
    float l_val = 0.0f;       // running softmax denominator
    if (owns_row) {
        for (int d = 0; d < head_dim; d++) {
            O_acc[d] = 0.0f;
        }
    }

    // kv_offset handles cached KV longer than Q (decode step)
    int kv_offset = kv_len - q_len;
    int num_kv_tiles = (kv_len + BC - 1) / BC;

    // ---- Inner loop over K/V tiles ----
    for (int j = 0; j < num_kv_tiles; j++) {
        int kv_tile_start = j * BC;
        int kv_tile_cols = min(BC, kv_len - kv_tile_start);

        // Causal: skip entire tile if all K positions are in the future.
        // The test depends only on block-level values, so every thread agrees
        // and the barriers below stay matched.
        if (causal) {
            int max_allowed_kv = (q_tile_start + q_tile_rows - 1) + kv_offset;
            if (kv_tile_start > max_allowed_kv) {
                continue;
            }
        }

        // ---- Load K tile into smem_kv ----
        int kv_elems = kv_tile_cols * head_dim;
        for (int i = tid; i < kv_elems; i += THREADS_PER_BLOCK) {
            int row = i / head_dim;
            int col = i % head_dim;
            smem_kv[row * head_dim + col] = K_head[(kv_tile_start + row) * head_dim + col];
        }
        for (int i = kv_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
            smem_kv[i] = __float2bfloat16(0.0f);
        }
        __syncthreads();

        // ---- Compute S = Q @ K^T * scale, causal mask, online softmax ----
        float P[BC];

        if (owns_row) {
            float row_max = -INFINITY;
            for (int c = 0; c < kv_tile_cols; c++) {
                float dot = 0.0f;
                for (int d = 0; d < head_dim; d++) {
                    dot += __bfloat162float(smem_q[tid * head_dim + d])
                         * __bfloat162float(smem_kv[c * head_dim + d]);
                }
                float s = dot * scale;

                if (causal) {
                    int q_pos = q_tile_start + tid;
                    int kv_pos = kv_tile_start + c;
                    if (kv_pos > q_pos + kv_offset) {
                        s = -INFINITY;
                    }
                }

                P[c] = s;  // store score temporarily in P
                row_max = fmaxf(row_max, s);
            }

            // A tile in which every key is masked for this row has
            // row_max == -INFINITY. If no earlier tile contributed either
            // (m_val is still -INFINITY, e.g. causal with kv_len < q_len), the
            // update below would evaluate expf(-inf - (-inf)) = NaN and poison
            // the row. Such a tile contributes nothing to the softmax, so it is
            // skipped; when m_val is already finite the skip is equivalent to
            // the regular update (all P == 0, correction == 1).
            if (row_max != -INFINITY) {
                // Online softmax: m_new, P = exp(S - m_new), l_new
                float m_new = fmaxf(m_val, row_max);

                float psum = 0.0f;
                for (int c = 0; c < kv_tile_cols; c++) {
                    P[c] = expf(P[c] - m_new);
                    psum += P[c];
                }

                // Rescale previous accumulator
                float correction = expf(m_val - m_new);
                l_val = correction * l_val + psum;

                for (int d = 0; d < head_dim; d++) {
                    O_acc[d] *= correction;
                }

                m_val = m_new;
            } else {
                for (int c = 0; c < kv_tile_cols; c++) {
                    P[c] = 0.0f;
                }
            }
        }

        // Sync before overwriting smem_kv with V tile
        __syncthreads();

        // ---- Load V tile (reuse smem_kv) ----
        int v_elems = kv_tile_cols * head_dim;
        for (int i = tid; i < v_elems; i += THREADS_PER_BLOCK) {
            int row = i / head_dim;
            int col = i % head_dim;
            smem_kv[row * head_dim + col] = V_head[(kv_tile_start + row) * head_dim + col];
        }
        for (int i = v_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
            smem_kv[i] = __float2bfloat16(0.0f);
        }
        __syncthreads();

        // ---- Accumulate O += P @ V_tile ----
        if (owns_row) {
            for (int c = 0; c < kv_tile_cols; c++) {
                float p = P[c];
                if (p != 0.0f) {
                    for (int d = 0; d < head_dim; d++) {
                        O_acc[d] += p * __bfloat162float(smem_kv[c * head_dim + d]);
                    }
                }
            }
        }

        // Keep the V tile intact until every thread has finished reading it.
        __syncthreads();
    }

    // ---- Final normalize and write output (convert FP32 → BF16) ----
    if (owns_row) {
        // l_val == 0 means no key was visible; emit zeros instead of dividing by zero.
        float inv_l = (l_val > 0.0f) ? (1.0f / l_val) : 0.0f;
        int global_row = q_tile_start + tid;
        for (int d = 0; d < head_dim; d++) {
            O_head[global_row * head_dim + d] = __float2bfloat16(O_acc[d] * inv_l);
        }
    }
}
|
||||
|
||||
// Flash Attention 2 forward with gpt-oss attention sinks + optional sliding window.
// Same tiling and launch contract as flash_attention_bf16_kernel, plus:
//   - sinks: [num_q_heads] BF16 or NULL — a per-head extra softmax logit (no
//     value), folded into the denominator after the K/V tiles.
//   - window_size > 0: sliding-window mask. Query at global position p attends
//     to keys k with p - window_size < k <= p (matches HF gpt-oss).
// KV tiles that are entirely in the future (causal) or entirely older than the
// window of every row in the Q tile are skipped without being loaded.
__global__ void flash_attention_sinks_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K,
    const __nv_bfloat16* __restrict__ V,
    __nv_bfloat16* __restrict__ O,
    const __nv_bfloat16* __restrict__ sinks,  // [num_q_heads] or NULL
    int num_q_heads, int num_kv_heads,
    int q_len, int kv_len, int head_dim,
    float scale, int causal, int window_size
) {
    int q_tile_idx = blockIdx.x;
    int bh = blockIdx.y;
    int batch_idx = bh / num_q_heads;
    int q_head = bh % num_q_heads;

    // GQA: map Q head to KV head
    int heads_per_group = num_q_heads / num_kv_heads;
    int kv_head = q_head / heads_per_group;

    int q_tile_start = q_tile_idx * BR;
    if (q_tile_start >= q_len) return;  // block-uniform exit, before any barrier
    int q_tile_rows = min(BR, q_len - q_tile_start);

    const __nv_bfloat16* Q_head = Q + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;
    const __nv_bfloat16* K_head = K + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    const __nv_bfloat16* V_head = V + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    __nv_bfloat16* O_head = O + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;

    int tid = threadIdx.x;

    extern __shared__ __nv_bfloat16 smem[];
    __nv_bfloat16* smem_q = smem;                   // BR * head_dim elements
    __nv_bfloat16* smem_kv = smem + BR * head_dim;  // BC * head_dim elements

    // Cooperative load of the Q tile, zero-padded up to BR rows.
    int q_elems = q_tile_rows * head_dim;
    for (int i = tid; i < q_elems; i += THREADS_PER_BLOCK) {
        int row = i / head_dim;
        int col = i % head_dim;
        smem_q[row * head_dim + col] = Q_head[(q_tile_start + row) * head_dim + col];
    }
    for (int i = q_elems + tid; i < BR * head_dim; i += THREADS_PER_BLOCK) {
        smem_q[i] = __float2bfloat16(0.0f);
    }
    __syncthreads();

    // Thread t (0 <= t < q_tile_rows) owns Q row t
    bool owns_row = (tid < q_tile_rows);

    float O_acc[128];         // per-thread FP32 accumulator (head_dim up to 128)
    float m_val = -INFINITY;  // running row maximum
    float l_val = 0.0f;       // running softmax denominator
    if (owns_row) {
        for (int d = 0; d < head_dim; d++) O_acc[d] = 0.0f;
    }

    int kv_offset = kv_len - q_len;
    int num_kv_tiles = (kv_len + BC - 1) / BC;

    // Global query position of this thread's row (only used when owns_row).
    // It does not depend on the KV tile, so it is computed once.
    int q_pos = q_tile_start + tid + kv_offset;
    // Smallest global query position in this Q tile (row 0).
    int tile_min_q_pos = q_tile_start + kv_offset;

    for (int j = 0; j < num_kv_tiles; j++) {
        int kv_tile_start = j * BC;
        int kv_tile_cols = min(BC, kv_len - kv_tile_start);

        // Both skip tests depend only on block-level values, so every thread
        // takes the same branch and the barriers below stay matched.
        if (causal) {
            int max_allowed_kv = (q_tile_start + q_tile_rows - 1) + kv_offset;
            if (kv_tile_start > max_allowed_kv) continue;
        }
        // Sliding window: a row masks keys with kv_pos <= q_pos - window_size.
        // Row 0 has the smallest q_pos, so if even its window starts past the
        // last key of this tile, the tile is masked for every row. Skipping it
        // leaves m_val, l_val and O_acc exactly as the masked path would.
        if (window_size > 0) {
            int last_kv_pos = kv_tile_start + kv_tile_cols - 1;
            if (last_kv_pos <= tile_min_q_pos - window_size) continue;
        }

        // ---- Load K tile into smem_kv ----
        int kv_elems = kv_tile_cols * head_dim;
        for (int i = tid; i < kv_elems; i += THREADS_PER_BLOCK) {
            int row = i / head_dim;
            int col = i % head_dim;
            smem_kv[row * head_dim + col] = K_head[(kv_tile_start + row) * head_dim + col];
        }
        for (int i = kv_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
            smem_kv[i] = __float2bfloat16(0.0f);
        }
        __syncthreads();

        float P[BC];

        if (owns_row) {
            float row_max = -INFINITY;
            for (int c = 0; c < kv_tile_cols; c++) {
                float dot = 0.0f;
                for (int d = 0; d < head_dim; d++) {
                    dot += __bfloat162float(smem_q[tid * head_dim + d])
                         * __bfloat162float(smem_kv[c * head_dim + d]);
                }
                float s = dot * scale;

                int kv_pos = kv_tile_start + c;
                if (causal && kv_pos > q_pos) {
                    s = -INFINITY;
                }
                // Sliding window: drop keys older than the window.
                if (window_size > 0 && kv_pos <= q_pos - window_size) {
                    s = -INFINITY;
                }

                P[c] = s;
                row_max = fmaxf(row_max, s);
            }

            // A KV tile that is fully masked for THIS row (every key causal- or
            // window-masked) has row_max == -INFINITY. Folding it in computes
            // expf(-inf - (-inf)) = NaN, and a later valid tile's 0*NaN
            // correction then poisons the whole row. The tile-level skips above
            // only remove tiles masked for every row of the Q tile, so a single
            // row can still see a fully masked tile here.
            // A masked tile contributes nothing to the softmax — skip it.
            if (row_max != -INFINITY) {
                float m_new = fmaxf(m_val, row_max);
                float psum = 0.0f;
                for (int c = 0; c < kv_tile_cols; c++) {
                    P[c] = expf(P[c] - m_new);
                    psum += P[c];
                }
                float correction = expf(m_val - m_new);
                l_val = correction * l_val + psum;
                for (int d = 0; d < head_dim; d++) O_acc[d] *= correction;
                m_val = m_new;
            } else {
                for (int c = 0; c < kv_tile_cols; c++) P[c] = 0.0f;
            }
        }

        // Sync before overwriting smem_kv with the V tile
        __syncthreads();

        // ---- Load V tile (reuse smem_kv) ----
        int v_elems = kv_tile_cols * head_dim;
        for (int i = tid; i < v_elems; i += THREADS_PER_BLOCK) {
            int row = i / head_dim;
            int col = i % head_dim;
            smem_kv[row * head_dim + col] = V_head[(kv_tile_start + row) * head_dim + col];
        }
        for (int i = v_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
            smem_kv[i] = __float2bfloat16(0.0f);
        }
        __syncthreads();

        // ---- Accumulate O += P @ V_tile ----
        if (owns_row) {
            for (int c = 0; c < kv_tile_cols; c++) {
                float p = P[c];
                if (p != 0.0f) {
                    for (int d = 0; d < head_dim; d++) {
                        O_acc[d] += p * __bfloat162float(smem_kv[c * head_dim + d]);
                    }
                }
            }
        }
        // Keep the V tile intact until every thread has finished reading it.
        __syncthreads();
    }

    // Fold in the per-head attention sink (extra logit, no value contribution).
    if (owns_row && sinks != nullptr) {
        float sink_logit = __bfloat162float(sinks[q_head]);
        float m_new = fmaxf(m_val, sink_logit);
        float correction = expf(m_val - m_new);
        l_val = correction * l_val + expf(sink_logit - m_new);
        for (int d = 0; d < head_dim; d++) O_acc[d] *= correction;
        m_val = m_new;
    }

    // Final normalize and write output (FP32 → BF16).
    if (owns_row) {
        // l_val == 0 means nothing was attended to; emit zeros instead of dividing by zero.
        float inv_l = (l_val > 0.0f) ? (1.0f / l_val) : 0.0f;
        int global_row = q_tile_start + tid;
        for (int d = 0; d < head_dim; d++) {
            O_head[global_row * head_dim + d] = __float2bfloat16(O_acc[d] * inv_l);
        }
    }
}
|
||||
|
||||
// ============================================================
|
||||
// Decode Attention kernel: optimized for Q_len=1 (single-token decode).
|
||||
// Parallelizes across KV sequence dimension instead of Q rows.
|
||||
//
|
||||
// Grid: (batch * num_q_heads, 1) — one block per Q head
|
||||
// Block: 256 threads — each thread handles ceil(kv_len / 256) KV positions
|
||||
// Uses online softmax reduction across threads.
|
||||
// ============================================================
|
||||
|
||||
#define DECODE_THREADS 256
|
||||
#define HEAD_DIM_MAX 128
|
||||
|
||||
// Single-query attention: one block per (batch, q_head), DECODE_THREADS threads.
//
// Parameters:
//   Q, O   [batch, num_q_heads, 1, head_dim]
//   K, V   [batch, num_kv_heads, kv_len, head_dim]
//   scale  multiplier applied to every Q·K dot product
//
// Each thread runs an online softmax over KV positions tid, tid + DECODE_THREADS,
// ...; the per-thread partials are then merged across the block.
// The kernel must be launched with exactly DECODE_THREADS threads per block and
// head_dim <= HEAD_DIM_MAX. kv_len == 0 yields an all-zero output row.
__global__ void decode_attention_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K,
    const __nv_bfloat16* __restrict__ V,
    __nv_bfloat16* __restrict__ O,
    int num_q_heads, int num_kv_heads,
    int kv_len, int head_dim,
    float scale
) {
    int bh = blockIdx.x;
    int batch_idx = bh / num_q_heads;
    int q_head = bh % num_q_heads;

    // GQA mapping
    int heads_per_group = num_q_heads / num_kv_heads;
    int kv_head = q_head / heads_per_group;

    int tid = threadIdx.x;

    // Pointers to this batch/head's data
    // Q: [batch, num_q_heads, 1, head_dim]
    const __nv_bfloat16* Q_ptr = Q + ((long long)batch_idx * num_q_heads + q_head) * head_dim;
    // K/V: [batch, num_kv_heads, kv_len, head_dim]
    const __nv_bfloat16* K_base = K + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    const __nv_bfloat16* V_base = V + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    __nv_bfloat16* O_ptr = O + ((long long)batch_idx * num_q_heads + q_head) * head_dim;

    // Per-thread FP32 copy of the Q vector (head_dim <= 128)
    float q_reg[HEAD_DIM_MAX];
    for (int d = 0; d < head_dim; d++) {
        q_reg[d] = __bfloat162float(Q_ptr[d]);
    }

    // Each thread processes a chunk of KV positions
    // Thread tid handles positions: tid, tid+DECODE_THREADS, tid+2*DECODE_THREADS, ...
    float local_max = -INFINITY;
    float local_sum = 0.0f;
    float local_O[HEAD_DIM_MAX];
    for (int d = 0; d < head_dim; d++) {
        local_O[d] = 0.0f;
    }

    for (int pos = tid; pos < kv_len; pos += DECODE_THREADS) {
        // Compute dot(Q, K[pos]) * scale
        const __nv_bfloat16* K_pos = K_base + pos * head_dim;
        float dot = 0.0f;
        for (int d = 0; d < head_dim; d++) {
            dot += q_reg[d] * __bfloat162float(K_pos[d]);
        }
        float s = dot * scale;

        // Online softmax update
        float new_max = fmaxf(local_max, s);
        float correction = expf(local_max - new_max);
        float p = expf(s - new_max);

        // Rescale running sum and O
        local_sum = local_sum * correction + p;
        for (int d = 0; d < head_dim; d++) {
            local_O[d] = local_O[d] * correction;
        }

        // Accumulate V[pos] weighted by p
        const __nv_bfloat16* V_pos = V_base + pos * head_dim;
        for (int d = 0; d < head_dim; d++) {
            local_O[d] += p * __bfloat162float(V_pos[d]);
        }

        local_max = new_max;
    }

    // --- Block-level online softmax reduction ---
    // We need to combine (local_max, local_sum, local_O) across all threads.
    // Strategy: reduce max, then each thread rescales, then reduce sum and O.

    // Shared memory for reduction
    __shared__ float smem_max[32];  // one per warp
    __shared__ float smem_sum[32];
    // One row of partial outputs per warp actually present in the block.
    __shared__ float smem_O_warp[DECODE_THREADS / 32][HEAD_DIM_MAX];

    // Step 1: Block-wide max reduction
    int lane = tid & 31;
    int warp_id = tid >> 5;
    int num_warps = DECODE_THREADS >> 5;  // 8 warps

    float warp_max = local_max;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_max = fmaxf(warp_max, __shfl_down_sync(0xffffffff, warp_max, offset));
    if (lane == 0) smem_max[warp_id] = warp_max;
    __syncthreads();

    float global_max;
    if (tid == 0) {
        global_max = smem_max[0];
        for (int i = 1; i < num_warps; i++)
            global_max = fmaxf(global_max, smem_max[i]);
        smem_max[0] = global_max;
    }
    __syncthreads();
    global_max = smem_max[0];

    // Step 2: Each thread rescales its local_sum and local_O with global_max.
    // Threads that saw no KV position keep local_max == -INFINITY; they are
    // given weight 0 explicitly to avoid expf(-inf - (-inf)) = NaN.
    float rescale = (local_max == -INFINITY) ? 0.0f : expf(local_max - global_max);
    local_sum *= rescale;
    for (int d = 0; d < head_dim; d++) {
        local_O[d] *= rescale;
    }

    // Step 3: Reduce sum across block
    float warp_sum = local_sum;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset);
    if (lane == 0) smem_sum[warp_id] = warp_sum;
    __syncthreads();

    float global_sum;
    if (tid == 0) {
        global_sum = 0.0f;
        for (int i = 0; i < num_warps; i++)
            global_sum += smem_sum[i];
        smem_sum[0] = global_sum;
    }
    __syncthreads();
    global_sum = smem_sum[0];

    // Step 4: Reduce O across block, dim by dim. Store one partial per warp
    // and sum in warp-id order; atomicAdd made greedy decode nondeterministic
    // when logits were close (same fix pattern as paged_attention.cu / gemv.cu).
    float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;

    // No zero-fill is needed: lane 0 of every warp writes its slot for each
    // d < head_dim below, and only those slots are read afterwards.
    for (int d = 0; d < head_dim; d++) {
        float val = local_O[d];
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1)
            val += __shfl_down_sync(0xffffffff, val, offset);
        if (lane == 0) smem_O_warp[warp_id][d] = val;
    }
    __syncthreads();

    // Thread 0..head_dim-1 write final output
    for (int d = tid; d < head_dim; d += DECODE_THREADS) {
        float out = 0.0f;
        for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d];
        O_ptr[d] = __float2bfloat16(out * inv_sum);
    }
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Launches flash_attention_bf16_kernel.
//   Q, O    device buffers [batch, num_q_heads, q_len, head_dim], BF16
//   K, V    device buffers [batch, num_kv_heads, kv_len, head_dim], BF16
//   scale   score multiplier forwarded to the kernel
//   causal  non-zero enables causal masking
//   stream  cudaStream_t passed as an opaque pointer
// Grid is (ceil(q_len / BR), batch * num_q_heads) with THREADS_PER_BLOCK threads.
// An empty batch, head count or query length is a no-op.
// NOTE(review): the kernel's per-thread accumulator holds 128 floats, so
// head_dim > 128 is not supported; it is not checked here.
void launch_flash_attention_bf16(
    const void* Q, const void* K, const void* V, void* O,
    int batch, int num_q_heads, int num_kv_heads,
    int q_len, int kv_len, int head_dim,
    float scale, int causal, void* stream
) {
    // A grid with a zero dimension is an invalid launch configuration;
    // with no query rows there is no output to produce.
    if (batch <= 0 || num_q_heads <= 0 || q_len <= 0) return;

    int q_tiles = (q_len - 1) / BR + 1;  // ceil(q_len / BR)
    dim3 grid(q_tiles, batch * num_q_heads);
    int block = THREADS_PER_BLOCK;

    // Shared memory: smem_q[BR * head_dim] + smem_kv[BC * head_dim], all BF16
    int smem_bytes = (BR + BC) * head_dim * (int)sizeof(__nv_bfloat16);

    flash_attention_bf16_kernel<<<grid, block, smem_bytes, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K,
        (const __nv_bfloat16*)V,
        (__nv_bfloat16*)O,
        num_q_heads, num_kv_heads,
        q_len, kv_len, head_dim,
        scale, causal
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Launches flash_attention_sinks_bf16_kernel.
//   Q, O         device buffers [batch, num_q_heads, q_len, head_dim], BF16
//   K, V         device buffers [batch, num_kv_heads, kv_len, head_dim], BF16
//   sinks        device buffer [num_q_heads] BF16, or NULL for no sink logit
//   scale        score multiplier forwarded to the kernel
//   causal       non-zero enables causal masking
//   window_size  > 0 enables the sliding-window mask
//   stream       cudaStream_t passed as an opaque pointer
// Grid is (ceil(q_len / BR), batch * num_q_heads) with THREADS_PER_BLOCK threads.
// An empty batch, head count or query length is a no-op.
// NOTE(review): the kernel's per-thread accumulator holds 128 floats, so
// head_dim > 128 is not supported; it is not checked here.
void launch_flash_attention_sinks_bf16(
    const void* Q, const void* K, const void* V, void* O,
    const void* sinks,
    int batch, int num_q_heads, int num_kv_heads,
    int q_len, int kv_len, int head_dim,
    float scale, int causal, int window_size, void* stream
) {
    // A grid with a zero dimension is an invalid launch configuration;
    // with no query rows there is no output to produce.
    if (batch <= 0 || num_q_heads <= 0 || q_len <= 0) return;

    int q_tiles = (q_len - 1) / BR + 1;  // ceil(q_len / BR)
    dim3 grid(q_tiles, batch * num_q_heads);
    int block = THREADS_PER_BLOCK;
    // Shared memory: Q tile (BR rows) + K/V tile (BC rows), all BF16
    int smem_bytes = (BR + BC) * head_dim * (int)sizeof(__nv_bfloat16);

    flash_attention_sinks_bf16_kernel<<<grid, block, smem_bytes, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K,
        (const __nv_bfloat16*)V,
        (__nv_bfloat16*)O,
        (const __nv_bfloat16*)sinks,
        num_q_heads, num_kv_heads,
        q_len, kv_len, head_dim,
        scale, causal, window_size
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Launches decode_attention_bf16_kernel: one block of DECODE_THREADS threads
// per (batch, q_head) pair.
//   Q, O    device buffers [batch, num_q_heads, 1, head_dim], BF16
//   K, V    device buffers [batch, num_kv_heads, kv_len, head_dim], BF16
//   scale   score multiplier forwarded to the kernel
//   causal  accepted but unused: it is not forwarded to the kernel
//   stream  cudaStream_t passed as an opaque pointer
// An empty batch or head count is a no-op.
void launch_decode_attention_bf16(
    const void* Q, const void* K, const void* V, void* O,
    int batch, int num_q_heads, int num_kv_heads,
    int kv_len, int head_dim,
    float scale, int causal, void* stream
) {
    (void)causal;  // unused by the decode kernel

    // A zero-block grid is an invalid launch configuration; with no
    // (batch, head) pairs there is no output to produce.
    if (batch <= 0 || num_q_heads <= 0) return;

    int grid = batch * num_q_heads;
    int block = DECODE_THREADS;

    decode_attention_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K,
        (const __nv_bfloat16*)V,
        (__nv_bfloat16*)O,
        num_q_heads, num_kv_heads,
        kv_len, head_dim,
        scale
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
614
csrc/attention/paged_attention.cu
Normal file
614
csrc/attention/paged_attention.cu
Normal file
@@ -0,0 +1,614 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <float.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// Paged decode attention kernel for BF16 with FP32 accumulation.
|
||||
//
|
||||
// Reads K/V from a paged pool indexed by a per-sequence block table.
|
||||
// One CUDA block per (sequence, q_head). Each block streams over the
|
||||
// sequence's KV positions and accumulates attention output via online
|
||||
// softmax.
|
||||
//
|
||||
// Layouts:
|
||||
// Q [batch, num_q_heads, 1, head_dim] BF16
|
||||
// K_cache [num_blocks, num_kv_heads, BLOCK_SIZE, head_dim] BF16
|
||||
// V_cache same
|
||||
// block_tables [max_seqs, max_blocks_per_seq] int32
|
||||
// — the i-th sequence in this launch reads row
|
||||
// block_tables[seq_slot[i] * stride + ...].
|
||||
// For simplicity the launch passes a packed row table
|
||||
// [batch, max_blocks_per_seq] (already gathered for the
|
||||
//                  active batch) so we just index by the sequence index (blockIdx.y).
|
||||
// context_lens [batch] int32 — number of valid tokens per sequence.
|
||||
//
|
||||
// One CUDA block: 256 threads, head_dim <= 128.
|
||||
|
||||
#define PAGED_BLOCK_SIZE 16
|
||||
#define PAGED_THREADS 256
|
||||
#define PAGED_HEAD_DIM_MAX 128
|
||||
|
||||
// Number of warps in one paged-attention block. All kernels below assume they
// are launched with exactly PAGED_THREADS threads per block.
#define PAGED_NUM_WARPS (PAGED_THREADS / 32)

// Zero-fills one (sequence, q_head) output row; used when the sequence has no
// KV positions to attend over.
static __device__ __forceinline__ void paged_store_zero_row(
    __nv_bfloat16* __restrict__ O_ptr, int head_dim, int tid
) {
    for (int d = tid; d < head_dim; d += PAGED_THREADS) {
        O_ptr[d] = __float2bfloat16(0.0f);
    }
}

// Loads the query row into FP32 and clears the per-thread output accumulator.
static __device__ __forceinline__ void paged_init_thread_state(
    const __nv_bfloat16* __restrict__ Q_ptr, int head_dim,
    float* q_reg, float* local_O
) {
    for (int d = 0; d < head_dim; d++) {
        q_reg[d] = __bfloat162float(Q_ptr[d]);
    }
    for (int d = 0; d < head_dim; d++) local_O[d] = 0.0f;
}

// Folds the K/V row at logical position `pos` into one thread's running
// online-softmax state (running max, running sum, weighted-V accumulator).
static __device__ __forceinline__ void paged_accumulate_position(
    const __nv_bfloat16* __restrict__ K_cache,
    const __nv_bfloat16* __restrict__ V_cache,
    const int* __restrict__ bt,
    int pos, int kv_head, int num_kv_heads, int head_dim, float scale,
    const float* q_reg,
    float& local_max, float& local_sum, float* local_O
) {
    // Logical position -> (physical block, slot) -> element offset inside the
    // pool [num_blocks, num_kv_heads, PAGED_BLOCK_SIZE, head_dim]. The whole
    // offset is computed in 64-bit so large pools cannot wrap an int.
    const int logical_blk = pos / PAGED_BLOCK_SIZE;
    const int slot_in_blk = pos % PAGED_BLOCK_SIZE;
    const long long phys_blk = bt[logical_blk];
    const long long row_off =
        ((phys_blk * num_kv_heads + kv_head) * PAGED_BLOCK_SIZE + slot_in_blk) * head_dim;
    const __nv_bfloat16* K_pos = K_cache + row_off;
    const __nv_bfloat16* V_pos = V_cache + row_off;

    // s = dot(Q, K[pos]) * scale
    float dot = 0.0f;
    for (int d = 0; d < head_dim; d++) {
        dot += q_reg[d] * __bfloat162float(K_pos[d]);
    }
    const float s = dot * scale;

    // Online softmax update: rebase the running state onto the new max, then
    // add this position's contribution.
    const float new_max = fmaxf(local_max, s);
    const float correction = expf(local_max - new_max);
    const float p = expf(s - new_max);

    local_sum = local_sum * correction + p;
    for (int d = 0; d < head_dim; d++) local_O[d] *= correction;
    for (int d = 0; d < head_dim; d++) {
        local_O[d] += p * __bfloat162float(V_pos[d]);
    }
    local_max = new_max;
}

// Merges the per-thread online-softmax states of the whole block and writes the
// normalized attention output row. Must be called by every thread of the block
// (it contains __syncthreads()).
//
// Partials are stored one per warp and summed in warp-id order; atomicAdd made
// greedy decode nondeterministic when logits were close.
static __device__ __forceinline__ void paged_block_reduce_store(
    float local_max, float local_sum, float* local_O,
    __nv_bfloat16* __restrict__ O_ptr, int head_dim, int tid
) {
    // Only PAGED_NUM_WARPS rows are ever written or read, and every cell that
    // is read is written first by lane 0 of its warp, so no zero-fill is needed.
    __shared__ float smem_max[PAGED_NUM_WARPS];
    __shared__ float smem_sum[PAGED_NUM_WARPS];
    __shared__ float smem_O_warp[PAGED_NUM_WARPS][PAGED_HEAD_DIM_MAX];

    const int lane = tid & 31;
    const int warp_id = tid >> 5;

    // Step 1: block-wide max. Each thread folds the per-warp maxima itself
    // (same order as a single-thread fold), which avoids a broadcast barrier.
    float warp_max = local_max;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_max = fmaxf(warp_max, __shfl_down_sync(0xffffffff, warp_max, offset));
    if (lane == 0) smem_max[warp_id] = warp_max;
    __syncthreads();

    float global_max = smem_max[0];
    for (int i = 1; i < PAGED_NUM_WARPS; i++)
        global_max = fmaxf(global_max, smem_max[i]);

    // Step 2: rescale local state to global_max. Threads that saw no position
    // keep local_max == -INFINITY and must contribute exactly zero.
    const float rescale =
        (local_max == -INFINITY) ? 0.0f : expf(local_max - global_max);
    local_sum *= rescale;
    for (int d = 0; d < head_dim; d++) local_O[d] *= rescale;

    // Step 3: block-wide sum, folded in warp-id order by every thread.
    float warp_sum = local_sum;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset);
    if (lane == 0) smem_sum[warp_id] = warp_sum;
    __syncthreads();

    float global_sum = 0.0f;
    for (int i = 0; i < PAGED_NUM_WARPS; i++) global_sum += smem_sum[i];

    // Step 4: reduce O across the block, dim by dim, one partial per warp.
    for (int d = 0; d < head_dim; d++) {
        float val = local_O[d];
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1)
            val += __shfl_down_sync(0xffffffff, val, offset);
        if (lane == 0) smem_O_warp[warp_id][d] = val;
    }
    __syncthreads();

    // A zero sum (nothing attended) yields a zero row instead of NaN.
    const float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
    for (int d = tid; d < head_dim; d += PAGED_THREADS) {
        float out = 0.0f;
        for (int i = 0; i < PAGED_NUM_WARPS; i++) out += smem_O_warp[i][d];
        O_ptr[d] = __float2bfloat16(out * inv_sum);
    }
}

// Base paged decode attention kernel.
//
// Grid : (num_q_heads, batch) — blockIdx.x is the query head, blockIdx.y the
//        sequence. Block: PAGED_THREADS threads.
// Preconditions (not checked on device): head_dim <= PAGED_HEAD_DIM_MAX and
// num_q_heads is a positive multiple of num_kv_heads.
//
// Each thread handles positions tid, tid + PAGED_THREADS, ... of the sequence
// with its own online-softmax state; the states are then merged block-wide.
__global__ void paged_decode_attention_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K_cache,
    const __nv_bfloat16* __restrict__ V_cache,
    __nv_bfloat16* __restrict__ O,
    const int* __restrict__ block_tables,   // [batch, max_blocks_per_seq]
    const int* __restrict__ context_lens,   // [batch]
    int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    float scale
) {
    const int seq_idx = blockIdx.y;   // batch dim
    const int q_head  = blockIdx.x;   // 0 .. num_q_heads-1
    const int tid     = threadIdx.x;

    // Q and O share the same [batch, num_q_heads, head_dim] row offset.
    const long long row_off = ((long long)seq_idx * num_q_heads + q_head) * head_dim;
    __nv_bfloat16* O_ptr = O + row_off;

    const int kv_len = context_lens[seq_idx];
    if (kv_len <= 0) {
        // Nothing to attend over; zero output for safety. kv_len is uniform
        // across the block, so this early return skips no barrier.
        paged_store_zero_row(O_ptr, head_dim, tid);
        return;
    }

    // GQA mapping: consecutive groups of query heads share one KV head.
    const int kv_head = q_head / (num_q_heads / num_kv_heads);
    const int* bt = block_tables + (long long)seq_idx * max_blocks_per_seq;

    float q_reg[PAGED_HEAD_DIM_MAX];
    float local_O[PAGED_HEAD_DIM_MAX];
    float local_max = -INFINITY;
    float local_sum = 0.0f;
    paged_init_thread_state(Q + row_off, head_dim, q_reg, local_O);

    for (int pos = tid; pos < kv_len; pos += PAGED_THREADS) {
        paged_accumulate_position(K_cache, V_cache, bt, pos, kv_head,
                                  num_kv_heads, head_dim, scale,
                                  q_reg, local_max, local_sum, local_O);
    }

    paged_block_reduce_store(local_max, local_sum, local_O, O_ptr, head_dim, tid);
}

// Tree-aware paged decode attention: per-query mask lets sibling candidates
// in the same batch attend to different subsets of newly-written K/V.
// `tree_start`: position where newly-written K/V begins (typically pos_offset).
// `tree_len`: number of newly-written K/V rows (= batch, one per query).
// `tree_mask[i][j] = 1` iff query i attends to K/V at position `tree_start+j`.
// Positions < tree_start are always attended (regular history).
//
// Launch geometry and preconditions are the same as the base kernel.
__global__ void paged_decode_attention_tree_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K_cache,
    const __nv_bfloat16* __restrict__ V_cache,
    __nv_bfloat16* __restrict__ O,
    const int* __restrict__ block_tables,
    const int* __restrict__ context_lens,
    const int* __restrict__ tree_mask,      // [batch, tree_len] int32
    int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    int tree_start, int tree_len,
    float scale
) {
    const int seq_idx = blockIdx.y;
    const int q_head  = blockIdx.x;
    const int tid     = threadIdx.x;

    const long long row_off = ((long long)seq_idx * num_q_heads + q_head) * head_dim;
    __nv_bfloat16* O_ptr = O + row_off;

    const int kv_len = context_lens[seq_idx];
    if (kv_len <= 0) {
        paged_store_zero_row(O_ptr, head_dim, tid);
        return;
    }

    const int kv_head = q_head / (num_q_heads / num_kv_heads);
    const int* bt = block_tables + (long long)seq_idx * max_blocks_per_seq;
    const int* mask_row = tree_mask + (long long)seq_idx * tree_len;
    const int tree_end = tree_start + tree_len;

    float q_reg[PAGED_HEAD_DIM_MAX];
    float local_O[PAGED_HEAD_DIM_MAX];
    float local_max = -INFINITY;
    float local_sum = 0.0f;
    paged_init_thread_state(Q + row_off, head_dim, q_reg, local_O);

    for (int pos = tid; pos < kv_len; pos += PAGED_THREADS) {
        // Tree mask: skip positions in [tree_start, tree_end) that the mask
        // marks as 0. Everything else is always attended.
        const bool in_tree = (pos >= tree_start) && (pos < tree_end);
        if (in_tree && mask_row[pos - tree_start] == 0) continue;

        paged_accumulate_position(K_cache, V_cache, bt, pos, kv_head,
                                  num_kv_heads, head_dim, scale,
                                  q_reg, local_max, local_sum, local_O);
    }

    // A thread whose positions were all masked keeps local_max == -INFINITY
    // and is weighted by zero inside the reduction.
    paged_block_reduce_store(local_max, local_sum, local_O, O_ptr, head_dim, tid);
}

// Extended paged decode attention with attention sinks and sliding window.
// sinks: [num_q_heads] BF16 — per-head extra logit appended before softmax.
// window_size: >0 = sliding window (only attend to last `window_size` positions), 0 = full.
//
// Launch geometry and preconditions are the same as the base kernel.
__global__ void paged_decode_attention_sinks_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K_cache,
    const __nv_bfloat16* __restrict__ V_cache,
    __nv_bfloat16* __restrict__ O,
    const int* __restrict__ block_tables,
    const int* __restrict__ context_lens,
    const __nv_bfloat16* __restrict__ sinks,  // [num_q_heads] or NULL
    int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    float scale, int window_size
) {
    const int seq_idx = blockIdx.y;
    const int q_head  = blockIdx.x;
    const int tid     = threadIdx.x;

    const long long row_off = ((long long)seq_idx * num_q_heads + q_head) * head_dim;
    __nv_bfloat16* O_ptr = O + row_off;

    const int kv_len = context_lens[seq_idx];
    if (kv_len <= 0) {
        paged_store_zero_row(O_ptr, head_dim, tid);
        return;
    }

    const int kv_head = q_head / (num_q_heads / num_kv_heads);
    const int* bt = block_tables + (long long)seq_idx * max_blocks_per_seq;

    // Sliding window: only attend to positions [kv_len - window_size, kv_len).
    const int start_pos =
        (window_size > 0 && kv_len > window_size) ? (kv_len - window_size) : 0;

    float q_reg[PAGED_HEAD_DIM_MAX];
    float local_O[PAGED_HEAD_DIM_MAX];
    float local_max = -INFINITY;
    float local_sum = 0.0f;
    paged_init_thread_state(Q + row_off, head_dim, q_reg, local_O);

    for (int pos = start_pos + tid; pos < kv_len; pos += PAGED_THREADS) {
        paged_accumulate_position(K_cache, V_cache, bt, pos, kv_head,
                                  num_kv_heads, head_dim, scale,
                                  q_reg, local_max, local_sum, local_O);
    }

    // Include the sink logit (only thread 0 handles it to avoid double-counting).
    // The sink absorbs probability mass but contributes no value (p * 0), so
    // only the running max/sum change and the accumulator is merely rebased.
    if (sinks != nullptr && tid == 0) {
        const float sink_logit = __bfloat162float(sinks[q_head]);
        const float new_max = fmaxf(local_max, sink_logit);
        const float correction = expf(local_max - new_max);
        const float p = expf(sink_logit - new_max);
        local_sum = local_sum * correction + p;
        for (int d = 0; d < head_dim; d++) local_O[d] *= correction;
        local_max = new_max;
    }

    paged_block_reduce_store(local_max, local_sum, local_O, O_ptr, head_dim, tid);
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Host launcher for paged_decode_attention_bf16_kernel.
//
// Launches one block of PAGED_THREADS threads per (q_head, sequence) on
// `stream`. An empty batch is a no-op. Shapes the kernel cannot handle are
// reported on stderr and nothing is launched.
void launch_paged_decode_attention_bf16(
    const void* Q,
    const void* K_cache,
    const void* V_cache,
    void* O,
    const int* block_tables,
    const int* context_lens,
    int batch, int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    float scale, void* stream
) {
    // A zero-sized grid is an invalid launch configuration; treat it as "no work".
    if (batch <= 0 || num_q_heads <= 0) return;

    // The kernel keeps Q and its accumulator in per-thread arrays of
    // PAGED_HEAD_DIM_MAX floats and maps q_head -> kv_head by integer division,
    // so a larger head_dim would overrun those arrays and a non-dividing
    // num_kv_heads would index a KV head that does not exist.
    if (head_dim <= 0 || head_dim > PAGED_HEAD_DIM_MAX ||
        num_kv_heads <= 0 || num_q_heads % num_kv_heads != 0) {
        fprintf(stderr,
                "launch_paged_decode_attention_bf16: unsupported shape "
                "(head_dim=%d, max=%d, num_q_heads=%d, num_kv_heads=%d)\n",
                head_dim, PAGED_HEAD_DIM_MAX, num_q_heads, num_kv_heads);
        return;
    }

    const dim3 grid(num_q_heads, batch);
    const int block = PAGED_THREADS;

    paged_decode_attention_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K_cache,
        (const __nv_bfloat16*)V_cache,
        (__nv_bfloat16*)O,
        block_tables, context_lens,
        num_q_heads, num_kv_heads,
        head_dim, max_blocks_per_seq,
        scale
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for paged_decode_attention_tree_bf16_kernel.
//
// Launches one block of PAGED_THREADS threads per (q_head, sequence) on
// `stream`. `tree_mask` is indexed as [batch, tree_len]. An empty batch is a
// no-op. Shapes the kernel cannot handle are reported on stderr and nothing is
// launched.
void launch_paged_decode_attention_tree_bf16(
    const void* Q,
    const void* K_cache,
    const void* V_cache,
    void* O,
    const int* block_tables,
    const int* context_lens,
    const int* tree_mask,
    int batch, int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    int tree_start, int tree_len,
    float scale, void* stream
) {
    // A zero-sized grid is an invalid launch configuration; treat it as "no work".
    if (batch <= 0 || num_q_heads <= 0) return;

    // Same shape limits as the base kernel (fixed-size per-thread arrays, GQA
    // integer division). A negative tree window would make the kernel index
    // tree_mask with a bogus row stride.
    if (head_dim <= 0 || head_dim > PAGED_HEAD_DIM_MAX ||
        num_kv_heads <= 0 || num_q_heads % num_kv_heads != 0 ||
        tree_start < 0 || tree_len < 0) {
        fprintf(stderr,
                "launch_paged_decode_attention_tree_bf16: unsupported shape "
                "(head_dim=%d, max=%d, num_q_heads=%d, num_kv_heads=%d, "
                "tree_start=%d, tree_len=%d)\n",
                head_dim, PAGED_HEAD_DIM_MAX, num_q_heads, num_kv_heads,
                tree_start, tree_len);
        return;
    }

    const dim3 grid(num_q_heads, batch);
    const int block = PAGED_THREADS;

    paged_decode_attention_tree_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K_cache,
        (const __nv_bfloat16*)V_cache,
        (__nv_bfloat16*)O,
        block_tables, context_lens, tree_mask,
        num_q_heads, num_kv_heads,
        head_dim, max_blocks_per_seq,
        tree_start, tree_len,
        scale
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for paged_decode_attention_sinks_bf16_kernel.
//
// Launches one block of PAGED_THREADS threads per (q_head, sequence) on
// `stream`. `sinks` may be NULL (no sink logit); `window_size` of 0 means full
// attention. An empty batch is a no-op. Shapes the kernel cannot handle are
// reported on stderr and nothing is launched.
void launch_paged_decode_attention_sinks_bf16(
    const void* Q,
    const void* K_cache,
    const void* V_cache,
    void* O,
    const int* block_tables,
    const int* context_lens,
    const void* sinks,
    int batch, int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    float scale, int window_size, void* stream
) {
    // A zero-sized grid is an invalid launch configuration; treat it as "no work".
    if (batch <= 0 || num_q_heads <= 0) return;

    // The kernel keeps Q and its accumulator in per-thread arrays of
    // PAGED_HEAD_DIM_MAX floats and maps q_head -> kv_head by integer division,
    // so reject shapes that would overrun those arrays or index a missing KV head.
    if (head_dim <= 0 || head_dim > PAGED_HEAD_DIM_MAX ||
        num_kv_heads <= 0 || num_q_heads % num_kv_heads != 0) {
        fprintf(stderr,
                "launch_paged_decode_attention_sinks_bf16: unsupported shape "
                "(head_dim=%d, max=%d, num_q_heads=%d, num_kv_heads=%d)\n",
                head_dim, PAGED_HEAD_DIM_MAX, num_q_heads, num_kv_heads);
        return;
    }

    const dim3 grid(num_q_heads, batch);
    const int block = PAGED_THREADS;

    paged_decode_attention_sinks_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K_cache,
        (const __nv_bfloat16*)V_cache,
        (__nv_bfloat16*)O,
        block_tables, context_lens,
        (const __nv_bfloat16*)sinks,
        num_q_heads, num_kv_heads,
        head_dim, max_blocks_per_seq,
        scale, window_size
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
215
csrc/attention/reshape_and_cache.cu
Normal file
215
csrc/attention/reshape_and_cache.cu
Normal file
@@ -0,0 +1,215 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// Scatter [num_tokens] new K/V into a paged KV pool for ONE sequence.
|
||||
//
|
||||
// Source layouts (BF16, contiguous):
|
||||
// k_src, v_src : [num_kv_heads, num_tokens, head_dim] (head-major)
|
||||
//
|
||||
// Pool layouts (BF16, contiguous):
|
||||
// k_pool, v_pool : [num_blocks_total, num_kv_heads, BLOCK_SIZE, head_dim]
|
||||
//
|
||||
// For token t (0 <= t < num_tokens):
|
||||
// p = start_pos + t
|
||||
// logical_blk = p / BLOCK_SIZE
|
||||
// slot_in_blk = p % BLOCK_SIZE
|
||||
// phys = block_ids[logical_blk]
|
||||
// pool[phys, h, slot_in_blk, :] := src[h, t, :]
|
||||
//
|
||||
// Replaces a Rust-side per-token, per-head cudaMemcpy loop. With Qwen3-8B
|
||||
// (8 KV heads, 36 layers) and a 1024-token prefill, that loop fired
|
||||
// ~290k device-side memcpys; one kernel launch per layer is dramatically
|
||||
// less overhead.
|
||||
//
|
||||
// Grid : (num_tokens, num_kv_heads)
|
||||
// Block: head_dim threads (≤128 in practice; head_dim is padded to a
|
||||
// multiple of 32 by the model and all our shipping configs are
|
||||
// 128, so a single warp's worth handles two slots in flight).
|
||||
|
||||
// Scatters the K and V rows of one (token, head) pair from the head-major
// source tensors into the paged pool. blockIdx.x selects the token and
// blockIdx.y the head; the threads of the block cooperatively copy the
// head_dim elements of that row.
__global__ void reshape_and_cache_bf16_kernel(
    const __nv_bfloat16* __restrict__ k_src,
    const __nv_bfloat16* __restrict__ v_src,
    __nv_bfloat16* __restrict__ k_pool,
    __nv_bfloat16* __restrict__ v_pool,
    const int* __restrict__ block_ids,
    int num_tokens, int num_heads,
    int head_dim, int start_pos, int block_size
) {
    const int token = blockIdx.x;
    const int head  = blockIdx.y;
    if (token >= num_tokens || head >= num_heads) return;

    // Absolute position in the sequence -> (logical block, slot) -> physical block.
    const int abs_pos   = start_pos + token;
    const int blk_index = abs_pos / block_size;
    const int blk_slot  = abs_pos - blk_index * block_size;
    const int phys_blk  = block_ids[blk_index];

    // Row starts: source is [head, token, :], pool is [phys, head, slot, :].
    const long long in_row  = ((long long)head * num_tokens + token) * head_dim;
    const long long out_row =
        (((long long)phys_blk * num_heads + head) * block_size + blk_slot) * head_dim;

    // Strided copy so the kernel stays correct when head_dim differs from the
    // number of threads in the block.
    const int stride = blockDim.x;
    for (int elem = threadIdx.x; elem < head_dim; elem += stride) {
        k_pool[out_row + elem] = k_src[in_row + elem];
        v_pool[out_row + elem] = v_src[in_row + elem];
    }
}
|
||||
|
||||
// Batched variant: writes one new K/V token per sequence into a paged
|
||||
// pool, indexed by a per-batch block table that also drives the paged
|
||||
// attention kernel. Used in the decode path where every seq advances
|
||||
// by exactly one position per step.
|
||||
//
|
||||
// Source layouts (BF16, contiguous):
|
||||
// k_src, v_src : [batch, num_kv_heads, head_dim]
|
||||
//
|
||||
// Pool layouts (BF16, contiguous):
|
||||
// k_pool, v_pool : [num_blocks_total, num_kv_heads, BLOCK_SIZE, head_dim]
|
||||
//
|
||||
// block_tables : int32 [batch, max_blocks_per_seq]
|
||||
// kv_lens : int32 [batch] (current seq_len BEFORE this step + 1
|
||||
// — i.e. the same buffer paged attention
|
||||
// reads. The new token's logical index
|
||||
// is `kv_lens[b] - 1`.)
|
||||
//
|
||||
// Grid : (batch, num_kv_heads)
|
||||
// Block: head_dim threads.
|
||||
|
||||
// Writes the newest K/V token of each sequence into the paged pool.
//
// blockIdx.x selects the sequence `b`, blockIdx.y the KV head `h`. The token
// is stored at logical position kv_lens[b] - 1, translated through row `b` of
// block_tables. Sequences with kv_lens[b] <= 0 and positions that fall outside
// the block-table row are skipped instead of writing out of bounds.
__global__ void reshape_and_cache_batched_bf16_kernel(
    const __nv_bfloat16* __restrict__ k_src,
    const __nv_bfloat16* __restrict__ v_src,
    __nv_bfloat16* __restrict__ k_pool,
    __nv_bfloat16* __restrict__ v_pool,
    const int* __restrict__ block_tables,
    const int* __restrict__ kv_lens,
    int num_heads, int head_dim,
    int block_size, int max_blocks_per_seq
) {
    const int b = blockIdx.x;
    const int h = blockIdx.y;

    // An empty slot (kv_lens[b] <= 0) has no new token. Without this guard
    // new_pos would be -1, giving slot_in_blk == -1 and a write that lands
    // before the start of the target block.
    const int new_pos = kv_lens[b] - 1;
    if (new_pos < 0) return;

    const int logical_blk = new_pos / block_size;
    const int slot_in_blk = new_pos - logical_blk * block_size;

    // Reading past this sequence's table row would pick up another sequence's
    // physical block and overwrite its cache; drop the write instead.
    if (logical_blk >= max_blocks_per_seq) return;

    // 64-bit row index so batch * max_blocks_per_seq cannot wrap an int.
    const int phys = block_tables[(long long)b * max_blocks_per_seq + logical_blk];

    // Source is [batch, num_heads, head_dim]; pool is [phys, head, slot, :].
    const long long src_off = ((long long)b * num_heads + h) * head_dim;
    const long long dst_off =
        (((long long)phys * num_heads + h) * block_size + slot_in_blk) * head_dim;

    // Strided copy keeps the kernel correct when head_dim != blockDim.x.
    for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {
        k_pool[dst_off + d] = k_src[src_off + d];
        v_pool[dst_off + d] = v_src[src_off + d];
    }
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Host launcher for reshape_and_cache_bf16_kernel.
//
// Launches a (num_tokens, num_heads) grid on `stream`, with one block per
// (token, head) row. Returns without launching when there is nothing to copy,
// matching launch_reshape_and_cache_batched_bf16.
void launch_reshape_and_cache_bf16(
    const void* k_src, const void* v_src,
    void* k_pool, void* v_pool,
    const void* block_ids,
    int num_tokens, int num_heads,
    int head_dim, int start_pos, int block_size,
    void* stream
) {
    // A zero-sized grid dimension is an invalid launch configuration, so both
    // grid extents must be checked, not only num_tokens.
    if (num_tokens <= 0 || num_heads <= 0) return;

    // At least one warp, at most the 1024-thread block limit; the kernel's
    // strided loop covers any head_dim beyond the thread count.
    int threads = head_dim < 32 ? 32 : head_dim;
    if (threads > 1024) threads = 1024;

    dim3 grid(num_tokens, num_heads);
    reshape_and_cache_bf16_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)k_src,
        (const __nv_bfloat16*)v_src,
        (__nv_bfloat16*)k_pool,
        (__nv_bfloat16*)v_pool,
        (const int*)block_ids,
        num_tokens, num_heads,
        head_dim, start_pos, block_size
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for reshape_and_cache_batched_bf16_kernel: one block per
// (sequence, head) pair on `stream`. Does nothing for an empty batch or zero
// heads.
void launch_reshape_and_cache_batched_bf16(
    const void* k_src, const void* v_src,
    void* k_pool, void* v_pool,
    const void* block_tables, const void* kv_lens,
    int batch, int num_heads,
    int head_dim, int block_size, int max_blocks_per_seq,
    void* stream
) {
    const bool nothing_to_do = (batch <= 0) || (num_heads <= 0);
    if (nothing_to_do) return;

    // Clamp the block width to [32, 1024] threads.
    int block_threads = head_dim;
    if (block_threads < 32) block_threads = 32;
    if (block_threads > 1024) block_threads = 1024;

    const dim3 launch_grid(batch, num_heads);
    cudaStream_t cuda_stream = (cudaStream_t)stream;

    reshape_and_cache_batched_bf16_kernel<<<launch_grid, block_threads, 0, cuda_stream>>>(
        (const __nv_bfloat16*)k_src,
        (const __nv_bfloat16*)v_src,
        (__nv_bfloat16*)k_pool,
        (__nv_bfloat16*)v_pool,
        (const int*)block_tables,
        (const int*)kv_lens,
        num_heads, head_dim, block_size, max_blocks_per_seq
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Copy one token's K/V from src_pos to dst_pos within one pool.
|
||||
// Grid: (num_kv_heads,). Block: head_dim threads.
|
||||
// pool: [num_blocks_total, num_kv_heads, block_size, head_dim]
|
||||
// block_ids: [max_blocks] for this sequence (logical → physical block map).
|
||||
// Copies one token's row (all head_dim elements, for every KV head) from
// logical position src_pos to dst_pos inside a single pool.
//
// blockIdx.x is the KV head and gridDim.x is taken as num_kv_heads, so the
// kernel must be launched with exactly one block per KV head. The copy is
// strided over blockDim.x, so head_dim may exceed the thread count.
__global__ void copy_kv_position_kernel(
    __nv_bfloat16* __restrict__ pool,
    const int* __restrict__ block_ids,
    int src_pos, int dst_pos,
    int head_dim, int block_size
) {
    const int h = blockIdx.x;
    const int num_kv_heads = gridDim.x;

    // Logical position -> (logical block, slot) -> physical block, for both ends.
    const int src_blk  = src_pos / block_size;
    const int src_slot = src_pos % block_size;
    const int src_phys = block_ids[src_blk];

    const int dst_blk  = dst_pos / block_size;
    const int dst_slot = dst_pos % block_size;
    const int dst_phys = block_ids[dst_blk];

    // Row starts inside [num_blocks_total, num_kv_heads, block_size, head_dim],
    // computed entirely in 64-bit.
    const long long src_row =
        (((long long)src_phys * num_kv_heads + h) * block_size + src_slot) * head_dim;
    const long long dst_row =
        (((long long)dst_phys * num_kv_heads + h) * block_size + dst_slot) * head_dim;

    // Strided loop (like the reshape_and_cache kernels): a one-element-per-
    // thread copy would silently truncate rows when head_dim > blockDim.x,
    // which happens once the launcher clamps the block to 1024 threads.
    for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {
        pool[dst_row + d] = pool[src_row + d];
    }
}
|
||||
|
||||
// Copies one token's K and V rows from src_pos to dst_pos within the paged
// pools of a single sequence, using one kernel launch per pool on `stream`.
//
// `block_ids` maps the sequence's logical blocks to physical blocks. Copying a
// position onto itself and empty shapes are no-ops. Negative positions or a
// non-positive block_size are reported on stderr and nothing is launched.
void launch_copy_kv_position(
    void* k_pool, void* v_pool,
    const int* block_ids,
    int src_pos, int dst_pos,
    int num_kv_heads, int head_dim, int block_size,
    void* stream
) {
    // A zero-sized grid is an invalid launch configuration; nothing to copy anyway.
    if (num_kv_heads <= 0 || head_dim <= 0) return;

    // The kernel divides by block_size and indexes block_ids with pos / block_size;
    // bad values here would fault or write outside the pool.
    if (src_pos < 0 || dst_pos < 0 || block_size <= 0) {
        fprintf(stderr,
                "launch_copy_kv_position: invalid arguments "
                "(src_pos=%d, dst_pos=%d, block_size=%d)\n",
                src_pos, dst_pos, block_size);
        return;
    }

    // Source and destination are the same row: skip both launches.
    if (src_pos == dst_pos) return;

    // At least one warp, at most the 1024-thread block limit.
    int threads = head_dim < 32 ? 32 : head_dim;
    if (threads > 1024) threads = 1024;

    // The kernel derives num_kv_heads from gridDim.x, so the grid must be
    // exactly one block per KV head.
    dim3 grid(num_kv_heads);
    copy_kv_position_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)k_pool, block_ids,
        src_pos, dst_pos, head_dim, block_size
    );
    CUDA_CHECK_LAST_ERROR();
    copy_kv_position_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)v_pool, block_ids,
        src_pos, dst_pos, head_dim, block_size
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
@@ -48,3 +48,18 @@ __device__ __forceinline__ float block_reduce_max(float val) {
|
||||
if (warp_id == 0) val = warp_reduce_max(val);
|
||||
return val;
|
||||
}
|
||||
|
||||
// --- Launch error checking ---
|
||||
// Always on, including release builds. A launch with an invalid config
|
||||
// (e.g. 32-bit overflow in grid/index math) is otherwise silent and produces
|
||||
// garbage with no clue — the MoE int32-overflow bug was found exactly because
|
||||
// release swallowed the launch failure. `cudaGetLastError()` does not
|
||||
// synchronize the stream, so the per-launch host cost is negligible.
|
||||
#include <cstdio>
|
||||
// Reads the pending launch status with cudaGetLastError() and, if it is not
// cudaSuccess, prints the call site and the CUDA error string to stderr.
// Execution continues either way.
#define CUDA_CHECK_LAST_ERROR()                                               \
    do {                                                                      \
        const cudaError_t cuda_launch_status_ = cudaGetLastError();           \
        if (cuda_launch_status_ != cudaSuccess) {                             \
            fprintf(stderr, "CUDA kernel launch error at %s:%d: %s\n",        \
                    __FILE__, __LINE__,                                       \
                    cudaGetErrorString(cuda_launch_status_));                 \
        }                                                                     \
    } while (0)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// Embedding lookup: out[seq_idx] = table[token_ids[seq_idx]]
|
||||
// Grid: num_tokens, Block: handles hidden_size elements per token.
|
||||
@@ -7,10 +8,12 @@ __global__ void embedding_f32(
|
||||
const float* __restrict__ table, // [vocab_size, hidden_size]
|
||||
const int* __restrict__ token_ids, // [num_tokens]
|
||||
float* __restrict__ out, // [num_tokens, hidden_size]
|
||||
int hidden_size
|
||||
int hidden_size,
|
||||
int vocab_size
|
||||
) {
|
||||
int token_idx = blockIdx.x;
|
||||
int tid = token_ids[token_idx];
|
||||
if (tid < 0 || tid >= vocab_size) return;
|
||||
const float* row = table + tid * hidden_size;
|
||||
float* dst = out + token_idx * hidden_size;
|
||||
|
||||
@@ -23,10 +26,12 @@ __global__ void embedding_bf16(
|
||||
const __nv_bfloat16* __restrict__ table,
|
||||
const int* __restrict__ token_ids,
|
||||
__nv_bfloat16* __restrict__ out,
|
||||
int hidden_size
|
||||
int hidden_size,
|
||||
int vocab_size
|
||||
) {
|
||||
int token_idx = blockIdx.x;
|
||||
int tid = token_ids[token_idx];
|
||||
if (tid < 0 || tid >= vocab_size) return;
|
||||
const __nv_bfloat16* row = table + tid * hidden_size;
|
||||
__nv_bfloat16* dst = out + token_idx * hidden_size;
|
||||
|
||||
@@ -38,18 +43,20 @@ __global__ void embedding_bf16(
|
||||
extern "C" {
|
||||
|
||||
void launch_embedding_f32(const void* table, const void* token_ids, void* out,
|
||||
int num_tokens, int hidden_size, void* stream) {
|
||||
int num_tokens, int hidden_size, int vocab_size, void* stream) {
|
||||
int block = (hidden_size < 256) ? hidden_size : 256;
|
||||
embedding_f32<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)table, (const int*)token_ids, (float*)out, hidden_size);
|
||||
(const float*)table, (const int*)token_ids, (float*)out, hidden_size, vocab_size);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
void launch_embedding_bf16(const void* table, const void* token_ids, void* out,
|
||||
int num_tokens, int hidden_size, void* stream) {
|
||||
int num_tokens, int hidden_size, int vocab_size, void* stream) {
|
||||
int block = (hidden_size < 256) ? hidden_size : 256;
|
||||
embedding_bf16<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)table, (const int*)token_ids,
|
||||
(__nv_bfloat16*)out, hidden_size);
|
||||
(__nv_bfloat16*)out, hidden_size, vocab_size);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <math.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// RoPE: Rotary Position Embedding
|
||||
// For each pair (x[2i], x[2i+1]) at position `pos`:
|
||||
// y[2i] = x[2i] * cos - x[2i+1] * sin
|
||||
// y[2i+1] = x[2i] * sin + x[2i+1] * cos
|
||||
// RoPE: Rotary Position Embedding, using the Qwen/Llama rotate_half layout.
|
||||
// For each dimension i in the first half at position `pos`:
|
||||
// y[i] = x[i] * cos - x[i + half_dim] * sin
|
||||
// y[i + half_dim] = x[i + half_dim] * cos + x[i] * sin
|
||||
// where cos/sin come from precomputed cos_cache/sin_cache.
|
||||
//
|
||||
// cos_cache[pos][i] = cos(pos * freq[i])
|
||||
@@ -35,11 +36,11 @@ __global__ void rope_f32(
|
||||
float sin_val = sin_cache[pos * half_dim + pair_idx];
|
||||
|
||||
int base = (token_idx * num_heads + head_idx) * head_dim;
|
||||
float x0 = x[base + 2 * pair_idx];
|
||||
float x1 = x[base + 2 * pair_idx + 1];
|
||||
float x0 = x[base + pair_idx];
|
||||
float x1 = x[base + pair_idx + half_dim];
|
||||
|
||||
x[base + 2 * pair_idx] = x0 * cos_val - x1 * sin_val;
|
||||
x[base + 2 * pair_idx + 1] = x0 * sin_val + x1 * cos_val;
|
||||
x[base + pair_idx] = x0 * cos_val - x1 * sin_val;
|
||||
x[base + pair_idx + half_dim] = x1 * cos_val + x0 * sin_val;
|
||||
}
|
||||
|
||||
__global__ void rope_bf16(
|
||||
@@ -61,11 +62,11 @@ __global__ void rope_bf16(
|
||||
float sin_val = sin_cache[pos * half_dim + pair_idx];
|
||||
|
||||
int base = (token_idx * num_heads + head_idx) * head_dim;
|
||||
float x0 = __bfloat162float(x[base + 2 * pair_idx]);
|
||||
float x1 = __bfloat162float(x[base + 2 * pair_idx + 1]);
|
||||
float x0 = __bfloat162float(x[base + pair_idx]);
|
||||
float x1 = __bfloat162float(x[base + pair_idx + half_dim]);
|
||||
|
||||
x[base + 2 * pair_idx] = __float2bfloat16(x0 * cos_val - x1 * sin_val);
|
||||
x[base + 2 * pair_idx + 1] = __float2bfloat16(x0 * sin_val + x1 * cos_val);
|
||||
x[base + pair_idx] = __float2bfloat16(x0 * cos_val - x1 * sin_val);
|
||||
x[base + pair_idx + half_dim] = __float2bfloat16(x1 * cos_val + x0 * sin_val);
|
||||
}
|
||||
|
||||
// Precompute cos/sin cache on GPU
|
||||
@@ -94,6 +95,7 @@ void launch_rope_f32(void* x, const void* cos_cache, const void* sin_cache,
|
||||
rope_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(float*)x, (const float*)cos_cache, (const float*)sin_cache,
|
||||
(const int*)positions, num_heads, head_dim);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
void launch_rope_bf16(void* x, const void* cos_cache, const void* sin_cache,
|
||||
@@ -104,6 +106,7 @@ void launch_rope_bf16(void* x, const void* cos_cache, const void* sin_cache,
|
||||
rope_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(__nv_bfloat16*)x, (const float*)cos_cache, (const float*)sin_cache,
|
||||
(const int*)positions, num_heads, head_dim);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
void launch_compute_rope_cache(void* cos_cache, void* sin_cache,
|
||||
@@ -111,6 +114,7 @@ void launch_compute_rope_cache(void* cos_cache, void* sin_cache,
|
||||
void* stream) {
|
||||
compute_rope_cache<<<max_seq_len, half_dim, 0, (cudaStream_t)stream>>>(
|
||||
(float*)cos_cache, (float*)sin_cache, max_seq_len, half_dim, theta);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
242
csrc/embedding/transpose.cu
Normal file
242
csrc/embedding/transpose.cu
Normal file
@@ -0,0 +1,242 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// Transpose between [S, H, D] and [H, S, D] layouts (used for RoPE and attention).
|
||||
// Also handles [S, H*D] → [H, S, D] (reshape_heads) and reverse (merge_heads).
|
||||
|
||||
// reshape_heads: [S, H*D] → [1, H, S, D]
//   source element [s, h*D + d]  is read from  flat[s * H*D + h*D + d]
//   target element [0, h, s, d]  is written to flat[h * S*D + s*D + d]
// One thread per element; the thread's global index is the flat source index.
__global__ void reshape_heads_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int seq_len, int num_heads, int head_dim
) {
    const int row_width = num_heads * head_dim;
    const int src = blockIdx.x * blockDim.x + threadIdx.x;

    if (src < seq_len * row_width) {
        // Split the flat source index into (s, h, d).
        const int s = src / row_width;
        const int within_row = src % row_width;
        const int h = within_row / head_dim;
        const int d = within_row % head_dim;

        const int dst = h * seq_len * head_dim + s * head_dim + d;
        out[dst] = in[src];
    }
}
|
||||
|
||||
// merge_heads: [1, H, S, D] → [S, H*D]
//   source element [0, h, s, d]  is read from  flat[h * S*D + s*D + d]
//   target element [s, h*D + d]  is written to flat[s * H*D + h*D + d]
// One thread per element; the thread's global index is the flat target index.
__global__ void merge_heads_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int seq_len, int num_heads, int head_dim
) {
    const int row_width = num_heads * head_dim;
    const int dst = blockIdx.x * blockDim.x + threadIdx.x;

    if (dst < seq_len * row_width) {
        // Split the flat target index into (s, h, d).
        const int s = dst / row_width;
        const int within_row = dst % row_width;
        const int h = within_row / head_dim;
        const int d = within_row % head_dim;

        const int src = h * seq_len * head_dim + s * head_dim + d;
        out[dst] = in[src];
    }
}
|
||||
|
||||
// transpose_for_rope: [1, H, S, D] → [S, H, D]
//   source [h, s, d] is read from  h*S*D + s*D + d
//   target [s, h, d] is written to s*H*D + h*D + d
// One thread per element; the thread's global index is the flat target index.
__global__ void transpose_hsd_to_shd_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int seq_len, int num_heads, int head_dim
) {
    const int dst = blockIdx.x * blockDim.x + threadIdx.x;
    const int elem_count = seq_len * num_heads * head_dim;

    if (dst < elem_count) {
        // Target index decomposes as s*H*D + h*D + d.
        const int per_token = num_heads * head_dim;
        const int s = dst / per_token;
        const int hd = dst % per_token;
        const int h = hd / head_dim;
        const int d = hd % head_dim;

        out[dst] = in[h * seq_len * head_dim + s * head_dim + d];
    }
}
|
||||
|
||||
// transpose_from_rope: [S, H, D] → [1, H, S, D]
//   source [s, h, d] is read from  s*H*D + h*D + d
//   target [h, s, d] is written to h*S*D + s*D + d
// One thread per element; the thread's global index is the flat target index.
__global__ void transpose_shd_to_hsd_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int seq_len, int num_heads, int head_dim
) {
    const int dst = blockIdx.x * blockDim.x + threadIdx.x;
    const int elem_count = seq_len * num_heads * head_dim;

    if (dst < elem_count) {
        // Target index decomposes as h*S*D + s*D + d.
        const int per_head = seq_len * head_dim;
        const int h = dst / per_head;
        const int sd = dst % per_head;
        const int s = sd / head_dim;
        const int d = sd % head_dim;

        out[dst] = in[s * num_heads * head_dim + h * head_dim + d];
    }
}
|
||||
|
||||
// repeat_kv: [1, KV_H, S, D] → [1, KV_H * n_rep, S, D]
// Output head `o` copies input head `o / n_rep`, so each KV head appears
// n_rep times consecutively. One thread per output element.
__global__ void repeat_kv_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int kv_heads, int n_rep, int seq_len, int head_dim
) {
    const int dst = blockIdx.x * blockDim.x + threadIdx.x;
    const int out_heads = kv_heads * n_rep;
    const int elem_count = out_heads * seq_len * head_dim;

    if (dst < elem_count) {
        const int per_head = seq_len * head_dim;
        const int dst_head = dst / per_head;
        const int offset_in_head = dst % per_head;   // (s, d) position, shared by source and target
        const int src_head = dst_head / n_rep;

        out[dst] = in[src_head * seq_len * head_dim + offset_in_head];
    }
}
|
||||
|
||||
// ---- Generic strided copy (up to 4D) ----
|
||||
// Each thread copies one element. Maps flat contiguous output index to strided input index.
|
||||
// Unused dimensions are padded with shape=1, stride=0.
|
||||
|
||||
// Gathers a strided (up to 4-D) BF16 view into a contiguous output buffer.
// Thread `p` writes out[p]; p is unravelled against shape0..shape3 with the
// last dimension varying fastest, and the source element is found at
// in_offset + sum(index_k * in_stride_k). `ndim` is accepted but not read here.
__global__ void strided_copy_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int numel,
    int ndim,
    int shape0, int shape1, int shape2, int shape3,
    int in_stride0, int in_stride1, int in_stride2, int in_stride3,
    int in_offset
) {
    const int p = blockIdx.x * blockDim.x + threadIdx.x;
    if (p >= numel) return;

    // Peel indices off from the innermost dimension outwards.
    const int i3 = p % shape3;
    const int q2 = p / shape3;
    const int i2 = q2 % shape2;
    const int q1 = q2 / shape2;
    const int i1 = q1 % shape1;
    const int i0 = q1 / shape1;

    const int src = in_offset + i0 * in_stride0 + i1 * in_stride1 + i2 * in_stride2 + i3 * in_stride3;
    out[p] = in[src];
}
|
||||
|
||||
// FP32 counterpart of the strided gather: thread `p` writes out[p], where p is
// unravelled against shape0..shape3 (last dimension fastest) and mapped to the
// source via in_offset + sum(index_k * in_stride_k). `ndim` is not read here.
__global__ void strided_copy_f32(
    const float* __restrict__ in,
    float* __restrict__ out,
    int numel,
    int ndim,
    int shape0, int shape1, int shape2, int shape3,
    int in_stride0, int in_stride1, int in_stride2, int in_stride3,
    int in_offset
) {
    const int p = blockIdx.x * blockDim.x + threadIdx.x;
    if (p >= numel) return;

    // Peel indices off from the innermost dimension outwards.
    const int i3 = p % shape3;
    const int q2 = p / shape3;
    const int i2 = q2 % shape2;
    const int q1 = q2 / shape2;
    const int i1 = q1 % shape1;
    const int i0 = q1 / shape1;

    const int src = in_offset + i0 * in_stride0 + i1 * in_stride1 + i2 * in_stride2 + i3 * in_stride3;
    out[p] = in[src];
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Host launcher for reshape_heads_bf16 ([S, H*D] → [1, H, S, D]).
// Uses one thread per element in blocks of 256 on `stream` (a cudaStream_t
// passed as void*). Launch errors are reported by CUDA_CHECK_LAST_ERROR.
void launch_reshape_heads_bf16(const void* in, void* out,
                               int seq_len, int num_heads, int head_dim, void* stream) {
    const int total = seq_len * num_heads * head_dim;
    // An empty tensor would produce a zero-block grid, which CUDA rejects as
    // an invalid launch configuration; there is nothing to copy, so skip it.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (total + block - 1) / block;
    reshape_heads_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for merge_heads_bf16 ([1, H, S, D] → [S, H*D]).
// Uses one thread per element in blocks of 256 on `stream` (a cudaStream_t
// passed as void*). Launch errors are reported by CUDA_CHECK_LAST_ERROR.
void launch_merge_heads_bf16(const void* in, void* out,
                             int seq_len, int num_heads, int head_dim, void* stream) {
    const int total = seq_len * num_heads * head_dim;
    // An empty tensor would produce a zero-block grid, which CUDA rejects as
    // an invalid launch configuration; there is nothing to copy, so skip it.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (total + block - 1) / block;
    merge_heads_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for transpose_hsd_to_shd_bf16 ([1, H, S, D] → [S, H, D]).
// Uses one thread per element in blocks of 256 on `stream` (a cudaStream_t
// passed as void*). Launch errors are reported by CUDA_CHECK_LAST_ERROR.
void launch_transpose_hsd_to_shd_bf16(const void* in, void* out,
                                      int seq_len, int num_heads, int head_dim, void* stream) {
    const int total = seq_len * num_heads * head_dim;
    // An empty tensor would produce a zero-block grid, which CUDA rejects as
    // an invalid launch configuration; there is nothing to copy, so skip it.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (total + block - 1) / block;
    transpose_hsd_to_shd_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for transpose_shd_to_hsd_bf16 ([S, H, D] → [1, H, S, D]).
// Uses one thread per element in blocks of 256 on `stream` (a cudaStream_t
// passed as void*). Launch errors are reported by CUDA_CHECK_LAST_ERROR.
void launch_transpose_shd_to_hsd_bf16(const void* in, void* out,
                                      int seq_len, int num_heads, int head_dim, void* stream) {
    const int total = seq_len * num_heads * head_dim;
    // An empty tensor would produce a zero-block grid, which CUDA rejects as
    // an invalid launch configuration; there is nothing to copy, so skip it.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (total + block - 1) / block;
    transpose_shd_to_hsd_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for repeat_kv_bf16 ([1, KV_H, S, D] → [1, KV_H * n_rep, S, D]).
// Uses one thread per OUTPUT element in blocks of 256 on `stream` (a
// cudaStream_t passed as void*). Launch errors go through CUDA_CHECK_LAST_ERROR.
void launch_repeat_kv_bf16(const void* in, void* out,
                           int kv_heads, int n_rep, int seq_len, int head_dim, void* stream) {
    const int total = kv_heads * n_rep * seq_len * head_dim;
    // An empty output would produce a zero-block grid, which CUDA rejects as
    // an invalid launch configuration; there is nothing to copy, so skip it.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (total + block - 1) / block;
    repeat_kv_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, kv_heads, n_rep, seq_len, head_dim);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for strided_copy_bf16: gathers `numel` elements of a strided
// (up to 4-D) BF16 view into contiguous `out`. Uses one thread per output
// element in blocks of 256 on `stream` (a cudaStream_t passed as void*).
// Launch errors are reported by CUDA_CHECK_LAST_ERROR.
void launch_strided_copy_bf16(const void* in, void* out, int numel, int ndim,
                              int shape0, int shape1, int shape2, int shape3,
                              int in_stride0, int in_stride1, int in_stride2, int in_stride3,
                              int in_offset, void* stream) {
    // numel == 0 would produce a zero-block grid, which CUDA rejects as an
    // invalid launch configuration; an empty copy is simply a no-op.
    if (numel <= 0) return;

    const int block = 256;
    const int grid = (numel + block - 1) / block;
    strided_copy_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, numel, ndim,
        shape0, shape1, shape2, shape3,
        in_stride0, in_stride1, in_stride2, in_stride3, in_offset);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for strided_copy_f32: gathers `numel` elements of a strided
// (up to 4-D) FP32 view into contiguous `out`. Uses one thread per output
// element in blocks of 256 on `stream` (a cudaStream_t passed as void*).
// Launch errors are reported by CUDA_CHECK_LAST_ERROR.
void launch_strided_copy_f32(const void* in, void* out, int numel, int ndim,
                             int shape0, int shape1, int shape2, int shape3,
                             int in_stride0, int in_stride1, int in_stride2, int in_stride3,
                             int in_offset, void* stream) {
    // numel == 0 would produce a zero-block grid, which CUDA rejects as an
    // invalid launch configuration; an empty copy is simply a no-op.
    if (numel <= 0) return;

    const int block = 256;
    const int grid = (numel + block - 1) / block;
    strided_copy_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const float*)in, (float*)out, numel, ndim,
        shape0, shape1, shape2, shape3,
        in_stride0, in_stride1, in_stride2, in_stride3, in_offset);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
196
csrc/gemm/gemv.cu
Normal file
196
csrc/gemm/gemv.cu
Normal file
@@ -0,0 +1,196 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// K-split GEMV for M=1 BF16 decode.
|
||||
//
|
||||
// y[n] = sum_k x[k] * W[k * N + n]
|
||||
//
|
||||
// Grid: (ceil(N / TILE_N), ceil(K / TILE_K)) partials, followed by a deterministic
|
||||
// fixed-order reduction over K blocks. The previous implementation used
|
||||
// atomicAdd into y_fp32[col]; that made BF16 greedy decode sensitive to
|
||||
// inter-block scheduling when logits were close.
|
||||
|
||||
#define GEMV_TILE_N 128
|
||||
#define GEMV_TILE_K 256
|
||||
#define GEMV_BLOCK 128
|
||||
|
||||
// Partial GEMV over one (column tile, K tile) pair.
//   blockIdx.x selects GEMV_TILE_N output columns, blockIdx.y selects a
//   GEMV_TILE_K slice of K. Each thread owns one column and writes
//   partials[block_k * N + col] = sum over the K slice of x[k] * W[k * N + col].
__global__ void gemv_bf16_partial_kernel(
    const __nv_bfloat16* __restrict__ x,
    const __nv_bfloat16* __restrict__ W,
    float* __restrict__ partials,
    int K, int N
) {
    __shared__ float x_shared[GEMV_TILE_K];

    const int lane = threadIdx.x;
    const int k_tile = blockIdx.y;
    const int col = blockIdx.x * GEMV_TILE_N + lane;

    const int k_begin = k_tile * GEMV_TILE_K;
    const int k_count = min(k_begin + GEMV_TILE_K, K) - k_begin;

    // Stage this K slice of x into shared memory. The load is indexed by
    // `lane`, not by `col`, so EVERY thread must take part and reach the
    // barrier, including threads whose column is past N. Returning before
    // this point would leave x_shared partly unwritten and make the
    // __syncthreads() divergent, which is why the column bounds check comes
    // after the barrier (this matters whenever N % GEMV_TILE_N != 0).
    for (int i = lane; i < k_count; i += GEMV_BLOCK) {
        x_shared[i] = __bfloat162float(x[k_begin + i]);
    }
    __syncthreads();

    if (col < N) {
        float acc = 0.0f;
        for (int ki = 0; ki < k_count; ++ki) {
            acc += x_shared[ki] * __bfloat162float(W[(long long)(k_begin + ki) * N + col]);
        }
        partials[(long long)k_tile * N + col] = acc;
    }
}
|
||||
|
||||
// Sums the per-K-block partials for each output column in ascending block
// order (fixed order, so the result does not depend on scheduling) and
// stores the FP32 total as BF16. One thread per column.
__global__ void gemv_reduce_to_bf16_kernel(
    const float* __restrict__ partials,
    __nv_bfloat16* __restrict__ dst,
    int n,
    int num_k_blocks
) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (col >= n) return;

    float total = 0.0f;
    for (int kb = 0; kb < num_k_blocks; ++kb) {
        total += partials[(long long)kb * n + col];
    }
    dst[col] = __float2bfloat16(total);
}
|
||||
|
||||
// Batched partial GEMV: M input rows share one W; blockIdx.z is the row.
// Every z-slice runs the same accumulation order over the same data as the
// single-row kernel, so results match M sequential launch_gemv_bf16 calls.
// The partials buffer must hold [M * num_k_blocks * N] floats.
__global__ void gemv_bf16_batched_partial_kernel(
    const __nv_bfloat16* __restrict__ x,        // [M, K]
    const __nv_bfloat16* __restrict__ W,        // [K, N]
    float* __restrict__ partials,               // [M, num_k_blocks, N]
    int K, int N
) {
    __shared__ float x_shared[GEMV_TILE_K];

    const int lane = threadIdx.x;
    const int k_tile = blockIdx.y;
    const int row = blockIdx.z;
    const int col = blockIdx.x * GEMV_TILE_N + lane;

    const int k_begin = k_tile * GEMV_TILE_K;
    const int k_count = min(k_begin + GEMV_TILE_K, K) - k_begin;

    // All threads help stage this row's K slice and reach the barrier; the
    // column bounds check is deliberately placed after __syncthreads().
    const __nv_bfloat16* x_row = x + (long long)row * K;
    for (int i = lane; i < k_count; i += GEMV_BLOCK) {
        x_shared[i] = __bfloat162float(x_row[k_begin + i]);
    }
    __syncthreads();

    if (col < N) {
        float acc = 0.0f;
        for (int ki = 0; ki < k_count; ++ki) {
            acc += x_shared[ki] * __bfloat162float(W[(long long)(k_begin + ki) * N + col]);
        }

        const int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K;
        partials[((long long)row * num_k_blocks + k_tile) * N + col] = acc;
    }
}
|
||||
|
||||
// Batched fixed-order reduction: for row blockIdx.y and one column per
// thread, sums that row's K-block partials in ascending order and stores the
// FP32 total as BF16 in dst[row, col].
__global__ void gemv_batched_reduce_to_bf16_kernel(
    const float* __restrict__ partials,   // [M, num_k_blocks, N]
    __nv_bfloat16* __restrict__ dst,      // [M, N]
    int n,
    int num_k_blocks
) {
    const int row = blockIdx.y;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (col < n) {
        const float* row_base = partials + (long long)row * num_k_blocks * n;
        float total = 0.0f;
        for (int kb = 0; kb < num_k_blocks; ++kb) {
            total += row_base[(long long)kb * n + col];
        }
        dst[(long long)row * n + col] = __float2bfloat16(total);
    }
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// M=1 BF16 GEMV: y_bf16[n] = sum_k x[k] * W[k * N + n].
//   x          [K]     BF16 input vector
//   W          [K, N]  BF16 weights
//   y_bf16     [N]     BF16 output
//   y_fp32_buf scratch for the per-K-block partial sums; the kernels index
//              it as [num_k_blocks, N] floats
//   stream     cudaStream_t passed as void*
// Runs the partial kernel, then a fixed-order reduction over K blocks.
void launch_gemv_bf16(
    const void* x,
    const void* W,
    void* y_bf16,
    void* y_fp32_buf,
    int K, int N,
    void* stream
) {
    // No output columns: every grid would be empty, and a zero-sized grid is
    // an invalid launch configuration.
    if (N <= 0) return;

    cudaStream_t s = (cudaStream_t)stream;
    const int num_k_blocks = (K > 0) ? (K + GEMV_TILE_K - 1) / GEMV_TILE_K : 0;

    // With K <= 0 there are no partials to compute (grid.y would be 0 and the
    // launch would fail). The reduction below still runs and writes zeros,
    // which is the value of an empty sum.
    if (num_k_blocks > 0) {
        dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N, num_k_blocks);
        gemv_bf16_partial_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
            (const __nv_bfloat16*)x,
            (const __nv_bfloat16*)W,
            (float*)y_fp32_buf,
            K, N
        );
        CUDA_CHECK_LAST_ERROR();
    }

    // Fixed-order FP32 reduction over K blocks, then BF16 conversion.
    const int conv_block = 256;
    const int conv_grid = (N + conv_block - 1) / conv_block;
    gemv_reduce_to_bf16_kernel<<<conv_grid, conv_block, 0, s>>>(
        (const float*)y_fp32_buf,
        (__nv_bfloat16*)y_bf16,
        N,
        num_k_blocks
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Batched BF16 GEMV: y_bf16[m, n] = sum_k x[m, k] * W[k * N + n] for M rows.
//   x          [M, K]  BF16
//   W          [K, N]  BF16
//   y_bf16     [M, N]  BF16
//   y_fp32_buf [M * num_k_blocks * N] FP32 scratch for partial sums
//   stream     cudaStream_t passed as void*
// M is mapped to gridDim.z (partials) and gridDim.y (reduction), so it is
// bounded by the CUDA limit on those grid dimensions.
void launch_gemv_bf16_batched(
    const void* x,        // [M, K] BF16
    const void* W,        // [K, N] BF16
    void* y_bf16,         // [M, N] BF16
    void* y_fp32_buf,     // [M * num_k_blocks * N] FP32
    int M, int K, int N,
    void* stream
) {
    // Empty output: every grid would have a zero dimension, which is an
    // invalid launch configuration.
    if (M <= 0 || N <= 0) return;

    cudaStream_t s = (cudaStream_t)stream;
    const int num_k_blocks = (K > 0) ? (K + GEMV_TILE_K - 1) / GEMV_TILE_K : 0;

    // With K <= 0 there are no partials to compute (grid.y would be 0); the
    // reduction below still writes zeros, the value of an empty sum.
    if (num_k_blocks > 0) {
        dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N, num_k_blocks, M);
        gemv_bf16_batched_partial_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
            (const __nv_bfloat16*)x,
            (const __nv_bfloat16*)W,
            (float*)y_fp32_buf,
            K, N
        );
        CUDA_CHECK_LAST_ERROR();
    }

    // Fixed-order FP32 reduction over K blocks per row, then BF16 conversion.
    const int conv_block = 256;
    const int conv_grid_x = (N + conv_block - 1) / conv_block;
    dim3 reduce_grid(conv_grid_x, M);
    gemv_batched_reduce_to_bf16_kernel<<<reduce_grid, conv_block, 0, s>>>(
        (const float*)y_fp32_buf,
        (__nv_bfloat16*)y_bf16,
        N,
        num_k_blocks
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
} // extern "C"
|
||||
@@ -1,4 +1,5 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// Naive GEMM: each thread computes one element of C.
|
||||
// C[i][j] = sum_k A[i][k] * B[k][j]
|
||||
@@ -46,6 +47,7 @@ void launch_gemm_naive_bf16(
|
||||
gemm_naive_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
void launch_gemm_naive_f32(
|
||||
@@ -57,6 +59,7 @@ void launch_gemm_naive_f32(
|
||||
gemm_naive_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)A, (const float*)B, (float*)C, M, N, K
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// Tiled GEMM using shared memory.
|
||||
// Each thread block loads TILE_SIZE x TILE_SIZE tiles of A and B
|
||||
@@ -100,6 +101,7 @@ void launch_gemm_tiled_f32(
|
||||
gemm_tiled_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)A, (const float*)B, (float*)C, M, N, K
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
void launch_gemm_tiled_bf16(
|
||||
@@ -111,6 +113,7 @@ void launch_gemm_tiled_bf16(
|
||||
gemm_tiled_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
|
||||
254
csrc/moe/moe_kernels.cu
Normal file
254
csrc/moe/moe_kernels.cu
Normal file
@@ -0,0 +1,254 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <float.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// ============================================================
|
||||
// MoE Top-K + Softmax kernel
|
||||
//
|
||||
// Input: router_logits [num_tokens, num_experts] BF16
|
||||
// Output: topk_ids [num_tokens, top_k] int32
|
||||
// topk_weights [num_tokens, top_k] float32
|
||||
//
|
||||
// One block per token. Threads cooperatively load the logits into shared
|
||||
// memory; thread 0 then finds the top-k indices via repeated argmax and computes softmax over the k winners.
|
||||
// num_experts <= 256 (fits in registers / shared memory).
|
||||
// ============================================================
|
||||
|
||||
#define MOE_MAX_EXPERTS 256
|
||||
#define MOE_MAX_TOPK 8
|
||||
|
||||
// Top-k expert selection plus softmax over the selected logits, one block
// per token. All threads load the token's logits into shared memory; thread
// 0 alone then performs the selection, the softmax and the output writes.
// The shared arrays are fixed-size, so num_experts must not exceed
// MOE_MAX_EXPERTS and top_k must not exceed MOE_MAX_TOPK.
__global__ void moe_topk_softmax_bf16_kernel(
    const __nv_bfloat16* __restrict__ router_logits,
    int* __restrict__ topk_ids,
    float* __restrict__ topk_weights,
    int num_experts, int top_k
) {
    __shared__ float smem_logits[MOE_MAX_EXPERTS];
    __shared__ int smem_ids[MOE_MAX_TOPK];
    __shared__ float smem_vals[MOE_MAX_TOPK];

    const int token = blockIdx.x;
    const __nv_bfloat16* row = router_logits + token * num_experts;

    // Cooperative BF16 -> FP32 load of this token's logits.
    for (int e = threadIdx.x; e < num_experts; e += blockDim.x) {
        smem_logits[e] = __bfloat162float(row[e]);
    }
    __syncthreads();

    // Everything below is serial work for thread 0. This return comes after
    // the only barrier in the kernel, so it cannot make it divergent.
    if (threadIdx.x != 0) return;

    // Repeated argmax: pick the largest remaining logit, record it, then mask
    // it with -inf so the next pass skips it. Ties go to the lowest index.
    for (int slot = 0; slot < top_k; ++slot) {
        int winner = 0;
        float winner_val = -INFINITY;
        for (int e = 0; e < num_experts; ++e) {
            const float candidate = smem_logits[e];
            if (candidate > winner_val) {
                winner_val = candidate;
                winner = e;
            }
        }
        smem_ids[slot] = winner;
        smem_vals[slot] = winner_val;
        smem_logits[winner] = -INFINITY;
    }

    // FP32 softmax over the selected logits, shifted by their maximum.
    float peak = smem_vals[0];
    for (int slot = 1; slot < top_k; ++slot) {
        peak = fmaxf(peak, smem_vals[slot]);
    }

    float denom = 0.0f;
    for (int slot = 0; slot < top_k; ++slot) {
        const float e = expf(smem_vals[slot] - peak);
        smem_vals[slot] = e;
        denom += e;
    }
    const float inv_denom = 1.0f / denom;

    // Emit ids and normalized weights for this token.
    int* ids_out = topk_ids + token * top_k;
    float* weights_out = topk_weights + token * top_k;
    for (int slot = 0; slot < top_k; ++slot) {
        ids_out[slot] = smem_ids[slot];
        weights_out[slot] = smem_vals[slot] * inv_denom;
    }
}
|
||||
|
||||
// ============================================================
|
||||
// MoE Replicate kernel
|
||||
//
|
||||
// Input: x [num_tokens, hidden] BF16
|
||||
// Output: x_rep [local_experts, num_tokens, hidden] BF16
|
||||
//
|
||||
// Copies x into each expert's batch slot.
|
||||
// ============================================================
|
||||
|
||||
// Copies x [num_tokens, hidden] into each of the local_experts slots of
// x_rep [local_experts, num_tokens, hidden]. One thread per output element.
__global__ void moe_replicate_bf16_kernel(
    const __nv_bfloat16* __restrict__ x,
    __nv_bfloat16* __restrict__ x_rep,
    int num_tokens, int hidden, int local_experts
) {
    // All index math is 64-bit: local_experts * num_tokens * hidden exceeds
    // the int32 range for prefill lengths inside the supported context
    // window, and a wrapped 32-bit count would break the launch.
    const long long dst = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    const long long elem_count = (long long)local_experts * num_tokens * hidden;

    if (dst < elem_count) {
        // x_rep[expert, token, dim] = x[token, dim]; dropping the expert
        // index is a modulo by the size of one copy.
        const long long copy_size = (long long)num_tokens * hidden;
        x_rep[dst] = x[dst % copy_size];
    }
}
|
||||
|
||||
// ============================================================
|
||||
// MoE Bias Add 3D kernel
|
||||
//
|
||||
// Input: x [batch, num_tokens, dim] BF16 (in-place output)
|
||||
// bias [batch, dim] BF16
|
||||
//
|
||||
// x[b, t, d] += bias[b, d]
|
||||
// ============================================================
|
||||
|
||||
// In-place x[b, t, d] += bias[b, d] for x [batch, num_tokens, dim] and
// bias [batch, dim]. The addition is done in FP32 and rounded back to BF16.
// One thread per element of x.
__global__ void moe_bias_add_3d_bf16_kernel(
    __nv_bfloat16* __restrict__ x,
    const __nv_bfloat16* __restrict__ bias,
    int batch, int num_tokens, int dim
) {
    // 64-bit flat index: batch * num_tokens * dim can exceed the int32 range
    // for long prefills.
    const long long pos = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    const long long elem_count = (long long)batch * num_tokens * dim;

    if (pos < elem_count) {
        const long long per_batch = (long long)num_tokens * dim;
        const int b = (int)(pos / per_batch);   // < batch, so it fits in int
        const int d = (int)(pos % dim);         // < dim

        const float biased = __bfloat162float(x[pos]) + __bfloat162float(bias[(long long)b * dim + d]);
        x[pos] = __float2bfloat16(biased);
    }
}
|
||||
|
||||
// ============================================================
|
||||
// MoE Weighted Sum kernel
|
||||
//
|
||||
// Input: expert_out [local_experts, num_tokens, hidden] BF16
|
||||
// topk_ids [num_tokens, top_k] int32 (global expert ids)
|
||||
// topk_weights[num_tokens, top_k] float32
|
||||
// expert_start: first global expert id this rank owns
|
||||
// local_experts: number of experts this rank owns
|
||||
//
|
||||
// Output: out [num_tokens, hidden] BF16
|
||||
//
|
||||
// For each (token, dim): accumulate in FP32:
|
||||
// sum = 0
|
||||
// for k in 0..top_k:
|
||||
// global_id = topk_ids[token, k]
|
||||
// if global_id in [expert_start, expert_start + local_experts):
|
||||
// local_id = global_id - expert_start
|
||||
// sum += topk_weights[token, k] * expert_out[local_id, token, dim]
|
||||
// out[token, dim] = bf16(sum)
|
||||
// ============================================================
|
||||
|
||||
// Combines expert outputs per (token, dim): for each of the token's top_k
// slots whose global expert id falls in
// [expert_start, expert_start + local_experts), adds
// topk_weights[token, k] * expert_out[local_id, token, dim] in FP32, then
// stores the sum as BF16 in out[token, dim]. Slots owned by other ranks are
// skipped (not multiplied by zero). One thread per output element.
__global__ void moe_weighted_sum_bf16_kernel(
    const __nv_bfloat16* __restrict__ expert_out,
    const int* __restrict__ topk_ids,
    const float* __restrict__ topk_weights,
    __nv_bfloat16* __restrict__ out,
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts
) {
    // 64-bit indexing: local_id * (num_tokens * hidden) exceeds the int32
    // range for long prefills and would address the wrong expert element.
    const long long pos = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    const long long per_expert = (long long)num_tokens * hidden;   // also the output size
    if (pos >= per_expert) return;

    const long long token = pos / hidden;
    const int dim = (int)(pos % hidden);

    const int* token_ids = topk_ids + token * top_k;
    const float* token_weights = topk_weights + token * top_k;

    float acc = 0.0f;
    for (int k = 0; k < top_k; ++k) {
        const int local_id = token_ids[k] - expert_start;
        if (local_id < 0 || local_id >= local_experts) continue;   // expert not owned by this rank

        const float w = token_weights[k];
        const float v = __bfloat162float(expert_out[local_id * per_expert + token * hidden + dim]);
        acc += w * v;
    }
    out[pos] = __float2bfloat16(acc);
}
|
||||
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Host launcher for moe_topk_softmax_bf16_kernel.
//   router_logits [num_tokens, num_experts] BF16
//   topk_ids      [num_tokens, top_k]       int32 (output)
//   topk_weights  [num_tokens, top_k]       float32 (output)
//   stream        cudaStream_t passed as void*
// Launches one 128-thread block per token. Unsupported shapes are reported
// on stderr and the launch is skipped, leaving the outputs unwritten.
void launch_moe_topk_softmax_bf16(
    const void* router_logits,
    void* topk_ids, void* topk_weights,
    int num_tokens, int num_experts, int top_k,
    void* stream
) {
    // Nothing to route: a zero-block grid is an invalid launch configuration,
    // and top_k == 0 produces no outputs.
    if (num_tokens <= 0 || top_k <= 0) return;

    // The kernel stages logits and winners in fixed-size shared arrays
    // (MOE_MAX_EXPERTS / MOE_MAX_TOPK entries). Larger shapes would write
    // past those arrays, so refuse them here instead of corrupting memory.
    if (num_experts <= 0 || num_experts > MOE_MAX_EXPERTS || top_k > MOE_MAX_TOPK) {
        fprintf(stderr,
                "launch_moe_topk_softmax_bf16: unsupported shape num_experts=%d (max %d), top_k=%d (max %d)\n",
                num_experts, MOE_MAX_EXPERTS, top_k, MOE_MAX_TOPK);
        return;
    }

    const int block = 128;
    moe_topk_softmax_bf16_kernel<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)router_logits,
        (int*)topk_ids, (float*)topk_weights,
        num_experts, top_k
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for moe_replicate_bf16_kernel: copies x [num_tokens, hidden]
// into every slot of x_rep [local_experts, num_tokens, hidden]. Uses one
// thread per output element in blocks of 256 on `stream` (a cudaStream_t
// passed as void*). The element count is computed in 64-bit.
void launch_moe_replicate_bf16(
    const void* x, void* x_rep,
    int num_tokens, int hidden, int local_experts,
    void* stream
) {
    const long long total = (long long)local_experts * num_tokens * hidden;
    // An empty output would produce a zero-block grid, which CUDA rejects as
    // an invalid launch configuration; there is nothing to copy, so skip it.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (int)((total + block - 1) / block);
    moe_replicate_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (__nv_bfloat16*)x_rep,
        num_tokens, hidden, local_experts
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for moe_bias_add_3d_bf16_kernel: in-place
// x[b, t, d] += bias[b, d] with x [batch, num_tokens, dim] and
// bias [batch, dim]. Uses one thread per element of x in blocks of 256 on
// `stream` (a cudaStream_t passed as void*). The count is computed in 64-bit.
void launch_moe_bias_add_3d_bf16(
    void* x, const void* bias,
    int batch, int num_tokens, int dim,
    void* stream
) {
    const long long total = (long long)batch * num_tokens * dim;
    // An empty tensor would produce a zero-block grid, which CUDA rejects as
    // an invalid launch configuration; there is nothing to update, so skip it.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (int)((total + block - 1) / block);
    moe_bias_add_3d_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)x, (const __nv_bfloat16*)bias,
        batch, num_tokens, dim
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for moe_weighted_sum_bf16_kernel.
//   expert_out   [local_experts, num_tokens, hidden] BF16
//   topk_ids     [num_tokens, top_k] int32 (global expert ids)
//   topk_weights [num_tokens, top_k] float32
//   out          [num_tokens, hidden] BF16 (output)
//   expert_start / local_experts: range of global expert ids this rank owns
//   stream       cudaStream_t passed as void*
// Uses one thread per output element in blocks of 256.
void launch_moe_weighted_sum_bf16(
    const void* expert_out,
    const void* topk_ids, const void* topk_weights,
    void* out,
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts,
    void* stream
) {
    const long long total = (long long)num_tokens * hidden;
    // An empty output would produce a zero-block grid, which CUDA rejects as
    // an invalid launch configuration; there is nothing to write, so skip it.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (int)((total + block - 1) / block);
    moe_weighted_sum_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)expert_out,
        (const int*)topk_ids, (const float*)topk_weights,
        (__nv_bfloat16*)out,
        num_tokens, hidden, top_k,
        expert_start, local_experts
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
254
csrc/moe/moe_sparse.cu
Normal file
254
csrc/moe/moe_sparse.cu
Normal file
@@ -0,0 +1,254 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cstdint>
|
||||
#include "../common.cuh"
|
||||
|
||||
// ============================================================
|
||||
// Sparse MoE decode GEMVs — compute ONLY the routed experts.
|
||||
//
|
||||
// The dense path replicates x across all local experts and runs a
|
||||
// batched GEMM, reading every expert's weights per token. Decode is
|
||||
// memory-bound, so reading only the top-k routed experts' weights
|
||||
// (~2 of 16 local on average at TP=2) is a ~8x byte reduction.
|
||||
//
|
||||
// Each block handles one (token, slot) pair's tile of output columns.
|
||||
// It reads topk_ids[token, slot] from device memory (no host sync),
|
||||
// and exits early if the expert is not owned by this rank. The early
|
||||
// return is BLOCK-UNIFORM (every thread sees the same topk_ids value
|
||||
// and returns before the shared-memory staging + __syncthreads), so
|
||||
// it is safe — unlike the divergent-return bug fixed in gemv.cu.
|
||||
//
|
||||
// Outputs for non-local slots are NEVER written (uninitialized memory,
|
||||
// possibly NaN bit patterns). Downstream consumers must SKIP non-local
|
||||
// slots rather than multiply by zero (NaN * 0 = NaN).
|
||||
//
|
||||
// Per-expert weight scale and bias are fused into the epilogue:
|
||||
// y[t, slot, n] = acc * w_scale[lid] + bias[lid, n]
|
||||
// which matches the dense path's GEMM -> moe_bias_add_3d sequence.
|
||||
//
|
||||
// Activation addressing (x_per_slot):
|
||||
// gate_up: all slots of a token share x[token, :] (x_per_slot=0)
|
||||
// down: each slot has its own activation row
|
||||
// x[token * top_k + slot, :] (x_per_slot=1)
|
||||
// ============================================================
|
||||
|
||||
#define SPARSE_TILE_N 8 // output columns per block (= warps per block)
|
||||
|
||||
// Weights FP8 E4M3 [local_experts, N, K], activations BF16 (W8A16).
// Decode is memory-bound (~2 FLOP/byte), so dequant-in-registers GEMV
// loses nothing to tensor cores and skips activation quantization.
//
// Index mapping (see the code below): blockIdx.z = token, blockIdx.y = top-k
// slot, blockIdx.x = tile of SPARSE_TILE_N output columns; each warp of the
// block produces one output column. Dynamic shared memory must hold K floats.
//
//   y[token, slot, n] = (sum_k x[k] * w[lid, n, k]) * w_scales[lid] + bias[lid, n]
// with lid = topk_ids[token, slot] - expert_start. Blocks whose expert falls
// outside [0, local_experts) return without writing y.
//
// K handling: when K is a multiple of 16 each lane streams 16 weights per
// 16-byte vector load. For any other K the weight rows are not 16-byte aligned
// and a K/16 loop would drop the tail, so the kernel falls back to byte loads
// that cover every element.
__global__ void moe_sparse_gemv_fp8_bf16_kernel(
    const __nv_bfloat16* __restrict__ x,      // [T, K] or [T*top_k, K]
    const __nv_fp8_e4m3* __restrict__ w,      // [local_experts, N, K]
    const float* __restrict__ w_scales,       // [local_experts]
    const __nv_bfloat16* __restrict__ bias,   // [local_experts, N]
    const int* __restrict__ topk_ids,         // [T, top_k] global expert ids
    __nv_bfloat16* __restrict__ y,            // [T, top_k, N]
    int N, int K, int top_k,
    int expert_start, int local_experts,
    int x_per_slot
) {
    const int token = blockIdx.z;
    const int slot = blockIdx.y;
    const int eid = topk_ids[token * top_k + slot];
    const int lid = eid - expert_start;
    if (lid < 0 || lid >= local_experts) return;  // block-uniform: safe

    // Stage the activation row into shared memory as float, once per block.
    extern __shared__ float xs[];  // [K]
    const __nv_bfloat16* xrow =
        x + (long long)(x_per_slot ? token * top_k + slot : token) * K;
    for (int i = threadIdx.x; i < K; i += blockDim.x) {
        xs[i] = __bfloat162float(xrow[i]);
    }
    __syncthreads();

    const int n = blockIdx.x * SPARSE_TILE_N + (threadIdx.x >> 5);
    if (n >= N) return;  // after __syncthreads and uniform within the warp: safe
    const int lane = threadIdx.x & 31;

    const uint8_t* wrow = (const uint8_t*)w + ((long long)lid * N + n) * K;
    float acc = 0.0f;
    if ((K & 15) == 0) {
        // One warp per output column; uint4 = 16 FP8 weights per lane, the
        // warp covers 512 contiguous bytes per iteration (coalesced).
        // Assumes the base pointer w is itself 16-byte aligned.
        for (int i = lane; i < (K >> 4); i += 32) {
            uint4 packed = *(const uint4*)(wrow + (long long)i * 16);
            const __nv_fp8_e4m3* pw = (const __nv_fp8_e4m3*)&packed;
            const float* xk = xs + i * 16;
            #pragma unroll
            for (int j = 0; j < 16; j++) {
                acc += xk[j] * float(pw[j]);
            }
        }
    } else {
        // Unaligned / ragged K: one weight byte per lane per iteration.
        const __nv_fp8_e4m3* wq = (const __nv_fp8_e4m3*)wrow;
        for (int k = lane; k < K; k += 32) {
            acc += xs[k] * float(wq[k]);
        }
    }

    // Warp tree reduction; lane 0 ends up with the full dot product.
    #pragma unroll
    for (int o = 16; o > 0; o >>= 1) {
        acc += __shfl_down_sync(0xffffffffu, acc, o);
    }
    if (lane == 0) {
        // Fused epilogue: per-expert weight scale, then per-column bias.
        float v = acc * w_scales[lid]
                + __bfloat162float(bias[(long long)lid * N + n]);
        y[((long long)token * top_k + slot) * N + n] = __float2bfloat16(v);
    }
}
|
||||
|
||||
// MXFP4 W4A16 variant: packed E2M1 nibbles + per-32 UE8M0 block scale,
|
||||
// same structure as batched_gemv_mxfp4_bf16_kernel but expert-indexed
|
||||
// via topk_ids and with fused per-expert bias.
|
||||
#define MXFP4_BLOCK 32
|
||||
|
||||
__device__ __constant__ float kSparseFp4Levels[8] =
|
||||
{0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
|
||||
|
||||
// Decodes a 4-bit code to float: the low three bits index kSparseFp4Levels
// for the magnitude, bit 3 (0x8) selects the negative value.
__device__ __forceinline__ float sparse_fp4_to_float(uint8_t code) {
    const float magnitude = kSparseFp4Levels[code & 0x7];
    if (code & 0x8) {
        return -magnitude;
    }
    return magnitude;
}
|
||||
|
||||
// Sparse MoE GEMV over MXFP4 weights with BF16 activations.
// Index mapping: blockIdx.z = token, blockIdx.y = top-k slot, blockIdx.x =
// tile of SPARSE_TILE_N output columns; each warp computes one column.
// Dynamic shared memory must hold K floats. Blocks whose expert is outside
// [expert_start, expert_start + local_experts) return without writing y.
//   y[token, slot, n] = sum_k x[k] * dequant(w[lid, n, k]) + bias[lid, n]
__global__ void moe_sparse_gemv_mxfp4_bf16_kernel(
    const __nv_bfloat16* __restrict__ x,        // [T, K] or [T*top_k, K]
    const uint8_t* __restrict__ w_packed,       // [local_experts, N, K/2]
    const uint8_t* __restrict__ w_scales,       // [local_experts, N, K/32]
    const __nv_bfloat16* __restrict__ bias,     // [local_experts, N]
    const int* __restrict__ topk_ids,           // [T, top_k]
    __nv_bfloat16* __restrict__ y,              // [T, top_k, N]
    int N, int K, int top_k,
    int expert_start, int local_experts,
    int x_per_slot
) {
    const int tok = blockIdx.z;
    const int slot_id = blockIdx.y;
    const int local_expert = topk_ids[tok * top_k + slot_id] - expert_start;
    // Same decision for every thread of the block, taken before the barrier.
    if (local_expert < 0 || local_expert >= local_experts) return;

    // Stage the activation row as float for all warps of the block.
    extern __shared__ float xs[];
    const int act_row = x_per_slot ? tok * top_k + slot_id : tok;
    const __nv_bfloat16* act = x + (long long)act_row * K;
    for (int k = threadIdx.x; k < K; k += blockDim.x) {
        xs[k] = __bfloat162float(act[k]);
    }
    __syncthreads();

    const int col = blockIdx.x * SPARSE_TILE_N + (threadIdx.x >> 5);
    if (col < N) {
        const int lane_id = threadIdx.x & 31;
        const int num_blocks = K / MXFP4_BLOCK;
        const long long w_row = (long long)local_expert * N + col;
        const uint8_t* nibbles = w_packed + w_row * (K >> 1);
        const uint8_t* block_scales = w_scales + w_row * num_blocks;

        float partial = 0.0f;
        for (int b = lane_id; b < num_blocks; b += 32) {
            // Scale byte is a biased power-of-two exponent (bias 127).
            const float scale = exp2f((float)((int)block_scales[b] - 127));
            // 16 bytes = 32 packed 4-bit codes; low nibble is the even k.
            uint4 chunk = *(const uint4*)(nibbles + (long long)b * 16);
            const uint8_t* bytes = (const uint8_t*)&chunk;
            const float* xb = xs + b * MXFP4_BLOCK;
            #pragma unroll
            for (int j = 0; j < 16; j++) {
                const uint8_t pair = bytes[j];
                partial += xb[2 * j]     * (sparse_fp4_to_float(pair & 0xF) * scale);
                partial += xb[2 * j + 1] * (sparse_fp4_to_float(pair >> 4) * scale);
            }
        }

        // Warp tree reduction into lane 0.
        #pragma unroll
        for (int off = 16; off > 0; off >>= 1) {
            partial += __shfl_down_sync(0xffffffffu, partial, off);
        }
        if (lane_id == 0) {
            const float result = partial + __bfloat162float(bias[w_row]);
            y[((long long)tok * top_k + slot_id) * N + col] = __float2bfloat16(result);
        }
    }
}
|
||||
|
||||
// Weighted sum over the slot axis: out[t, d] = sum over local slots of
// topk_weights[t, k] * down[t, k, d]. Non-local slots hold uninitialized
// memory and are SKIPPED (not multiplied by zero).
//
// One thread per output element (t, d). The flat element index and all
// derived offsets are 64-bit so num_tokens * hidden cannot wrap a 32-bit int;
// this matches the 64-bit sizing used by the dense weighted-sum launcher.
__global__ void moe_weighted_sum_sparse_bf16_kernel(
    const __nv_bfloat16* __restrict__ down,     // [T, top_k, hidden]
    const int* __restrict__ topk_ids,           // [T, top_k]
    const float* __restrict__ topk_weights,     // [T, top_k]
    __nv_bfloat16* __restrict__ out,            // [T, hidden]
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts
) {
    const long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    const long long total = (long long)num_tokens * hidden;
    if (idx >= total) return;

    const int token = (int)(idx / hidden);
    const int dim = (int)(idx % hidden);
    const long long slot_base = (long long)token * top_k;

    float sum = 0.0f;
    for (int k = 0; k < top_k; k++) {
        const int lid = topk_ids[slot_base + k] - expert_start;
        // Only slots routed to an expert owned by this rank contribute.
        if (lid >= 0 && lid < local_experts) {
            const float w = topk_weights[slot_base + k];
            const float v = __bfloat162float(down[(slot_base + k) * hidden + dim]);
            sum += w * v;
        }
    }
    out[idx] = __float2bfloat16(sum);
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Host launcher for moe_sparse_gemv_fp8_bf16_kernel.
// Grid: (ceil(N / SPARSE_TILE_N), top_k, num_tokens); block: SPARSE_TILE_N
// warps; dynamic shared memory: K floats for the staged activation row.
//   x, w, w_scales, bias, topk_ids, y : device pointers, cast to the kernel's
//                                       typed parameters
//   stream                            : cudaStream_t carried as an opaque pointer
// Empty problems (no tokens, slots or output columns) are a no-op, because a
// zero-sized grid dimension is an invalid launch configuration.
void launch_moe_sparse_gemv_fp8_bf16(
    const void* x, const void* w, const void* w_scales, const void* bias,
    const void* topk_ids, void* y,
    int num_tokens, int N, int K, int top_k,
    int expert_start, int local_experts, int x_per_slot,
    void* stream
) {
    if (num_tokens == 0 || top_k == 0 || N == 0) return;

    dim3 grid((N + SPARSE_TILE_N - 1) / SPARSE_TILE_N, top_k, num_tokens);
    const int block = SPARSE_TILE_N * 32;
    const size_t smem = (size_t)K * sizeof(float);

    // Dynamic shared memory above the default 48 KB limit needs an explicit
    // per-kernel opt-in, otherwise the launch itself fails (K > 12288). The
    // return value is not checked here: if the device cannot grant the size,
    // the launch below fails and CUDA_CHECK_LAST_ERROR reports it.
    const size_t default_dyn_smem_limit = 48 * 1024;
    if (smem > default_dyn_smem_limit) {
        cudaFuncSetAttribute(moe_sparse_gemv_fp8_bf16_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             (int)smem);
    }

    moe_sparse_gemv_fp8_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const __nv_fp8_e4m3*)w,
        (const float*)w_scales, (const __nv_bfloat16*)bias,
        (const int*)topk_ids, (__nv_bfloat16*)y,
        N, K, top_k, expert_start, local_experts, x_per_slot
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for moe_sparse_gemv_mxfp4_bf16_kernel.
// Grid: (ceil(N / SPARSE_TILE_N), top_k, num_tokens); block: SPARSE_TILE_N
// warps; dynamic shared memory: K floats for the staged activation row.
//   x, w_packed, w_scales, bias, topk_ids, y : device pointers, cast to the
//                                              kernel's typed parameters
//   stream : cudaStream_t carried as an opaque pointer
// Empty problems (no tokens, slots or output columns) are a no-op, because a
// zero-sized grid dimension is an invalid launch configuration.
void launch_moe_sparse_gemv_mxfp4_bf16(
    const void* x, const void* w_packed, const void* w_scales, const void* bias,
    const void* topk_ids, void* y,
    int num_tokens, int N, int K, int top_k,
    int expert_start, int local_experts, int x_per_slot,
    void* stream
) {
    if (num_tokens == 0 || top_k == 0 || N == 0) return;

    dim3 grid((N + SPARSE_TILE_N - 1) / SPARSE_TILE_N, top_k, num_tokens);
    const int block = SPARSE_TILE_N * 32;
    const size_t smem = (size_t)K * sizeof(float);

    // Dynamic shared memory above the default 48 KB limit needs an explicit
    // per-kernel opt-in, otherwise the launch itself fails (K > 12288). The
    // return value is not checked here: if the device cannot grant the size,
    // the launch below fails and CUDA_CHECK_LAST_ERROR reports it.
    const size_t default_dyn_smem_limit = 48 * 1024;
    if (smem > default_dyn_smem_limit) {
        cudaFuncSetAttribute(moe_sparse_gemv_mxfp4_bf16_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             (int)smem);
    }

    moe_sparse_gemv_mxfp4_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const uint8_t*)w_packed,
        (const uint8_t*)w_scales, (const __nv_bfloat16*)bias,
        (const int*)topk_ids, (__nv_bfloat16*)y,
        N, K, top_k, expert_start, local_experts, x_per_slot
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for moe_weighted_sum_sparse_bf16_kernel. The grid is sized so
// that it provides at least num_tokens * hidden threads (256 per block).
//   down         : device pointer, passed as const __nv_bfloat16*
//   topk_ids     : device pointer, passed as const int*
//   topk_weights : device pointer, passed as const float*
//   out          : device pointer, passed as __nv_bfloat16*
//   stream       : cudaStream_t carried as an opaque pointer
// The element count is computed in 64-bit (as the dense weighted-sum launcher
// does) so num_tokens * hidden cannot wrap before the grid size is derived.
// A zero element count is a no-op, since a zero-sized grid is an invalid
// launch configuration.
void launch_moe_weighted_sum_sparse_bf16(
    const void* down, const void* topk_ids, const void* topk_weights,
    void* out,
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts,
    void* stream
) {
    const long long total = (long long)num_tokens * hidden;
    if (total == 0) return;  // nothing to reduce

    const int block = 256;
    const int grid = (int)((total + block - 1) / block);
    moe_weighted_sum_sparse_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)down,
        (const int*)topk_ids, (const float*)topk_weights,
        (__nv_bfloat16*)out,
        num_tokens, hidden, top_k, expert_start, local_experts
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
@@ -14,27 +14,34 @@ __global__ void layernorm_f32(
|
||||
const float* x_row = x + row * hidden_size;
|
||||
float* out_row = out + row * hidden_size;
|
||||
|
||||
// Welford online: compute mean and variance in one pass
|
||||
// Pass 1: compute mean
|
||||
float local_sum = 0.0f;
|
||||
float local_sum_sq = 0.0f;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float v = x_row[i];
|
||||
local_sum += v;
|
||||
local_sum_sq += v * v;
|
||||
local_sum += x_row[i];
|
||||
}
|
||||
local_sum = block_reduce_sum(local_sum);
|
||||
local_sum_sq = block_reduce_sum(local_sum_sq);
|
||||
|
||||
__shared__ float s_mean, s_inv_std;
|
||||
if (threadIdx.x == 0) {
|
||||
float mean = local_sum / hidden_size;
|
||||
float var = local_sum_sq / hidden_size - mean * mean;
|
||||
s_mean = mean;
|
||||
s_inv_std = rsqrtf(var + eps);
|
||||
s_mean = local_sum / hidden_size;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float mean = s_mean;
|
||||
|
||||
// Pass 2: compute variance = sum((x - mean)^2) / N
|
||||
float local_var = 0.0f;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float d = x_row[i] - mean;
|
||||
local_var += d * d;
|
||||
}
|
||||
local_var = block_reduce_sum(local_var);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
s_inv_std = rsqrtf(local_var / hidden_size + eps);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float inv_std = s_inv_std;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
out_row[i] = gamma[i] * (x_row[i] - mean) * inv_std + beta[i];
|
||||
@@ -52,26 +59,34 @@ __global__ void layernorm_bf16(
|
||||
const __nv_bfloat16* x_row = x + row * hidden_size;
|
||||
__nv_bfloat16* out_row = out + row * hidden_size;
|
||||
|
||||
// Pass 1: compute mean
|
||||
float local_sum = 0.0f;
|
||||
float local_sum_sq = 0.0f;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float v = __bfloat162float(x_row[i]);
|
||||
local_sum += v;
|
||||
local_sum_sq += v * v;
|
||||
local_sum += __bfloat162float(x_row[i]);
|
||||
}
|
||||
local_sum = block_reduce_sum(local_sum);
|
||||
local_sum_sq = block_reduce_sum(local_sum_sq);
|
||||
|
||||
__shared__ float s_mean, s_inv_std;
|
||||
if (threadIdx.x == 0) {
|
||||
float mean = local_sum / hidden_size;
|
||||
float var = local_sum_sq / hidden_size - mean * mean;
|
||||
s_mean = mean;
|
||||
s_inv_std = rsqrtf(var + eps);
|
||||
s_mean = local_sum / hidden_size;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float mean = s_mean;
|
||||
|
||||
// Pass 2: compute variance = sum((x - mean)^2) / N
|
||||
float local_var = 0.0f;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float d = __bfloat162float(x_row[i]) - mean;
|
||||
local_var += d * d;
|
||||
}
|
||||
local_var = block_reduce_sum(local_var);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
s_inv_std = rsqrtf(local_var / hidden_size + eps);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float inv_std = s_inv_std;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float v = __bfloat162float(x_row[i]);
|
||||
@@ -86,17 +101,21 @@ extern "C" {
|
||||
// Host launcher for layernorm_f32: one block per row, up to 1024 threads.
//   x, gamma, beta, out : device pointers, passed to the kernel as float buffers
//   stream              : cudaStream_t carried as an opaque pointer
// The block size is rounded up to a whole number of warps (minimum one warp).
// NOTE(review): block_reduce_sum is defined outside this view; presumably it
// is warp-shuffle based and needs full warps -- confirm. The extra threads do
// no element work because the kernel loops are bounded by hidden_size.
// rows == 0 is a no-op, since a zero-sized grid is an invalid launch
// configuration.
void launch_layernorm_f32(const void* x, const void* gamma, const void* beta,
                          void* out, int rows, int hidden_size, float eps, void* stream) {
    if (rows == 0) return;

    int block = (hidden_size < 1024) ? hidden_size : 1024;
    block = (block + 31) / 32 * 32;  // whole warps only
    if (block < 32) block = 32;
    layernorm_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const float*)x, (const float*)gamma, (const float*)beta,
        (float*)out, hidden_size, eps);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for layernorm_bf16: one block per row, up to 1024 threads.
//   x, gamma, beta, out : device pointers, passed to the kernel as BF16 buffers
//   stream              : cudaStream_t carried as an opaque pointer
// The block size is rounded up to a whole number of warps (minimum one warp).
// NOTE(review): block_reduce_sum is defined outside this view; presumably it
// is warp-shuffle based and needs full warps -- confirm. The extra threads do
// no element work because the kernel loops are bounded by hidden_size.
// rows == 0 is a no-op, since a zero-sized grid is an invalid launch
// configuration.
void launch_layernorm_bf16(const void* x, const void* gamma, const void* beta,
                           void* out, int rows, int hidden_size, float eps, void* stream) {
    if (rows == 0) return;

    int block = (hidden_size < 1024) ? hidden_size : 1024;
    block = (block + 31) / 32 * 32;  // whole warps only
    if (block < 32) block = 32;
    layernorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const __nv_bfloat16*)gamma, (const __nv_bfloat16*)beta,
        (__nv_bfloat16*)out, hidden_size, eps);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
|
||||
@@ -63,21 +63,78 @@ __global__ void rmsnorm_bf16(
|
||||
}
|
||||
}
|
||||
|
||||
// Fused Add + RMSNorm: sum_out = x + residual, normed_out = rmsnorm(sum_out, gamma, eps)
// Each block handles one row of [hidden_size].
//
//   x, residual : inputs, one row of hidden_size BF16 values per block
//   gamma       : per-column scale, indexed by column only
//   normed_out  : receives sum * rsqrt(mean(sum^2) + eps) * gamma
//   sum_out     : receives x + residual (rounded to BF16)
//
// The row offset is computed in 64-bit so rows * hidden_size beyond the int
// range cannot wrap the pointer arithmetic.
__global__ void add_rmsnorm_bf16(
    const __nv_bfloat16* __restrict__ x,
    const __nv_bfloat16* __restrict__ residual,
    const __nv_bfloat16* __restrict__ gamma,
    __nv_bfloat16* __restrict__ normed_out,
    __nv_bfloat16* __restrict__ sum_out,
    int hidden_size, float eps
) {
    const long long row_offset = (long long)blockIdx.x * hidden_size;
    const __nv_bfloat16* x_row = x + row_offset;
    const __nv_bfloat16* res_row = residual + row_offset;
    __nv_bfloat16* sum_row = sum_out + row_offset;
    __nv_bfloat16* norm_row = normed_out + row_offset;

    // Pass 1: compute sum = x + residual, and accumulate sum_sq.
    // The squared sum uses the float value before it is rounded to BF16.
    float sum_sq = 0.0f;
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        float s = __bfloat162float(x_row[i]) + __bfloat162float(res_row[i]);
        sum_row[i] = __float2bfloat16(s);
        sum_sq += s * s;
    }
    sum_sq = block_reduce_sum(sum_sq);

    // Thread 0 publishes 1 / rms for the whole block.
    __shared__ float s_rms_inv;
    if (threadIdx.x == 0) {
        s_rms_inv = rsqrtf(sum_sq / hidden_size + eps);
    }
    __syncthreads();

    // Pass 2: normed_out = sum * rms_inv * gamma.
    // Each thread re-reads only the sum_row entries it wrote in pass 1 (same
    // index stride), so no extra synchronization on sum_row is needed.
    float rms_inv = s_rms_inv;
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        float s = __bfloat162float(sum_row[i]);
        float g = __bfloat162float(gamma[i]);
        norm_row[i] = __float2bfloat16(s * rms_inv * g);
    }
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Host launcher for rmsnorm_f32: one block per row, up to 1024 threads.
//   x, gamma, out : device pointers, passed to the kernel as float buffers
//   stream        : cudaStream_t carried as an opaque pointer
// The block size is rounded up to a whole number of warps (minimum one warp).
// NOTE(review): the kernel and its block reduction are outside this view;
// presumably the reduction is warp-shuffle based and the element loops are
// bounded by hidden_size (the existing 32-thread minimum already relies on
// that) -- confirm.
// rows == 0 is a no-op, since a zero-sized grid is an invalid launch
// configuration.
void launch_rmsnorm_f32(const void* x, const void* gamma, void* out,
                        int rows, int hidden_size, float eps, void* stream) {
    if (rows == 0) return;

    int block = (hidden_size < 1024) ? hidden_size : 1024;
    block = (block + 31) / 32 * 32;  // whole warps only
    if (block < 32) block = 32;
    rmsnorm_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const float*)x, (const float*)gamma, (float*)out, hidden_size, eps);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for rmsnorm_bf16: one block per row, up to 1024 threads.
//   x, gamma, out : device pointers, passed to the kernel as BF16 buffers
//   stream        : cudaStream_t carried as an opaque pointer
// The block size is rounded up to a whole number of warps (minimum one warp).
// NOTE(review): the block reduction is outside this view; presumably it is
// warp-shuffle based and the element loops are bounded by hidden_size (the
// existing 32-thread minimum already relies on that) -- confirm.
// rows == 0 is a no-op, since a zero-sized grid is an invalid launch
// configuration.
void launch_rmsnorm_bf16(const void* x, const void* gamma, void* out,
                         int rows, int hidden_size, float eps, void* stream) {
    if (rows == 0) return;

    int block = (hidden_size < 1024) ? hidden_size : 1024;
    block = (block + 31) / 32 * 32;  // whole warps only
    if (block < 32) block = 32;
    rmsnorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const __nv_bfloat16*)gamma,
        (__nv_bfloat16*)out, hidden_size, eps);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for add_rmsnorm_bf16: one block per row, up to 1024 threads.
//   x, residual, gamma    : device pointers, passed as const BF16 buffers
//   normed_out, sum_out   : device pointers, passed as mutable BF16 buffers
//   stream                : cudaStream_t carried as an opaque pointer
// The block size is rounded up to a whole number of warps (minimum one warp).
// NOTE(review): block_reduce_sum is defined outside this view; presumably it
// is warp-shuffle based and needs full warps -- confirm. The extra threads do
// no element work because the kernel loops are bounded by hidden_size.
// rows == 0 is a no-op, since a zero-sized grid is an invalid launch
// configuration.
void launch_add_rmsnorm_bf16(const void* x, const void* residual, const void* gamma,
                             void* normed_out, void* sum_out,
                             int rows, int hidden_size, float eps, void* stream) {
    if (rows == 0) return;

    int block = (hidden_size < 1024) ? hidden_size : 1024;
    block = (block + 31) / 32 * 32;  // whole warps only
    if (block < 32) block = 32;
    add_rmsnorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const __nv_bfloat16*)residual,
        (const __nv_bfloat16*)gamma,
        (__nv_bfloat16*)normed_out, (__nv_bfloat16*)sum_out,
        hidden_size, eps);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
|
||||
53
csrc/quantization/dequant_fp8.cu
Normal file
53
csrc/quantization/dequant_fp8.cu
Normal file
@@ -0,0 +1,53 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// Dequantize FP8 E4M3 → BF16 with per-expert (per-batch-slice) FP32 scale.
|
||||
//
|
||||
// Input: src [num_experts, rows, cols] FP8 E4M3 (1 byte each)
|
||||
// scales [num_experts] FP32
|
||||
// Output: dst [num_experts, rows, cols] BF16
|
||||
//
|
||||
// Each element: dst[e, r, c] = bf16( float(src[e, r, c]) * scales[e] )
|
||||
|
||||
// Element-wise FP8 E4M3 -> BF16 dequantization: each thread converts one
// element and multiplies it by the scale of the expert slice it belongs to.
__global__ void dequant_fp8e4m3_to_bf16_kernel(
    const __nv_fp8_e4m3* __restrict__ src,
    const float* __restrict__ scales,
    __nv_bfloat16* __restrict__ dst,
    int num_experts, int rows, int cols
) {
    // All index math is 64-bit: num_experts * rows * cols can exceed int32
    // for large expert weight stacks.
    const long long per_expert = (long long)rows * cols;
    const long long count = (long long)num_experts * rows * cols;
    const long long gid = (long long)blockIdx.x * blockDim.x + threadIdx.x;

    if (gid < count) {
        // The flat index divided by the slice size selects the expert scale.
        const int which = (int)(gid / per_expert);
        dst[gid] = __float2bfloat16(float(src[gid]) * scales[which]);
    }
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Host launcher for dequant_fp8e4m3_to_bf16_kernel. The grid is sized so that
// it provides at least num_experts * rows * cols threads (256 per block); the
// element count is computed in 64-bit to avoid int overflow.
//   src    : device pointer, passed as const __nv_fp8_e4m3*
//   scales : device pointer, passed as const float*
//   dst    : device pointer, passed as __nv_bfloat16*
//   stream : cudaStream_t carried as an opaque pointer
// A zero element count is a no-op, since a zero-sized grid is an invalid
// launch configuration.
void launch_dequant_fp8e4m3_to_bf16(
    const void* src,
    const void* scales,
    void* dst,
    int num_experts, int rows, int cols,
    void* stream
) {
    const long long total = (long long)num_experts * rows * cols;
    if (total == 0) return;  // nothing to dequantize

    const int block = 256;
    const int grid = (int)((total + block - 1) / block);
    dequant_fp8e4m3_to_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_fp8_e4m3*)src,
        (const float*)scales,
        (__nv_bfloat16*)dst,
        num_experts, rows, cols
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
135
csrc/quantization/mxfp4_gemm.cu
Normal file
135
csrc/quantization/mxfp4_gemm.cu
Normal file
@@ -0,0 +1,135 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cstdint>
|
||||
#include "../common.cuh"
|
||||
|
||||
// MXFP4 W4A16 for MoE experts. Weights stored [E, N, K] with K (reduction)
|
||||
// contiguous, blocked by 32: packed 4-bit E2M1 (two nibbles/byte, lo = even k)
|
||||
// + one UE8M0 scale byte per 32 elements. The decode win is reading 4-bit
|
||||
// weights from HBM (half of FP8) and dequantizing on-chip to BF16.
|
||||
|
||||
#define MXFP4_BLOCK 32
|
||||
|
||||
// E2M1 magnitude by 3-bit code; bit 3 is the sign.
|
||||
__device__ __constant__ float kFp4Levels[8] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
|
||||
|
||||
// Decodes a 4-bit E2M1 code to float: the low three bits index kFp4Levels for
// the magnitude, bit 3 (0x8) selects the negative value.
__device__ __forceinline__ float fp4_to_float(uint8_t code) {
    const float magnitude = kFp4Levels[code & 0x7];
    if (code & 0x8) {
        return -magnitude;
    }
    return magnitude;
}
|
||||
|
||||
// Decode (M=1) fused GEMV, batched over experts.
|
||||
// y[e, n] = sum_k x[e, k] * dequant(W[e, n, k])
|
||||
// Grid: (N/TILE_N, E). Each block loads the activation x[e, :] into shared
|
||||
// memory ONCE and computes TILE_N output columns from it (one warp per column),
|
||||
// so the activation is read from HBM once per TILE_N outputs instead of once
|
||||
// per output. Weights are unique per output and read coalesced as uint4; the
|
||||
// UE8M0 block scale is hoisted to once per 32-element block.
|
||||
#define MXFP4_TILE_N 8 // output columns per block (= warps per block)
|
||||
|
||||
// Batched decode GEMV over MXFP4 weights: y[e, n] = sum_k x[e, k] * dequant(W[e, n, k]).
// blockIdx.y selects the expert, blockIdx.x a tile of MXFP4_TILE_N output
// columns; each warp of the block computes one column from the activation row
// staged in dynamic shared memory (K floats).
__global__ void batched_gemv_mxfp4_bf16_kernel(
    const __nv_bfloat16* __restrict__ x,         // [E, K]
    const uint8_t* __restrict__ w_packed,        // [E, N, K/2]
    const uint8_t* __restrict__ w_scales,        // [E, N, K/32]
    __nv_bfloat16* __restrict__ y,               // [E, N]
    int E, int N, int K
) {
    extern __shared__ float xs[];                // [K] activation for this expert
    const int expert = blockIdx.y;
    const int warp_id = threadIdx.x >> 5;        // 0..TILE_N-1
    const int lane_id = threadIdx.x & 31;
    const int num_blocks = K / MXFP4_BLOCK;

    // All threads cooperate to convert x[expert, :] to float in shared memory.
    const __nv_bfloat16* x_row = x + (long long)expert * K;
    for (int k = threadIdx.x; k < K; k += blockDim.x) {
        xs[k] = __bfloat162float(x_row[k]);
    }
    __syncthreads();

    const int col = blockIdx.x * MXFP4_TILE_N + warp_id;
    if (col < N) {
        const long long w_row = (long long)expert * N + col;
        const uint8_t* nibbles = w_packed + w_row * (K >> 1);
        const uint8_t* block_scales = w_scales + w_row * num_blocks;

        float partial = 0.0f;
        for (int b = lane_id; b < num_blocks; b += 32) {
            // Scale byte is a biased power-of-two exponent (bias 127).
            const float scale = exp2f((float)((int)block_scales[b] - 127));
            // 16 bytes = 32 nibbles; low nibble is the even k.
            uint4 chunk = *(const uint4*)(nibbles + (long long)b * 16);
            const uint8_t* bytes = (const uint8_t*)&chunk;
            const float* xb = xs + b * MXFP4_BLOCK;
            #pragma unroll
            for (int j = 0; j < 16; j++) {
                const uint8_t pair = bytes[j];
                partial += xb[2 * j]     * (fp4_to_float(pair & 0xF) * scale);
                partial += xb[2 * j + 1] * (fp4_to_float(pair >> 4) * scale);
            }
        }

        // Warp tree reduction into lane 0.
        #pragma unroll
        for (int off = 16; off > 0; off >>= 1) {
            partial += __shfl_down_sync(0xffffffffu, partial, off);
        }
        if (lane_id == 0) {
            y[w_row] = __float2bfloat16(partial);
        }
    }
}
|
||||
|
||||
// Prefill fallback: dequant MXFP4 [E, N, K] -> BF16 [E, K, N] (transposed back
|
||||
// to the [E, K, N] layout the BF16 batched GEMM expects). Not bandwidth-optimal,
|
||||
// but prefill is compute-bound so it is not the decode hot path.
|
||||
// Dequantizes MXFP4 weights laid out [E, N, K] and writes them transposed as
// BF16 [E, K, N]. One thread per (e, n, k) element.
__global__ void dequant_mxfp4_to_bf16_t_kernel(
    const uint8_t* __restrict__ w_packed,        // [E, N, K/2]
    const uint8_t* __restrict__ w_scales,        // [E, N, K/32]
    __nv_bfloat16* __restrict__ out,             // [E, K, N]
    int E, int N, int K
) {
    const long long gid = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    const long long count = (long long)E * N * K;
    if (gid < count) {
        // Unflatten gid as (expert, col, kk) with kk fastest.
        const int kk = gid % K;
        const int col = (gid / K) % N;
        const int expert = gid / ((long long)N * K);
        const long long src_row = (long long)expert * N + col;

        const int bytes_per_row = K >> 1;
        const int scales_per_row = K / MXFP4_BLOCK;

        // Two codes per byte: odd kk in the high nibble, even kk in the low.
        const uint8_t packed = w_packed[src_row * bytes_per_row + (kk >> 1)];
        const uint8_t nibble = (kk & 1) ? (packed >> 4) : (packed & 0xF);
        const uint8_t exponent = w_scales[src_row * scales_per_row + kk / MXFP4_BLOCK];
        const float scale = exp2f((float)((int)exponent - 127));

        // Transposed destination: out[expert, kk, col].
        out[((long long)expert * K + kk) * N + col] =
            __float2bfloat16(fp4_to_float(nibble) * scale);
    }
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Host launcher for batched_gemv_mxfp4_bf16_kernel.
// Grid: (ceil(N / MXFP4_TILE_N), E); block: MXFP4_TILE_N warps; dynamic shared
// memory: K floats for the staged activation row.
//   x, w_packed, w_scales, y : device pointers, cast to the kernel's typed
//                              parameters
//   stream                   : cudaStream_t carried as an opaque pointer
// Empty problems (E == 0 or N == 0) are a no-op, because a zero-sized grid
// dimension is an invalid launch configuration.
void launch_batched_gemv_mxfp4_bf16(
    const void* x, const void* w_packed, const void* w_scales, void* y,
    int E, int N, int K, void* stream
) {
    if (E == 0 || N == 0) return;

    dim3 grid((N + MXFP4_TILE_N - 1) / MXFP4_TILE_N, E);
    const int block = MXFP4_TILE_N * 32;  // one warp per output column
    const size_t smem = (size_t)K * sizeof(float);

    // Dynamic shared memory above the default 48 KB limit needs an explicit
    // per-kernel opt-in, otherwise the launch itself fails (K > 12288). The
    // return value is not checked here: if the device cannot grant the size,
    // the launch below fails and CUDA_CHECK_LAST_ERROR reports it.
    const size_t default_dyn_smem_limit = 48 * 1024;
    if (smem > default_dyn_smem_limit) {
        cudaFuncSetAttribute(batched_gemv_mxfp4_bf16_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             (int)smem);
    }

    batched_gemv_mxfp4_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const uint8_t*)w_packed, (const uint8_t*)w_scales,
        (__nv_bfloat16*)y, E, N, K
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for dequant_mxfp4_to_bf16_t_kernel. The grid is sized so that
// it provides at least E * N * K threads (256 per block); the element count is
// computed in 64-bit to avoid int overflow.
//   w_packed, w_scales : device pointers, passed as const uint8_t*
//   out                : device pointer, passed as __nv_bfloat16*
//   stream             : cudaStream_t carried as an opaque pointer
// A zero element count is a no-op, since a zero-sized grid is an invalid
// launch configuration.
void launch_dequant_mxfp4_to_bf16_t(
    const void* w_packed, const void* w_scales, void* out,
    int E, int N, int K, void* stream
) {
    const long long total = (long long)E * N * K;
    if (total == 0) return;  // nothing to dequantize

    const int block = 256;
    const long long grid = (total + block - 1) / block;
    dequant_mxfp4_to_bf16_t_kernel<<<(unsigned)grid, block, 0, (cudaStream_t)stream>>>(
        (const uint8_t*)w_packed, (const uint8_t*)w_scales, (__nv_bfloat16*)out,
        E, N, K
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
160
csrc/quantization/quantize_fp8.cu
Normal file
160
csrc/quantization/quantize_fp8.cu
Normal file
@@ -0,0 +1,160 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <float.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// Per-row quantize BF16 → FP8 E4M3 with per-row FP32 scale output.
|
||||
//
|
||||
// Input: src [num_rows, cols] BF16
|
||||
// Output: dst [num_rows, cols] FP8 E4M3
|
||||
// scales [num_rows] FP32
|
||||
//
|
||||
// Algorithm per row:
|
||||
// absmax = max(|src[row, :]|)
|
||||
// scale = absmax / 448.0 (FP8 E4M3 max representable)
|
||||
// dst[row, i] = fp8(src[row, i] / scale)
|
||||
//
|
||||
// Grid: one block per row. Block: 256 threads.
|
||||
// Each thread handles ceil(cols / 256) elements.
|
||||
|
||||
#define QUANT_BLOCK 256
|
||||
#define FP8_E4M3_MAX 448.0f
|
||||
|
||||
// Row-wise BF16 -> FP8 E4M3 quantization, one block per row.
// Writes scales[row] = max(|row|) / FP8_E4M3_MAX (clamped to at least 1e-12)
// and dst[row, i] = fp8(src[row, i] / scales[row]).
// The strided loops and the shared-memory reduction both use QUANT_BLOCK, so
// the kernel expects to be launched with QUANT_BLOCK threads per block.
__global__ void quantize_bf16_to_fp8e4m3_rowwise_kernel(
    const __nv_bfloat16* __restrict__ src,
    __nv_fp8_e4m3* __restrict__ dst,
    float* __restrict__ scales,
    int num_rows, int cols
) {
    const int row = blockIdx.x;
    if (row >= num_rows) return;  // whole block exits together
    const int lane = threadIdx.x;

    const long long row_base = (long long)row * cols;
    const __nv_bfloat16* in_row = src + row_base;
    __nv_fp8_e4m3* out_row = dst + row_base;

    // Phase 1: per-thread absolute maximum over a strided slice of the row.
    __shared__ float smem_max[QUANT_BLOCK];
    float running = 0.0f;
    for (int c = lane; c < cols; c += QUANT_BLOCK) {
        running = fmaxf(running, fabsf(__bfloat162float(in_row[c])));
    }
    smem_max[lane] = running;
    __syncthreads();

    // Phase 2: halving tree reduction in shared memory; slot 0 holds the max.
    for (int half = QUANT_BLOCK >> 1; half > 0; half >>= 1) {
        if (lane < half) {
            smem_max[lane] = fmaxf(smem_max[lane], smem_max[lane + half]);
        }
        __syncthreads();
    }

    // Phase 3: derive the scale; the floor keeps an all-zero row from
    // producing a division by zero.
    float scale = smem_max[0] / FP8_E4M3_MAX;
    scale = (scale < 1e-12f) ? 1e-12f : scale;
    const float inv = 1.0f / scale;
    if (lane == 0) {
        scales[row] = scale;
    }

    // Phase 4: scale and convert every element of the row.
    for (int c = lane; c < cols; c += QUANT_BLOCK) {
        out_row[c] = __nv_fp8_e4m3(__bfloat162float(in_row[c]) * inv);
    }
}
|
||||
|
||||
// Row-wise scale: data[row, :] *= scales[row] (in-place, BF16)
|
||||
// In-place row scaling: every element of data[row, :] is multiplied by
// scales[row] in float and rounded back to BF16. One block per row.
__global__ void rowwise_scale_bf16_kernel(
    __nv_bfloat16* __restrict__ data,
    const float* __restrict__ scales,
    int num_rows, int cols
) {
    const int r = blockIdx.x;
    if (r < num_rows) {
        const float factor = scales[r];
        __nv_bfloat16* line = data + (long long)r * cols;
        for (int c = threadIdx.x; c < cols; c += blockDim.x) {
            line[c] = __float2bfloat16(__bfloat162float(line[c]) * factor);
        }
    }
}
|
||||
|
||||
// Combined dequant scale for batched MoE FP8 GEMM output.
|
||||
// data[row, :] *= a_scales[row] * b_scales[row / tokens]
|
||||
// where row = expert * tokens + token. a_scales is the per-token activation
|
||||
// scale; b_scales is the per-expert scalar weight scale. Lets a single
|
||||
// strided-batched FP8 matmul (alpha=1, scales=1) recover the real result in
|
||||
// one pass instead of folding the weight scale into a per-expert GEMM call.
|
||||
// In-place row scaling with a combined factor:
// data[row, :] *= a_scales[row] * b_scales[row / tokens]. One block per row;
// the product is formed once per row in float, then applied to each element.
__global__ void rowwise_scale_moe_bf16_kernel(
    __nv_bfloat16* __restrict__ data,
    const float* __restrict__ a_scales,
    const float* __restrict__ b_scales,
    int num_rows, int cols, int tokens
) {
    const int r = blockIdx.x;
    if (r < num_rows) {
        // row / tokens picks the b_scales entry shared by a group of rows.
        const float factor = a_scales[r] * b_scales[r / tokens];
        __nv_bfloat16* line = data + (long long)r * cols;
        for (int c = threadIdx.x; c < cols; c += blockDim.x) {
            line[c] = __float2bfloat16(__bfloat162float(line[c]) * factor);
        }
    }
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Host launcher for rowwise_scale_bf16_kernel: one 256-thread block per row.
//   data   : device pointer, passed as a mutable BF16 buffer (scaled in place)
//   scales : device pointer, passed as const float*
//   stream : cudaStream_t carried as an opaque pointer
// num_rows == 0 is a no-op, since a zero-sized grid is an invalid launch
// configuration.
void launch_rowwise_scale_bf16(
    void* data, const void* scales,
    int num_rows, int cols,
    void* stream
) {
    if (num_rows == 0) return;  // nothing to scale

    const int block = 256;
    const int grid = num_rows;
    rowwise_scale_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)data, (const float*)scales,
        num_rows, cols
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for rowwise_scale_moe_bf16_kernel: one 256-thread block per
// row.
//   data               : device pointer, passed as a mutable BF16 buffer
//   a_scales, b_scales : device pointers, passed as const float*
//   tokens             : forwarded to the kernel, which indexes b_scales by
//                        row / tokens
//   stream             : cudaStream_t carried as an opaque pointer
// num_rows == 0 is a no-op, since a zero-sized grid is an invalid launch
// configuration.
void launch_rowwise_scale_moe_bf16(
    void* data, const void* a_scales, const void* b_scales,
    int num_rows, int cols, int tokens,
    void* stream
) {
    if (num_rows == 0) return;  // nothing to scale

    const int block = 256;
    const int grid = num_rows;
    rowwise_scale_moe_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)data, (const float*)a_scales, (const float*)b_scales,
        num_rows, cols, tokens
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host launcher for quantize_bf16_to_fp8e4m3_rowwise_kernel: one block of
// QUANT_BLOCK threads per row (the kernel's reduction is sized by QUANT_BLOCK).
//   src    : device pointer, passed as const __nv_bfloat16*
//   dst    : device pointer, passed as __nv_fp8_e4m3*
//   scales : device pointer, passed as float* (one scale written per row)
//   stream : cudaStream_t carried as an opaque pointer
// num_rows == 0 is a no-op, since a zero-sized grid is an invalid launch
// configuration.
void launch_quantize_bf16_to_fp8e4m3_rowwise(
    const void* src,
    void* dst,
    void* scales,
    int num_rows, int cols,
    void* stream
) {
    if (num_rows == 0) return;  // nothing to quantize

    const int grid = num_rows;
    const int block = QUANT_BLOCK;
    quantize_bf16_to_fp8e4m3_rowwise_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)src,
        (__nv_fp8_e4m3*)dst,
        (float*)scales,
        num_rows, cols
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
92
csrc/reduce/argmax.cu
Normal file
92
csrc/reduce/argmax.cu
Normal file
@@ -0,0 +1,92 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <float.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// Argmax along the last dim of a [rows, cols] tensor.
|
||||
// One block per row; output is [rows] int32 indices of the max element.
|
||||
//
|
||||
// Reduction: each thread scans a strided slice and tracks the running
|
||||
// (value, index) pair, then warp-shuffle reduce, then a single-warp
|
||||
// reduce over per-warp leaders. Tie-break: smaller index wins so the
|
||||
// result is deterministic across launches.
|
||||
//
|
||||
// For BF16 logits the comparison happens in FP32 to avoid losing
|
||||
// precision near the top of the distribution.
|
||||
|
||||
// Per-row argmax over BF16 logits; launch with one block per row.
//
// logits  : row `blockIdx.x` starts at logits + row * cols.
// out_idx : receives one int per row: the index of the largest element,
//           with the smallest index winning ties.
// cols    : number of elements in each row.
//
// Launch requirements that follow from the code below:
//   * blockDim.x must be a multiple of 32, because every shuffle names all
//     32 lanes in its mask;
//   * blockDim.x must be at most 1024, so that the per-warp leaders fit in
//     the 32-slot shared arrays.
//
// Values are compared as FP32. Elements that never compare greater than the
// -FLT_MAX seed (-inf and NaN) are never selected; if a row contains nothing
// else, index 0 is reported instead of the internal INT_MAX sentinel.
__global__ void argmax_bf16_kernel(
    const __nv_bfloat16* __restrict__ logits,
    int* __restrict__ out_idx,
    int cols
) {
    int row = blockIdx.x;
    const __nv_bfloat16* row_ptr = logits + (long long)row * cols;
    int tid = threadIdx.x;
    unsigned mask = 0xffffffff;

    // Each thread scans its own strided slice of the row. INT_MAX marks
    // "no candidate yet" and loses every index tie-break below.
    float local_max = -FLT_MAX;
    int local_idx = INT_MAX;
    for (int i = tid; i < cols; i += blockDim.x) {
        float v = __bfloat162float(row_ptr[i]);
        // Strict `>` keeps the smallest index on ties, since i only grows.
        if (v > local_max) {
            local_max = v;
            local_idx = i;
        }
    }

    // Warp-level reduction of (value, index) pairs; lane 0 ends up holding
    // the warp's winner.
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        float other_val = __shfl_down_sync(mask, local_max, offset);
        int other_idx = __shfl_down_sync(mask, local_idx, offset);
        bool take = (other_val > local_max) ||
                    (other_val == local_max && other_idx < local_idx);
        if (take) {
            local_max = other_val;
            local_idx = other_idx;
        }
    }

    // Warp leaders publish their winner to shared memory, then warp 0
    // reduces those entries.
    __shared__ float s_val[32];
    __shared__ int s_idx[32];
    int lane = tid & 31;
    int warp_id = tid >> 5;
    int num_warps = (blockDim.x + 31) >> 5;

    if (lane == 0) {
        s_val[warp_id] = local_max;
        s_idx[warp_id] = local_idx;
    }
    __syncthreads();

    if (warp_id == 0) {
        // Lanes beyond the number of warps contribute the neutral pair.
        float v = (tid < num_warps) ? s_val[lane] : -FLT_MAX;
        int i = (tid < num_warps) ? s_idx[lane] : INT_MAX;
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            float ov = __shfl_down_sync(mask, v, offset);
            int oi = __shfl_down_sync(mask, i, offset);
            bool take = (ov > v) || (ov == v && oi < i);
            if (take) { v = ov; i = oi; }
        }
        if (lane == 0) {
            // If nothing beat the seed (row is all -inf/NaN, or cols == 0),
            // `i` is still the INT_MAX sentinel. Writing it out would hand
            // callers an out-of-range index, so report index 0, which is the
            // smallest-index answer for a row of equal -inf values.
            out_idx[row] = (i == INT_MAX) ? 0 : i;
        }
    }
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Host-side launcher for argmax_bf16_kernel.
//
// logits  : device buffer, handed to the kernel as const __nv_bfloat16*.
// out_idx : device buffer, handed to the kernel as int*; one entry per row.
// rows    : number of rows; also the grid size (one block per row).
// cols    : elements per row.
// stream  : opaque handle, cast to cudaStream_t for the launch.
void launch_argmax_bf16(const void* logits, void* out_idx,
                        int rows, int cols, void* stream) {
    // A zero-sized grid is an invalid launch configuration, and with no rows
    // there is no output to produce.
    if (rows == 0) return;

    // 1024 threads/block keeps occupancy high and gives 32 warps for the
    // final reduce (matches the 32-slot shared arrays in the kernel).
    int block = 1024;
    argmax_bf16_kernel<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)logits, (int*)out_idx, cols);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
@@ -90,17 +90,19 @@ __global__ void softmax_bf16(
|
||||
extern "C" {
|
||||
|
||||
// Host-side launcher for the softmax_f32 kernel: one block per row.
//
// x / out : device buffers, handed to the kernel as const float* / float*.
// rows    : grid size.
// cols    : forwarded to the kernel; also drives the block size.
// stream  : opaque handle, cast to cudaStream_t for the launch.
void launch_softmax_f32(const void* x, void* out, int rows, int cols, void* stream) {
    // `block` is declared exactly once: min(cols, 512) threads, but never
    // fewer than one warp.
    int block = (cols < 512) ? cols : 512;
    if (block < 32) block = 32;
    // Round up to a whole number of warps so that no warp is partially
    // populated. The result stays within [32, 512].
    // NOTE(review): the kernel body is not visible here; this assumes its
    // reductions are warp-based like the other kernels in this project and
    // that threads with index >= cols simply idle — confirm in the kernel.
    block = (block + 31) & ~31;
    softmax_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const float*)x, (float*)out, cols);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Host-side launcher for the softmax_bf16 kernel: one block per row.
//
// x / out : device buffers, handed to the kernel as const __nv_bfloat16* /
//           __nv_bfloat16*.
// rows    : grid size.
// cols    : forwarded to the kernel; also drives the block size.
// stream  : opaque handle, cast to cudaStream_t for the launch.
void launch_softmax_bf16(const void* x, void* out, int rows, int cols, void* stream) {
    // `block` is declared exactly once: min(cols, 512) threads, but never
    // fewer than one warp.
    int block = (cols < 512) ? cols : 512;
    if (block < 32) block = 32;
    // Round up to a whole number of warps so that no warp is partially
    // populated. The result stays within [32, 512].
    // NOTE(review): the kernel body is only partly visible in this view; this
    // assumes its reductions are warp-based like the other kernels in this
    // project and that threads with index >= cols simply idle — confirm.
    block = (block + 31) & ~31;
    softmax_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (__nv_bfloat16*)out, cols);
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
| 抽象层级 | Level 0.5 | 自写 CUDA kernel + cuBLAS 可切换,便于 benchmark 对比 |
|
||||
| 硬件 | 8×RTX 5090 (Blackwell, CC 12.0, 32GB GDDR7) | 纯 PCIe Gen5 x16 互联,无 NVLink (详见下方硬件拓扑) |
|
||||
| 语言 | Rust + CUDA (C/C++) | Rust FFI 调用 CUDA |
|
||||
| 起步模型 | GPT-2 124M → Qwen3-7B | 从简单到实用 |
|
||||
| 起步模型 | GPT-2 124M → Qwen3-8B | 从简单到实用 |
|
||||
| 精度 | BF16/FP16 | 后期扩展 FP8 |
|
||||
| Tensor | 自己实现 | 完整学习 tensor 抽象设计 |
|
||||
| Tokenizer | 自己实现 BPE | 学习分词机制 |
|
||||
@@ -101,7 +101,7 @@ Phase 8: GPT-2 完整推理 ◄──────────── 里程碑
|
||||
│
|
||||
Phase 9: KV Cache + Autoregressive Generation
|
||||
│
|
||||
Phase 10: Qwen3-7B 支持 ◄─────────── 里程碑 ② 7B 模型推理
|
||||
Phase 10: Qwen3-8B 支持 ◄─────────── 里程碑 ② 8B 模型推理
|
||||
│
|
||||
Phase 11: Paged Attention + KV Cache Manager
|
||||
│
|
||||
@@ -109,7 +109,7 @@ Phase 12: Continuous Batching + Request Scheduler
|
||||
│
|
||||
Phase 13: HTTP API + SSE Streaming ◄── 里程碑 ③ 端到端 API 可用
|
||||
│
|
||||
Phase 14: Flash Attention v2
|
||||
Phase 14: Flash Attention (FA2 for SM120)
|
||||
│
|
||||
Phase 15: 性能优化 ◄──────────────── 里程碑 ④ 50% vLLM throughput
|
||||
│
|
||||
@@ -625,8 +625,8 @@ safetensors file (disk)
|
||||
|
||||
- [ ] 加载 GPT-2 124M (`openai-community/gpt2`),打印所有 tensor name, shape, dtype
|
||||
- [ ] 抽查几个 tensor 的前 10 个值,与 PyTorch `from_pretrained` 对比
|
||||
- [ ] 加载 Qwen3-7B sharded 权重,验证所有 tensor 都成功加载
|
||||
- [ ] 性能: 测量 7B 模型权重加载时间 (mmap → GPU 全流程)
|
||||
- [ ] 加载 Qwen3-8B sharded 权重,验证所有 tensor 都成功加载
|
||||
- [ ] 性能: 测量 8B 模型权重加载时间 (mmap → GPU 全流程)
|
||||
- [ ] 错误处理: 缺少 tensor、dtype 不匹配、文件不存在等情况
|
||||
|
||||
---
|
||||
@@ -869,15 +869,15 @@ weights × V_cache [B, H, S, D] → output [B, H, 1, D]
|
||||
|
||||
---
|
||||
|
||||
## Phase 10: Qwen3-7B 支持 — 里程碑 ②
|
||||
## Phase 10: Qwen3-8B 支持 — 里程碑 ②
|
||||
|
||||
**Crate**: `xserv-model`
|
||||
|
||||
**目标**: 扩展模型定义以支持 Qwen3-7B,验证输出正确性。
|
||||
**目标**: 扩展模型定义以支持 Qwen3-8B,验证输出正确性。
|
||||
|
||||
### 架构对比
|
||||
|
||||
| 特性 | GPT-2 (124M) | Qwen3-7B |
|
||||
| 特性 | GPT-2 (124M) | Qwen3-8B |
|
||||
|------|-------------|----------|
|
||||
| Normalization | LayerNorm (pre-LN) | RMSNorm (pre-LN) |
|
||||
| Position Encoding | Learned absolute (wpe) | RoPE (无单独参数) |
|
||||
@@ -885,8 +885,8 @@ weights × V_cache [B, H, S, D] → output [B, H, 1, D]
|
||||
| Activation | GELU | SwiGLU (SiLU gate) |
|
||||
| FFN | Linear(H→4H) → GELU → Linear(4H→H) | gate_proj + up_proj → SiLU gate → down_proj |
|
||||
| Vocab Size | 50,257 | ~152,000 |
|
||||
| Hidden Size | 768 | 3,584 (7B) |
|
||||
| Layers | 12 | 28 |
|
||||
| Hidden Size | 768 | 4,096 (8B) |
|
||||
| Layers | 12 | 36 |
|
||||
| Tied Embeddings | Yes | No |
|
||||
|
||||
### 需要新增/修改的组件
|
||||
@@ -948,16 +948,16 @@ pub struct Qwen3DecoderLayer {
|
||||
### 显存预算 (BF16, 单卡 5090 32GB)
|
||||
|
||||
```
|
||||
模型权重: 7B × 2B = ~14 GB
|
||||
KV cache: 28 layers × 2(KV) × 8 heads × 4096 tokens × 128 dim × 2B ≈ 4.5 GB
|
||||
模型权重: 8B × 2B = ~16 GB
|
||||
KV cache: 36 layers × 2(KV) × 8 heads × 4096 tokens × 128 dim × 2B ≈ 0.6 GB
|
||||
Activation (单请求): ~1 GB
|
||||
────────────────────────
|
||||
总计: ~19.5 GB (单请求),剩余 ~12 GB 可用于更多并发
|
||||
总计: ~17.6 GB (单请求),剩余 ~14 GB 可用于更多并发
|
||||
```
|
||||
|
||||
### 测试验收
|
||||
|
||||
- [ ] 加载 Qwen3-7B 权重到单张 5090,打印模型结构和参数量
|
||||
- [ ] 加载 Qwen3-8B 权重到单张 5090,打印模型结构和参数量
|
||||
- [ ] Prefill logits 与 HF transformers 对比: 输入 "你好" → top-5 logits 一致
|
||||
- [ ] 英文生成: "What is the capital of France?" → 生成合理回答
|
||||
- [ ] 中文生成: "请介绍一下量子计算" → 生成通顺中文
|
||||
@@ -1196,7 +1196,7 @@ GET /health # 健康检查
|
||||
**Chat Completion Request**:
|
||||
```json
|
||||
{
|
||||
"model": "qwen3-7b",
|
||||
"model": "qwen3-8b",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is 1+1?"}
|
||||
@@ -1211,13 +1211,13 @@ GET /health # 健康检查
|
||||
|
||||
**SSE Streaming Response**:
|
||||
```
|
||||
data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1234567890,"model":"qwen3-7b","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
|
||||
data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1234567890,"model":"qwen3-8b","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1234567890,"model":"qwen3-7b","choices":[{"index":0,"delta":{"content":"The"},"finish_reason":null}]}
|
||||
data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1234567890,"model":"qwen3-8b","choices":[{"index":0,"delta":{"content":"The"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1234567890,"model":"qwen3-7b","choices":[{"index":0,"delta":{"content":" answer"},"finish_reason":null}]}
|
||||
data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1234567890,"model":"qwen3-8b","choices":[{"index":0,"delta":{"content":" answer"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1234567890,"model":"qwen3-7b","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
|
||||
data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1234567890,"model":"qwen3-8b","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
|
||||
|
||||
data: [DONE]
|
||||
```
|
||||
@@ -1228,7 +1228,7 @@ data: [DONE]
|
||||
"id": "chatcmpl-xxx",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "qwen3-7b",
|
||||
"model": "qwen3-8b",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": "The answer is 2."},
|
||||
@@ -1278,7 +1278,7 @@ Client (curl / Python OpenAI SDK)
|
||||
```bash
|
||||
curl http://localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model":"qwen3-7b","messages":[{"role":"user","content":"Hello"}],"stream":true}'
|
||||
-d '{"model":"qwen3-8b","messages":[{"role":"user","content":"Hello"}],"stream":true}'
|
||||
```
|
||||
看到 SSE 逐 token 输出
|
||||
|
||||
@@ -1287,7 +1287,7 @@ Client (curl / Python OpenAI SDK)
|
||||
from openai import OpenAI
|
||||
client = OpenAI(base_url="http://localhost:8080/v1", api_key="unused")
|
||||
for chunk in client.chat.completions.create(
|
||||
model="qwen3-7b",
|
||||
model="qwen3-8b",
|
||||
messages=[{"role": "user", "content": "What is 1+1?"}],
|
||||
stream=True
|
||||
):
|
||||
@@ -1302,12 +1302,26 @@ Client (curl / Python OpenAI SDK)
|
||||
|
||||
---
|
||||
|
||||
## Phase 14: Flash Attention v2
|
||||
## Phase 14: Flash Attention (FA2 for SM120)
|
||||
|
||||
**Crate**: `xserv-kernels`
|
||||
**CUDA 源码**: `csrc/attention/flash_attention.cu`
|
||||
|
||||
**目标**: 实现 Flash Attention v2 的 CUDA kernel,大幅降低 attention 的显存占用并提升速度。
|
||||
**目标**: 实现 Flash Attention 的 CUDA kernel,大幅降低 attention 的显存占用并提升速度。
|
||||
|
||||
### 硬件适配说明
|
||||
|
||||
Flash Attention 已发展到第 4 代 (FA4, arxiv 2603.05451),但各版本有明确的硬件依赖:
|
||||
|
||||
| 版本 | 目标架构 | 关键硬件特性 | RTX 5090 兼容 |
|
||||
|------|---------|------------|--------------|
|
||||
| FA2 | 通用 CUDA (SM75+) | 标准 shared memory + HMMA | **是** ✅ |
|
||||
| FA3 | Hopper SM90 (H100) | TMA + WGMMA + warp specialization | 否 |
|
||||
| FA4 | Blackwell SM100 (B200/B300) | TMEM + async MMA + 2-CTA mode | 否 |
|
||||
|
||||
**RTX 5090 (SM120, CC 12.0) 使用的是消费级 Blackwell 架构 (GB202),与数据中心 Blackwell (B200, SM100) 是不同的硅片设计。SM120 物理上没有 TMEM (Tensor Memory) 子系统,因此 FA4 的 kernel 无法在 5090 上运行。这不是软件限制,是硬件级差异。**
|
||||
|
||||
因此本项目实现 **FA2 算法**,使用标准 CUDA (shared memory + HMMA)。FA2 的核心优化——online softmax tiling、不 materialize S×S score 矩阵 (额外显存从 O(S²) 降到 O(S))——在任何架构上都有效。
|
||||
|
||||
### 核心思想
|
||||
|
||||
@@ -1323,16 +1337,18 @@ Flash Attention 的解法:
|
||||
- 将 Q, K, V 分成 tiles,在 SRAM (shared memory) 中计算
|
||||
- 使用 **online softmax trick**: 边算边更新 running max 和 running sum
|
||||
|
||||
### 算法 (Forward Pass)
|
||||
### 算法 (Forward Pass, FA2)
|
||||
|
||||
FA2 相比 FA1 的改进: 外层循环遍历 Q tiles (而非 K/V),减少 HBM 读写次数。
|
||||
|
||||
```
|
||||
Br, Bc = tile sizes for Q and K/V respectively
|
||||
|
||||
for each Q tile (q_start..q_start+Br):
|
||||
for each Q tile (q_start..q_start+Br): ← 外层: Q tiles
|
||||
load Q_tile [Br, D] to shared memory
|
||||
initialize: O_tile = 0, l = 0, m = -inf // running sum and max
|
||||
initialize: O_tile = 0, l = 0, m = -inf // running sum and max
|
||||
|
||||
for each K,V tile (kv_start..kv_start+Bc):
|
||||
for each K,V tile (kv_start..kv_start+Bc): ← 内层: K/V tiles
|
||||
load K_tile [Bc, D], V_tile [Bc, D] to shared memory
|
||||
|
||||
// Compute attention scores for this tile pair
|
||||
@@ -1345,6 +1361,8 @@ for each Q tile (q_start..q_start+Br):
|
||||
m_new = max(m, rowmax(S_tile)) // new running max
|
||||
P_tile = exp(S_tile - m_new) // safe exp
|
||||
l_new = exp(m - m_new) * l + rowsum(P_tile) // update running sum
|
||||
|
||||
// Rescale and accumulate output
|
||||
O_tile = diag(exp(m - m_new)) * O_tile + P_tile @ V_tile
|
||||
m = m_new
|
||||
l = l_new
|
||||
@@ -1356,9 +1374,12 @@ for each Q tile (q_start..q_start+Br):
|
||||
### 实现要点
|
||||
|
||||
1. **Tile 大小选择**:
|
||||
- 受限于 shared memory (5090 Blackwell CC 12.0: 需要实测确认 per-SM shared memory 上限)
|
||||
- 需要同时存 Q_tile, K_tile, V_tile, S_tile
|
||||
- 典型值: Br=Bc=128 for D=128, BF16
|
||||
- 5090 SM120: shared memory per SM = 100 KB (需实测确认)
|
||||
- 需同时存 Q_tile, K_tile, V_tile, S_tile
|
||||
- BF16: Q_tile [Br, D] = Br × 128 × 2B; K_tile [Bc, D] = Bc × 128 × 2B
|
||||
- S_tile [Br, Bc] 保持 FP32 = Br × Bc × 4B
|
||||
- 推荐起步: Br=Bc=64, head_dim=128 → Q/K/V tile 各 16KB + S_tile (FP32) 16KB,共需 ~64KB shared memory
|
||||
- 优化版: Br=Bc=128 需要更多 shared memory, 可能需要拆分
|
||||
|
||||
2. **Causal mask 优化**:
|
||||
- 如果 K/V tile 完全在 Q tile 的"未来"(kv_start > q_end)→ 跳过整个 tile
|
||||
@@ -1369,10 +1390,14 @@ for each Q tile (q_start..q_start+Br):
|
||||
- Q, K, V 的加载用 BF16(节省 bandwidth)
|
||||
- 最终 O 转回 BF16 写出
|
||||
|
||||
4. **与 Paged Attention 的结合**:
|
||||
- Flash Attention 的 K/V tile 遍历逻辑需要适配间接寻址
|
||||
- 每个 tile 查 block_table 得到物理地址
|
||||
- 这是 "Flash-Decoding" / "FlashInfer" 的核心
|
||||
4. **GQA 支持**:
|
||||
- K/V heads 数量 < Q heads 时,kernel 中做 `kv_head = q_head / num_groups` 索引
|
||||
- 不需要 repeat_kv 操作,直接在 kernel 内部解决
|
||||
|
||||
5. **Decode attention 特化**:
|
||||
- Decode 时 Q 只有 1 行 (Br=1),退化为 vector-matrix attention
|
||||
- 可以写一个专门的 decode attention kernel (类似 FlashDecoding)
|
||||
- 沿 KV sequence 维度做 parallel reduction
|
||||
|
||||
### 测试验收
|
||||
|
||||
@@ -1386,8 +1411,9 @@ for each Q tile (q_start..q_start+Br):
|
||||
| 8192 | OOM? | MB | OOM? | ms |
|
||||
| 32768 | OOM | MB | OOM | ms |
|
||||
|
||||
- [ ] 集成到 Qwen3-7B,端到端 decode latency 对比
|
||||
- [ ] 集成到 Qwen3-8B,端到端 decode latency 对比
|
||||
- [ ] Profile: `ncu` 分析 compute utilization, memory throughput
|
||||
- [ ] GQA 支持: 无 repeat_kv 开销
|
||||
|
||||
---
|
||||
|
||||
@@ -1441,7 +1467,7 @@ ncu --target-processes all --set full ./target/release/xserv-server
|
||||
|
||||
### 测试验收
|
||||
|
||||
- [ ] 安装 vLLM,同一台机器跑 Qwen3-7B
|
||||
- [ ] 安装 vLLM,同一台机器跑 Qwen3-8B
|
||||
- [ ] Benchmark 对比:
|
||||
|
||||
| Metric | vLLM | xserv | Ratio |
|
||||
@@ -1488,7 +1514,7 @@ ncu --target-processes all --set full ./target/release/xserv-server
|
||||
|
||||
- **无损**: rejection sampling 保证输出分布与纯 target model 一致
|
||||
- **加速条件**: draft model 足够快且与 target 分布接近
|
||||
- **Draft model 选择**: Qwen3-0.5B / Qwen3-1.5B 作为 Qwen3-7B 的 draft
|
||||
- **Draft model 选择**: Qwen3-0.6B / Qwen3-1.7B 作为 Qwen3-8B 的 draft
|
||||
|
||||
### KV Cache 处理
|
||||
|
||||
@@ -1578,7 +1604,7 @@ Row Parallel: down_proj 按行切分
|
||||
|
||||
### 测试验收
|
||||
|
||||
- [ ] TP=2: Qwen3-7B 输出与单卡 (TP=1) 完全一致
|
||||
- [ ] TP=2: Qwen3-8B 输出与单卡 (TP=1) 完全一致
|
||||
- [ ] TP=4: 每卡权重显存占用约 1/4
|
||||
- [ ] Scaling benchmark (同组 GPU 0-3):
|
||||
|
||||
@@ -1646,7 +1672,7 @@ tensor_fp8 = cast_to_fp8(tensor / scale)
|
||||
| FP8 E4M3 | X.XX | +0.XX |
|
||||
| INT8 weight-only | X.XX | +0.XX |
|
||||
|
||||
- [ ] 显存: FP8 权重占用约 BF16 的一半 (~7 GB for 7B model)
|
||||
- [ ] 显存: FP8 权重占用约 BF16 的一半 (~8 GB for 8B model)
|
||||
- [ ] 性能: FP8 GEMM throughput vs BF16 GEMM
|
||||
|
||||
---
|
||||
@@ -1722,16 +1748,39 @@ Text → Tokenizer → Text Tokens ────────────→
|
||||
|
||||
---
|
||||
|
||||
## 实际进展记录(与原计划的分叉,2026-06 更新)
|
||||
|
||||
Phase 0–17 按计划完成。Phase 18 起实际路线偏离了上面的原计划
|
||||
(speculative decoding 与多模态推迟),实际走向是 MoE + 量化 + 稀疏化:
|
||||
|
||||
| 实际 Phase | 内容 | 文档 |
|
||||
|---|---|---|
|
||||
| 18 | Pipeline Parallelism(PP=2/4) | `18-pipeline-parallelism.md`、`benchmarks/pp-sweep.md` |
|
||||
| 19 | **gpt-oss-20b MoE**:harmony 格式、attention sinks + sliding window、YaRN;两个 CUDA bug 实战(prefill sinks NaN、GEMV 未初始化 smem);GSM8K 94.5% 对齐 llama.cpp;FP8 W8A8 / MXFP4 W4A16 量化 | `19-gpt-oss-moe.md`、`benchmarks/{fp8-quantization,mxfp4-and-llama-decode}.md` |
|
||||
| 20 | **稀疏 top-k MoE decode**:只算被路由的专家,decode 13.9→7.0ms,TP=2 下 decode/TTFT 全面快于 llama.cpp 同配置;gpt-oss 单卡 serving | `20-sparse-moe.md`、`benchmarks/sparse-moe.md` |
|
||||
| 21 | **decode CUDA Graph + GPU argmax**:整个 decode step 录成一个图回放(thread-local launch stream、retained-warmup 分配策略、NCCL capture);greedy 采样换 GPU argmax。TPOT 7.5→5.9ms(TP=1)/ 5.8ms(TP=2);TP=2 全面领先 llama(1.26-1.47×),TP=1 差距 2.5×→2.0× | `21-cuda-graph-decode.md` |
|
||||
|
||||
**下一步候选(按预期收益排序):**
|
||||
|
||||
| 候选 Phase | 内容 | 预期 |
|
||||
|---|---|---|
|
||||
| 22 | **非专家权重量化**:qkv/o + lm_head(1.16GB/token)仍是 BF16 | TPOT 再省 ~1.5ms |
|
||||
| 23 | **稀疏 prefill**(按专家 permute + grouped GEMM) | 长 prompt TTFT 51-75 → ~30ms |
|
||||
| 24 | server 侧 harmony channel 分离(`reasoning_content` 流式输出,对齐 llama-server 行为) | API 易用性 |
|
||||
| — | Speculative Decoding、多模态(原 16/19) | 推迟 |
|
||||
|
||||
## 里程碑总结
|
||||
|
||||
| 里程碑 | Phase | 验收标准 |
|
||||
|--------|-------|---------|
|
||||
| ① GPT-2 推理 | 8 | CLI 输入 prompt, GPT-2 生成连贯文本, logits 与 PyTorch 一致 |
|
||||
| ② Qwen3-7B 推理 | 10 | 7B 模型中英文对话, 多轮 chat template 正确 |
|
||||
| ② Qwen3-8B 推理 | 10 | 8B 模型中英文对话, 多轮 chat template 正确 |
|
||||
| ③ E2E API | 13 | HTTP streaming API, Python OpenAI SDK 可调用, 10 并发正确 |
|
||||
| ④ 性能达标 | 15 | throughput >= 50% vLLM, profiling 报告完成 |
|
||||
| ⑤ 多卡推理 | 17 | TP=2/4 同组 GPU 推理正确, scaling benchmark 完成 |
|
||||
| ⑥ 多模态 | 19 | 图片输入 → 文字回答, API 端到端 |
|
||||
| ⑥ MoE 模型(实际) | 19 | gpt-oss-20b 端到端正确, GSM8K 与 llama.cpp 持平 ✅ |
|
||||
| ⑦ 性能反超(实际) | 20 | 同配置 decode 快于 llama.cpp(TP=2 达成;单卡是 Phase 21+ 目标) ✅ |
|
||||
| ⑧ 多模态 | 推迟 | 图片输入 → 文字回答, API 端到端 |
|
||||
|
||||
## 外部依赖清单
|
||||
|
||||
|
||||
92
docs/05-attention.md
Normal file
92
docs/05-attention.md
Normal file
@@ -0,0 +1,92 @@
|
||||
# Phase 5: Naive Attention Kernel — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
实现标准 Multi-Head Attention(不做 Flash/Paged 优化),用组合式方法(GEMM + Softmax)完成。这是理解 attention 计算流程的基础,也是后续 Flash Attention 的 baseline。
|
||||
|
||||
## 计算流程
|
||||
|
||||
```
|
||||
Input: Q [B, H, S, D], K [B, H, S, D], V [B, H, S, D]
|
||||
B=batch, H=num_heads, S=seq_len, D=head_dim
|
||||
|
||||
1. scores = Q @ K^T / sqrt(D) → [B, H, S, S]
|
||||
2. scores += causal_mask → 上三角置为 -inf
|
||||
3. weights = softmax(scores, dim=-1) → [B, H, S, S]
|
||||
4. output = weights @ V → [B, H, S, D]
|
||||
```
|
||||
|
||||
## 设计选择
|
||||
|
||||
### 组合式实现(Phase 3 GEMM + Phase 4 Softmax)
|
||||
|
||||
不写新的 fused CUDA kernel,而是复用已有的 matmul 和 softmax:
|
||||
- `scores = batched_matmul(Q, K^T)` — 需要支持 batched GEMM
|
||||
- `masked_fill(scores, causal_mask, -inf)` — 新的逐元素 kernel
|
||||
- `softmax(scores)` — 复用 Phase 4
|
||||
- `output = batched_matmul(weights, V)` — 复用 batched GEMM
|
||||
|
||||
这意味着需要先扩展 matmul 支持 batched GEMM(cublasGemmStridedBatchedEx)。
|
||||
|
||||
### Causal Mask
|
||||
|
||||
不显式构造 mask 矩阵。写一个 kernel:
|
||||
```
|
||||
if (col > row + offset) score = -infinity
|
||||
```
|
||||
其中 offset 用于支持 KV cache 场景(decode 时 query 的 row 偏移)。
|
||||
|
||||
### Batched GEMM via cuBLAS
|
||||
|
||||
`cublasGemmStridedBatchedEx` 在一个 batch 维度上并行执行多个 GEMM:
|
||||
```
|
||||
C[b] = A[b] @ B[b] for b = 0..batch_count
|
||||
stride_a = M * K, stride_b = K * N, stride_c = M * N
|
||||
```
|
||||
|
||||
Attention 中 batch 维度 = B * H(batch_size × num_heads)。
|
||||
|
||||
## 文件布局
|
||||
|
||||
```
|
||||
csrc/attention/
|
||||
└── causal_mask.cu # causal mask fill kernel
|
||||
|
||||
crates/xserv-kernels/src/
|
||||
├── gemm.rs # 扩展: batched_matmul
|
||||
├── attention.rs # NEW: multi_head_attention()
|
||||
└── causal_mask.rs # NEW: causal mask apply
|
||||
```
|
||||
|
||||
## API 设计
|
||||
|
||||
```rust
|
||||
/// Multi-head attention (naive, materializes S×S scores).
|
||||
/// q, k, v: [batch, num_heads, seq_len, head_dim]
|
||||
/// Returns: [batch, num_heads, seq_len, head_dim]
|
||||
pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor;
|
||||
|
||||
/// Batched matmul: A[b] @ B[b] for all b.
|
||||
/// a: [..., M, K], b: [..., K, N] → [..., M, N]
|
||||
pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor;
|
||||
```
|
||||
|
||||
## Test Plan
|
||||
|
||||
- [x] batched_matmul: [4,8,32,64]×[4,8,64,32] → max_err 2.7e-7
|
||||
- [x] attention (non-causal): B=1,H=2,S=8,D=16 → max_err 4.5e-8
|
||||
- [x] attention (causal): B=1,H=2,S=16,D=32 → max_err 3.0e-8
|
||||
- [x] attention (causal, larger): B=2,H=4,S=64,D=64 → max_err 6.0e-8
|
||||
- [x] causal mask 语义: position 0 只能看到 token 0,output[0] == V[0] → exact
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **`to_device` 不应强制 contiguous**:最初 `to_device()` 会先调 `contiguous()`,而 GPU 的 `contiguous()` 又调 `to_device(Cpu)`,导致无限递归栈溢出。修复:`to_device()` 直接传输 raw storage,保留 strides/offset,用户需要时自己调 `contiguous()`。GPU `contiguous()` 现在走 GPU→CPU→CPU contiguous→CPU→GPU 路径——正确但低效,Phase 15 需要写 GPU contiguous kernel。
|
||||
|
||||
2. **Batched GEMM via `cublasGemmStridedBatchedEx`**:row-major trick 同 Phase 3,额外参数是 stride(元素数,不是字节)。stride_a = M×K, stride_b = K×N, stride_c = M×N。注意初始版本错误地乘了 `elem_size`,cuBLAS 的 stride 单位是元素。
|
||||
|
||||
3. **Attention 的组合式实现足够验证正确性**:没有写 fused kernel,而是复用 `batched_matmul` + `scale` + `causal_mask` + `softmax`。精度极好(max_err < 1e-7),因为每步都在 FP32 中完成。缺点是 S×S score 矩阵完全 materialize(O(S²) 显存),Flash Attention 会解决。
|
||||
|
||||
4. **Scale kernel 的必要性**:原本想在 CPU 上做 scale(round-trip),但那太慢了。加了 `scale_f32/bf16` 逐元素 CUDA kernel。未来可以把 scale 合进 GEMM 的 alpha 参数,省一次 kernel launch。
|
||||
|
||||
5. **Causal mask 的 offset 设计**:`col > row + offset` 中的 offset 为 KV cache 场景预留。Decode 时 Q 只有 1 行但 KV cache 有前 S 行,offset = kv_len - q_len 确保 decode query 能看到所有 cached tokens。
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user