Compare commits
10 Commits
cf1e9e41db...531cd3fe08
| Author | SHA1 | Date |
|---|---|---|
| | 531cd3fe08 | |
| | 013465fc06 | |
| | 8414f8d1e6 | |
| | 34224c7c93 | |
| | 4088f49b7d | |
| | 2a92f268a9 | |
| | 5343391dbd | |
| | 1897b2e17a | |
| | 63f5599717 | |
| | fb20178992 | |
41
README.md
@@ -3,18 +3,24 @@
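 > An LLM inference engine built from scratch in **Rust + CUDA**, with the goal of thoroughly understanding the full LLM serving stack.

 xserv does not depend on off-the-shelf frameworks such as PyTorch / vLLM / TensorRT; it implements its own tensor abstraction, CUDA kernels,
-tokenizer, model forward pass, KV cache, scheduler, and an OpenAI-compatible HTTP service. It currently runs
-**Qwen3-8B** (BF16) on a single RTX 5090 and provides a standard benchmark for comparing correctness and performance against **llama.cpp**.
+tokenizer, model forward pass, KV cache, scheduler, and an OpenAI-compatible HTTP service. It supports **Qwen3-8B** (BF16)
+and **gpt-oss-20b** (MoE, BF16/FP8/MXFP4 quantization), multi-GPU TP/PP, and provides a standard benchmark for comparing
+correctness and performance against **llama.cpp**.

 ## Status at a glance

-- **Models**: GPT-2 (124M), Qwen3-8B (BF16)
-- **Performance** (RTX 5090, Qwen3-8B BF16, greedy decoding, single stream): about **56 tok/s**, roughly 1.4× HF transformers and ~0.6× llama.cpp
-- **Accuracy**: on AIME 2025 / GSM8K, roughly on par with llama.cpp using the same weights (numerical fidelity check passed)
-- **Serving**: OpenAI-compatible `/v1/chat/completions`, with SSE streaming output
-- **Key capabilities**: hand-written GEMM / Flash-Attention 2 (SM120) / Paged-Attention kernels,
-  paged KV cache (with elastic GPU memory via **CPU swap-out/swap-in**), continuous batching,
-  CUDA Graph decode, and a KV pool sized adaptively to available GPU memory
+- **Models**: GPT-2 (124M), Qwen3-8B (BF16), gpt-oss-20b (32-expert top-4 MoE, harmony format)
+- **Performance** (RTX 5090, greedy, single stream):
+  - Qwen3-8B BF16, single GPU: about 56 tok/s (1.4× HF transformers)
+  - gpt-oss-20b FP8 sparse MoE + CUDA Graph decode: **TPOT 5.8 ms (~172 tok/s,
+    same speed at TP=1/2)**; at the same configuration, TP=2 is faster than llama.cpp across the board (1.26-1.47×), while llama.cpp's
+    single-GPU mode (2.8 ms) still leads, by a 2.0× gap
+- **Accuracy**: on the full GSM8K set, on par with llama.cpp using the same weights (94.5% vs 94.4%); no regression from FP8/MXFP4 quantization
+- **Serving**: OpenAI-compatible `/v1/chat/completions`, SSE streaming; once quantized, gpt-oss can be **served on a single 32 GB GPU**
+- **Key capabilities**: hand-written GEMM / Flash-Attention 2 (SM120, with attention sinks + sliding window) /
+  Paged-Attention kernels, paged KV cache (with **CPU swap-out/swap-in**), continuous batching,
+  CUDA Graph decode (Qwen3 single GPU + whole-graph replay of the full gpt-oss path), **Tensor/Pipeline parallelism** (NCCL, TP=1/2/4, PP=2/4),
+  **FP8 W8A8 / MXFP4 W4A16 quantization**, **sparse top-k MoE decode** (only the routed experts are computed)

 > This is primarily a learning project, advanced phase by phase, with numerical/end-to-end verification at every step.
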
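The single-stream throughput figures quoted in the bullets above follow from time per output token (TPOT): tokens per second is 1000 divided by the TPOT in milliseconds. A minimal sketch of that conversion, with function names that are illustrative only and not part of xserv:

```rust
/// Convert time per output token (milliseconds) to single-stream tokens/second.
fn tpot_ms_to_tok_per_s(tpot_ms: f64) -> f64 {
    1000.0 / tpot_ms
}

/// Inverse conversion: tokens/second to milliseconds per token.
fn tok_per_s_to_tpot_ms(tok_per_s: f64) -> f64 {
    1000.0 / tok_per_s
}

fn main() {
    // 5.8 ms per token is the gpt-oss-20b figure quoted in the README.
    println!("{:.1} tok/s", tpot_ms_to_tok_per_s(5.8)); // 172.4 tok/s
    // 56 tok/s is the Qwen3-8B figure quoted in the README.
    println!("{:.2} ms/token", tok_per_s_to_tpot_ms(56.0)); // 17.86 ms/token
}
```

This only holds for one decode stream; with batching, aggregate throughput is no longer a simple reciprocal of per-request TPOT.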
@@ -26,16 +32,19 @@ xserv/
 │   ├── gemm/              # GEMM (naive / tiled / gemv)
 │   ├── attention/         # Flash-Attention 2 (SM120), Paged-Attention, causal mask
 │   ├── normalization/     # LayerNorm / RMSNorm
-│   ├── activation/        # GELU / SiLU
+│   ├── activation/        # GELU / SiLU / gpt-oss GLU
 │   ├── embedding/         # embedding lookup / RoPE / transpose
+│   ├── moe/               # MoE top-k routing, sparse expert GEMV, weighted sum
+│   ├── quantization/      # FP8 quantize/dequantize, cuBLASLt FP8 GEMM, MXFP4 GEMV
 │   └── reduce/            # softmax
 ├── crates/
 │   ├── xserv-cuda/        # CUDA FFI, Stream, GPU memory allocator, pinned memory, CUDA Graph
 │   ├── xserv-tensor/      # Tensor type (strided layout, BF16/F16/F32, CPU↔GPU)
 │   ├── xserv-kernels/     # kernel registry (switchable between hand-written kernels and cuBLAS)
 │   ├── xserv-tokenizer/   # BPE tokenizer
-│   ├── xserv-model/       # model definitions (GPT-2 / Qwen3), weight loading, KV cache, sampling
-│   └── xserv-server/      # tokio + axum HTTP service, scheduler
+│   ├── xserv-distributed/ # NCCL FFI, TP context (AllReduce)
+│   ├── xserv-model/       # model definitions (GPT-2 / Qwen3 / gpt-oss MoE), weight loading, KV cache, sampling
+│   └── xserv-server/      # tokio + axum HTTP service, scheduler, TP/PP engine
 ├── tools/                 # helper scripts + benchmark suite (see below)
 └── docs/                  # design docs for each Phase + benchmark reports
 ```
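The `moe/` entry in the tree above lists top-k routing, sparse expert computation, and a weighted sum. The following CPU sketch shows how those three steps fit together for one token. It is a sketch under stated assumptions, not the xserv kernels: the function names are hypothetical, experts are modelled as a plain closure, and the router weights are assumed to be a softmax over only the selected logits.

```rust
/// Pick the `k` largest router logits and softmax-normalize them.
/// Returns (expert index, weight) pairs, highest logit first.
fn route_top_k(logits: &[f32], k: usize) -> Vec<(usize, f32)> {
    assert!(k >= 1 && k <= logits.len(), "k must be in 1..=num_experts");
    let mut order: Vec<usize> = (0..logits.len()).collect();
    // Stable sort of expert indices by logit, descending.
    order.sort_by(|&a, &b| logits[b].total_cmp(&logits[a]));
    order.truncate(k);

    // Softmax over the selected logits only (assumed normalization).
    let max = logits[order[0]];
    let exps: Vec<f32> = order.iter().map(|&i| (logits[i] - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    order.into_iter().zip(exps).map(|(i, e)| (i, e / sum)).collect()
}

/// Sparse decode step for one token: run only the routed experts and
/// accumulate their outputs weighted by the router.
fn sparse_moe(
    x: &[f32],
    logits: &[f32],
    k: usize,
    expert: impl Fn(usize, &[f32]) -> Vec<f32>,
) -> Vec<f32> {
    let mut out = vec![0.0f32; x.len()];
    for (idx, w) in route_top_k(logits, k) {
        let y = expert(idx, x); // experts that were not routed are never evaluated
        for (o, v) in out.iter_mut().zip(y) {
            *o += w * v;
        }
    }
    out
}

fn main() {
    // Toy setup: 4 experts, expert i scales its input by (i + 1).
    let logits = [0.1f32, 2.0, -1.0, 2.0];
    println!("{:?}", route_top_k(&logits, 2)); // [(1, 0.5), (3, 0.5)]
    let out = sparse_moe(&[1.0, 2.0], &logits, 2, |i, x| {
        x.iter().map(|v| v * (i as f32 + 1.0)).collect()
    });
    println!("{out:?}"); // [3.0, 6.0]
}
```

With 32 experts and top-4 routing as described in the README, a loop of this shape touches 4 of 32 expert weight sets per token, which is where the decode-time saving comes from.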
@@ -185,12 +194,14 @@ All 12 GSM8K cells are 29/30, with xserv and llama.cpp in full agreement; AIME's ±1

 ## Roadmap (excerpt)

-Completed Phases 0–18: CUDA infrastructure → Tensor → GEMM → Transformer kernels → Attention →
+Completed Phases 0–21: CUDA infrastructure → Tensor → GEMM → Transformer kernels → Attention →
 model loading → tokenizer → GPT-2 → KV cache → Qwen3-8B → Paged Attention → continuous batching →
-HTTP API → Flash Attention 2 → performance optimization → **tensor parallelism (TP)** → **pipeline parallelism (PP)**;
+HTTP API → Flash Attention 2 → performance optimization → **tensor parallelism (TP)** → **pipeline parallelism (PP)** →
+**gpt-oss MoE + FP8/MXFP4 quantization** → **sparse top-k MoE decode** → **whole-graph CUDA Graph replay for decode**;
 plus infrastructure such as the **llama.cpp comparison benchmark** and **KV CPU swap-out**.

-Next steps: PP microbatch/1F1B pipeline overlap (throughput gains), 2D TP×PP, speculative decoding, quantization (FP8 / INT8), multimodal.
+Next steps: quantizing non-expert weights (lm_head/qkv/o), sparse prefill (grouped GEMM), server-side harmony
+channel separation, PP microbatch/1F1B, speculative decoding, multimodal. See the progress log in `docs/00-roadmap.md` for details.

 ## License

@@ -111,6 +111,22 @@ pub fn cached_trim() {
 /// Called from `GpuBuffer::Drop` for pooled buffers. Takes raw pointer
 /// and size to avoid re-triggering Drop.
 pub fn return_to_pool(ptr: *mut u8, len: usize) {
+    // During CUDA graph capture, buffers freed by the captured code are
+    // quarantined instead of pooled: the instantiated graph references their
+    // addresses on every replay, so they must never be handed to another
+    // consumer for as long as the graph lives.
+    let quarantined = RETAINED.with(|cell| {
+        let mut r = cell.borrow_mut();
+        if let Some(list) = r.as_mut() {
+            list.push((ptr, len));
+            true
+        } else {
+            false
+        }
+    });
+    if quarantined {
+        return;
+    }
     ALLOCATOR.with(|cell| {
         let mut alloc = cell.borrow_mut();
         let bucket = bucket_size(len);
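`return_to_pool` files each block under `bucket_size(len)`, which this file documents as rounding up to the next power of two with a 512-byte minimum. The function body is not shown in this diff, so the following is a minimal sketch of that rule under that assumption, not the actual implementation:

```rust
/// Round `size` up to the next power of two, with a 512-byte floor.
fn bucket_size(size: usize) -> usize {
    let min = 512;
    // `next_power_of_two` returns the value itself when it is already a power of two.
    size.max(min).next_power_of_two()
}

fn main() {
    for size in [1, 512, 513, 1024, 3000] {
        println!("{size} -> {}", bucket_size(size));
    }
    // 1 -> 512, 512 -> 512, 513 -> 1024, 1024 -> 1024, 3000 -> 4096
}
```

Power-of-two buckets keep the number of distinct free lists small at the cost of up to roughly 2× internal fragmentation per block.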
@@ -119,6 +135,44 @@ pub fn return_to_pool(ptr: *mut u8, len: usize) {
     });
 }

+thread_local! {
+    static RETAINED: RefCell<Option<Vec<(*mut u8, usize)>>> = const { RefCell::new(None) };
+}
+
+/// Buffers freed while a retain window was active. Holding this keeps their
+/// memory out of the pool; dropping it returns the blocks (on the owning
+/// thread) for reuse.
+pub struct RetainedBlocks(Vec<(*mut u8, usize)>);
+
+impl Drop for RetainedBlocks {
+    fn drop(&mut self) {
+        for (ptr, len) in self.0.drain(..) {
+            return_to_pool(ptr, len);
+        }
+    }
+}
+
+/// Start quarantining buffers freed on this thread (see `return_to_pool`).
+/// Must be paired with `end_retain` on the same thread; nesting unsupported.
+pub fn begin_retain() {
+    RETAINED.with(|cell| {
+        let mut r = cell.borrow_mut();
+        assert!(r.is_none(), "begin_retain: retain window already active");
+        *r = Some(Vec::new());
+    });
+}
+
+/// Stop quarantining and hand the quarantined blocks to the caller.
+pub fn end_retain() -> RetainedBlocks {
+    RETAINED.with(|cell| {
+        let list = cell
+            .borrow_mut()
+            .take()
+            .expect("end_retain without begin_retain");
+        RetainedBlocks(list)
+    })
+}
+
 /// Round up to next power-of-2, minimum 512 bytes.
 fn bucket_size(size: usize) -> usize {
     let min = 512;
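The retain window added in this file can be exercised without a GPU. The sketch below mirrors the `begin_retain` / `end_retain` / `RetainedBlocks` shape from the diff, but it is a simplified mock: device pointers are replaced by plain integers, the pool is a flat `Vec`, and `pool_len` is a hypothetical helper added only to observe the effect.

```rust
use std::cell::RefCell;

// Mock "device blocks" as (address, length) integers so the sketch runs without CUDA.
type Block = (usize, usize);

thread_local! {
    static POOL: RefCell<Vec<Block>> = RefCell::new(Vec::new());
    static RETAINED: RefCell<Option<Vec<Block>>> = RefCell::new(None);
}

/// Freed blocks go to the quarantine list while a window is open, else to the pool.
fn return_to_pool(block: Block) {
    let quarantined = RETAINED.with(|cell| {
        let mut r = cell.borrow_mut();
        if let Some(list) = r.as_mut() {
            list.push(block);
            true
        } else {
            false
        }
    });
    if !quarantined {
        POOL.with(|p| p.borrow_mut().push(block));
    }
}

/// Holding this keeps the quarantined blocks out of the pool.
struct RetainedBlocks(Vec<Block>);

impl Drop for RetainedBlocks {
    fn drop(&mut self) {
        for b in self.0.drain(..) {
            return_to_pool(b);
        }
    }
}

fn begin_retain() {
    RETAINED.with(|cell| {
        let mut r = cell.borrow_mut();
        assert!(r.is_none(), "retain window already active");
        *r = Some(Vec::new());
    });
}

fn end_retain() -> RetainedBlocks {
    RETAINED.with(|cell| RetainedBlocks(cell.borrow_mut().take().expect("no active window")))
}

/// Hypothetical helper: number of blocks currently reusable.
fn pool_len() -> usize {
    POOL.with(|p| p.borrow().len())
}

fn main() {
    return_to_pool((0x1000, 512)); // no window open: pooled immediately
    begin_retain();
    return_to_pool((0x2000, 512)); // "freed during capture": quarantined
    let held = end_retain();
    println!("reusable while the graph lives: {}", pool_len()); // 1
    drop(held); // graph destroyed: the quarantined block becomes reusable
    println!("reusable after release: {}", pool_len()); // 2
}
```

The design point is that the owner of the captured graph also owns the `RetainedBlocks` value, so the lifetime of the quarantined addresses is tied to the lifetime of the graph by ordinary Rust ownership.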
@@ -48,9 +48,7 @@ pub fn device_info(device: u32) -> Result<DeviceInfo> {
     // Heap-allocate oversized buffer for cudaDeviceProp (layout varies by CUDA version).
     // CUDA 12.x struct is ~5-6 KB; use 32 KB to guard against future growth.
     let mut prop_buf = vec![0u8; 32768];
-    error::check(unsafe {
-        ffi::cudaGetDeviceProperties(prop_buf.as_mut_ptr(), device as i32)
-    })?;
+    error::check(unsafe { ffi::cudaGetDeviceProperties(prop_buf.as_mut_ptr(), device as i32) })?;
     // Name is always the first field: char[256].
     let name = unsafe { CStr::from_ptr(prop_buf.as_ptr() as *const c_char) }
         .to_string_lossy()
@@ -15,6 +15,7 @@ pub const CUDA_ERROR_OUT_OF_MEMORY: i32 = 2;

 /// cudaStreamCaptureMode::cudaStreamCaptureModeGlobal
 pub const CUDA_STREAM_CAPTURE_MODE_GLOBAL: i32 = 0;
+pub const CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL: i32 = 1;

 unsafe extern "C" {
     // --- Device ---
@@ -63,11 +64,5 @@ unsafe extern "C" {
     pub fn cudaGraphExecDestroy(graph_exec: CudaGraphExec) -> i32;

     // --- Our test kernel ---
-    pub fn launch_vecadd_f32(
-        a: *const f32,
-        b: *const f32,
-        c: *mut f32,
-        n: i32,
-        stream: CudaStream,
-    );
+    pub fn launch_vecadd_f32(a: *const f32, b: *const f32, c: *mut f32, n: i32, stream: CudaStream);
 }
@@ -50,31 +50,25 @@ impl CudaGraph {
     pub fn begin_capture(&mut self, stream: &CudaStream) -> Result<()> {
         // If we have an old graph, destroy it first
         self.destroy_inner();
+        // THREAD_LOCAL: only "potentially unsafe" CUDA calls (cudaMalloc etc.)
+        // made by THIS thread invalidate the capture. With GLOBAL mode, TP rank
+        // threads capturing concurrently would poison each other's captures.
         error::check(unsafe {
-            ffi::cudaStreamBeginCapture(
-                stream.as_raw(),
-                ffi::CUDA_STREAM_CAPTURE_MODE_GLOBAL,
-            )
+            ffi::cudaStreamBeginCapture(stream.as_raw(), ffi::CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL)
         })
     }

     /// End capture and instantiate the executable graph.
     pub fn end_capture(&mut self, stream: &CudaStream) -> Result<()> {
-        error::check(unsafe {
-            ffi::cudaStreamEndCapture(stream.as_raw(), &mut self.graph)
-        })?;
-        error::check(unsafe {
-            ffi::cudaGraphInstantiate(&mut self.exec, self.graph, 0)
-        })
+        error::check(unsafe { ffi::cudaStreamEndCapture(stream.as_raw(), &mut self.graph) })?;
+        error::check(unsafe { ffi::cudaGraphInstantiate(&mut self.exec, self.graph, 0) })
     }

     /// Replay the captured graph on `stream`.
     /// Panics if no graph has been captured yet.
     pub fn launch(&self, stream: &CudaStream) -> Result<()> {
         assert!(self.is_ready(), "CudaGraph::launch called before capture");
-        error::check(unsafe {
-            ffi::cudaGraphLaunch(self.exec, stream.as_raw())
-        })
+        error::check(unsafe { ffi::cudaGraphLaunch(self.exec, stream.as_raw()) })
     }

     fn destroy_inner(&mut self) {
@@ -11,4 +11,4 @@ pub use device::DeviceInfo;
 pub use error::{CudaError, Result};
 pub use graph::CudaGraph;
 pub use memory::{GpuBuffer, PinnedBuffer};
-pub use stream::CudaStream;
+pub use stream::{CudaStream, StreamGuard, current_stream_raw, push_stream};
@@ -22,7 +22,12 @@ impl GpuBuffer {
         assert!(len > 0, "cannot allocate 0 bytes on GPU");
         let mut ptr = std::ptr::null_mut();
         error::check(unsafe { ffi::cudaMalloc(&mut ptr, len) })?;
-        Ok(Self { ptr, len, owned: true, pooled: false })
+        Ok(Self {
+            ptr,
+            len,
+            owned: true,
+            pooled: false,
+        })
     }

     /// Mark this buffer as pooled (returned to caching allocator on drop)
@@ -92,9 +97,7 @@ impl GpuBuffer {
     /// Copy from another GPU buffer (D2D).
     pub fn copy_from_device(&mut self, src: &GpuBuffer) -> Result<()> {
         let n = src.len.min(self.len);
-        error::check(unsafe {
-            ffi::cudaMemcpy(self.ptr, src.ptr, n, ffi::CUDA_MEMCPY_D2D)
-        })
+        error::check(unsafe { ffi::cudaMemcpy(self.ptr, src.ptr, n, ffi::CUDA_MEMCPY_D2D) })
     }

     /// Fill buffer with zeros.
@@ -103,7 +106,13 @@ impl GpuBuffer {
     }

     /// Copy `count` bytes from `src` buffer at `src_offset` to this buffer at `dst_offset`.
-    pub fn copy_from_device_at(&mut self, src: &GpuBuffer, src_offset: usize, dst_offset: usize, count: usize) -> Result<()> {
+    pub fn copy_from_device_at(
+        &mut self,
+        src: &GpuBuffer,
+        src_offset: usize,
+        dst_offset: usize,
+        count: usize,
+    ) -> Result<()> {
         assert!(src_offset + count <= src.len);
         assert!(dst_offset + count <= self.len);
         error::check(unsafe {
@@ -117,7 +126,14 @@ impl GpuBuffer {
     }

     /// Async copy `count` bytes from `src` at `src_offset` to `self` at `dst_offset` on `stream`.
-    pub fn copy_from_device_at_async(&mut self, src: &GpuBuffer, src_offset: usize, dst_offset: usize, count: usize, stream: &CudaStream) -> Result<()> {
+    pub fn copy_from_device_at_async(
+        &mut self,
+        src: &GpuBuffer,
+        src_offset: usize,
+        dst_offset: usize,
+        count: usize,
+        stream: &CudaStream,
+    ) -> Result<()> {
         assert!(src_offset + count <= src.len);
         assert!(dst_offset + count <= self.len);
         error::check(unsafe {
@@ -161,9 +177,7 @@ impl GpuBuffer {

     /// Async zero fill on stream.
     pub fn zero_async(&mut self, stream: &CudaStream) -> Result<()> {
-        error::check(unsafe {
-            ffi::cudaMemsetAsync(self.ptr, 0, self.len, stream.as_raw())
-        })
+        error::check(unsafe { ffi::cudaMemsetAsync(self.ptr, 0, self.len, stream.as_raw()) })
     }

     /// Consume the buffer without freeing GPU memory. Returns the raw pointer and length.
@@ -178,7 +192,12 @@ impl GpuBuffer {
     /// Reconstruct a GpuBuffer from a raw pointer + length.
     /// Safety: ptr must have been allocated with cudaMalloc, len must be correct.
     pub unsafe fn from_raw(ptr: *mut u8, len: usize) -> Self {
-        Self { ptr, len, owned: true, pooled: false }
+        Self {
+            ptr,
+            len,
+            owned: true,
+            pooled: false,
+        }
     }

     /// Create a non-owning view of GPU memory. Dropping this buffer does NOT
@@ -189,7 +208,12 @@ impl GpuBuffer {
     /// `ptr` must point to a valid GPU allocation of at least `len` bytes that
     /// will remain live for the lifetime of the returned `GpuBuffer`.
     pub unsafe fn borrow_raw(ptr: *mut u8, len: usize) -> Self {
-        Self { ptr, len, owned: false, pooled: false }
+        Self {
+            ptr,
+            len,
+            owned: false,
+            pooled: false,
+        }
     }
 }

@@ -31,3 +31,39 @@ impl Drop for CudaStream {

 // Can move across threads, but not shared without synchronization
 unsafe impl Send for CudaStream {}
+
+// --- Thread-local launch stream -------------------------------------------
+//
+// Every kernel wrapper in xserv-kernels launches on `current_stream_raw()`,
+// which defaults to the legacy null stream (the historical behavior). CUDA
+// graph capture requires work to be issued on an explicit stream, so capture
+// code installs its stream here for the duration of the captured region via
+// `push_stream` / `StreamGuard`.
+
+use std::cell::Cell;
+
+thread_local! {
+    static CURRENT_STREAM: Cell<ffi::CudaStream> = const { Cell::new(std::ptr::null_mut()) };
+}
+
+/// The stream kernel launches on this thread should use (null = legacy default).
+pub fn current_stream_raw() -> ffi::CudaStream {
+    CURRENT_STREAM.with(|c| c.get())
+}
+
+/// RAII guard that installs a launch stream for the current thread and
+/// restores the previous one on drop.
+pub struct StreamGuard {
+    prev: ffi::CudaStream,
+}
+
+pub fn push_stream(stream: &CudaStream) -> StreamGuard {
+    let prev = CURRENT_STREAM.with(|c| c.replace(stream.as_raw()));
+    StreamGuard { prev }
+}
+
+impl Drop for StreamGuard {
+    fn drop(&mut self) {
+        CURRENT_STREAM.with(|c| c.set(self.prev));
+    }
+}
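The `push_stream` / `StreamGuard` pair is a thread-local "current value" stack built from RAII guards. The sketch below reproduces the pattern without CUDA: `RawStream` is a hypothetical stand-in (a plain integer, with 0 playing the role of the null stream) for the real raw stream handle.

```rust
use std::cell::Cell;

// Stand-in for a raw CUDA stream handle; 0 plays the role of the null stream.
type RawStream = usize;

thread_local! {
    static CURRENT_STREAM: Cell<RawStream> = Cell::new(0);
}

/// The stream that launches on this thread should use.
fn current_stream_raw() -> RawStream {
    CURRENT_STREAM.with(|c| c.get())
}

/// Restores the previously installed stream when dropped.
struct StreamGuard {
    prev: RawStream,
}

fn push_stream(stream: RawStream) -> StreamGuard {
    let prev = CURRENT_STREAM.with(|c| c.replace(stream));
    StreamGuard { prev }
}

impl Drop for StreamGuard {
    fn drop(&mut self) {
        CURRENT_STREAM.with(|c| c.set(self.prev));
    }
}

fn main() {
    assert_eq!(current_stream_raw(), 0);
    {
        let _outer = push_stream(7);
        {
            let _inner = push_stream(9);
            assert_eq!(current_stream_raw(), 9);
        }
        // Dropping the inner guard restores the outer stream.
        assert_eq!(current_stream_raw(), 7);
    }
    assert_eq!(current_stream_raw(), 0);
    println!("stream restored");
}
```

One pitfall with this shape: the guard must be bound to a named variable such as `_guard`. Writing `let _ = push_stream(s);` drops the guard immediately, so the stream would be restored before any work is issued.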
@@ -14,7 +14,10 @@ fn test_device_info() {
         info.compute_major, info.compute_minor
     );
     println!(" SM Count: {}", info.sm_count);
-    println!(" Shared Mem/Block: {} KB", info.shared_mem_per_block / 1024);
+    println!(
+        " Shared Mem/Block: {} KB",
+        info.shared_mem_per_block / 1024
+    );
     println!(" Warp Size: {}", info.warp_size);
     println!(" Max Threads/Block: {}", info.max_threads_per_block);

@@ -145,7 +148,11 @@ fn test_caching_allocator() {

     // Second allocation of same size: should hit cache
     let _buf2 = alloc.alloc(1024).unwrap();
-    assert_eq!(alloc.stats().cuda_malloc_count, 1, "should reuse cached buffer");
+    assert_eq!(
+        alloc.stats().cuda_malloc_count,
+        1,
+        "should reuse cached buffer"
+    );
     assert_eq!(alloc.stats().cache_hit_count, 1);
 }

@@ -198,11 +205,17 @@ fn test_async_copy() {
     }

     let mut gpu = GpuBuffer::alloc(4096).unwrap();
-    unsafe { gpu.copy_from_host_async(pinned.as_slice(), &stream).unwrap() };
+    unsafe {
+        gpu.copy_from_host_async(pinned.as_slice(), &stream)
+            .unwrap()
+    };
     stream.synchronize().unwrap();

     let mut out_pinned = PinnedBuffer::alloc(4096).unwrap();
-    unsafe { gpu.copy_to_host_async(out_pinned.as_mut_slice(), &stream).unwrap() };
+    unsafe {
+        gpu.copy_to_host_async(out_pinned.as_mut_slice(), &stream)
+            .unwrap()
+    };
     stream.synchronize().unwrap();

     assert_eq!(pinned.as_slice(), out_pinned.as_slice());
@@ -34,7 +34,12 @@ pub const NCCL_SUCCESS: i32 = 0;
 unsafe extern "C" {
     pub fn ncclGetUniqueId(uid: *mut NcclUniqueId) -> i32;
     // ncclUniqueId is passed BY VALUE (a 128-byte struct) per the NCCL ABI.
-    pub fn ncclCommInitRank(comm: *mut NcclComm, nranks: i32, commid: NcclUniqueId, rank: i32) -> i32;
+    pub fn ncclCommInitRank(
+        comm: *mut NcclComm,
+        nranks: i32,
+        commid: NcclUniqueId,
+        rank: i32,
+    ) -> i32;
     pub fn ncclCommDestroy(comm: NcclComm) -> i32;
     pub fn ncclAllReduce(
         sendbuff: *const c_void,
@@ -78,5 +83,10 @@ pub fn err_string(result: i32) -> String {
 }

 pub fn check(result: i32, what: &str) {
-    assert_eq!(result, NCCL_SUCCESS, "{what} failed: {}", err_string(result));
+    assert_eq!(
+        result,
+        NCCL_SUCCESS,
+        "{what} failed: {}",
+        err_string(result)
+    );
 }
@@ -9,15 +9,18 @@ pub mod ffi;
 use std::ffi::c_void;

 use ffi::{NcclComm, NcclUniqueId};
-use xserv_cuda::device;
 use xserv_cuda::GpuBuffer;
+use xserv_cuda::device;

 pub use ffi::NcclUniqueId as UniqueId;

-/// The CUDA "null" (default) stream. The model's kernels and cuBLAS calls run
-/// on it, so issuing NCCL on the same stream keeps AllReduce correctly ordered
-/// after the producing matmul and before the consuming kernel - no extra sync.
-const NULL_STREAM: xserv_cuda::ffi::CudaStream = std::ptr::null_mut();
+/// NCCL is issued on the thread's current launch stream (legacy null stream
+/// by default, the capture stream during CUDA graph capture). The model's
+/// kernels run on the same stream, so AllReduce stays correctly ordered after
+/// the producing matmul and before the consuming kernel - no extra sync.
+fn launch_stream() -> xserv_cuda::ffi::CudaStream {
+    xserv_cuda::stream::current_stream_raw()
+}

 /// Generate a unique id on one rank (typically rank 0) and broadcast the bytes
 /// to all ranks out-of-band (e.g. via a shared variable across threads).
@@ -52,7 +55,12 @@ impl TpContext {
             "ncclCommInitRank",
         );
         ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)");
-        Self { rank, world, device, comm }
+        Self {
+            rank,
+            world,
+            device,
+            comm,
+        }
     }

     /// In-place AllReduce(sum) over `count` BF16 elements in `buf`.
@@ -80,7 +88,7 @@ impl TpContext {
                     ffi::NCCL_BF16,
                     ffi::NCCL_SUM,
                     self.comm,
-                    NULL_STREAM,
+                    launch_stream(),
                 )
             },
             "ncclAllReduce",
@@ -124,7 +132,12 @@ impl PpContext {
             "ncclCommInitRank",
         );
         ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)");
-        Self { stage, world, device, comm }
+        Self {
+            stage,
+            world,
+            device,
+            comm,
+        }
     }

     /// Send `count` BF16 elements at `ptr` to `peer`, on the null stream so it is
@@ -135,7 +148,16 @@ impl PpContext {
     /// `ptr` must point to at least `count` BF16 elements of valid device memory.
     pub fn send_bf16_ptr(&self, ptr: *const c_void, count: usize, peer: usize) {
         ffi::check(
-            unsafe { ffi::ncclSend(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, NULL_STREAM) },
+            unsafe {
+                ffi::ncclSend(
+                    ptr,
+                    count,
+                    ffi::NCCL_BF16,
+                    peer as i32,
+                    self.comm,
+                    launch_stream(),
+                )
+            },
             "ncclSend",
         );
     }
@@ -146,7 +168,16 @@ impl PpContext {
     /// `ptr` must point to at least `count` BF16 elements of valid device memory.
     pub fn recv_bf16_ptr(&self, ptr: *mut c_void, count: usize, peer: usize) {
         ffi::check(
-            unsafe { ffi::ncclRecv(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, NULL_STREAM) },
+            unsafe {
+                ffi::ncclRecv(
+                    ptr,
+                    count,
+                    ffi::NCCL_BF16,
+                    peer as i32,
+                    self.comm,
+                    launch_stream(),
+                )
+            },
             "ncclRecv",
         );
     }
@@ -2,8 +2,8 @@

 use half::bf16;
 use std::thread;
-use xserv_cuda::{device, GpuBuffer};
-use xserv_distributed::{get_unique_id, TpContext};
+use xserv_cuda::{GpuBuffer, device};
+use xserv_distributed::{TpContext, get_unique_id};

 #[test]
 fn allreduce_two_gpu_sum() {
@@ -25,9 +25,7 @@ fn allreduce_two_gpu_sum() {
                 // Rank r fills its buffer with (r + 1).
                 let val = bf16::from_f32((rank + 1) as f32);
                 let host = vec![val; n];
-                let src = unsafe {
-                    std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2)
-                };
+                let src = unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
                 let mut buf = GpuBuffer::alloc(n * 2).unwrap();
                 buf.copy_from_host(src).unwrap();

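The test above has rank r fill its buffer with (r + 1) before an AllReduce(sum). The assertion itself is not shown in this hunk, but for two ranks the expected result is 1 + 2 = 3 in every element on every rank. A CPU reference for that semantics might look like the following; `allreduce_sum` is a hypothetical helper, and `f32` stands in for BF16 (the values 1, 2, and 3 are exactly representable in both).

```rust
/// CPU reference for an in-place AllReduce(sum): afterwards every rank holds
/// the element-wise sum of all ranks' buffers.
fn allreduce_sum(bufs: &mut [Vec<f32>]) {
    let n = bufs[0].len();
    let mut total = vec![0.0f32; n];
    for b in bufs.iter() {
        assert_eq!(b.len(), n, "all ranks must contribute the same element count");
        for (t, v) in total.iter_mut().zip(b) {
            *t += *v;
        }
    }
    for b in bufs.iter_mut() {
        b.copy_from_slice(&total);
    }
}

fn main() {
    let world = 2;
    let n = 4;
    // Rank r fills its buffer with (r + 1), as in the test.
    let mut bufs: Vec<Vec<f32>> = (0..world).map(|r| vec![(r + 1) as f32; n]).collect();
    allreduce_sum(&mut bufs);
    println!("{bufs:?}"); // [[3.0, 3.0, 3.0, 3.0], [3.0, 3.0, 3.0, 3.0]]
}
```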
@@ -6,8 +6,8 @@
 use half::bf16;
 use std::ffi::c_void;
 use std::thread;
-use xserv_cuda::{device, GpuBuffer};
-use xserv_distributed::{get_unique_id, PpContext};
+use xserv_cuda::{GpuBuffer, device};
+use xserv_distributed::{PpContext, get_unique_id};

 #[test]
 fn pp_send_recv_two_stages() {
@@ -30,7 +30,8 @@ fn pp_send_recv_two_stages() {
                 if stage == 0 {
                     // Fill with a known pattern and send to stage 1.
                     let host: Vec<bf16> = (0..n).map(|i| bf16::from_f32((i % 97) as f32)).collect();
-                    let src = unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
+                    let src =
+                        unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
                     buf.copy_from_host(src).unwrap();
                     pp.send_bf16_ptr(buf.as_mut_ptr() as *const c_void, n, 1);
                     device::synchronize().unwrap();
@@ -31,6 +31,7 @@ fn main() {
         .file("../../csrc/attention/paged_attention.cu")
         .file("../../csrc/attention/reshape_and_cache.cu")
         .file("../../csrc/moe/moe_kernels.cu")
+        .file("../../csrc/moe/moe_sparse.cu")
         .file("../../csrc/quantization/dequant_fp8.cu")
         .file("../../csrc/quantization/quantize_fp8.cu")
         .file("../../csrc/quantization/mxfp4_gemm.cu")
@@ -6,76 +6,220 @@ unsafe extern "C" {
|
||||
fn launch_gelu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_silu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_scale_f32(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
|
||||
fn launch_scale_bf16(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
|
||||
fn launch_add_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_mul_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_mul_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_gpt_oss_glu_bf16(gate_up: *const c_void, out: *mut c_void, n_elements: i32,
|
||||
alpha: f32, limit: f32, stream: *mut c_void);
|
||||
fn launch_scale_f32(
|
||||
x: *const c_void,
|
||||
out: *mut c_void,
|
||||
scale: f32,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_scale_bf16(
|
||||
x: *const c_void,
|
||||
out: *mut c_void,
|
||||
scale: f32,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_add_f32(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_add_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_mul_f32(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_mul_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_silu_mul_bf16(
|
||||
gate: *const c_void,
|
||||
up: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_gpt_oss_glu_bf16(
|
||||
gate_up: *const c_void,
|
||||
out: *mut c_void,
|
||||
n_elements: i32,
|
||||
alpha: f32,
|
||||
limit: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_bias_add_2d_bf16(
|
||||
x: *const c_void,
|
||||
bias: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
|
||||
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
|
||||
fn dispatch_unary(
|
||||
x: &Tensor,
|
||||
f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
|
||||
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
|
||||
) -> Tensor {
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel();
|
||||
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||||
assert!(
|
||||
n <= i32::MAX as usize,
|
||||
"tensor too large for i32 kernel param ({n} elements)"
|
||||
);
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
DType::F32 => f32_fn(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => bf16_fn(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype"),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn dispatch_binary(a: &Tensor, b: &Tensor,
|
||||
f32_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
|
||||
bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
|
||||
fn dispatch_binary(
|
||||
a: &Tensor,
|
||||
b: &Tensor,
|
||||
f32_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
|
||||
bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
|
||||
) -> Tensor {
|
||||
assert_eq!(a.shape(), b.shape());
|
||||
assert!(a.is_contiguous() && b.is_contiguous());
|
||||
assert!(matches!(a.device(), Device::Cuda(_)));
|
||||
assert_eq!(a.dtype(), b.dtype());
|
||||
let out = Tensor::empty(a.shape(), a.dtype(), a.device());
|
||||
let n = a.numel();
|
||||
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||||
assert!(
|
||||
n <= i32::MAX as usize,
|
||||
"tensor too large for i32 kernel param ({n} elements)"
|
||||
);
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
match a.dtype() {
|
||||
DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
DType::F32 => f32_fn(
|
||||
a.data_ptr() as _,
|
||||
b.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => bf16_fn(
|
||||
a.data_ptr() as _,
|
||||
b.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype"),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub fn gelu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16) }
|
||||
pub fn silu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_silu_f32, launch_silu_bf16) }
|
||||
pub fn gelu(x: &Tensor) -> Tensor {
|
||||
dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16)
|
||||
}
|
||||
pub fn silu(x: &Tensor) -> Tensor {
|
||||
dispatch_unary(x, launch_silu_f32, launch_silu_bf16)
|
||||
}
|
||||
|
||||
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel();
|
||||
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||||
assert!(
|
||||
n <= i32::MAX as usize,
|
||||
"tensor too large for i32 kernel param ({n} elements)"
|
||||
);
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
||||
DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
||||
DType::F32 => launch_scale_f32(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
scale_val,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_scale_bf16(
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
scale_val,
|
||||
n,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for scale"),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub fn add(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_add_f32, launch_add_bf16) }
|
||||
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16) }
|
||||
pub fn add(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
dispatch_binary(a, b, launch_add_f32, launch_add_bf16)
|
||||
}
|
||||
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16)
|
||||
}
|
||||
|
||||
+/// Row-broadcast bias add: out[r, c] = x[r, c] + bias[c] (BF16 only).
+pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor {
+    assert_eq!(x.ndim(), 2);
+    assert_eq!(bias.ndim(), 1);
+    assert_eq!(x.dtype(), DType::BF16);
+    assert_eq!(bias.dtype(), DType::BF16);
+    assert!(x.is_contiguous() && bias.is_contiguous());
+    assert!(matches!(x.device(), Device::Cuda(_)));
+    let rows = x.shape()[0];
+    let cols = x.shape()[1];
+    assert_eq!(
+        bias.shape()[0],
+        cols,
+        "bias size {} != cols {cols}",
+        bias.shape()[0]
+    );
+    assert!(rows * cols <= i32::MAX as usize);
+    let out = Tensor::empty(&[rows, cols], DType::BF16, x.device());
+    unsafe {
+        launch_bias_add_2d_bf16(
+            x.data_ptr() as _,
+            bias.data_ptr() as _,
+            out.data_ptr() as *mut c_void,
+            rows as i32,
+            cols as i32,
+            xserv_cuda::current_stream_raw(),
+        );
+    }
+    out
+}
|
||||
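For checking `bias_add_2d` against a known-good result, a CPU reference of the documented formula `out[r, c] = x[r, c] + bias[c]` can be sketched as below. The function name `bias_add_2d_ref` and the use of `f32` slices are assumptions made for readability; the kernel in the diff is BF16-only and works on device tensors.

```rust
/// CPU reference (hypothetical, f32 for readability) of the row-broadcast
/// bias add on a row-major [rows, cols] buffer: out[r, c] = x[r, c] + bias[c].
pub fn bias_add_2d_ref(x: &[f32], bias: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(x.len(), rows * cols, "x must hold rows * cols elements");
    assert_eq!(bias.len(), cols, "bias size must equal cols");
    // In row-major layout, the column of flat index i is i % cols.
    x.iter()
        .enumerate()
        .map(|(i, &v)| v + bias[i % cols])
        .collect()
}
```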
|
||||
 /// Fused SiLU×Mul: out = silu(gate) * up (BF16 only)
 /// Saves one HBM read + one HBM write compared to separate silu + mul.
@@ -86,7 +230,10 @@ pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
     assert_eq!(gate.dtype(), DType::BF16, "silu_mul requires BF16");
     let out = Tensor::empty(gate.shape(), gate.dtype(), gate.device());
     let n = gate.numel();
-    assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
+    assert!(
+        n <= i32::MAX as usize,
+        "tensor too large for i32 kernel param ({n} elements)"
+    );
     let n = n as i32;
     unsafe {
         launch_silu_mul_bf16(
@@ -94,7 +241,7 @@ pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
             up.data_ptr() as *const c_void,
             out.data_ptr() as *mut c_void,
             n,
-            std::ptr::null_mut(),
+            xserv_cuda::current_stream_raw(),
         );
     }
     out
|
||||
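The doc comment on `silu_mul` gives the fused formula `out = silu(gate) * up`. A small CPU reference can make the math explicit, using the standard definition silu(x) = x / (1 + e^(-x)). The names `silu` and `silu_mul_ref` and the `f32` slices are assumptions for illustration; the kernel itself is BF16-only.

```rust
/// Standard SiLU activation: x * sigmoid(x) = x / (1 + exp(-x)).
pub fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// CPU reference (hypothetical) for the fused op: out[i] = silu(gate[i]) * up[i].
/// Fusing avoids materializing silu(gate) as an intermediate buffer.
pub fn silu_mul_ref(gate: &[f32], up: &[f32]) -> Vec<f32> {
    assert_eq!(gate.len(), up.len(), "gate and up must have the same length");
    gate.iter().zip(up).map(|(&g, &u)| silu(g) * u).collect()
}
```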
@@ -122,7 +269,7 @@ pub fn gpt_oss_glu(gate_up: &Tensor, alpha: f32, limit: f32) -> Tensor {
             n_elements,
             alpha,
             limit,
-            std::ptr::null_mut(),
+            xserv_cuda::current_stream_raw(),
         );
     }
     out
|
||||
|
||||
@@ -2,8 +2,13 @@ use std::ffi::c_void;
 use xserv_tensor::{DType, Device, Tensor};
 
 unsafe extern "C" {
-    fn launch_argmax_bf16(logits: *const c_void, out_idx: *mut c_void,
-        rows: i32, cols: i32, stream: *mut c_void);
+    fn launch_argmax_bf16(
+        logits: *const c_void,
+        out_idx: *mut c_void,
+        rows: i32,
+        cols: i32,
+        stream: *mut c_void,
+    );
 }
 
 /// GPU argmax over the last dim of a [rows, cols] BF16 tensor.
@@ -19,7 +24,10 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
     assert_eq!(logits.ndim(), 2, "argmax expects a 2D [rows, cols] tensor");
     assert_eq!(logits.dtype(), DType::BF16, "argmax kernel is BF16-only");
     assert!(logits.is_contiguous(), "argmax requires contiguous input");
-    assert!(matches!(logits.device(), Device::Cuda(_)), "argmax requires GPU input");
+    assert!(
+        matches!(logits.device(), Device::Cuda(_)),
+        "argmax requires GPU input"
+    );
 
     let rows = logits.shape()[0];
     let cols = logits.shape()[1];
@@ -35,8 +43,9 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
         launch_argmax_bf16(
             logits.data_ptr() as *const c_void,
             out.as_mut_ptr() as *mut c_void,
-            rows as i32, cols as i32,
-            std::ptr::null_mut(),
+            rows as i32,
+            cols as i32,
+            xserv_cuda::current_stream_raw(),
         );
     }
 
@@ -44,9 +53,8 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
     out.copy_to_host(&mut host_bytes).expect("argmax D2H");
     drop(out); // returned to pool
 
-    let host_i32: &[i32] = unsafe {
-        std::slice::from_raw_parts(host_bytes.as_ptr() as *const i32, rows)
-    };
+    let host_i32: &[i32] =
+        unsafe { std::slice::from_raw_parts(host_bytes.as_ptr() as *const i32, rows) };
     host_i32.iter().map(|&v| v as u32).collect()
 }
 
@@ -62,4 +70,3 @@ pub fn argmax_bf16_single(logits: &Tensor) -> u32 {
     };
     argmax_bf16_to_host(&view)[0]
 }
-
|
||||
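`argmax_bf16_to_host` reinterprets the copied bytes as `&[i32]` with `std::slice::from_raw_parts`, which requires the pointer to be aligned for `i32`; whether that holds depends on how `host_bytes` is allocated, which these hunks do not show. As a sketch, the same decoding can be done without a pointer cast, alongside a CPU reference for the row-wise argmax. Both function names are hypothetical, and the tie-breaking rule (first maximum wins) is an assumption, not something the diff states about the kernel.

```rust
/// Hypothetical safe alternative: decodes native-endian i32 indices from a
/// byte buffer and converts them to u32, with no alignment requirement.
pub fn decode_indices(host_bytes: &[u8]) -> Vec<u32> {
    assert_eq!(host_bytes.len() % 4, 0, "buffer must hold whole i32 values");
    host_bytes
        .chunks_exact(4)
        .map(|c| i32::from_ne_bytes([c[0], c[1], c[2], c[3]]) as u32)
        .collect()
}

/// CPU reference (f32 for readability): index of the maximum in each row of a
/// row-major [rows, cols] buffer. Ties resolve to the first maximum (assumed).
pub fn argmax_rows_ref(logits: &[f32], cols: usize) -> Vec<u32> {
    assert!(cols > 0, "cols must be non-zero");
    assert_eq!(logits.len() % cols, 0, "buffer must hold whole rows");
    logits
        .chunks_exact(cols)
        .map(|row| {
            let mut best = 0usize;
            for (i, &v) in row.iter().enumerate() {
                if v > row[best] {
                    best = i;
                }
            }
            best as u32
        })
        .collect()
}
```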
@@ -6,28 +6,67 @@ use crate::gemm::batched_matmul;
 use crate::softmax::softmax;
 
 unsafe extern "C" {
-    fn launch_causal_mask_f32(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
-        offset: i32, stream: *mut c_void);
-    fn launch_causal_mask_bf16(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
-        offset: i32, stream: *mut c_void);
+    fn launch_causal_mask_f32(
+        scores: *mut c_void,
+        batch: i32,
+        rows: i32,
+        cols: i32,
+        offset: i32,
+        stream: *mut c_void,
+    );
+    fn launch_causal_mask_bf16(
+        scores: *mut c_void,
+        batch: i32,
+        rows: i32,
+        cols: i32,
+        offset: i32,
+        stream: *mut c_void,
+    );
     fn launch_flash_attention_bf16(
-        q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
-        batch: i32, num_q_heads: i32, num_kv_heads: i32,
-        q_len: i32, kv_len: i32, head_dim: i32,
-        scale: f32, causal: i32, stream: *mut c_void,
+        q: *const c_void,
+        k: *const c_void,
+        v: *const c_void,
+        o: *mut c_void,
+        batch: i32,
+        num_q_heads: i32,
+        num_kv_heads: i32,
+        q_len: i32,
+        kv_len: i32,
+        head_dim: i32,
+        scale: f32,
+        causal: i32,
+        stream: *mut c_void,
     );
     fn launch_flash_attention_sinks_bf16(
-        q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
+        q: *const c_void,
+        k: *const c_void,
+        v: *const c_void,
+        o: *mut c_void,
         sinks: *const c_void,
-        batch: i32, num_q_heads: i32, num_kv_heads: i32,
-        q_len: i32, kv_len: i32, head_dim: i32,
-        scale: f32, causal: i32, window_size: i32, stream: *mut c_void,
+        batch: i32,
+        num_q_heads: i32,
+        num_kv_heads: i32,
+        q_len: i32,
+        kv_len: i32,
+        head_dim: i32,
+        scale: f32,
+        causal: i32,
+        window_size: i32,
+        stream: *mut c_void,
     );
     fn launch_decode_attention_bf16(
-        q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
-        batch: i32, num_q_heads: i32, num_kv_heads: i32,
-        kv_len: i32, head_dim: i32,
-        scale: f32, causal: i32, stream: *mut c_void,
+        q: *const c_void,
+        k: *const c_void,
+        v: *const c_void,
+        o: *mut c_void,
+        batch: i32,
+        num_q_heads: i32,
+        num_kv_heads: i32,
+        kv_len: i32,
+        head_dim: i32,
+        scale: f32,
+        causal: i32,
+        stream: *mut c_void,
     );
     fn launch_paged_decode_attention_bf16(
         q: *const c_void,
@@ -36,9 +75,13 @@ unsafe extern "C" {
         o: *mut c_void,
         block_tables: *const i32,
         context_lens: *const i32,
-        batch: i32, num_q_heads: i32, num_kv_heads: i32,
-        head_dim: i32, max_blocks_per_seq: i32,
-        scale: f32, stream: *mut c_void,
+        batch: i32,
+        num_q_heads: i32,
+        num_kv_heads: i32,
+        head_dim: i32,
+        max_blocks_per_seq: i32,
+        scale: f32,
+        stream: *mut c_void,
     );
     fn launch_paged_decode_attention_sinks_bf16(
         q: *const c_void,
@@ -48,24 +91,40 @@ unsafe extern "C" {
         block_tables: *const i32,
         context_lens: *const i32,
         sinks: *const c_void,
-        batch: i32, num_q_heads: i32, num_kv_heads: i32,
-        head_dim: i32, max_blocks_per_seq: i32,
-        scale: f32, window_size: i32, stream: *mut c_void,
+        batch: i32,
+        num_q_heads: i32,
+        num_kv_heads: i32,
+        head_dim: i32,
+        max_blocks_per_seq: i32,
+        scale: f32,
+        window_size: i32,
+        stream: *mut c_void,
     );
     fn launch_reshape_and_cache_bf16(
-        k_src: *const c_void, v_src: *const c_void,
-        k_pool: *mut c_void, v_pool: *mut c_void,
+        k_src: *const c_void,
+        v_src: *const c_void,
+        k_pool: *mut c_void,
+        v_pool: *mut c_void,
         block_ids: *const c_void,
-        num_tokens: i32, num_heads: i32,
-        head_dim: i32, start_pos: i32, block_size: i32,
+        num_tokens: i32,
+        num_heads: i32,
+        head_dim: i32,
+        start_pos: i32,
+        block_size: i32,
         stream: *mut c_void,
     );
     fn launch_reshape_and_cache_batched_bf16(
-        k_src: *const c_void, v_src: *const c_void,
-        k_pool: *mut c_void, v_pool: *mut c_void,
-        block_tables: *const c_void, kv_lens: *const c_void,
-        batch: i32, num_heads: i32,
-        head_dim: i32, block_size: i32, max_blocks_per_seq: i32,
+        k_src: *const c_void,
+        v_src: *const c_void,
+        k_pool: *mut c_void,
+        v_pool: *mut c_void,
+        block_tables: *const c_void,
+        kv_lens: *const c_void,
+        batch: i32,
+        num_heads: i32,
+        head_dim: i32,
+        block_size: i32,
+        max_blocks_per_seq: i32,
         stream: *mut c_void,
     );
 }
|
||||
@@ -84,20 +143,30 @@ unsafe extern "C" {
 /// `block_ids_gpu` must contain at least `(start_pos + num_tokens + block_size - 1) / block_size`
 /// valid physical block ids.
 pub unsafe fn reshape_and_cache_bf16(
-    k_src: *const c_void, v_src: *const c_void,
-    k_pool_ptr: *mut c_void, v_pool_ptr: *mut c_void,
+    k_src: *const c_void,
+    v_src: *const c_void,
+    k_pool_ptr: *mut c_void,
+    v_pool_ptr: *mut c_void,
     block_ids_gpu: *const i32,
-    num_tokens: usize, num_heads: usize,
-    head_dim: usize, start_pos: usize, block_size: usize,
+    num_tokens: usize,
+    num_heads: usize,
+    head_dim: usize,
+    start_pos: usize,
+    block_size: usize,
     stream: *mut c_void,
 ) {
     unsafe {
         launch_reshape_and_cache_bf16(
-            k_src, v_src,
-            k_pool_ptr, v_pool_ptr,
+            k_src,
+            v_src,
+            k_pool_ptr,
+            v_pool_ptr,
             block_ids_gpu as *const c_void,
-            num_tokens as i32, num_heads as i32,
-            head_dim as i32, start_pos as i32, block_size as i32,
+            num_tokens as i32,
+            num_heads as i32,
+            head_dim as i32,
+            start_pos as i32,
+            block_size as i32,
             stream,
         );
     }
|
||||
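The safety comment on `reshape_and_cache_bf16` requires at least `(start_pos + num_tokens + block_size - 1) / block_size` valid block ids. The arithmetic can be sketched on the host as below. Both function names are hypothetical, and the position-to-slot mapping in `slot_for` is the usual paged-KV convention, assumed here rather than taken from the kernel source.

```rust
/// Number of physical blocks that must be valid so that tokens
/// [start_pos, start_pos + num_tokens) all land in an allocated block
/// (ceiling division, as in the doc comment in the diff).
pub fn blocks_needed(start_pos: usize, num_tokens: usize, block_size: usize) -> usize {
    assert!(block_size > 0, "block_size must be non-zero");
    (start_pos + num_tokens + block_size - 1) / block_size
}

/// Assumed mapping from a logical token position to
/// (index into the block-id table, offset inside that block).
pub fn slot_for(pos: usize, block_size: usize) -> (usize, usize) {
    assert!(block_size > 0, "block_size must be non-zero");
    (pos / block_size, pos % block_size)
}
```

For example, with `block_size = 16`, appending 2 tokens at `start_pos = 15` crosses a block boundary, so two block ids must be valid.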
@@ -113,21 +182,32 @@ pub unsafe fn reshape_and_cache_bf16(
 /// All pointers must be on the same GPU. `block_tables` and `kv_lens` must
 /// already be synced to the device for the active batch.
 pub unsafe fn reshape_and_cache_batched_bf16(
-    k_src: *const c_void, v_src: *const c_void,
-    k_pool_ptr: *mut c_void, v_pool_ptr: *mut c_void,
-    block_tables_gpu: *const i32, kv_lens_gpu: *const i32,
-    batch: usize, num_heads: usize,
-    head_dim: usize, block_size: usize, max_blocks_per_seq: usize,
+    k_src: *const c_void,
+    v_src: *const c_void,
+    k_pool_ptr: *mut c_void,
+    v_pool_ptr: *mut c_void,
+    block_tables_gpu: *const i32,
+    kv_lens_gpu: *const i32,
+    batch: usize,
+    num_heads: usize,
+    head_dim: usize,
+    block_size: usize,
+    max_blocks_per_seq: usize,
     stream: *mut c_void,
 ) {
     unsafe {
         launch_reshape_and_cache_batched_bf16(
-            k_src, v_src,
-            k_pool_ptr, v_pool_ptr,
+            k_src,
+            v_src,
+            k_pool_ptr,
+            v_pool_ptr,
             block_tables_gpu as *const c_void,
             kv_lens_gpu as *const c_void,
-            batch as i32, num_heads as i32,
-            head_dim as i32, block_size as i32, max_blocks_per_seq as i32,
+            batch as i32,
+            num_heads as i32,
+            head_dim as i32,
+            block_size as i32,
+            max_blocks_per_seq as i32,
             stream,
         );
     }
|
||||
@@ -143,13 +223,19 @@ fn apply_causal_mask(scores: &Tensor, offset: usize) {
         match scores.dtype() {
             DType::F32 => launch_causal_mask_f32(
                 scores.data_ptr() as *mut c_void,
-                batch as i32, rows as i32, cols as i32, offset as i32,
-                std::ptr::null_mut(),
+                batch as i32,
+                rows as i32,
+                cols as i32,
+                offset as i32,
+                xserv_cuda::current_stream_raw(),
             ),
             DType::BF16 => launch_causal_mask_bf16(
                 scores.data_ptr() as *mut c_void,
-                batch as i32, rows as i32, cols as i32, offset as i32,
-                std::ptr::null_mut(),
+                batch as i32,
+                rows as i32,
+                cols as i32,
+                offset as i32,
+                xserv_cuda::current_stream_raw(),
             ),
             _ => panic!("unsupported dtype for causal mask"),
         }
|
||||
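The hunk shows only how `apply_causal_mask` dispatches, not what the kernel computes. As a sketch under a stated assumption, a causal mask with an offset is commonly defined so that query row `r` may attend to key columns `c <= r + offset`, with all later columns set to negative infinity before softmax. The function `causal_mask_ref` below is hypothetical and encodes that assumed rule on the CPU; the actual kernel semantics should be checked against the CUDA source.

```rust
/// CPU reference (hypothetical) of an offset causal mask over a row-major
/// [batch, rows, cols] score buffer. Assumed rule: column c is masked for
/// row r when c > r + offset, i.e. the key lies in the query's future.
pub fn causal_mask_ref(scores: &mut [f32], batch: usize, rows: usize, cols: usize, offset: usize) {
    assert_eq!(scores.len(), batch * rows * cols, "buffer size mismatch");
    for b in 0..batch {
        for r in 0..rows {
            for c in 0..cols {
                if c > r + offset {
                    scores[(b * rows + r) * cols + c] = f32::NEG_INFINITY;
                }
            }
        }
    }
}
```

With `offset = 0` this gives the usual lower-triangular mask; a positive offset corresponds to query rows that start after `offset` already-cached keys.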
@@ -214,11 +300,7 @@ pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
     let kv_len = k.shape()[2];
 
     let scale = 1.0 / (head_dim as f32).sqrt();
-    let output = Tensor::empty(
-        &[batch, num_q_heads, 1, head_dim],
-        DType::BF16,
-        q.device(),
-    );
+    let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
 
     unsafe {
         launch_decode_attention_bf16(
@@ -233,7 +315,7 @@ pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
             head_dim as i32,
             scale,
             1, // causal (always 1 for decode)
-            std::ptr::null_mut(),
+            xserv_cuda::current_stream_raw(),
         );
     }
|
||||
|
||||
@@ -266,8 +348,14 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
 
     assert_eq!(k.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
     assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
-    assert!(num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads");
-    assert!(head_dim <= 128, "flash_attention supports head_dim up to 128");
+    assert!(
+        num_q_heads % num_kv_heads == 0,
+        "num_q_heads must be divisible by num_kv_heads"
+    );
+    assert!(
+        head_dim <= 128,
+        "flash_attention supports head_dim up to 128"
+    );
 
     // Dispatch to specialized decode kernel for single-token generation
     if q_len == 1 {
@@ -295,7 +383,7 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
             head_dim as i32,
             scale,
             if causal { 1 } else { 0 },
-            std::ptr::null_mut(),
+            xserv_cuda::current_stream_raw(),
         );
     }
|
||||
|
||||
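Two constraints recur in the attention wrappers: the softmax scale `1.0 / (head_dim as f32).sqrt()` and the grouped-query requirement `num_q_heads % num_kv_heads == 0`. The sketch below restates both as small host functions. The names are hypothetical, and the head mapping (consecutive query heads share one KV head) is the common GQA convention, assumed here rather than read from the kernels.

```rust
/// Softmax scale used by the wrappers in the diff: 1 / sqrt(head_dim).
pub fn attention_scale(head_dim: usize) -> f32 {
    1.0 / (head_dim as f32).sqrt()
}

/// Assumed GQA mapping: with group = num_q_heads / num_kv_heads,
/// query heads [g * group, (g + 1) * group) all read KV head g.
pub fn kv_head_for(q_head: usize, num_q_heads: usize, num_kv_heads: usize) -> usize {
    assert!(num_kv_heads > 0, "num_kv_heads must be non-zero");
    assert!(
        num_q_heads % num_kv_heads == 0,
        "num_q_heads must be divisible by num_kv_heads"
    );
    assert!(q_head < num_q_heads, "q_head out of range");
    q_head / (num_q_heads / num_kv_heads)
}
```

The divisibility assert is what makes the integer division well defined: every KV head then serves exactly the same number of query heads.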
@@ -333,10 +421,18 @@ pub fn flash_attention_sinks(
     assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
     assert!(num_q_heads % num_kv_heads == 0);
     assert!(head_dim <= 128);
-    assert_eq!(sinks.shape()[0], num_q_heads, "sinks must have num_q_heads entries");
+    assert_eq!(
+        sinks.shape()[0],
+        num_q_heads,
+        "sinks must have num_q_heads entries"
+    );
 
     let scale = 1.0 / (head_dim as f32).sqrt();
-    let output = Tensor::empty(&[batch, num_q_heads, q_len, head_dim], DType::BF16, q.device());
+    let output = Tensor::empty(
+        &[batch, num_q_heads, q_len, head_dim],
+        DType::BF16,
+        q.device(),
+    );
 
     unsafe {
         launch_flash_attention_sinks_bf16(
@@ -354,7 +450,7 @@ pub fn flash_attention_sinks(
             scale,
             1, // always causal
             window_size as i32,
-            std::ptr::null_mut(),
+            xserv_cuda::current_stream_raw(),
         );
     }
|
||||
|
||||
@@ -383,17 +479,20 @@ pub fn paged_decode_attention(
     max_blocks_per_seq: usize,
 ) -> Tensor {
     assert_eq!(q.ndim(), 4);
-    assert_eq!(q.shape()[2], 1, "paged_decode_attention requires q_len == 1");
+    assert_eq!(
+        q.shape()[2],
+        1,
+        "paged_decode_attention requires q_len == 1"
+    );
     assert_eq!(q.dtype(), DType::BF16);
-    assert!(num_q_heads % num_kv_heads == 0, "GQA: num_q_heads must be divisible by num_kv_heads");
+    assert!(
+        num_q_heads % num_kv_heads == 0,
+        "GQA: num_q_heads must be divisible by num_kv_heads"
+    );
     assert!(head_dim <= 128);
 
     let scale = 1.0 / (head_dim as f32).sqrt();
-    let output = Tensor::empty(
-        &[batch, num_q_heads, 1, head_dim],
-        DType::BF16,
-        q.device(),
-    );
+    let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
 
     unsafe {
         launch_paged_decode_attention_bf16(
@@ -409,7 +508,7 @@ pub fn paged_decode_attention(
             head_dim as i32,
             max_blocks_per_seq as i32,
             scale,
-            std::ptr::null_mut(),
+            xserv_cuda::current_stream_raw(),
         );
     }
|
||||
|
||||
@@ -442,11 +541,7 @@ pub fn paged_decode_attention_sinks(
     assert!(head_dim <= 128);
 
     let scale = 1.0 / (head_dim as f32).sqrt();
-    let output = Tensor::empty(
-        &[batch, num_q_heads, 1, head_dim],
-        DType::BF16,
-        q.device(),
-    );
+    let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
 
     unsafe {
         launch_paged_decode_attention_sinks_bf16(
@@ -464,7 +559,7 @@ pub fn paged_decode_attention_sinks(
             max_blocks_per_seq as i32,
             scale,
             window_size as i32,
-            std::ptr::null_mut(),
+            xserv_cuda::current_stream_raw(),
         );
     }
|
||||
|
||||
|
||||
@@ -5,104 +5,302 @@ use std::ffi::c_void;
 
 // Re-declare the extern functions we need (same as in the individual modules)
 unsafe extern "C" {
-    fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
-        rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
-    fn launch_add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
-        normed_out: *mut c_void, sum_out: *mut c_void,
-        rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
-    fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
-    fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
-    fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
-        num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
-    fn launch_reshape_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
-    fn launch_merge_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
-    fn launch_transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
-    fn launch_transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
-    fn launch_rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
-        positions: *const c_void, num_tokens: i32, num_heads: i32,
-        head_dim: i32, stream: *mut c_void);
-    fn launch_gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void, y_fp32_buf: *mut c_void,
-        k: i32, n: i32, stream: *mut c_void);
+    fn launch_rmsnorm_bf16(
+        x: *const c_void,
+        gamma: *const c_void,
+        out: *mut c_void,
+        rows: i32,
+        hidden_size: i32,
+        eps: f32,
+        stream: *mut c_void,
+    );
+    fn launch_add_rmsnorm_bf16(
+        x: *const c_void,
+        residual: *const c_void,
+        gamma: *const c_void,
+        normed_out: *mut c_void,
+        sum_out: *mut c_void,
+        rows: i32,
+        hidden_size: i32,
+        eps: f32,
+        stream: *mut c_void,
+    );
+    fn launch_silu_mul_bf16(
+        gate: *const c_void,
+        up: *const c_void,
+        out: *mut c_void,
+        n: i32,
+        stream: *mut c_void,
+    );
+    fn launch_add_bf16(
+        a: *const c_void,
+        b: *const c_void,
+        out: *mut c_void,
+        n: i32,
+        stream: *mut c_void,
+    );
+    fn launch_embedding_bf16(
+        table: *const c_void,
+        token_ids: *const c_void,
+        out: *mut c_void,
+        num_tokens: i32,
+        hidden_size: i32,
+        vocab_size: i32,
+        stream: *mut c_void,
+    );
+    fn launch_reshape_heads_bf16(
+        inp: *const c_void,
+        out: *mut c_void,
+        seq_len: i32,
+        num_heads: i32,
+        head_dim: i32,
+        stream: *mut c_void,
+    );
+    fn launch_merge_heads_bf16(
+        inp: *const c_void,
+        out: *mut c_void,
+        seq_len: i32,
+        num_heads: i32,
+        head_dim: i32,
+        stream: *mut c_void,
+    );
+    fn launch_transpose_hsd_to_shd_bf16(
+        inp: *const c_void,
+        out: *mut c_void,
+        seq_len: i32,
+        num_heads: i32,
+        head_dim: i32,
+        stream: *mut c_void,
+    );
+    fn launch_transpose_shd_to_hsd_bf16(
+        inp: *const c_void,
+        out: *mut c_void,
+        seq_len: i32,
+        num_heads: i32,
+        head_dim: i32,
+        stream: *mut c_void,
+    );
+    fn launch_rope_bf16(
+        x: *mut c_void,
+        cos_cache: *const c_void,
+        sin_cache: *const c_void,
+        positions: *const c_void,
+        num_tokens: i32,
+        num_heads: i32,
+        head_dim: i32,
+        stream: *mut c_void,
+    );
+    fn launch_gemv_bf16(
+        x: *const c_void,
+        w: *const c_void,
+        y_bf16: *mut c_void,
+        y_fp32_buf: *mut c_void,
+        k: i32,
+        n: i32,
+        stream: *mut c_void,
+    );
     fn launch_decode_attention_bf16(
-        q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
-        batch: i32, num_q_heads: i32, num_kv_heads: i32,
-        kv_len: i32, head_dim: i32,
-        scale: f32, causal: i32, stream: *mut c_void,
+        q: *const c_void,
+        k: *const c_void,
+        v: *const c_void,
+        o: *mut c_void,
+        batch: i32,
+        num_q_heads: i32,
+        num_kv_heads: i32,
+        kv_len: i32,
+        head_dim: i32,
+        scale: f32,
+        causal: i32,
+        stream: *mut c_void,
     );
 }
|
||||
|
||||
 /// Raw rmsnorm dispatch: writes to pre-allocated `out`.
-pub unsafe fn rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
-    rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void) {
+pub unsafe fn rmsnorm_bf16(
+    x: *const c_void,
+    gamma: *const c_void,
+    out: *mut c_void,
+    rows: i32,
+    hidden_size: i32,
+    eps: f32,
+    stream: *mut c_void,
+) {
     launch_rmsnorm_bf16(x, gamma, out, rows, hidden_size, eps, stream);
 }
 
 /// Raw add_rmsnorm dispatch.
-pub unsafe fn add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
-    normed_out: *mut c_void, sum_out: *mut c_void,
-    rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void) {
-    launch_add_rmsnorm_bf16(x, residual, gamma, normed_out, sum_out, rows, hidden_size, eps, stream);
+pub unsafe fn add_rmsnorm_bf16(
+    x: *const c_void,
+    residual: *const c_void,
+    gamma: *const c_void,
+    normed_out: *mut c_void,
+    sum_out: *mut c_void,
+    rows: i32,
+    hidden_size: i32,
+    eps: f32,
+    stream: *mut c_void,
+) {
+    launch_add_rmsnorm_bf16(
+        x,
+        residual,
+        gamma,
+        normed_out,
+        sum_out,
+        rows,
+        hidden_size,
+        eps,
+        stream,
+    );
 }
|
||||
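The signature of `add_rmsnorm_bf16` has two outputs, `sum_out` and `normed_out`, which suggests a fused residual-add followed by RMSNorm. The sketch below is a CPU reference under that assumption: `sum = x + residual`, then `normed = sum / sqrt(mean(sum^2) + eps) * gamma`, computed per row. The function name is hypothetical, and whether the kernel places `eps` and `gamma` exactly this way is not shown in the diff.

```rust
/// CPU reference (hypothetical, f32) of a fused add + RMSNorm over rows of
/// length `hidden_size`. Returns (normed_out, sum_out).
pub fn add_rmsnorm_ref(
    x: &[f32],
    residual: &[f32],
    gamma: &[f32],
    hidden_size: usize,
    eps: f32,
) -> (Vec<f32>, Vec<f32>) {
    assert!(hidden_size > 0, "hidden_size must be non-zero");
    assert_eq!(x.len(), residual.len(), "x and residual must match");
    assert_eq!(x.len() % hidden_size, 0, "buffer must hold whole rows");
    assert_eq!(gamma.len(), hidden_size, "gamma must have hidden_size entries");

    // Residual add, kept as a separate output for the next layer.
    let sum: Vec<f32> = x.iter().zip(residual).map(|(&a, &b)| a + b).collect();

    let mut normed = Vec::with_capacity(sum.len());
    for row in sum.chunks_exact(hidden_size) {
        let mean_sq = row.iter().map(|v| v * v).sum::<f32>() / hidden_size as f32;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        normed.extend(row.iter().zip(gamma).map(|(&v, &g)| v * inv_rms * g));
    }
    (normed, sum)
}
```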
|
||||
 /// Raw silu_mul dispatch.
-pub unsafe fn silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void) {
+pub unsafe fn silu_mul_bf16(
+    gate: *const c_void,
+    up: *const c_void,
+    out: *mut c_void,
+    n: i32,
+    stream: *mut c_void,
+) {
     launch_silu_mul_bf16(gate, up, out, n, stream);
 }
 
 /// Raw add dispatch.
-pub unsafe fn add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void) {
+pub unsafe fn add_bf16(
+    a: *const c_void,
+    b: *const c_void,
+    out: *mut c_void,
+    n: i32,
+    stream: *mut c_void,
+) {
     launch_add_bf16(a, b, out, n, stream);
 }
 
 /// Raw embedding dispatch.
-pub unsafe fn embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
-    num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void) {
-    launch_embedding_bf16(table, token_ids, out, num_tokens, hidden_size, vocab_size, stream);
+pub unsafe fn embedding_bf16(
+    table: *const c_void,
+    token_ids: *const c_void,
+    out: *mut c_void,
+    num_tokens: i32,
+    hidden_size: i32,
+    vocab_size: i32,
+    stream: *mut c_void,
+) {
+    launch_embedding_bf16(
+        table,
+        token_ids,
+        out,
+        num_tokens,
+        hidden_size,
+        vocab_size,
+        stream,
+    );
 }
 
 /// Raw reshape_heads dispatch.
-pub unsafe fn reshape_heads_bf16(inp: *const c_void, out: *mut c_void,
-    seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
+pub unsafe fn reshape_heads_bf16(
+    inp: *const c_void,
+    out: *mut c_void,
+    seq_len: i32,
+    num_heads: i32,
+    head_dim: i32,
+    stream: *mut c_void,
+) {
     launch_reshape_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
 }
 
 /// Raw merge_heads dispatch.
-pub unsafe fn merge_heads_bf16(inp: *const c_void, out: *mut c_void,
-    seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
+pub unsafe fn merge_heads_bf16(
+    inp: *const c_void,
+    out: *mut c_void,
+    seq_len: i32,
+    num_heads: i32,
+    head_dim: i32,
+    stream: *mut c_void,
+) {
     launch_merge_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
 }
 
 /// Raw transpose HSD->SHD dispatch.
-pub unsafe fn transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void,
-    seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
+pub unsafe fn transpose_hsd_to_shd_bf16(
+    inp: *const c_void,
+    out: *mut c_void,
+    seq_len: i32,
+    num_heads: i32,
+    head_dim: i32,
+    stream: *mut c_void,
+) {
     launch_transpose_hsd_to_shd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
 }
 
 /// Raw transpose SHD->HSD dispatch.
-pub unsafe fn transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void,
-    seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
+pub unsafe fn transpose_shd_to_hsd_bf16(
+    inp: *const c_void,
+    out: *mut c_void,
+    seq_len: i32,
+    num_heads: i32,
+    head_dim: i32,
+    stream: *mut c_void,
+) {
     launch_transpose_shd_to_hsd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
 }
|
||||
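The two transpose wrappers are named after their layouts. Reading HSD as `[num_heads, seq_len, head_dim]` and SHD as `[seq_len, num_heads, head_dim]` (an interpretation of the names, not something the diff states), the index mapping can be sketched on the CPU as below. `transpose_hsd_to_shd_ref` is a hypothetical name, and `f32` stands in for BF16.

```rust
/// CPU reference (hypothetical) for an HSD -> SHD transpose, assuming
/// HSD = [num_heads, seq_len, head_dim] and SHD = [seq_len, num_heads, head_dim].
/// Each contiguous head_dim vector is moved as a unit.
pub fn transpose_hsd_to_shd_ref(
    inp: &[f32],
    seq_len: usize,
    num_heads: usize,
    head_dim: usize,
) -> Vec<f32> {
    assert_eq!(inp.len(), seq_len * num_heads * head_dim, "buffer size mismatch");
    let mut out = vec![0.0f32; inp.len()];
    for h in 0..num_heads {
        for s in 0..seq_len {
            let src = (h * seq_len + s) * head_dim;
            let dst = (s * num_heads + h) * head_dim;
            out[dst..dst + head_dim].copy_from_slice(&inp[src..src + head_dim]);
        }
    }
    out
}
```

The reverse direction swaps the `src` and `dst` formulas.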
|
||||
 /// Raw RoPE dispatch (in-place).
-pub unsafe fn rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
-    positions: *const c_void, num_tokens: i32, num_heads: i32,
-    head_dim: i32, stream: *mut c_void) {
-    launch_rope_bf16(x, cos_cache, sin_cache, positions, num_tokens, num_heads, head_dim, stream);
+pub unsafe fn rope_bf16(
+    x: *mut c_void,
+    cos_cache: *const c_void,
+    sin_cache: *const c_void,
+    positions: *const c_void,
+    num_tokens: i32,
+    num_heads: i32,
+    head_dim: i32,
+    stream: *mut c_void,
+) {
+    launch_rope_bf16(
+        x, cos_cache, sin_cache, positions, num_tokens, num_heads, head_dim, stream,
+    );
 }
 
 /// Raw GEMV dispatch (BF16, M=1). Caller must provide fp32 accumulator buffer.
-pub unsafe fn gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void,
-    y_fp32_buf: *mut c_void, k: i32, n: i32, stream: *mut c_void) {
+pub unsafe fn gemv_bf16(
+    x: *const c_void,
+    w: *const c_void,
+    y_bf16: *mut c_void,
+    y_fp32_buf: *mut c_void,
+    k: i32,
+    n: i32,
+    stream: *mut c_void,
+) {
     launch_gemv_bf16(x, w, y_bf16, y_fp32_buf, k, n, stream);
 }
 
 /// Raw decode attention dispatch.
-pub unsafe fn decode_attention_bf16(q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
-    batch: i32, num_q_heads: i32, num_kv_heads: i32,
-    kv_len: i32, head_dim: i32,
-    scale: f32, stream: *mut c_void) {
-    launch_decode_attention_bf16(q, k, v, o, batch, num_q_heads, num_kv_heads, kv_len, head_dim, scale, 1, stream);
+pub unsafe fn decode_attention_bf16(
+    q: *const c_void,
+    k: *const c_void,
+    v: *const c_void,
+    o: *mut c_void,
+    batch: i32,
+    num_q_heads: i32,
+    num_kv_heads: i32,
+    kv_len: i32,
+    head_dim: i32,
+    scale: f32,
+    stream: *mut c_void,
+) {
+    launch_decode_attention_bf16(
+        q,
+        k,
+        v,
+        o,
+        batch,
+        num_q_heads,
+        num_kv_heads,
+        kv_len,
+        head_dim,
+        scale,
+        1,
+        stream,
+    );
 }
 
 // cuBLAS FFI
|
||||
|
||||
@@ -1,12 +1,25 @@
 use std::ffi::c_void;
-use xserv_cuda::GpuBuffer;
 use xserv_tensor::{DType, Device, Tensor};
 
 unsafe extern "C" {
-    fn launch_embedding_f32(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
-        num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
-    fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
-        num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
+    fn launch_embedding_f32(
+        table: *const c_void,
+        token_ids: *const c_void,
+        out: *mut c_void,
+        num_tokens: i32,
+        hidden_size: i32,
+        vocab_size: i32,
+        stream: *mut c_void,
+    );
+    fn launch_embedding_bf16(
+        table: *const c_void,
+        token_ids: *const c_void,
+        out: *mut c_void,
+        num_tokens: i32,
+        hidden_size: i32,
+        vocab_size: i32,
+        stream: *mut c_void,
+    );
 }
 
 /// Embedding lookup: table[token_ids[i]] for each i.
@@ -19,8 +32,14 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
     let hidden_size = table.shape()[1];
     let num_tokens = token_ids.len();
     let vocab_size = table.shape()[0];
-    assert!(num_tokens <= i32::MAX as usize, "too many tokens for i32 kernel param");
-    assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
+    assert!(
+        num_tokens <= i32::MAX as usize,
+        "too many tokens for i32 kernel param"
+    );
+    assert!(
+        hidden_size <= i32::MAX as usize,
+        "hidden_size too large for i32 kernel param"
+    );
 
     // Upload token_ids to GPU
     let ids_bytes = unsafe {
@@ -29,26 +48,51 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
             num_tokens * std::mem::size_of::<u32>(),
         )
     };
-    let mut ids_gpu = xserv_cuda::allocator::cached_alloc(ids_bytes.len()).expect("alloc token_ids");
+    let mut ids_gpu =
+        xserv_cuda::allocator::cached_alloc(ids_bytes.len()).expect("alloc token_ids");
     ids_gpu.copy_from_host(ids_bytes).unwrap();
 
     for &tid in token_ids {
-        assert!((tid as usize) < vocab_size, "token_id {tid} out of bounds (vocab_size={vocab_size})");
+        assert!(
+            (tid as usize) < vocab_size,
+            "token_id {tid} out of bounds (vocab_size={vocab_size})"
+        );
     }
 
+    embedding_device_ids(table, ids_gpu.as_ptr() as *const c_void, num_tokens)
+}
+
+/// Embedding lookup with token ids already on the GPU (u32, [num_tokens]).
+/// Used by the CUDA-graph decode path, where ids live in a persistent device
+/// buffer updated outside the captured region (no bounds check possible here).
+pub fn embedding_device_ids(table: &Tensor, ids_gpu: *const c_void, num_tokens: usize) -> Tensor {
+    assert_eq!(table.ndim(), 2);
+    assert!(table.is_contiguous());
+    assert!(matches!(table.device(), Device::Cuda(_)));
+    let hidden_size = table.shape()[1];
+    let vocab_size = table.shape()[0];
+
     let out = Tensor::empty(&[num_tokens, hidden_size], table.dtype(), table.device());
 
     unsafe {
         match table.dtype() {
             DType::F32 => launch_embedding_f32(
-                table.data_ptr() as _, ids_gpu.as_ptr() as _,
+                table.data_ptr() as _,
+                ids_gpu,
                 out.data_ptr() as *mut c_void,
-                num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(),
+                num_tokens as i32,
+                hidden_size as i32,
+                vocab_size as i32,
+                xserv_cuda::current_stream_raw(),
             ),
             DType::BF16 => launch_embedding_bf16(
-                table.data_ptr() as _, ids_gpu.as_ptr() as _,
+                table.data_ptr() as _,
+                ids_gpu,
                 out.data_ptr() as *mut c_void,
-                num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(),
+                num_tokens as i32,
+                hidden_size as i32,
+                vocab_size as i32,
+                xserv_cuda::current_stream_raw(),
             ),
             _ => panic!("unsupported dtype for embedding"),
         }
|
||||
|
||||
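The `i32::MAX` asserts in `embedding` guard the `usize as i32` casts handed to the kernel launchers. A minimal sketch of the same guard written with `i32::try_from`, plus what an unchecked cast would do instead. `to_kernel_i32` and `unchecked_cast` are hypothetical helpers for illustration and are not part of xserv.

```rust
/// Hypothetical helper (not in xserv): narrow a host-side size to the `i32`
/// a kernel launcher expects, reporting an error instead of wrapping.
fn to_kernel_i32(value: usize, what: &str) -> Result<i32, String> {
    i32::try_from(value)
        .map_err(|_| format!("{what} = {value} does not fit in an i32 kernel param"))
}

/// What a bare `as` cast does with an oversized value: it keeps the low
/// 32 bits, so the kernel would silently receive a negative size.
fn unchecked_cast(value: usize) -> i32 {
    value as i32
}
```

Calling `to_kernel_i32(num_tokens, "num_tokens")` before a launch gives the same protection as the asserts, with a recoverable error rather than a panic; which of the two is preferable depends on how the caller wants to fail.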
@@ -1,14 +1,22 @@
|
||||
use std::cell::RefCell;
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::error::{self, Result};
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_cuda::error::{self, Result};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
const CUBLAS_WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
|
||||
|
||||
// GEMV: single-kernel, no FP32 temp buffer needed
|
||||
unsafe extern "C" {
|
||||
fn launch_gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void, y_fp32_buf: *mut c_void, k: i32, n: i32, stream: *mut c_void);
|
||||
fn launch_gemv_bf16(
|
||||
x: *const c_void,
|
||||
w: *const c_void,
|
||||
y_bf16: *mut c_void,
|
||||
y_fp32_buf: *mut c_void,
|
||||
k: i32,
|
||||
n: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
@@ -20,10 +28,42 @@ pub enum GemmBackend {
|
||||
|
||||
// --- FFI: custom CUDA kernels ---
|
||||
unsafe extern "C" {
|
||||
fn launch_gemm_naive_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_naive_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_tiled_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_tiled_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_naive_f32(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
c: *mut c_void,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_gemm_naive_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
c: *mut c_void,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_gemm_tiled_f32(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
c: *mut c_void,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_gemm_tiled_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
c: *mut c_void,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
// --- FFI: cuBLAS ---
|
||||
@@ -46,25 +86,46 @@ unsafe extern "C" {
|
||||
fn cublasSetWorkspace_v2(handle: CublasHandle, workspace: *mut c_void, size: usize) -> i32;
|
||||
fn cublasGemmEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32, transb: i32,
|
||||
m: i32, n: i32, k: i32,
|
||||
transa: i32,
|
||||
transb: i32,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void, a_type: i32, lda: i32,
|
||||
b: *const c_void, b_type: i32, ldb: i32,
|
||||
a: *const c_void,
|
||||
a_type: i32,
|
||||
lda: i32,
|
||||
b: *const c_void,
|
||||
b_type: i32,
|
||||
ldb: i32,
|
||||
beta: *const c_void,
|
||||
c: *mut c_void, c_type: i32, ldc: i32,
|
||||
c: *mut c_void,
|
||||
c_type: i32,
|
||||
ldc: i32,
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
) -> i32;
|
||||
fn cublasGemmStridedBatchedEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32, transb: i32,
|
||||
m: i32, n: i32, k: i32,
|
||||
transa: i32,
|
||||
transb: i32,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void, a_type: i32, lda: i32, stride_a: i64,
|
||||
b: *const c_void, b_type: i32, ldb: i32, stride_b: i64,
|
||||
a: *const c_void,
|
||||
a_type: i32,
|
||||
lda: i32,
|
||||
stride_a: i64,
|
||||
b: *const c_void,
|
||||
b_type: i32,
|
||||
ldb: i32,
|
||||
stride_b: i64,
|
||||
beta: *const c_void,
|
||||
c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64,
|
||||
c: *mut c_void,
|
||||
c_type: i32,
|
||||
ldc: i32,
|
||||
stride_c: i64,
|
||||
batch_count: i32,
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
@@ -89,9 +150,16 @@ impl CublasContext {
|
||||
         // set, so we keep the GpuBuffer in this struct.
         let mut workspace = GpuBuffer::alloc(CUBLAS_WORKSPACE_BYTES)?;
         error::check(unsafe {
-            cublasSetWorkspace_v2(handle, workspace.as_mut_ptr() as *mut c_void, CUBLAS_WORKSPACE_BYTES)
+            cublasSetWorkspace_v2(
+                handle,
+                workspace.as_mut_ptr() as *mut c_void,
+                CUBLAS_WORKSPACE_BYTES,
+            )
         })?;
-        Ok(Self { handle, _workspace: Some(workspace) })
+        Ok(Self {
+            handle,
+            _workspace: Some(workspace),
+        })
     }
 }
|
||||
@@ -123,9 +191,7 @@ where
|
||||
|
||||
 /// Get the thread-local cuBLAS handle for use with dispatch module.
 pub fn cublas_handle() -> CublasHandle {
-    CUBLAS_CTX.with(|cell| {
-        cell.borrow().handle
-    })
+    CUBLAS_CTX.with(|cell| cell.borrow().handle)
 }
|
||||
/// Matrix multiplication: C = A @ B
|
||||
@@ -136,8 +202,14 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
assert_eq!(b.ndim(), 2);
|
||||
assert_eq!(a.shape()[1], b.shape()[0], "inner dimension mismatch");
|
||||
assert_eq!(a.dtype(), b.dtype(), "dtype mismatch");
|
||||
assert!(a.is_contiguous() && b.is_contiguous(), "matmul requires contiguous tensors");
|
||||
assert!(matches!(a.device(), Device::Cuda(_)), "matmul requires GPU tensors");
|
||||
assert!(
|
||||
a.is_contiguous() && b.is_contiguous(),
|
||||
"matmul requires contiguous tensors"
|
||||
);
|
||||
assert!(
|
||||
matches!(a.device(), Device::Cuda(_)),
|
||||
"matmul requires GPU tensors"
|
||||
);
|
||||
|
||||
let m = a.shape()[0];
|
||||
let k = a.shape()[1];
|
||||
@@ -151,35 +223,66 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
let a_ptr = a.data_ptr() as *const c_void;
|
||||
let b_ptr = b.data_ptr() as *const c_void;
|
||||
let c_ptr = c.data_ptr() as *mut c_void;
|
||||
-    let null_stream = std::ptr::null_mut();
+    let null_stream = xserv_cuda::current_stream_raw();
|
||||
match backend {
|
||||
GemmBackend::Naive => {
|
||||
unsafe {
|
||||
match dtype {
|
||||
DType::F32 => launch_gemm_naive_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
DType::BF16 => launch_gemm_naive_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
_ => panic!("unsupported dtype for naive GEMM"),
|
||||
}
|
||||
GemmBackend::Naive => unsafe {
|
||||
match dtype {
|
||||
DType::F32 => launch_gemm_naive_f32(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
null_stream,
|
||||
),
|
||||
DType::BF16 => launch_gemm_naive_bf16(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
null_stream,
|
||||
),
|
||||
_ => panic!("unsupported dtype for naive GEMM"),
|
||||
}
|
||||
}
|
||||
GemmBackend::Tiled => {
|
||||
unsafe {
|
||||
match dtype {
|
||||
DType::F32 => launch_gemm_tiled_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
DType::BF16 => launch_gemm_tiled_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
_ => panic!("unsupported dtype for tiled GEMM"),
|
||||
}
|
||||
},
|
||||
GemmBackend::Tiled => unsafe {
|
||||
match dtype {
|
||||
DType::F32 => launch_gemm_tiled_f32(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
null_stream,
|
||||
),
|
||||
DType::BF16 => launch_gemm_tiled_bf16(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
null_stream,
|
||||
),
|
||||
_ => panic!("unsupported dtype for tiled GEMM"),
|
||||
}
|
||||
}
|
||||
},
|
||||
GemmBackend::CuBlas => {
|
||||
if m == 1 && dtype == DType::BF16 && n >= 256 {
|
||||
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(n * 4).unwrap();
|
||||
unsafe {
|
||||
launch_gemv_bf16(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
fp32_buf.as_mut_ptr() as *mut c_void,
|
||||
k as i32, n as i32,
|
||||
k as i32,
|
||||
n as i32,
|
||||
null_stream,
|
||||
);
|
||||
}
|
||||
@@ -197,16 +300,26 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
cublasSetStream_v2(handle, null_stream);
|
||||
error::check(cublasGemmEx(
|
||||
handle,
|
||||
CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
n as i32, m as i32, k as i32,
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_N,
|
||||
n as i32,
|
||||
m as i32,
|
||||
k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b_ptr, b_type, n as i32,
|
||||
a_ptr, a_type, k as i32,
|
||||
b_ptr,
|
||||
b_type,
|
||||
n as i32,
|
||||
a_ptr,
|
||||
a_type,
|
||||
k as i32,
|
||||
&beta as *const f32 as *const c_void,
|
||||
c_ptr, c_type, n as i32,
|
||||
c_ptr,
|
||||
c_type,
|
||||
n as i32,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
-1,
|
||||
)).expect("cuBLAS GEMM failed");
|
||||
))
|
||||
.expect("cuBLAS GEMM failed");
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -260,21 +373,34 @@ pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
let stride_c = (m * n) as i64;
|
||||
|
||||
with_cublas(|handle| unsafe {
|
||||
-        cublasSetStream_v2(handle, std::ptr::null_mut());
+        cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
// Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major)
|
||||
error::check(cublasGemmStridedBatchedEx(
|
||||
handle,
|
||||
CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
n as i32, m as i32, k as i32,
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_N,
|
||||
n as i32,
|
||||
m as i32,
|
||||
k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b.data_ptr() as _, b_type, n as i32, stride_b,
|
||||
a.data_ptr() as _, a_type, k as i32, stride_a,
|
||||
b.data_ptr() as _,
|
||||
b_type,
|
||||
n as i32,
|
||||
stride_b,
|
||||
a.data_ptr() as _,
|
||||
a_type,
|
||||
k as i32,
|
||||
stride_a,
|
||||
&beta as *const f32 as *const c_void,
|
||||
c.data_ptr() as *mut c_void, c_type, n as i32, stride_c,
|
||||
c.data_ptr() as *mut c_void,
|
||||
c_type,
|
||||
n as i32,
|
||||
stride_c,
|
||||
batch as i32,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
-1,
|
||||
)).expect("cuBLAS batched GEMM failed");
|
||||
))
|
||||
.expect("cuBLAS batched GEMM failed");
|
||||
});
|
||||
c
|
||||
}
|
||||
|
||||
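The "row-major trick" comment in `batched_matmul` can be checked on the CPU. The sketch below assumes a plain reference column-major GEMM (`gemm_col_major`, a hypothetical stand-in for cuBLAS) and calls it with the operands swapped and `(n, m, k)` as the dimensions, in the same argument order as the `cublasGemmEx` / `cublasGemmStridedBatchedEx` calls in this diff. The result buffer, read as row-major, is `A @ B`.

```rust
/// Reference column-major GEMM: C (m x n) = A (m x k) * B (k x n),
/// with leading dimensions lda / ldb / ldc as in the BLAS convention.
/// Hypothetical CPU stand-in for the cuBLAS call.
#[allow(clippy::too_many_arguments)]
fn gemm_col_major(
    m: usize, n: usize, k: usize,
    a: &[f32], lda: usize,
    b: &[f32], ldb: usize,
    c: &mut [f32], ldc: usize,
) {
    for j in 0..n {
        for i in 0..m {
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += a[i + p * lda] * b[p + j * ldb];
            }
            c[i + j * ldc] = acc;
        }
    }
}

/// Row-major C[m, n] = A[m, k] @ B[k, n] using only the column-major routine.
/// A row-major buffer reinterpreted as column-major is the transpose, so
/// computing C^T = B^T @ A^T in column-major writes C in row-major order.
fn matmul_row_major(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * n];
    // Operands swapped: B first (ld = n), then A (ld = k); dims are (n, m, k).
    gemm_col_major(n, m, k, b, n, a, k, &mut c, n);
    c
}
```

No data is transposed or copied; only the argument order and leading dimensions change, which is why the trick costs nothing at run time.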
@@ -2,10 +2,26 @@ use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_layernorm_f32(x: *const c_void, gamma: *const c_void, beta: *const c_void,
|
||||
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_layernorm_bf16(x: *const c_void, gamma: *const c_void, beta: *const c_void,
|
||||
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_layernorm_f32(
|
||||
x: *const c_void,
|
||||
gamma: *const c_void,
|
||||
beta: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_layernorm_bf16(
|
||||
x: *const c_void,
|
||||
gamma: *const c_void,
|
||||
beta: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor {
|
||||
@@ -17,21 +33,37 @@ pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor
|
||||
assert_eq!(beta.shape(), &[hidden_size]);
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
|
||||
assert!(
|
||||
rows <= i32::MAX as usize,
|
||||
"too many rows for i32 kernel param"
|
||||
);
|
||||
assert!(
|
||||
hidden_size <= i32::MAX as usize,
|
||||
"hidden_size too large for i32 kernel param"
|
||||
);
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_layernorm_f32(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
||||
x.data_ptr() as _,
|
||||
gamma.data_ptr() as _,
|
||||
beta.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_layernorm_bf16(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
||||
x.data_ptr() as _,
|
||||
gamma.data_ptr() as _,
|
||||
beta.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for layernorm"),
|
||||
}
|
||||
|
||||
@@ -12,16 +12,22 @@ pub mod rope;
|
||||
 pub mod softmax;
 pub mod transpose;
 
-pub use activation::{add, gelu, gpt_oss_glu, mul, scale, silu, silu_mul};
+pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu_mul};
 pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
-pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
-pub use attention::{attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16};
-pub use embedding::embedding;
-pub use gemm::{batched_matmul, matmul, GemmBackend};
+pub use attention::{
+    attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention,
+    paged_decode_attention_sinks, reshape_and_cache_batched_bf16, reshape_and_cache_bf16,
+};
+pub use embedding::{embedding, embedding_device_ids};
+pub use gemm::{GemmBackend, batched_matmul, matmul};
 pub use layernorm::layernorm;
 pub use rmsnorm::{add_rmsnorm, rmsnorm};
-pub use rope::{rope_inplace, RopeCache};
+pub use rope::{RopeCache, rope_inplace, rope_inplace_device_pos};
 pub use softmax::softmax;
+pub use transpose::{
+    merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu,
+    transpose_for_rope_gpu, transpose_from_rope_gpu,
+};
 
 /// Register GPU kernels with the tensor crate. Call once at startup.
 pub fn init() {
|
||||
@@ -1,43 +1,113 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use xserv_tensor::{DType, Tensor};
|
||||
|
||||
use crate::gemm::{cublas_handle, CublasHandle};
|
||||
use crate::gemm::{CublasHandle, cublas_handle};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_moe_topk_softmax_bf16(
|
||||
router_logits: *const c_void,
|
||||
topk_ids: *mut c_void, topk_weights: *mut c_void,
|
||||
num_tokens: i32, num_experts: i32, top_k: i32,
|
||||
topk_ids: *mut c_void,
|
||||
topk_weights: *mut c_void,
|
||||
num_tokens: i32,
|
||||
num_experts: i32,
|
||||
top_k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_replicate_bf16(
|
||||
x: *const c_void, x_rep: *mut c_void,
|
||||
num_tokens: i32, hidden: i32, local_experts: i32,
|
||||
x: *const c_void,
|
||||
x_rep: *mut c_void,
|
||||
num_tokens: i32,
|
||||
hidden: i32,
|
||||
local_experts: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_bias_add_3d_bf16(
|
||||
x: *mut c_void, bias: *const c_void,
|
||||
batch: i32, num_tokens: i32, dim: i32,
|
||||
x: *mut c_void,
|
||||
bias: *const c_void,
|
||||
batch: i32,
|
||||
num_tokens: i32,
|
||||
dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_weighted_sum_bf16(
|
||||
expert_out: *const c_void,
|
||||
topk_ids: *const c_void, topk_weights: *const c_void,
|
||||
topk_ids: *const c_void,
|
||||
topk_weights: *const c_void,
|
||||
out: *mut c_void,
|
||||
num_tokens: i32, hidden: i32, top_k: i32,
|
||||
expert_start: i32, local_experts: i32,
|
||||
num_tokens: i32,
|
||||
hidden: i32,
|
||||
top_k: i32,
|
||||
expert_start: i32,
|
||||
local_experts: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
fn launch_moe_sparse_gemv_fp8_bf16(
|
||||
x: *const c_void,
|
||||
w: *const c_void,
|
||||
w_scales: *const c_void,
|
||||
bias: *const c_void,
|
||||
topk_ids: *const c_void,
|
||||
y: *mut c_void,
|
||||
num_tokens: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
top_k: i32,
|
||||
expert_start: i32,
|
||||
local_experts: i32,
|
||||
x_per_slot: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_sparse_gemv_mxfp4_bf16(
|
||||
x: *const c_void,
|
||||
w_packed: *const c_void,
|
||||
w_scales: *const c_void,
|
||||
bias: *const c_void,
|
||||
topk_ids: *const c_void,
|
||||
y: *mut c_void,
|
||||
num_tokens: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
top_k: i32,
|
||||
expert_start: i32,
|
||||
local_experts: i32,
|
||||
x_per_slot: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_weighted_sum_sparse_bf16(
|
||||
down: *const c_void,
|
||||
topk_ids: *const c_void,
|
||||
topk_weights: *const c_void,
|
||||
out: *mut c_void,
|
||||
num_tokens: i32,
|
||||
hidden: i32,
|
||||
top_k: i32,
|
||||
expert_start: i32,
|
||||
local_experts: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
fn cublasGemmStridedBatchedEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32, transb: i32,
|
||||
m: i32, n: i32, k: i32,
|
||||
transa: i32,
|
||||
transb: i32,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void, a_type: i32, lda: i32, stride_a: i64,
|
||||
b: *const c_void, b_type: i32, ldb: i32, stride_b: i64,
|
||||
a: *const c_void,
|
||||
a_type: i32,
|
||||
lda: i32,
|
||||
stride_a: i64,
|
||||
b: *const c_void,
|
||||
b_type: i32,
|
||||
ldb: i32,
|
||||
stride_b: i64,
|
||||
beta: *const c_void,
|
||||
c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64,
|
||||
c: *mut c_void,
|
||||
c_type: i32,
|
||||
ldc: i32,
|
||||
stride_c: i64,
|
||||
batch_count: i32,
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
@@ -65,6 +135,9 @@ pub fn moe_topk_softmax(
|
||||
     let num_tokens = router_logits.shape()[0];
     assert_eq!(router_logits.shape()[1], num_experts);
 
+    // NOTE: topk_ids actually holds i32 expert indices; DType has no I32, so
+    // this is a raw 4-byte buffer mislabeled F32. Never read it as floats;
+    // all consumers (weighted-sum / sparse GEMV kernels) cast to int*.
     let topk_ids = Tensor::empty(&[num_tokens, top_k], DType::F32, router_logits.device());
     let topk_weights = Tensor::empty(&[num_tokens, top_k], DType::F32, router_logits.device());
|
||||
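The NOTE on `topk_ids` says the buffer is labeled F32 but holds `i32` expert indices. A small CPU sketch of why reading such a slot as a float gives nonsense, while a bit-level reinterpretation recovers the index. Both function names are hypothetical and exist only for this illustration.

```rust
/// Store an i32 expert id in a 4-byte slot typed as f32, preserving the bits.
/// This mirrors what the kernel does when it writes ints into the F32 buffer.
fn store_expert_id(id: i32) -> f32 {
    f32::from_bits(id as u32)
}

/// Recover the expert id by reinterpreting the bits, the host-side analogue
/// of the kernels casting the pointer to `int*`.
fn expert_id_from_mislabeled(slot: f32) -> i32 {
    slot.to_bits() as i32
}
```

For a small id such as 7, the stored slot is a subnormal float close to zero, so a value conversion like `slot as i32` yields 0 rather than 7. Any host-side debugging code that copies this tensor back should therefore go through `to_bits` (or read the bytes as `i32`) instead of treating the elements as numbers.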
@@ -73,8 +146,10 @@ pub fn moe_topk_softmax(
|
||||
router_logits.data_ptr() as *const c_void,
|
||||
topk_ids.data_ptr() as *mut c_void,
|
||||
topk_weights.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, num_experts as i32, top_k as i32,
|
||||
std::ptr::null_mut(),
|
||||
num_tokens as i32,
|
||||
num_experts as i32,
|
||||
top_k as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -88,14 +163,20 @@ pub fn moe_replicate(x: &Tensor, local_experts: usize) -> Tensor {
|
||||
assert!(x.is_contiguous());
|
||||
let num_tokens = x.shape()[0];
|
||||
let hidden = x.shape()[1];
|
||||
let out = Tensor::empty(&[local_experts, num_tokens, hidden], DType::BF16, x.device());
|
||||
let out = Tensor::empty(
|
||||
&[local_experts, num_tokens, hidden],
|
||||
DType::BF16,
|
||||
x.device(),
|
||||
);
|
||||
|
||||
unsafe {
|
||||
launch_moe_replicate_bf16(
|
||||
x.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden as i32, local_experts as i32,
|
||||
std::ptr::null_mut(),
|
||||
num_tokens as i32,
|
||||
hidden as i32,
|
||||
local_experts as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -117,8 +198,10 @@ pub fn moe_bias_add_3d(x: &Tensor, bias: &Tensor) {
|
||||
launch_moe_bias_add_3d_bf16(
|
||||
x.data_ptr() as *mut c_void,
|
||||
bias.data_ptr() as *const c_void,
|
||||
batch as i32, num_tokens as i32, dim as i32,
|
||||
std::ptr::null_mut(),
|
||||
batch as i32,
|
||||
num_tokens as i32,
|
||||
dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -149,15 +232,171 @@ pub fn moe_weighted_sum(
|
||||
topk_ids.data_ptr() as *const c_void,
|
||||
topk_weights.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden as i32, top_k as i32,
|
||||
expert_start as i32, local_experts as i32,
|
||||
std::ptr::null_mut(),
|
||||
num_tokens as i32,
|
||||
hidden as i32,
|
||||
top_k as i32,
|
||||
expert_start as i32,
|
||||
local_experts as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
+/// Sparse MoE GEMV (FP8 W8A16): compute only the routed experts.
+///
+/// x:        [num_tokens, K] BF16 (x_per_slot=false, gate_up) or
+///           [num_tokens * top_k, K] BF16 (x_per_slot=true, down)
+/// w_fp8_t:  [local_experts, N, K] FP8E4M3 (transposed weight layout)
+/// w_scales: [local_experts] F32 per-expert scalar scales
+/// bias:     [local_experts, N] BF16 (fused into the epilogue)
+/// topk_ids: [num_tokens, top_k] i32 global expert ids (GPU)
+///
+/// Returns y [num_tokens, top_k, N] BF16. Slots routed to experts NOT
+/// owned by this rank are left UNWRITTEN (uninitialized memory); the
+/// consumer must skip them (see moe_weighted_sum_sparse).
+#[allow(clippy::too_many_arguments)]
+pub fn moe_sparse_gemv_fp8(
+    x: &Tensor,
+    w_fp8_t: &Tensor,
+    w_scales: &Tensor,
+    bias: &Tensor,
+    topk_ids: &Tensor,
+    num_tokens: usize,
+    top_k: usize,
+    expert_start: usize,
+    local_experts: usize,
+    x_per_slot: bool,
+) -> Tensor {
+    assert_eq!(x.dtype(), DType::BF16);
+    assert!(x.is_contiguous());
+    assert_eq!(w_fp8_t.dtype(), DType::FP8E4M3);
+    let n = w_fp8_t.shape()[1];
+    let k = w_fp8_t.shape()[2];
+    // The kernel reads weights as uint4 (16 FP8 values per lane) and would
+    // silently skip a K%16 tail.
+    assert_eq!(k % 16, 0, "sparse FP8 GEMV requires K % 16 == 0, got {k}");
+    assert_eq!(x.shape()[x.ndim() - 1], k);
+    assert_eq!(
+        x.shape()[0],
+        if x_per_slot {
+            num_tokens * top_k
+        } else {
+            num_tokens
+        }
+    );
+
+    let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
+    unsafe {
+        launch_moe_sparse_gemv_fp8_bf16(
+            x.data_ptr() as *const c_void,
+            w_fp8_t.data_ptr() as *const c_void,
+            w_scales.data_ptr() as *const c_void,
+            bias.data_ptr() as *const c_void,
+            topk_ids.data_ptr() as *const c_void,
+            y.data_ptr() as *mut c_void,
+            num_tokens as i32,
+            n as i32,
+            k as i32,
+            top_k as i32,
+            expert_start as i32,
+            local_experts as i32,
+            x_per_slot as i32,
+            xserv_cuda::current_stream_raw(),
+        );
+    }
+    y
+}
|
||||
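The `K % 16 == 0` assert in `moe_sparse_gemv_fp8` exists because a loop that consumes 16 values per load never sees a shorter tail. A CPU sketch of that failure mode, assuming a hypothetical `dot_chunks_of_16` that walks only whole 16-element chunks the way a `uint4`-vectorized kernel would:

```rust
/// Dot product over whole 16-element chunks only. Hypothetical CPU stand-in
/// for a kernel that loads 16 FP8 weights per lane; `chunks_exact` drops any
/// remainder, just as the vectorized loop would.
fn dot_chunks_of_16(x: &[f32], w: &[f32]) -> f32 {
    x.chunks_exact(16)
        .zip(w.chunks_exact(16))
        .map(|(xc, wc)| xc.iter().zip(wc).map(|(a, b)| a * b).sum::<f32>())
        .sum()
}

/// Reference dot product over every element.
fn dot_full(x: &[f32], w: &[f32]) -> f32 {
    x.iter().zip(w).map(|(a, b)| a * b).sum()
}
```

With `K = 32` the two functions agree; with `K = 20` the chunked version ignores the last four products and returns a wrong but plausible-looking number, which is why rejecting the shape up front is safer than relying on the kernel.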
+/// Sparse MoE GEMV (MXFP4 W4A16): same contract as moe_sparse_gemv_fp8,
+/// with packed 4-bit weights [E, N, K/2] + UE8M0 block scales [E, N, K/32].
+#[allow(clippy::too_many_arguments)]
+pub fn moe_sparse_gemv_mxfp4(
+    x: &Tensor,
+    w_packed: &Tensor,
+    w_scales: &Tensor,
+    bias: &Tensor,
+    topk_ids: &Tensor,
+    num_tokens: usize,
+    top_k: usize,
+    n: usize,
+    k: usize,
+    expert_start: usize,
+    local_experts: usize,
+    x_per_slot: bool,
+) -> Tensor {
+    assert_eq!(x.dtype(), DType::BF16);
+    assert!(x.is_contiguous());
+    // 32-element MXFP4 blocks, read as uint4 (32 nibbles) per lane.
+    assert_eq!(k % 32, 0, "sparse MXFP4 GEMV requires K % 32 == 0, got {k}");
+    assert_eq!(x.shape()[x.ndim() - 1], k);
+    assert_eq!(
+        x.shape()[0],
+        if x_per_slot {
+            num_tokens * top_k
+        } else {
+            num_tokens
+        }
+    );
+
+    let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
+    unsafe {
+        launch_moe_sparse_gemv_mxfp4_bf16(
+            x.data_ptr() as *const c_void,
+            w_packed.data_ptr() as *const c_void,
+            w_scales.data_ptr() as *const c_void,
+            bias.data_ptr() as *const c_void,
+            topk_ids.data_ptr() as *const c_void,
+            y.data_ptr() as *mut c_void,
+            num_tokens as i32,
+            n as i32,
+            k as i32,
+            top_k as i32,
+            expert_start as i32,
+            local_experts as i32,
+            x_per_slot as i32,
+            xserv_cuda::current_stream_raw(),
+        );
+    }
+    y
+}
|
||||
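The MXFP4 layout named in the doc comment (two 4-bit values per byte, one UE8M0 scale per 32-element block) can be decoded on the CPU as follows. This is a sketch under stated assumptions: the E2M1 value table and the `2^(e - 127)` scale follow the OCP Microscaling format, while the nibble order (low nibble holds the even element) is an assumption about the packing and may differ from what the xserv kernel expects. All function names are hypothetical.

```rust
/// Magnitudes of the eight E2M1 codes (the low 3 bits of a nibble).
const E2M1: [f32; 8] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0];

/// Decode one 4-bit E2M1 value; bit 3 is the sign.
fn decode_e2m1(nibble: u8) -> f32 {
    let mag = E2M1[(nibble & 0x7) as usize];
    if nibble & 0x8 != 0 { -mag } else { mag }
}

/// Decode a UE8M0 block scale: an unsigned exponent with bias 127.
/// The value 255 is reserved for NaN.
fn decode_ue8m0(e: u8) -> f32 {
    if e == 255 { f32::NAN } else { 2.0f32.powi(e as i32 - 127) }
}

/// Dequantize one 32-element block: 16 packed bytes plus one scale byte.
/// ASSUMPTION: low nibble = element 2i, high nibble = element 2i + 1.
fn dequant_block(packed: &[u8; 16], scale: u8) -> [f32; 32] {
    let s = decode_ue8m0(scale);
    let mut out = [0.0f32; 32];
    for (i, byte) in packed.iter().enumerate() {
        out[2 * i] = decode_e2m1(byte & 0x0F) * s;
        out[2 * i + 1] = decode_e2m1(byte >> 4) * s;
    }
    out
}
```

One block is 16 bytes, which matches the "read as uint4 (32 nibbles) per lane" comment and explains the `K % 32 == 0` requirement: a row with a partial block would have no complete scale group to decode.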
+/// Weighted sum over the slot axis of the sparse GEMV output.
+///
+/// down: [num_tokens, top_k, hidden] BF16 (non-local slots uninitialized
+///       and skipped, never multiplied by zero: NaN * 0 = NaN).
+pub fn moe_weighted_sum_sparse(
+    down: &Tensor,
+    topk_ids: &Tensor,
+    topk_weights: &Tensor,
+    expert_start: usize,
+    local_experts: usize,
+) -> Tensor {
+    assert_eq!(down.ndim(), 3);
+    assert_eq!(down.dtype(), DType::BF16);
+    let num_tokens = down.shape()[0];
+    let top_k = down.shape()[1];
+    let hidden = down.shape()[2];
+
+    let out = Tensor::empty(&[num_tokens, hidden], DType::BF16, down.device());
+    unsafe {
+        launch_moe_weighted_sum_sparse_bf16(
+            down.data_ptr() as *const c_void,
+            topk_ids.data_ptr() as *const c_void,
+            topk_weights.data_ptr() as *const c_void,
+            out.data_ptr() as *mut c_void,
+            num_tokens as i32,
+            hidden as i32,
+            top_k as i32,
+            expert_start as i32,
+            local_experts as i32,
+            xserv_cuda::current_stream_raw(),
+        );
+    }
+    out
+}
|
||||
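The doc comment on `moe_weighted_sum_sparse` insists that non-local slots are skipped rather than multiplied by zero. A CPU sketch of the difference, using two hypothetical scalar reductions over one token's `top_k` slots; the ownership test `start <= id < start + local` is an assumption about how a rank decides which experts are local.

```rust
/// Weighted sum that skips slots whose expert is not owned by this rank,
/// so uninitialized slot values are never read into the accumulator.
fn weighted_sum_skip(vals: &[f32], weights: &[f32], ids: &[i32], start: i32, local: i32) -> f32 {
    let mut acc = 0.0f32;
    for ((v, w), id) in vals.iter().zip(weights).zip(ids) {
        if *id >= start && *id < start + local {
            acc += v * w;
        }
    }
    acc
}

/// Looks equivalent, but masks non-local slots with a zero weight instead.
/// If the slot holds NaN (possible for uninitialized memory), NaN * 0 = NaN
/// poisons the whole sum.
fn weighted_sum_mask(vals: &[f32], weights: &[f32], ids: &[i32], start: i32, local: i32) -> f32 {
    let mut acc = 0.0f32;
    for ((v, w), id) in vals.iter().zip(weights).zip(ids) {
        let owned = *id >= start && *id < start + local;
        acc += v * if owned { *w } else { 0.0 };
    }
    acc
}
```

With a NaN in a non-local slot, the skipping version returns the sum of the owned slots while the masking version returns NaN, which would then propagate through the all-reduce to every rank.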
/// Strided batched GEMM for MoE expert forward.
|
||||
/// C[b] = A[b] @ B[b] for b in 0..batch
|
||||
///
|
||||
@@ -202,16 +441,28 @@ pub fn batched_gemm_strided(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
|
||||
let handle = cublas_handle();
|
||||
unsafe {
|
||||
-        cublasSetStream_v2(handle, std::ptr::null_mut());
+        cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
let status = cublasGemmStridedBatchedEx(
|
||||
handle,
|
||||
0, 0, // CUBLAS_OP_N, CUBLAS_OP_N
|
||||
n as i32, m as i32, k as i32,
|
||||
0,
|
||||
0, // CUBLAS_OP_N, CUBLAS_OP_N
|
||||
n as i32,
|
||||
m as i32,
|
||||
k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b.data_ptr() as *const c_void, CUDA_R_16BF, n as i32, stride_b,
|
||||
a.data_ptr() as *const c_void, CUDA_R_16BF, k as i32, stride_a,
|
||||
b.data_ptr() as *const c_void,
|
||||
CUDA_R_16BF,
|
||||
n as i32,
|
||||
stride_b,
|
||||
a.data_ptr() as *const c_void,
|
||||
CUDA_R_16BF,
|
||||
k as i32,
|
||||
stride_a,
|
||||
&beta as *const f32 as *const c_void,
|
||||
c.data_ptr() as *mut c_void, CUDA_R_16BF, n as i32, stride_c,
|
||||
c.data_ptr() as *mut c_void,
|
||||
CUDA_R_16BF,
|
||||
n as i32,
|
||||
stride_c,
|
||||
batch as i32,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
CUBLAS_GEMM_DEFAULT,
|
||||
|
||||
@@ -13,30 +13,46 @@ unsafe extern "C" {
|
||||
src: *const c_void,
|
||||
scales: *const c_void,
|
||||
dst: *mut c_void,
|
||||
num_experts: i32, rows: i32, cols: i32,
|
||||
num_experts: i32,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_quantize_bf16_to_fp8e4m3_rowwise(
|
||||
src: *const c_void,
|
||||
dst: *mut c_void,
|
||||
scales: *mut c_void,
|
||||
num_rows: i32, cols: i32,
|
||||
num_rows: i32,
|
||||
cols: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_rowwise_scale_moe_bf16(
|
||||
data: *mut c_void,
|
||||
a_scales: *const c_void,
|
||||
b_scales: *const c_void,
|
||||
num_rows: i32, cols: i32, tokens: i32,
|
||||
num_rows: i32,
|
||||
cols: i32,
|
||||
tokens: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_batched_gemv_mxfp4_bf16(
|
||||
x: *const c_void, w_packed: *const c_void, w_scales: *const c_void, y: *mut c_void,
|
||||
e: i32, n: i32, k: i32, stream: *mut c_void,
|
||||
x: *const c_void,
|
||||
w_packed: *const c_void,
|
||||
w_scales: *const c_void,
|
||||
y: *mut c_void,
|
||||
e: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_dequant_mxfp4_to_bf16_t(
|
||||
w_packed: *const c_void, w_scales: *const c_void, out: *mut c_void,
|
||||
e: i32, n: i32, k: i32, stream: *mut c_void,
|
||||
w_packed: *const c_void,
|
||||
w_scales: *const c_void,
|
||||
out: *mut c_void,
|
||||
e: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -66,34 +82,68 @@ struct CublasLtMatmulHeuristicResult {
|
||||
unsafe extern "C" {
|
||||
fn cublasLtCreate(handle: *mut CublasLtHandle) -> i32;
|
||||
fn cublasLtDestroy(handle: CublasLtHandle) -> i32;
|
||||
fn cublasLtMatmulDescCreate(desc: *mut CublasLtMatmulDesc, compute_type: i32, scale_type: i32) -> i32;
|
||||
fn cublasLtMatmulDescCreate(
|
||||
desc: *mut CublasLtMatmulDesc,
|
||||
compute_type: i32,
|
||||
scale_type: i32,
|
||||
) -> i32;
|
||||
fn cublasLtMatmulDescDestroy(desc: CublasLtMatmulDesc) -> i32;
|
||||
fn cublasLtMatmulDescSetAttribute(desc: CublasLtMatmulDesc, attr: i32, buf: *const c_void, size: usize) -> i32;
|
||||
fn cublasLtMatrixLayoutCreate(layout: *mut CublasLtMatrixLayout, dtype: i32, rows: u64, cols: u64, ld: i64) -> i32;
|
||||
fn cublasLtMatmulDescSetAttribute(
|
||||
desc: CublasLtMatmulDesc,
|
||||
attr: i32,
|
||||
buf: *const c_void,
|
||||
size: usize,
|
||||
) -> i32;
|
||||
fn cublasLtMatrixLayoutCreate(
|
||||
layout: *mut CublasLtMatrixLayout,
|
||||
dtype: i32,
|
||||
rows: u64,
|
||||
cols: u64,
|
||||
ld: i64,
|
||||
) -> i32;
|
||||
fn cublasLtMatrixLayoutDestroy(layout: CublasLtMatrixLayout) -> i32;
|
||||
fn cublasLtMatrixLayoutSetAttribute(layout: CublasLtMatrixLayout, attr: i32, buf: *const c_void, size: usize) -> i32;
|
||||
fn cublasLtMatrixLayoutSetAttribute(
|
||||
layout: CublasLtMatrixLayout,
|
||||
attr: i32,
|
||||
buf: *const c_void,
|
||||
size: usize,
|
||||
) -> i32;
|
||||
fn cublasLtMatmulPreferenceCreate(pref: *mut CublasLtMatmulPreference) -> i32;
|
||||
fn cublasLtMatmulPreferenceDestroy(pref: CublasLtMatmulPreference) -> i32;
|
||||
fn cublasLtMatmulPreferenceSetAttribute(pref: CublasLtMatmulPreference, attr: i32, buf: *const c_void, size: usize) -> i32;
|
||||
fn cublasLtMatmulPreferenceSetAttribute(
|
||||
pref: CublasLtMatmulPreference,
|
||||
attr: i32,
|
||||
buf: *const c_void,
|
||||
size: usize,
|
||||
) -> i32;
|
||||
fn cublasLtMatmulAlgoGetHeuristic(
|
||||
handle: CublasLtHandle, desc: CublasLtMatmulDesc,
|
||||
a_layout: CublasLtMatrixLayout, b_layout: CublasLtMatrixLayout,
|
||||
c_layout: CublasLtMatrixLayout, d_layout: CublasLtMatrixLayout,
|
||||
handle: CublasLtHandle,
|
||||
desc: CublasLtMatmulDesc,
|
||||
a_layout: CublasLtMatrixLayout,
|
||||
b_layout: CublasLtMatrixLayout,
|
||||
c_layout: CublasLtMatrixLayout,
|
||||
d_layout: CublasLtMatrixLayout,
|
||||
pref: CublasLtMatmulPreference,
|
||||
requested: i32,
|
||||
results: *mut CublasLtMatmulHeuristicResult,
|
||||
found: *mut i32,
|
||||
) -> i32;
|
||||
fn cublasLtMatmul(
|
||||
handle: CublasLtHandle, desc: CublasLtMatmulDesc,
|
||||
handle: CublasLtHandle,
|
||||
desc: CublasLtMatmulDesc,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void, a_layout: CublasLtMatrixLayout,
|
||||
b: *const c_void, b_layout: CublasLtMatrixLayout,
|
||||
a: *const c_void,
|
||||
a_layout: CublasLtMatrixLayout,
|
||||
b: *const c_void,
|
||||
b_layout: CublasLtMatrixLayout,
|
||||
beta: *const c_void,
|
||||
c: *const c_void, c_layout: CublasLtMatrixLayout,
|
||||
d: *mut c_void, d_layout: CublasLtMatrixLayout,
|
||||
c: *const c_void,
|
||||
c_layout: CublasLtMatrixLayout,
|
||||
d: *mut c_void,
|
||||
d_layout: CublasLtMatrixLayout,
|
||||
algo: *const CublasLtMatmulAlgo,
|
||||
workspace: *mut c_void, workspace_size: usize,
|
||||
workspace: *mut c_void,
|
||||
workspace_size: usize,
|
||||
stream: *mut c_void,
|
||||
) -> i32;
|
||||
}
|
||||
@@ -107,17 +157,11 @@ const CUDA_R_8F_E4M3: i32 = 28;
|
||||
// MatmulDesc attributes
|
||||
const CUBLASLT_MATMUL_DESC_A_SCALE_POINTER: i32 = 17;
|
||||
const CUBLASLT_MATMUL_DESC_B_SCALE_POINTER: i32 = 18;
|
||||
const CUBLASLT_MATMUL_DESC_A_SCALE_MODE: i32 = 31;
|
||||
const CUBLASLT_MATMUL_DESC_B_SCALE_MODE: i32 = 32;
|
||||
|
||||
// MatrixLayout attributes
|
||||
const CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT: i32 = 5;
|
||||
const CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET: i32 = 6;
|
||||
|
||||
// Scale modes
|
||||
const CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR: i32 = 0;
|
||||
const CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F: i32 = 3;
|
||||
|
||||
// MatmulPreference attributes
|
||||
const CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES: i32 = 1;
|
||||
|
||||
@@ -159,8 +203,15 @@ impl CublasLtContext {
|
||||
assert_eq!(status, 0, "cublasLtCreate failed: {status}");
|
||||
let workspace = GpuBuffer::alloc(WORKSPACE_BYTES).expect("alloc cublasLt workspace");
|
||||
let mut one_buf = GpuBuffer::alloc(4).expect("alloc cublasLt fp8 scale");
|
||||
one_buf.copy_from_host(&1.0f32.to_le_bytes()).expect("init fp8 scale");
|
||||
Self { handle, workspace, one_buf, plans: HashMap::new() }
|
||||
one_buf
|
||||
.copy_from_host(&1.0f32.to_le_bytes())
|
||||
.expect("init fp8 scale");
|
||||
Self {
|
||||
handle,
|
||||
workspace,
|
||||
one_buf,
|
||||
plans: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the cached strided-batched plan for (m, n, k, batch), building it on
|
||||
@@ -216,10 +267,25 @@ unsafe fn build_fp8_plan(
|
||||
|
||||
// transA=T (required for FP8 on Blackwell)
|
||||
let trans_a: i32 = 1;
|
||||
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a as *const i32 as _, 4);
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATMUL_DESC_TRANSA,
|
||||
&trans_a as *const i32 as _,
|
||||
4,
|
||||
);
|
||||
let ptr_sz = std::mem::size_of::<*const c_void>();
|
||||
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
|
||||
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
|
||||
&one_ptr as *const _ as _,
|
||||
ptr_sz,
|
||||
);
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
|
||||
&one_ptr as *const _ as _,
|
||||
ptr_sz,
|
||||
);
|
||||
|
||||
// Per-expert strides in ELEMENTS for the strided-batch layout.
|
||||
let stride_a = (n * k) as i64; // weights [N, K]
|
||||
@@ -227,10 +293,18 @@ unsafe fn build_fp8_plan(
|
||||
let stride_c = (m * n) as i64; // output [M, N]
|
||||
let bc = batch as i32;
|
||||
let set_batch = |layout: CublasLtMatrixLayout, stride: i64| {
|
||||
cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&bc as *const i32 as _, 4);
|
||||
cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
|
||||
&stride as *const i64 as _, 8);
|
||||
cublasLtMatrixLayoutSetAttribute(
|
||||
layout,
|
||||
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&bc as *const i32 as _,
|
||||
4,
|
||||
);
|
||||
cublasLtMatrixLayoutSetAttribute(
|
||||
layout,
|
||||
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
|
||||
&stride as *const i64 as _,
|
||||
8,
|
||||
);
|
||||
};
|
||||
|
||||
// "A" layout (weights, transposed): physical (K, N) col-major, ld=K
|
||||
@@ -252,20 +326,39 @@ unsafe fn build_fp8_plan(
|
||||
let mut pref: CublasLtMatmulPreference = std::ptr::null_mut();
|
||||
cublasLtMatmulPreferenceCreate(&mut pref);
|
||||
let ws_bytes = WORKSPACE_BYTES as u64;
|
||||
cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws_bytes as *const u64 as _, 8);
|
||||
cublasLtMatmulPreferenceSetAttribute(
|
||||
pref,
|
||||
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&ws_bytes as *const u64 as _,
|
||||
8,
|
||||
);
|
||||
|
||||
let mut heuristic = std::mem::zeroed::<CublasLtMatmulHeuristicResult>();
|
||||
let mut found: i32 = 0;
|
||||
let status = cublasLtMatmulAlgoGetHeuristic(
|
||||
handle, desc, a_layout, b_layout, c_layout, d_layout,
|
||||
pref, 1, &mut heuristic, &mut found,
|
||||
handle,
|
||||
desc,
|
||||
a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
d_layout,
|
||||
pref,
|
||||
1,
|
||||
&mut heuristic,
|
||||
&mut found,
|
||||
);
|
||||
assert!(
|
||||
status == 0 && found > 0,
|
||||
"cublasLtMatmulAlgoGetHeuristic failed for batched FP8 GEMM (m={m}, n={n}, k={k}, batch={batch}): status={status}, found={found}"
|
||||
);
|
||||
assert!(status == 0 && found > 0,
|
||||
"cublasLtMatmulAlgoGetHeuristic failed for batched FP8 GEMM (m={m}, n={n}, k={k}, batch={batch}): status={status}, found={found}");
|
||||
cublasLtMatmulPreferenceDestroy(pref);
|
||||
|
||||
Fp8Plan {
|
||||
desc, a_layout, b_layout, c_layout, d_layout,
|
||||
desc,
|
||||
a_layout,
|
||||
b_layout,
|
||||
c_layout,
|
||||
d_layout,
|
||||
algo: heuristic.algo,
|
||||
workspace_size: heuristic.workspace_size,
|
||||
}
|
||||
@@ -305,8 +398,10 @@ pub fn dequant_fp8_to_bf16(src: &Tensor, scales: &Tensor) -> Tensor {
|
||||
src.data_ptr() as *const c_void,
|
||||
scales.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_experts as i32, rows as i32, cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
num_experts as i32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -335,8 +430,9 @@ pub fn quantize_bf16_to_fp8_rowwise(src: &Tensor) -> (Tensor, Tensor) {
|
||||
src.data_ptr() as *const c_void,
|
||||
fp8_out.data_ptr() as *mut c_void,
|
||||
scales.data_ptr() as *mut c_void,
|
||||
num_rows as i32, cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
num_rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -398,23 +494,27 @@ pub fn batched_gemm_fp8(
|
||||
|
||||
unsafe {
|
||||
let status = cublasLtMatmul(
|
||||
handle, plan.desc,
|
||||
handle,
|
||||
plan.desc,
|
||||
&alpha as *const f32 as _,
|
||||
b_fp8_t.data_ptr() as *const c_void, // cuBLASLt "A" = weights
|
||||
plan.a_layout,
|
||||
a_fp8.data_ptr() as *const c_void, // cuBLASLt "B" = activations
|
||||
a_fp8.data_ptr() as *const c_void, // cuBLASLt "B" = activations
|
||||
plan.b_layout,
|
||||
&beta as *const f32 as _,
|
||||
c.data_ptr() as *const c_void, // C (unused with beta=0)
|
||||
c.data_ptr() as *const c_void, // C (unused with beta=0)
|
||||
plan.c_layout,
|
||||
c.data_ptr() as *mut c_void, // D = output
|
||||
c.data_ptr() as *mut c_void, // D = output
|
||||
plan.d_layout,
|
||||
&plan.algo,
|
||||
ws_ptr,
|
||||
plan.workspace_size,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
assert_eq!(
|
||||
status, 0,
|
||||
"batched cublasLtMatmul FP8 failed: status={status}"
|
||||
);
|
||||
assert_eq!(status, 0, "batched cublasLtMatmul FP8 failed: status={status}");
|
||||
}
|
||||
});
|
||||
|
||||
@@ -429,8 +529,10 @@ pub fn batched_gemm_fp8(
|
||||
c.data_ptr() as *mut c_void,
|
||||
a_scales.data_ptr() as *const c_void,
|
||||
b_scales.data_ptr() as *const c_void,
|
||||
total_rows, n as i32, m as i32,
|
||||
std::ptr::null_mut(),
|
||||
total_rows,
|
||||
n as i32,
|
||||
m as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -448,7 +550,13 @@ pub fn batched_gemm_fp8(
|
||||
/// w_scales: [E, N, K/32] byte tensor — UE8M0 scale per 32-element block
|
||||
///
|
||||
/// Returns: [E, N] BF16, where y[e,n] = sum_k x[e,k] * dequant(W[e,n,k]).
|
||||
pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: usize, k: usize) -> Tensor {
|
||||
pub fn batched_gemv_mxfp4(
|
||||
x: &Tensor,
|
||||
w_packed: &Tensor,
|
||||
w_scales: &Tensor,
|
||||
n: usize,
|
||||
k: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous());
|
||||
let e = x.shape()[0];
|
||||
@@ -461,8 +569,10 @@ pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: u
|
||||
w_packed.data_ptr() as *const c_void,
|
||||
w_scales.data_ptr() as *const c_void,
|
||||
y.data_ptr() as *mut c_void,
|
||||
e as i32, n as i32, k as i32,
|
||||
std::ptr::null_mut(),
|
||||
e as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
y
|
||||
@@ -470,15 +580,23 @@ pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: u
|
||||
|
||||
/// Dequantize MXFP4 weights [E, N, K] → BF16 [E, K, N] for the prefill GEMM path
|
||||
/// (the BF16 batched GEMM expects weights as [E, K, N]).
|
||||
pub fn dequant_mxfp4_to_bf16_t(w_packed: &Tensor, w_scales: &Tensor, e: usize, n: usize, k: usize) -> Tensor {
|
||||
pub fn dequant_mxfp4_to_bf16_t(
|
||||
w_packed: &Tensor,
|
||||
w_scales: &Tensor,
|
||||
e: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
) -> Tensor {
|
||||
let out = Tensor::empty(&[e, k, n], DType::BF16, w_packed.device());
|
||||
unsafe {
|
||||
launch_dequant_mxfp4_to_bf16_t(
|
||||
w_packed.data_ptr() as *const c_void,
|
||||
w_scales.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
e as i32, n as i32, k as i32,
|
||||
std::ptr::null_mut(),
|
||||
e as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
|
||||
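The MXFP4 doc comments above describe weights as packed 4-bit values with one UE8M0 scale byte per 32-element block, and the GEMV as `y[e,n] = sum_k x[e,k] * dequant(W[e,n,k])`. A minimal CPU sketch of that arithmetic for a single expert is given below, for use as a reference when checking the kernels. The E2M1 value table and the bias-127 power-of-two scale follow the OCP Microscaling format. The nibble order (low nibble first) and all function names here are assumptions for illustration and are not taken from this repository.

```rust
/// Magnitudes of the eight E2M1 codes (sign bit stripped), per the OCP MX format.
const E2M1_MAGNITUDES: [f32; 8] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0];

/// Decode one 4-bit E2M1 code: bit 3 is the sign, bits 0..=2 index the table.
fn decode_e2m1(code: u8) -> f32 {
    let mag = E2M1_MAGNITUDES[(code & 0x7) as usize];
    if code & 0x8 != 0 { -mag } else { mag }
}

/// Decode a UE8M0 scale byte: an unsigned power of two with bias 127.
/// 0xFF is reserved for NaN in the MX format.
fn decode_ue8m0(scale: u8) -> f32 {
    if scale == 0xFF {
        f32::NAN
    } else {
        2.0f32.powi(scale as i32 - 127)
    }
}

/// Dequantize one weight row of `k` values.
/// ASSUMPTION: two codes per byte, low nibble first (hypothetical layout).
fn dequant_mxfp4_row(packed: &[u8], scales: &[u8], k: usize) -> Vec<f32> {
    assert_eq!(k % 32, 0, "k must be a multiple of the 32-element block");
    assert_eq!(packed.len(), k / 2);
    assert_eq!(scales.len(), k / 32);
    (0..k)
        .map(|i| {
            let byte = packed[i / 2];
            let code = if i % 2 == 0 { byte & 0x0F } else { byte >> 4 };
            decode_e2m1(code) * decode_ue8m0(scales[i / 32])
        })
        .collect()
}

/// Reference GEMV for one expert: y[n] = sum_k x[k] * dequant(W[n, k]).
fn gemv_mxfp4_ref(x: &[f32], packed: &[u8], scales: &[u8], n: usize, k: usize) -> Vec<f32> {
    assert_eq!(x.len(), k);
    (0..n)
        .map(|row| {
            let w = dequant_mxfp4_row(
                &packed[row * k / 2..(row + 1) * k / 2],
                &scales[row * k / 32..(row + 1) * k / 32],
                k,
            );
            x.iter().zip(&w).map(|(a, b)| a * b).sum::<f32>()
        })
        .collect()
}
```

Running the reference per expert and comparing against the GPU output within a BF16-sized tolerance would be one way to validate both the decode GEMV and the prefill dequant path, provided the nibble order is first confirmed against the real kernel.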
@@ -2,13 +2,35 @@ use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_rmsnorm_f32(x: *const c_void, gamma: *const c_void, out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
|
||||
normed_out: *mut c_void, sum_out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_rmsnorm_f32(
|
||||
x: *const c_void,
|
||||
gamma: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_rmsnorm_bf16(
|
||||
x: *const c_void,
|
||||
gamma: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_add_rmsnorm_bf16(
|
||||
x: *const c_void,
|
||||
residual: *const c_void,
|
||||
gamma: *const c_void,
|
||||
normed_out: *mut c_void,
|
||||
sum_out: *mut c_void,
|
||||
rows: i32,
|
||||
hidden_size: i32,
|
||||
eps: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
|
||||
@@ -20,19 +42,35 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
|
||||
assert_eq!(x.dtype(), gamma.dtype());
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
|
||||
assert!(
|
||||
rows <= i32::MAX as usize,
|
||||
"too many rows for i32 kernel param"
|
||||
);
|
||||
assert!(
|
||||
hidden_size <= i32::MAX as usize,
|
||||
"hidden_size too large for i32 kernel param"
|
||||
);
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_rmsnorm_f32(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
x.data_ptr() as _,
|
||||
gamma.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_rmsnorm_bf16(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
x.data_ptr() as _,
|
||||
gamma.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for rmsnorm"),
|
||||
}
|
||||
@@ -56,8 +94,14 @@ pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (
|
||||
assert_eq!(gamma.shape(), &[hidden_size]);
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
|
||||
assert!(
|
||||
rows <= i32::MAX as usize,
|
||||
"too many rows for i32 kernel param"
|
||||
);
|
||||
assert!(
|
||||
hidden_size <= i32::MAX as usize,
|
||||
"hidden_size too large for i32 kernel param"
|
||||
);
|
||||
let normed_out = Tensor::empty(x.shape(), DType::BF16, x.device());
|
||||
let sum_out = Tensor::empty(x.shape(), DType::BF16, x.device());
|
||||
|
||||
@@ -71,7 +115,7 @@ pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
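The `rmsnorm` and `add_rmsnorm` wrappers above only forward pointers to the CUDA launchers, so the math itself is not visible in this diff. A minimal CPU sketch of what such kernels conventionally compute is given below, for use as a test oracle. It assumes the standard RMSNorm definition (`x / sqrt(mean(x^2) + eps) * gamma`) and assumes that the fused variant normalizes `x + residual` and also returns that sum, which matches the `normed_out` / `sum_out` parameter names but is not confirmed by the source. Function names are hypothetical.

```rust
/// Reference RMSNorm over rows of length `gamma.len()`.
/// ASSUMPTION: standard definition, eps added inside the square root.
fn rmsnorm_ref(x: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
    let hidden = gamma.len();
    assert!(hidden > 0 && x.len() % hidden == 0);
    let mut out = Vec::with_capacity(x.len());
    for row in x.chunks(hidden) {
        // Mean of squares for this row, then the reciprocal RMS.
        let mean_sq = row.iter().map(|v| v * v).sum::<f32>() / hidden as f32;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        out.extend(row.iter().zip(gamma).map(|(v, g)| v * inv_rms * g));
    }
    out
}

/// Reference fused residual-add + RMSNorm.
/// ASSUMPTION: returns (rmsnorm(x + residual), x + residual).
fn add_rmsnorm_ref(
    x: &[f32],
    residual: &[f32],
    gamma: &[f32],
    eps: f32,
) -> (Vec<f32>, Vec<f32>) {
    assert_eq!(x.len(), residual.len());
    let sum: Vec<f32> = x.iter().zip(residual).map(|(a, b)| a + b).collect();
    let normed = rmsnorm_ref(&sum, gamma, eps);
    (normed, sum)
}
```

Because the GPU path works in BF16, a comparison against this F32 reference would need a tolerance on the order of BF16 precision rather than exact equality.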
@@ -3,15 +3,34 @@ use xserv_cuda::GpuBuffer;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_rope_f32(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
|
||||
positions: *const c_void, num_tokens: i32, num_heads: i32,
|
||||
head_dim: i32, stream: *mut c_void);
|
||||
fn launch_rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
|
||||
positions: *const c_void, num_tokens: i32, num_heads: i32,
|
||||
head_dim: i32, stream: *mut c_void);
|
||||
fn launch_compute_rope_cache(cos_cache: *mut c_void, sin_cache: *mut c_void,
|
||||
max_seq_len: i32, half_dim: i32, theta: f32,
|
||||
stream: *mut c_void);
|
||||
fn launch_rope_f32(
|
||||
x: *mut c_void,
|
||||
cos_cache: *const c_void,
|
||||
sin_cache: *const c_void,
|
||||
positions: *const c_void,
|
||||
num_tokens: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_rope_bf16(
|
||||
x: *mut c_void,
|
||||
cos_cache: *const c_void,
|
||||
sin_cache: *const c_void,
|
||||
positions: *const c_void,
|
||||
num_tokens: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_compute_rope_cache(
|
||||
cos_cache: *mut c_void,
|
||||
sin_cache: *mut c_void,
|
||||
max_seq_len: i32,
|
||||
half_dim: i32,
|
||||
theta: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
pub struct RopeCache {
|
||||
@@ -30,12 +49,21 @@ impl RopeCache {
|
||||
|
||||
unsafe {
|
||||
launch_compute_rope_cache(
|
||||
cos.as_mut_ptr() as _, sin.as_mut_ptr() as _,
|
||||
max_seq_len as i32, half_dim as i32, theta, std::ptr::null_mut(),
|
||||
cos.as_mut_ptr() as _,
|
||||
sin.as_mut_ptr() as _,
|
||||
max_seq_len as i32,
|
||||
half_dim as i32,
|
||||
theta,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
Self { cos, sin, max_seq_len, half_dim }
|
||||
Self {
|
||||
cos,
|
||||
sin,
|
||||
max_seq_len,
|
||||
half_dim,
|
||||
}
|
||||
}
|
||||
|
||||
/// YaRN (Yet another RoPE extensioN) RoPE cache. Applies frequency-dependent
|
||||
@@ -68,8 +96,8 @@ impl RopeCache {
|
||||
let mut inv_freq = vec![0.0f64; half_dim];
|
||||
for i in 0..half_dim {
|
||||
let pos_freq = theta.powf((2 * i) as f64 / dim);
|
||||
let inv_freq_extrapolation = 1.0 / pos_freq; // original
|
||||
let inv_freq_interpolation = 1.0 / (factor * pos_freq); // scaled
|
||||
let inv_freq_extrapolation = 1.0 / pos_freq; // original
|
||||
let inv_freq_interpolation = 1.0 / (factor * pos_freq); // scaled
|
||||
|
||||
// Linear ramp: 0 where we keep original, 1 where we interpolate
|
||||
let ramp = if (high - low).abs() < 0.001 {
|
||||
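The hunk above cuts off in the middle of the YaRN frequency blend. As a rough illustration of how the visible pieces fit together (`inv_freq_extrapolation`, `inv_freq_interpolation`, and a linear ramp that is 0 where the original frequency is kept and 1 where it is interpolated), a self-contained sketch follows. The blend formula and the handling of the degenerate `high == low` case follow the commonly used YaRN formulation and are assumptions here, since those lines are not shown in this diff. `low` and `high` are taken as plain parameters, and the function name is hypothetical.

```rust
/// Sketch of YaRN inverse frequencies for a head dimension `dim`.
/// ASSUMPTION: `low`/`high` are the precomputed ramp bounds (in units of the
/// frequency index i), and the final blend is interp*ramp + extrap*(1-ramp).
fn yarn_inv_freq(theta: f64, dim: usize, factor: f64, low: f64, high: f64) -> Vec<f64> {
    let half_dim = dim / 2;
    (0..half_dim)
        .map(|i| {
            let pos_freq = theta.powf((2 * i) as f64 / dim as f64);
            let extrap = 1.0 / pos_freq; // original frequency
            let interp = 1.0 / (factor * pos_freq); // position-interpolated frequency
            // Avoid dividing by ~0 when the ramp bounds coincide.
            let span = if (high - low).abs() < 0.001 { 0.001 } else { high - low };
            // Linear ramp: 0 keeps the original, 1 uses the interpolated value.
            let ramp = ((i as f64 - low) / span).clamp(0.0, 1.0);
            interp * ramp + extrap * (1.0 - ramp)
        })
        .collect()
}
```

With this shape, high-frequency dimensions (small `i`) keep their original rotation speed while low-frequency dimensions are slowed down by `factor`, which is the intent stated in the ramp comment.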
@@ -101,16 +129,19 @@ impl RopeCache {
|
||||
let nbytes = total * std::mem::size_of::<f32>();
|
||||
let mut cos = GpuBuffer::alloc(nbytes).expect("alloc yarn cos_cache");
|
||||
let mut sin = GpuBuffer::alloc(nbytes).expect("alloc yarn sin_cache");
|
||||
let cos_bytes = unsafe {
|
||||
std::slice::from_raw_parts(cos_host.as_ptr() as *const u8, nbytes)
|
||||
};
|
||||
let sin_bytes = unsafe {
|
||||
std::slice::from_raw_parts(sin_host.as_ptr() as *const u8, nbytes)
|
||||
};
|
||||
let cos_bytes =
|
||||
unsafe { std::slice::from_raw_parts(cos_host.as_ptr() as *const u8, nbytes) };
|
||||
let sin_bytes =
|
||||
unsafe { std::slice::from_raw_parts(sin_host.as_ptr() as *const u8, nbytes) };
|
||||
cos.copy_from_host(cos_bytes).unwrap();
|
||||
sin.copy_from_host(sin_bytes).unwrap();
|
||||
|
||||
Self { cos, sin, max_seq_len, half_dim }
|
||||
Self {
|
||||
cos,
|
||||
sin,
|
||||
max_seq_len,
|
||||
half_dim,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,24 +164,46 @@ pub fn rope_inplace(x: &Tensor, cache: &RopeCache, positions: &[u32]) {
|
||||
num_tokens * std::mem::size_of::<u32>(),
|
||||
)
|
||||
};
|
||||
let mut pos_gpu = xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions");
|
||||
let mut pos_gpu =
|
||||
xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions");
|
||||
pos_gpu.copy_from_host(pos_bytes).unwrap();
|
||||
|
||||
rope_inplace_device_pos(x, cache, pos_gpu.as_ptr() as *const c_void);
|
||||
}
|
||||
|
||||
/// RoPE in-place with positions already on the GPU (u32, [num_tokens]).
|
||||
/// Used by the CUDA-graph decode path, where the position lives in a
|
||||
/// persistent device buffer updated outside the captured region.
|
||||
pub fn rope_inplace_device_pos(x: &Tensor, cache: &RopeCache, pos_gpu: *const c_void) {
|
||||
assert_eq!(x.ndim(), 3);
|
||||
assert!(x.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
let num_tokens = x.shape()[0];
|
||||
let num_heads = x.shape()[1];
|
||||
let head_dim = x.shape()[2];
|
||||
assert_eq!(head_dim / 2, cache.half_dim);
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_rope_f32(
|
||||
x.data_ptr() as *mut c_void,
|
||||
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
||||
pos_gpu.as_ptr() as _,
|
||||
num_tokens as i32, num_heads as i32, head_dim as i32,
|
||||
std::ptr::null_mut(),
|
||||
cache.cos.as_ptr() as _,
|
||||
cache.sin.as_ptr() as _,
|
||||
pos_gpu,
|
||||
num_tokens as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_rope_bf16(
|
||||
x.data_ptr() as *mut c_void,
|
||||
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
||||
pos_gpu.as_ptr() as _,
|
||||
num_tokens as i32, num_heads as i32, head_dim as i32,
|
||||
std::ptr::null_mut(),
|
||||
cache.cos.as_ptr() as _,
|
||||
cache.sin.as_ptr() as _,
|
||||
pos_gpu,
|
||||
num_tokens as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for rope"),
|
||||
}
|
||||
|
||||
@@ -2,8 +2,20 @@ use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_softmax_f32(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
|
||||
fn launch_softmax_bf16(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
|
||||
fn launch_softmax_f32(
|
||||
x: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_softmax_bf16(
|
||||
x: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
/// Softmax along the last dimension.
|
||||
@@ -14,19 +26,31 @@ pub fn softmax(x: &Tensor) -> Tensor {
|
||||
|
||||
let cols = *x.shape().last().unwrap();
|
||||
let rows = x.numel() / cols;
|
||||
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||
assert!(cols <= i32::MAX as usize, "cols too large for i32 kernel param");
|
||||
assert!(
|
||||
rows <= i32::MAX as usize,
|
||||
"too many rows for i32 kernel param"
|
||||
);
|
||||
assert!(
|
||||
cols <= i32::MAX as usize,
|
||||
"cols too large for i32 kernel param"
|
||||
);
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_softmax_f32(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32, std::ptr::null_mut(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_softmax_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32, std::ptr::null_mut(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for softmax"),
|
||||
}
|
||||
|
||||
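For checking the `softmax` wrapper above ("Softmax along the last dimension"), a small CPU reference can serve as an oracle. The sketch below uses the usual max-subtraction trick for numerical stability. Whether the CUDA kernel uses the same trick is not visible in this diff, and the function name is hypothetical.

```rust
/// Reference softmax over rows of length `cols` (the last dimension).
/// Subtracting the row maximum before exp() avoids overflow for large inputs.
fn softmax_ref(x: &[f32], cols: usize) -> Vec<f32> {
    assert!(cols > 0 && x.len() % cols == 0);
    let mut out = Vec::with_capacity(x.len());
    for row in x.chunks(cols) {
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = row.iter().map(|v| (v - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        out.extend(exps.iter().map(|e| e / sum));
    }
    out
}
```

A test in the style of `check_close` could upload the same data to the GPU, call `softmax`, and compare against this reference with a small absolute tolerance.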
@@ -2,19 +2,79 @@ use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_reshape_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_merge_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_repeat_kv_bf16(inp: *const c_void, out: *mut c_void, kv_heads: i32, n_rep: i32, seq_len: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_strided_copy_bf16(inp: *const c_void, out: *mut c_void, numel: i32, ndim: i32,
|
||||
shape0: i32, shape1: i32, shape2: i32, shape3: i32,
|
||||
in_stride0: i32, in_stride1: i32, in_stride2: i32, in_stride3: i32,
|
||||
in_offset: i32, stream: *mut c_void);
|
||||
fn launch_strided_copy_f32(inp: *const c_void, out: *mut c_void, numel: i32, ndim: i32,
|
||||
shape0: i32, shape1: i32, shape2: i32, shape3: i32,
|
||||
in_stride0: i32, in_stride1: i32, in_stride2: i32, in_stride3: i32,
|
||||
in_offset: i32, stream: *mut c_void);
|
||||
fn launch_reshape_heads_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_merge_heads_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_transpose_hsd_to_shd_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_transpose_shd_to_hsd_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
seq_len: i32,
|
||||
num_heads: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_repeat_kv_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
kv_heads: i32,
|
||||
n_rep: i32,
|
||||
seq_len: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_strided_copy_bf16(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
numel: i32,
|
||||
ndim: i32,
|
||||
shape0: i32,
|
||||
shape1: i32,
|
||||
shape2: i32,
|
||||
shape3: i32,
|
||||
in_stride0: i32,
|
||||
in_stride1: i32,
|
||||
in_stride2: i32,
|
||||
in_stride3: i32,
|
||||
in_offset: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_strided_copy_f32(
|
||||
inp: *const c_void,
|
||||
out: *mut c_void,
|
||||
numel: i32,
|
||||
ndim: i32,
|
||||
shape0: i32,
|
||||
shape1: i32,
|
||||
shape2: i32,
|
||||
shape3: i32,
|
||||
in_stride0: i32,
|
||||
in_stride1: i32,
|
||||
in_stride2: i32,
|
||||
in_stride3: i32,
|
||||
in_offset: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
/// [S, H*D] → [1, H, S, D] on GPU (BF16)
|
||||
@@ -24,8 +84,12 @@ pub fn reshape_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim:
|
||||
let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_reshape_heads_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
seq_len as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
@@ -39,36 +103,58 @@ pub fn merge_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u
|
||||
let out = Tensor::empty(&[seq_len, hidden], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_merge_heads_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
seq_len as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// [1, H, S, D] → [S, H, D] for RoPE on GPU (BF16)
|
||||
pub fn transpose_for_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
pub fn transpose_for_rope_gpu(
|
||||
x: &Tensor,
|
||||
seq_len: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::empty(&[seq_len, num_heads, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_transpose_hsd_to_shd_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
seq_len as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// [S, H, D] → [1, H, S, D] after RoPE on GPU (BF16)
|
||||
pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
pub fn transpose_from_rope_gpu(
|
||||
x: &Tensor,
|
||||
seq_len: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_transpose_shd_to_hsd_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
seq_len as i32,
|
||||
num_heads as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
@@ -76,7 +162,9 @@ pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, hea
|
||||
|
||||
/// [1, KV_H, S, D] → [1, KV_H*n_rep, S, D] on GPU (BF16)
|
||||
pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor {
|
||||
if n_rep == 1 { return x.clone(); }
|
||||
if n_rep == 1 {
|
||||
return x.clone();
|
||||
}
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let kv_heads = x.shape()[1];
|
||||
@@ -86,8 +174,13 @@ pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor {
|
||||
let out = Tensor::empty(&[1, new_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_repeat_kv_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
kv_heads as i32, n_rep as i32, seq_len as i32, head_dim as i32, std::ptr::null_mut(),
|
||||
x.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
kv_heads as i32,
|
||||
n_rep as i32,
|
||||
seq_len as i32,
|
||||
head_dim as i32,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
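The `repeat_kv_gpu` doc comment gives the shape change `[1, KV_H, S, D] → [1, KV_H*n_rep, S, D]` but not the head ordering. A CPU sketch of the conventional grouped-query layout is shown below, in which each KV head is repeated `n_rep` times consecutively, so output head `h` reads input head `h / n_rep`. That ordering is an assumption and should be checked against the CUDA kernel. The function name is hypothetical.

```rust
/// Reference repeat_kv on a contiguous [KV_H, S, D] buffer (batch of 1).
/// ASSUMPTION: repeat-interleave ordering, i.e. output head h = input head h / n_rep.
fn repeat_kv_ref(
    x: &[f32],
    kv_heads: usize,
    n_rep: usize,
    seq_len: usize,
    head_dim: usize,
) -> Vec<f32> {
    let head_elems = seq_len * head_dim;
    assert_eq!(x.len(), kv_heads * head_elems);
    let mut out = Vec::with_capacity(x.len() * n_rep);
    for h in 0..kv_heads {
        let head = &x[h * head_elems..(h + 1) * head_elems];
        // Emit the same [S, D] slab n_rep times in a row.
        for _ in 0..n_rep {
            out.extend_from_slice(head);
        }
    }
    out
}
```

With `n_rep == 1` this returns a copy of the input, which mirrors the early return in `repeat_kv_gpu`.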
@@ -122,20 +215,41 @@ pub fn strided_to_contiguous_gpu(x: &Tensor) -> Tensor {
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::BF16 => launch_strided_copy_bf16(
|
||||
storage_ptr as _, out.data_ptr() as *mut c_void,
|
||||
numel as i32, ndim as i32,
|
||||
shape4[0], shape4[1], shape4[2], shape4[3],
|
||||
strides4[0], strides4[1], strides4[2], strides4[3],
|
||||
in_offset, std::ptr::null_mut(),
|
||||
storage_ptr as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
numel as i32,
|
||||
ndim as i32,
|
||||
shape4[0],
|
||||
shape4[1],
|
||||
shape4[2],
|
||||
shape4[3],
|
||||
strides4[0],
|
||||
strides4[1],
|
||||
strides4[2],
|
||||
strides4[3],
|
||||
in_offset,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::F32 => launch_strided_copy_f32(
|
||||
storage_ptr as _, out.data_ptr() as *mut c_void,
|
||||
numel as i32, ndim as i32,
|
||||
shape4[0], shape4[1], shape4[2], shape4[3],
|
||||
strides4[0], strides4[1], strides4[2], strides4[3],
|
||||
in_offset, std::ptr::null_mut(),
|
||||
storage_ptr as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
numel as i32,
|
||||
ndim as i32,
|
||||
shape4[0],
|
||||
shape4[1],
|
||||
shape4[2],
|
||||
shape4[3],
|
||||
strides4[0],
|
||||
strides4[1],
|
||||
strides4[2],
|
||||
strides4[3],
|
||||
in_offset,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!(
|
||||
"strided_to_contiguous_gpu: unsupported dtype {:?}",
|
||||
x.dtype()
|
||||
),
|
||||
_ => panic!("strided_to_contiguous_gpu: unsupported dtype {:?}", x.dtype()),
|
||||
}
|
||||
}
|
||||
out
|
||||
|
||||
@@ -1,11 +1,21 @@
 use xserv_kernels::*;
 use xserv_tensor::{Device, Tensor};
 
-fn init() { xserv_cuda::device::set_device(0).unwrap(); }
+fn init() {
+    xserv_cuda::device::set_device(0).unwrap();
+}
 
-fn cpu_attention(q: &[f32], k: &[f32], v: &[f32],
-    batch: usize, heads: usize, q_len: usize, kv_len: usize, head_dim: usize,
-    causal: bool) -> Vec<f32> {
+fn cpu_attention(
+    q: &[f32],
+    k: &[f32],
+    v: &[f32],
+    batch: usize,
+    heads: usize,
+    q_len: usize,
+    kv_len: usize,
+    head_dim: usize,
+    causal: bool,
+) -> Vec<f32> {
     let mut out = vec![0.0f32; batch * heads * q_len * head_dim];
     let scale = 1.0 / (head_dim as f32).sqrt();
 
@@ -70,8 +80,13 @@ fn check_close(a: &[f32], b: &[f32], atol: f32, name: &str) {
     let mut max_err = 0.0f32;
     for (i, (x, y)) in a.iter().zip(b).enumerate() {
         let err = (x - y).abs();
-        if err > max_err { max_err = err; }
-        assert!(err <= atol, "{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}");
+        if err > max_err {
+            max_err = err;
+        }
+        assert!(
+            err <= atol,
+            "{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}"
+        );
     }
     println!("{name}: max_err = {max_err:.6e}");
 }
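The `check_close` helper in this hunk asserts an absolute tolerance element by element and prints the largest error. A minimal sketch of the same idea as pure functions is given here; `max_abs_err` and `all_close` are hypothetical names for illustration, not part of the xserv test suite. It uses `f32::total_cmp`, so a NaN difference surfaces as the maximum and fails the tolerance check instead of being silently skipped.

```rust
/// Returns the (index, error) of the largest absolute difference,
/// or `None` when both slices are empty.
/// Hypothetical helper for illustration; not part of xserv.
fn max_abs_err(a: &[f32], b: &[f32]) -> Option<(usize, f32)> {
    assert_eq!(a.len(), b.len(), "length mismatch");
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y).abs())
        .enumerate()
        // total_cmp orders NaN above every finite value, so a NaN error wins.
        .max_by(|l, r| l.1.total_cmp(&r.1))
}

/// True when every element pair differs by at most `atol`.
fn all_close(a: &[f32], b: &[f32], atol: f32) -> bool {
    match max_abs_err(a, b) {
        // `NaN <= atol` is false, so NaN makes the comparison fail.
        Some((_, err)) => err <= atol,
        None => true,
    }
}
```

Returning the worst index rather than panicking at the first mismatch can make it easier to see whether an error is isolated or spread across the whole tensor.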
@@ -105,7 +120,9 @@ fn test_batched_matmul() {
         for i in 0..m {
             for j in 0..n {
                 let mut s = 0.0f32;
-                for kk in 0..k { s += a_cpu[i * k + kk] * b_cpu[kk * n + j]; }
+                for kk in 0..k {
+                    s += a_cpu[i * k + kk] * b_cpu[kk * n + j];
+                }
                 expected[i * n + j] = s;
             }
         }
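The triple loop in this hunk is the CPU reference that the GPU matmul is compared against. A self-contained sketch of that reference follows. The name and signature match `cpu_matmul_f32` as it appears later in this diff, but the body is an assumption modeled on the loop above, with row-major layout assumed for both inputs.

```rust
/// Row-major reference GEMM: C[m x n] = A[m x k] * B[k x n].
/// The body is an illustrative assumption; only the signature appears in the diff.
fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k, "A must hold m*k elements");
    assert_eq!(b.len(), k * n, "B must hold k*n elements");
    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            // Dot product of row i of A with column j of B.
            let mut s = 0.0f32;
            for kk in 0..k {
                s += a[i * k + kk] * b[kk * n + j];
            }
            c[i * n + j] = s;
        }
    }
    c
}
```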
@@ -116,7 +133,10 @@ fn test_batched_matmul() {
|
||||
#[test]
|
||||
fn test_attention_no_causal() {
|
||||
init();
|
||||
let b = 1; let h = 2; let s = 8; let d = 16;
|
||||
let b = 1;
|
||||
let h = 2;
|
||||
let s = 8;
|
||||
let d = 16;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data = make_data(b * h * s * d);
|
||||
@@ -126,13 +146,21 @@ fn test_attention_no_causal() {
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let out = attention(&q, &k, &v, false).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-4, "attention_no_causal");
|
||||
check_close(
|
||||
out.as_slice::<f32>(),
|
||||
&expected,
|
||||
1e-4,
|
||||
"attention_no_causal",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_causal() {
|
||||
init();
|
||||
let b = 1; let h = 2; let s = 16; let d = 32;
|
||||
let b = 1;
|
||||
let h = 2;
|
||||
let s = 16;
|
||||
let d = 32;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data = make_data(b * h * s * d);
|
||||
@@ -148,7 +176,10 @@ fn test_attention_causal() {
|
||||
#[test]
|
||||
fn test_attention_causal_larger() {
|
||||
init();
|
||||
let b = 2; let h = 4; let s = 64; let d = 64;
|
||||
let b = 2;
|
||||
let h = 4;
|
||||
let s = 64;
|
||||
let d = 64;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data = make_data(b * h * s * d);
|
||||
@@ -158,18 +189,28 @@ fn test_attention_causal_larger() {
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let out = attention(&q, &k, &v, true).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-2, "attention_causal_larger");
|
||||
check_close(
|
||||
out.as_slice::<f32>(),
|
||||
&expected,
|
||||
1e-2,
|
||||
"attention_causal_larger",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_causal_first_row_sees_only_first_token() {
|
||||
init();
|
||||
let b = 1; let h = 1; let s = 4; let d = 8;
|
||||
let b = 1;
|
||||
let h = 1;
|
||||
let s = 4;
|
||||
let d = 8;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data: Vec<f32> = (0..s * d).map(|i| {
|
||||
if i < d { 1.0 } else { 0.0 } // only first V row is nonzero
|
||||
}).collect();
|
||||
let v_data: Vec<f32> = (0..s * d)
|
||||
.map(|i| {
|
||||
if i < d { 1.0 } else { 0.0 } // only first V row is nonzero
|
||||
})
|
||||
.collect();
|
||||
|
||||
let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
@@ -181,7 +222,11 @@ fn test_attention_causal_first_row_sees_only_first_token() {
|
||||
// output[0] should be exactly V[0] = [1, 1, 1, ...1]
|
||||
let result = out.as_slice::<f32>();
|
||||
for i in 0..d {
|
||||
assert!((result[i] - 1.0).abs() < 1e-5,
|
||||
"first row should equal V[0], got {} at dim {}", result[i], i);
|
||||
assert!(
|
||||
(result[i] - 1.0).abs() < 1e-5,
|
||||
"first row should equal V[0], got {} at dim {}",
|
||||
result[i],
|
||||
i
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
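The attention tests above compare the GPU kernel against `cpu_attention`, whose body is not shown in this diff. A minimal sketch of such a reference is given here under stated assumptions: tensors are contiguous `[batch, heads, len, head_dim]`, and with `causal = true` query `i` may attend to keys `0..=i + (kv_len - q_len)`. The real helper may differ in both respects.

```rust
/// Reference scaled dot-product attention on the CPU.
/// Layout and causal-offset convention are assumptions for illustration.
fn cpu_attention(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    batch: usize,
    heads: usize,
    q_len: usize,
    kv_len: usize,
    head_dim: usize,
    causal: bool,
) -> Vec<f32> {
    let mut out = vec![0.0f32; batch * heads * q_len * head_dim];
    let scale = 1.0 / (head_dim as f32).sqrt();
    // Assumed convention: queries are aligned to the end of the KV sequence.
    let offset = kv_len.saturating_sub(q_len);
    for bh in 0..batch * heads {
        let q_base = bh * q_len * head_dim;
        let kv_base = bh * kv_len * head_dim;
        for i in 0..q_len {
            // Number of keys this query is allowed to see.
            let visible = if causal { (i + offset + 1).min(kv_len) } else { kv_len };
            let q_row = &q[q_base + i * head_dim..q_base + (i + 1) * head_dim];
            let scores: Vec<f32> = (0..visible)
                .map(|j| {
                    let k_row = &k[kv_base + j * head_dim..kv_base + (j + 1) * head_dim];
                    q_row.iter().zip(k_row).map(|(a, b)| a * b).sum::<f32>() * scale
                })
                .collect();
            // Softmax with max subtraction for numerical stability.
            let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
            let denom: f32 = exps.iter().sum();
            let o = &mut out[q_base + i * head_dim..q_base + (i + 1) * head_dim];
            for (j, e) in exps.iter().enumerate() {
                let p = e / denom;
                let v_row = &v[kv_base + j * head_dim..kv_base + (j + 1) * head_dim];
                for (od, vd) in o.iter_mut().zip(v_row) {
                    *od += p * vd;
                }
            }
        }
    }
    out
}
```

With a causal mask the first query sees only the first key, so its softmax weight is exactly 1 and the first output row equals `V[0]`. This is the property `test_attention_causal_first_row_sees_only_first_token` checks on the GPU side.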
@@ -1,5 +1,5 @@
|
||||
use half::bf16;
|
||||
use xserv_kernels::{matmul, GemmBackend};
|
||||
use xserv_kernels::{GemmBackend, matmul};
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
|
||||
@@ -75,70 +75,110 @@ fn run_gemm_test_bf16(backend: GemmBackend, m: usize, n: usize, k: usize) {
|
||||
// --- F32 tests ---
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_f32_small() { run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4); }
|
||||
fn test_gemm_naive_f32_small() {
|
||||
run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_f32_medium() { run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64); }
|
||||
fn test_gemm_naive_f32_medium() {
|
||||
run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_f32_rect() { run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48); }
|
||||
fn test_gemm_naive_f32_rect() {
|
||||
run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_f32_small() { run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4); }
|
||||
fn test_gemm_tiled_f32_small() {
|
||||
run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_f32_medium() { run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128); }
|
||||
fn test_gemm_tiled_f32_medium() {
|
||||
run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_f32_rect() { run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97); }
|
||||
fn test_gemm_tiled_f32_rect() {
|
||||
run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_small() { run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4); }
|
||||
fn test_gemm_cublas_f32_small() {
|
||||
run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_medium() { run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256); }
|
||||
fn test_gemm_cublas_f32_medium() {
|
||||
run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_rect() { run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97); }
|
||||
fn test_gemm_cublas_f32_rect() {
|
||||
run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97);
|
||||
}
|
||||
|
||||
// --- BF16 tests ---
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_bf16_small() { run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4); }
|
||||
fn test_gemm_naive_bf16_small() {
|
||||
run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_bf16_medium() { run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64); }
|
||||
fn test_gemm_naive_bf16_medium() {
|
||||
run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_bf16_small() { run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4); }
|
||||
fn test_gemm_tiled_bf16_small() {
|
||||
run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_bf16_medium() { run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128); }
|
||||
fn test_gemm_tiled_bf16_medium() {
|
||||
run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4); }
|
||||
fn test_gemm_cublas_bf16_small() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256); }
|
||||
fn test_gemm_cublas_bf16_medium() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256);
|
||||
}
|
||||
|
||||
// --- Custom GEMV tests (M=1, BF16 fast path) ---
|
||||
|
||||
#[test]
|
||||
fn test_gemv_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 64, 64); }
|
||||
fn test_gemv_bf16_small() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 64, 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemv_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 256, 256); }
|
||||
fn test_gemv_bf16_medium() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 256, 256);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemv_bf16_4096() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 4096, 4096); }
|
||||
fn test_gemv_bf16_4096() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 4096, 4096);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemv_bf16_rect() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 512, 4096); }
|
||||
fn test_gemv_bf16_rect() {
|
||||
run_gemm_test_bf16(GemmBackend::CuBlas, 1, 512, 4096);
|
||||
}
|
||||
|
||||
// --- Larger benchmark-style tests ---
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_1024() { run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024); }
|
||||
fn test_gemm_cublas_f32_1024() {
|
||||
run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemm_consistency_all_backends() {
|
||||
|
||||
@@ -2,7 +2,9 @@ use half::bf16;
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
fn init() { xserv_cuda::device::set_device(0).unwrap(); }
|
||||
fn init() {
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
}
|
||||
|
||||
// --- CPU reference implementations ---
|
||||
|
||||
@@ -37,10 +39,12 @@ fn cpu_layernorm(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32, hidden: usize
 
 fn cpu_gelu(x: &[f32]) -> Vec<f32> {
     let sqrt_2_over_pi = 0.7978845608f32;
-    x.iter().map(|&v| {
-        let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
-        0.5 * v * (1.0 + inner.tanh())
-    }).collect()
+    x.iter()
+        .map(|&v| {
+            let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
+            0.5 * v * (1.0 + inner.tanh())
+        })
+        .collect()
 }
 
 fn cpu_silu(x: &[f32]) -> Vec<f32> {
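`cpu_gelu` in this hunk uses the tanh approximation of GELU, `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`. A scalar version is sketched here so the limiting behaviour is easy to check by hand; `gelu_tanh` is a hypothetical name, and the constants are copied from the hunk.

```rust
/// Tanh approximation of GELU for a single value.
/// Hypothetical scalar helper; constants taken from `cpu_gelu` in the diff.
fn gelu_tanh(x: f32) -> f32 {
    let sqrt_2_over_pi = 0.7978845608f32;
    let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
    0.5 * x * (1.0 + inner.tanh())
}
```

For large positive inputs `tanh` saturates at 1 and the function approaches the identity; for large negative inputs it saturates at -1 and the output approaches 0.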
@@ -88,8 +92,13 @@ fn check_close(result: &[f32], expected: &[f32], atol: f32, name: &str) {
|
||||
let mut max_err = 0.0f32;
|
||||
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
|
||||
let err = (r - e).abs();
|
||||
if err > max_err { max_err = err; }
|
||||
assert!(err <= atol, "{name}: mismatch at [{i}]: got {r}, expected {e}, err {err}");
|
||||
if err > max_err {
|
||||
max_err = err;
|
||||
}
|
||||
assert!(
|
||||
err <= atol,
|
||||
"{name}: mismatch at [{i}]: got {r}, expected {e}, err {err}"
|
||||
);
|
||||
}
|
||||
println!("{name}: max_err = {max_err:.6e}");
|
||||
}
|
||||
@@ -208,13 +217,18 @@ fn test_softmax_sum_to_one() {
     init();
     let rows = 4;
     let cols = 2048;
-    let data: Vec<f32> = (0..rows * cols).map(|i| ((i % 31) as f32 - 15.0) * 0.5).collect();
+    let data: Vec<f32> = (0..rows * cols)
+        .map(|i| ((i % 31) as f32 - 15.0) * 0.5)
+        .collect();
     let x = Tensor::from_slice(&data, &[rows, cols]).to_device(Device::Cuda(0));
     let out = softmax(&x).to_device(Device::Cpu);
     let result = out.as_slice::<f32>();
     for r in 0..rows {
         let row_sum: f32 = result[r * cols..(r + 1) * cols].iter().sum();
-        assert!((row_sum - 1.0).abs() < 1e-5, "softmax row {r} sum = {row_sum}");
+        assert!(
+            (row_sum - 1.0).abs() < 1e-5,
+            "softmax row {r} sum = {row_sum}"
+        );
     }
 }
 
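`test_softmax_sum_to_one` checks that every row of the GPU softmax sums to 1. A CPU counterpart of that invariant might look like the sketch below; `cpu_softmax_rows` is a hypothetical name, and subtracting the row maximum before exponentiating is a standard stabilisation that is assumed here rather than taken from the xserv kernel.

```rust
/// Row-wise softmax over a flat row-major buffer with `cols` columns.
/// Hypothetical CPU reference for illustration.
fn cpu_softmax_rows(x: &[f32], cols: usize) -> Vec<f32> {
    assert!(cols > 0 && x.len() % cols == 0, "buffer must hold whole rows");
    let mut out = Vec::with_capacity(x.len());
    for row in x.chunks(cols) {
        // Subtract the row maximum so exp() cannot overflow.
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = row.iter().map(|v| (v - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        out.extend(exps.iter().map(|e| e / sum));
    }
    out
}
```

Because of the max subtraction, inputs as large as 1000.0 still produce finite probabilities.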
@@ -247,8 +261,10 @@ fn test_embedding_f32() {
|
||||
for i in 0..hidden {
|
||||
let expected = table_data[tid as usize * hidden + i];
|
||||
let got = result[seq_idx * hidden + i];
|
||||
assert!((got - expected).abs() < 1e-6,
|
||||
"embedding mismatch at [{seq_idx},{i}]: got {got}, expected {expected}");
|
||||
assert!(
|
||||
(got - expected).abs() < 1e-6,
|
||||
"embedding mismatch at [{seq_idx},{i}]: got {got}, expected {expected}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -270,8 +286,8 @@ fn test_rope_f32() {
|
||||
let mut expected = x_data.clone();
|
||||
cpu_rope(&mut expected, &positions, num_heads, head_dim, theta);
|
||||
|
||||
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
|
||||
.to_device(Device::Cuda(0));
|
||||
let x =
|
||||
Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim]).to_device(Device::Cuda(0));
|
||||
let cache = RopeCache::new(64, head_dim, theta);
|
||||
rope_inplace(&x, &cache, &positions);
|
||||
|
||||
@@ -292,8 +308,8 @@ fn test_rope_position_0_identity() {
|
||||
.map(|i| (i as f32 + 1.0) * 0.1)
|
||||
.collect();
|
||||
|
||||
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
|
||||
.to_device(Device::Cuda(0));
|
||||
let x =
|
||||
Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim]).to_device(Device::Cuda(0));
|
||||
let cache = RopeCache::new(64, head_dim, 10000.0);
|
||||
rope_inplace(&x, &cache, &positions);
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use xserv_distributed::{TpContext, UniqueId, get_unique_id};
|
||||
use xserv_model::{loader, GptOss, ModelConfig, PagedKVCache, BLOCK_SIZE};
|
||||
use xserv_model::{BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, loader};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -23,8 +23,12 @@ fn main() {
|
||||
|
||||
eprintln!(
|
||||
"gpt-oss-20b: layers={}, hidden={}, heads={}/{} kv, experts={}, top_k={}, vocab={}",
|
||||
config.num_layers(), config.hidden(), config.num_heads(),
|
||||
config.num_kv_heads(), config.num_experts(), config.experts_per_token(),
|
||||
config.num_layers(),
|
||||
config.hidden(),
|
||||
config.num_heads(),
|
||||
config.num_kv_heads(),
|
||||
config.num_experts(),
|
||||
config.experts_per_token(),
|
||||
config.vocab_size
|
||||
);
|
||||
eprintln!("TP world={world}, max_tokens={max_tokens}");
|
||||
@@ -59,17 +63,29 @@ fn main() {
|
||||
let tp0 = Arc::new(TpContext::init(0, world, uid, 0));
|
||||
eprintln!("[rank 0] Loading weights...");
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
|
||||
eprintln!("[rank 0] Loaded {} tensors, building model...", weights.len());
|
||||
eprintln!(
|
||||
"[rank 0] Loaded {} tensors, building model...",
|
||||
weights.len()
|
||||
);
|
||||
let model = GptOss::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp0));
|
||||
let total_blocks = max_blocks_per_seq + 64;
|
||||
let mut cache = PagedKVCache::new_tp(
|
||||
&config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, 0,
|
||||
&config,
|
||||
local_kv,
|
||||
total_blocks,
|
||||
0,
|
||||
4,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
0,
|
||||
);
|
||||
eprintln!("[rank 0] Ready.");
|
||||
|
||||
// Prompt
|
||||
let prompt_arg = get_arg::<String>(&args, "--prompt");
|
||||
let prompt = prompt_arg.as_deref().unwrap_or("What is the meaning of life?");
|
||||
let prompt = prompt_arg
|
||||
.as_deref()
|
||||
.unwrap_or("What is the meaning of life?");
|
||||
let token_ids = tokenizer.encode(prompt);
|
||||
eprintln!("Prompt ({} tokens): {prompt}", token_ids.len());
|
||||
|
||||
@@ -83,11 +99,21 @@ fn main() {
|
||||
// (oracle) next token. Removes free-running compounding so it isolates
|
||||
// whether per-position logits agree with the llama.cpp trajectory.
|
||||
if let Some(forced) = get_arg::<String>(&args, "--forced") {
|
||||
let forced_ids: Vec<u32> = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect();
|
||||
let forced_ids: Vec<u32> = forced
|
||||
.split(',')
|
||||
.filter_map(|s| s.trim().parse().ok())
|
||||
.collect();
|
||||
let mut seq = token_ids.clone();
|
||||
seq.extend_from_slice(&forced_ids);
|
||||
// Workers must run the same prefill in lockstep (TP AllReduces match up).
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: seq.clone(), slot });
|
||||
broadcast_cmd(
|
||||
&worker_txs,
|
||||
&worker_handles,
|
||||
WorkerCmd::Prefill {
|
||||
tokens: seq.clone(),
|
||||
slot,
|
||||
},
|
||||
);
|
||||
let logits = model.forward_prefill_paged(&seq, slot, &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
@@ -99,19 +125,31 @@ fn main() {
|
||||
// position i predicts seq[i+1]; we check the forced region
|
||||
for i in (plen - 1)..(seq.len() - 1) {
|
||||
let row = &data[i * vocab..(i + 1) * vocab];
|
||||
let argmax = row.iter().enumerate()
|
||||
let argmax = row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(j, _)| j as u32).unwrap();
|
||||
.map(|(j, _)| j as u32)
|
||||
.unwrap();
|
||||
let expected = seq[i + 1];
|
||||
let ok = argmax == expected;
|
||||
if ok { matches += 1; }
|
||||
if ok {
|
||||
matches += 1;
|
||||
}
|
||||
total += 1;
|
||||
eprintln!("pos {i}: xserv_argmax={argmax} oracle={expected} {}", if ok {"OK"} else {"DIFF"});
|
||||
eprintln!(
|
||||
"pos {i}: xserv_argmax={argmax} oracle={expected} {}",
|
||||
if ok { "OK" } else { "DIFF" }
|
||||
);
|
||||
}
|
||||
eprintln!("\nTeacher-forced top-1 agreement: {matches}/{total} = {:.1}%",
|
||||
100.0 * matches as f64 / total as f64);
|
||||
eprintln!(
|
||||
"\nTeacher-forced top-1 agreement: {matches}/{total} = {:.1}%",
|
||||
100.0 * matches as f64 / total as f64
|
||||
);
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
|
||||
for (h, _) in worker_handles { h.join().unwrap(); }
|
||||
for (h, _) in worker_handles {
|
||||
h.join().unwrap();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
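The `--forced` path above prefills the prompt plus the oracle tokens in one pass and then checks, for each position `i` in the forced region, whether the argmax of row `i` equals `seq[i + 1]`. The counting logic can be sketched on plain `f32` logits as follows. `argmax` and `top1_agreement` are hypothetical names, and `f32::total_cmp` is swapped in for the `partial_cmp(..).unwrap()` used in the diff so that a NaN logit does not panic.

```rust
/// Index of the largest value in a row, or `None` for an empty row.
fn argmax(row: &[f32]) -> Option<u32> {
    row.iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(j, _)| j as u32)
}

/// Teacher-forced top-1 agreement over the forced region.
/// `logits` is row-major `[seq.len(), vocab]`; returns (matches, total).
/// Hypothetical helper for illustration.
fn top1_agreement(logits: &[f32], vocab: usize, seq: &[u32], prompt_len: usize) -> (usize, usize) {
    assert!(prompt_len >= 1 && prompt_len <= seq.len());
    assert_eq!(logits.len(), seq.len() * vocab);
    let mut matches = 0;
    let mut total = 0;
    // Position i predicts seq[i + 1]; start at the last prompt token.
    for i in (prompt_len - 1)..(seq.len() - 1) {
        let row = &logits[i * vocab..(i + 1) * vocab];
        if argmax(row) == Some(seq[i + 1]) {
            matches += 1;
        }
        total += 1;
    }
    (matches, total)
}
```

Because every position is conditioned on the oracle prefix, one early disagreement does not cascade into later positions, which is what separates this measurement from free-running generation.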
@@ -120,8 +158,18 @@ fn main() {
|
||||
// per-position top-1 agreement bucketed by position. Localizes long-context
|
||||
// decode degradation (which prefill teacher-forcing cannot see).
|
||||
if let Some(forced) = get_arg::<String>(&args, "--forced-decode") {
|
||||
let forced_ids: Vec<u32> = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect();
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: token_ids.clone(), slot });
|
||||
let forced_ids: Vec<u32> = forced
|
||||
.split(',')
|
||||
.filter_map(|s| s.trim().parse().ok())
|
||||
.collect();
|
||||
broadcast_cmd(
|
||||
&worker_txs,
|
||||
&worker_handles,
|
||||
WorkerCmd::Prefill {
|
||||
tokens: token_ids.clone(),
|
||||
slot,
|
||||
},
|
||||
);
|
||||
let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
let mut pred = sample_greedy_last(&logits); // prediction for forced[0]
|
||||
@@ -133,34 +181,55 @@ fn main() {
|
||||
matches += ok as usize;
|
||||
total += 1;
|
||||
let b = i / bucket;
|
||||
if buckets.len() <= b { buckets.push((0, 0)); }
|
||||
if buckets.len() <= b {
|
||||
buckets.push((0, 0));
|
||||
}
|
||||
buckets[b].0 += ok as usize;
|
||||
buckets[b].1 += 1;
|
||||
// Teacher-force: feed the oracle token through the decode path.
|
||||
let pos = cache.seq_len(slot);
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode {
|
||||
tokens: vec![f], positions: vec![pos], slots: vec![slot],
|
||||
});
|
||||
broadcast_cmd(
|
||||
&worker_txs,
|
||||
&worker_handles,
|
||||
WorkerCmd::Decode {
|
||||
tokens: vec![f],
|
||||
positions: vec![pos],
|
||||
slots: vec![slot],
|
||||
},
|
||||
);
|
||||
let logits = model.forward_decode_paged(&[f], &[pos], &[slot], &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
pred = sample_greedy_last(&logits);
|
||||
}
|
||||
eprintln!("Teacher-forced DECODE agreement: {matches}/{total} = {:.1}%",
|
||||
100.0 * matches as f64 / total as f64);
|
||||
eprintln!(
|
||||
"Teacher-forced DECODE agreement: {matches}/{total} = {:.1}%",
|
||||
100.0 * matches as f64 / total as f64
|
||||
);
|
||||
for (b, (m, t)) in buckets.iter().enumerate() {
|
||||
eprintln!(" pos[{:>4}..{:<4}]: {m:>3}/{t:<3} = {:.0}%",
|
||||
b * bucket, b * bucket + t, 100.0 * (*m as f64) / (*t as f64));
|
||||
eprintln!(
|
||||
" pos[{:>4}..{:<4}]: {m:>3}/{t:<3} = {:.0}%",
|
||||
b * bucket,
|
||||
b * bucket + t,
|
||||
100.0 * (*m as f64) / (*t as f64)
|
||||
);
|
||||
}
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
|
||||
for (h, _) in worker_handles { h.join().unwrap(); }
|
||||
for (h, _) in worker_handles {
|
||||
h.join().unwrap();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Prefill
|
||||
let t0 = Instant::now();
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill {
|
||||
tokens: token_ids.clone(), slot,
|
||||
});
|
||||
broadcast_cmd(
|
||||
&worker_txs,
|
||||
&worker_handles,
|
||||
WorkerCmd::Prefill {
|
||||
tokens: token_ids.clone(),
|
||||
slot,
|
||||
},
|
||||
);
|
||||
let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
let ttft = t0.elapsed();
|
||||
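The `--forced-decode` path in this hunk groups per-position agreement into fixed-size buckets so that degradation at long context shows up as a drop in the later buckets. The bookkeeping can be sketched in isolation; `bucket_agreement` is a hypothetical name, and the input is assumed to be one boolean per forced position.

```rust
/// Groups per-position hits into buckets of `bucket` positions.
/// Returns one (matches, total) pair per bucket, in position order.
/// Hypothetical helper for illustration.
fn bucket_agreement(ok: &[bool], bucket: usize) -> Vec<(usize, usize)> {
    assert!(bucket > 0, "bucket size must be positive");
    let mut buckets: Vec<(usize, usize)> = Vec::new();
    for (i, &hit) in ok.iter().enumerate() {
        let b = i / bucket;
        // Positions arrive in order, so at most one new bucket is needed.
        if buckets.len() <= b {
            buckets.push((0, 0));
        }
        buckets[b].0 += hit as usize;
        buckets[b].1 += 1;
    }
    buckets
}
```

The last bucket may be partial, which is why the diff prints the range end as `b * bucket + t` instead of `(b + 1) * bucket`.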
@@ -172,18 +241,27 @@ fn main() {
|
||||
print!("{prompt}");
|
||||
|
||||
// Decode
|
||||
let mut decoder = GraphedGptOssDecoder::new();
|
||||
let decode_start = Instant::now();
|
||||
for _ in 1..max_tokens {
|
||||
let text = tokenizer.decode(&[next]);
|
||||
print!("{text}");
|
||||
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = cache.seq_len(slot);
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode {
|
||||
tokens: vec![next], positions: vec![pos], slots: vec![slot],
|
||||
});
|
||||
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut cache);
|
||||
broadcast_cmd(
|
||||
&worker_txs,
|
||||
&worker_handles,
|
||||
WorkerCmd::Decode {
|
||||
tokens: vec![next],
|
||||
positions: vec![pos],
|
||||
slots: vec![slot],
|
||||
},
|
||||
);
|
||||
let logits = decoder.decode(&model, &[next], &[pos], &[slot], &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
|
||||
next = sample_greedy_last(&logits);
|
||||
@@ -195,13 +273,20 @@ fn main() {
|
||||
let gen_tokens = output_tokens.len();
|
||||
let full_text = tokenizer.decode(&output_tokens);
|
||||
eprintln!("\nGenerated text: {full_text}");
|
||||
eprintln!("Token IDs: {:?}", &output_tokens[..output_tokens.len().min(20)]);
|
||||
eprintln!(
|
||||
"Token IDs: {:?}",
|
||||
&output_tokens[..output_tokens.len().min(20)]
|
||||
);
|
||||
let tpot = if gen_tokens > 1 {
|
||||
decode_elapsed.as_secs_f64() * 1000.0 / (gen_tokens - 1) as f64
|
||||
} else { 0.0 };
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
let tok_s = if gen_tokens > 1 {
|
||||
(gen_tokens - 1) as f64 / decode_elapsed.as_secs_f64()
|
||||
} else { 0.0 };
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
eprintln!("\n--- Performance ---");
|
||||
eprintln!("Generated: {} tokens", gen_tokens);
|
||||
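The two metrics computed in this hunk are reciprocals of each other: TPOT is decode time per generated token in milliseconds, and tok/s is tokens per second of decode time. The first token is excluded because it comes from prefill. A small sketch of the arithmetic follows; `decode_metrics` is a hypothetical name, and the zero-duration guard is an addition that the diff does not have.

```rust
use std::time::Duration;

/// Returns (TPOT in ms, tokens per second) for the decode phase.
/// The first generated token comes from prefill, so `gen_tokens - 1` steps are timed.
/// Hypothetical helper; the zero-duration guard is an added assumption.
fn decode_metrics(gen_tokens: usize, decode_elapsed: Duration) -> (f64, f64) {
    if gen_tokens > 1 && !decode_elapsed.is_zero() {
        let steps = (gen_tokens - 1) as f64;
        let secs = decode_elapsed.as_secs_f64();
        (secs * 1000.0 / steps, steps / secs)
    } else {
        (0.0, 0.0)
    }
}
```

For example, 11 generated tokens over 1 second of decode time give 10 timed steps, so TPOT is 100 ms and throughput is 10 tok/s.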
@@ -221,8 +306,15 @@ fn main() {
|
||||
#[derive(Clone)]
|
||||
enum WorkerCmd {
|
||||
Register(usize),
|
||||
Prefill { tokens: Vec<u32>, slot: usize },
|
||||
Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
|
||||
Prefill {
|
||||
tokens: Vec<u32>,
|
||||
slot: usize,
|
||||
},
|
||||
Decode {
|
||||
tokens: Vec<u32>,
|
||||
positions: Vec<usize>,
|
||||
slots: Vec<usize>,
|
||||
},
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
@@ -240,16 +332,25 @@ fn worker_loop(
|
||||
let tp = Arc::new(TpContext::init(rank, world, uid, rank as u32));
|
||||
eprintln!("[rank {rank}] Loading weights...");
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
|
||||
let model = GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp));
|
||||
let model =
|
||||
GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp));
|
||||
let local_kv = config.num_kv_heads() / world;
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = max_blocks_per_seq + 64;
|
||||
let mut cache = PagedKVCache::new_tp(
|
||||
&config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, rank as u32,
|
||||
&config,
|
||||
local_kv,
|
||||
total_blocks,
|
||||
0,
|
||||
4,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
rank as u32,
|
||||
);
|
||||
eprintln!("[rank {rank}] Ready.");
|
||||
ack_tx.send(()).unwrap();
|
||||
|
||||
let mut decoder = GraphedGptOssDecoder::new();
|
||||
while let Ok(cmd) = rx.recv() {
|
||||
match cmd {
|
||||
WorkerCmd::Register(slot) => {
|
||||
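`worker_loop` sizes the paged KV cache with a ceiling division: a sequence of `max_seq_len` tokens needs `(max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE` blocks, and the pool adds 64 blocks on top. The arithmetic is sketched below. The value `16` for `BLOCK_SIZE` is a placeholder assumption, since the real constant comes from `xserv_model` and is not shown in this diff, and both function names are hypothetical.

```rust
/// Placeholder block size for illustration only; the real value lives in xserv_model.
const BLOCK_SIZE: usize = 16;

/// Number of KV blocks needed to hold `seq_len` tokens (ceiling division).
fn blocks_for(seq_len: usize) -> usize {
    (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE
}

/// Total pool size: blocks for one max-length sequence plus spare blocks.
fn pool_size(max_seq_len: usize, headroom: usize) -> usize {
    blocks_for(max_seq_len) + headroom
}
```

Rounding up matters because a partially filled block still occupies a whole block: with the placeholder size, 17 tokens need 2 blocks, not 1.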
@@ -258,8 +359,12 @@ fn worker_loop(
|
||||
WorkerCmd::Prefill { tokens, slot } => {
|
||||
let _ = model.forward_prefill_paged(&tokens, slot, &mut cache);
|
||||
}
|
||||
WorkerCmd::Decode { tokens, positions, slots } => {
|
||||
let _ = model.forward_decode_paged(&tokens, &positions, &slots, &mut cache);
|
||||
WorkerCmd::Decode {
|
||||
tokens,
|
||||
positions,
|
||||
slots,
|
||||
} => {
|
||||
let _ = decoder.decode(&model, &tokens, &positions, &slots, &mut cache);
|
||||
}
|
||||
WorkerCmd::Shutdown => break,
|
||||
}
|
||||
@@ -286,20 +391,26 @@ fn wait_workers(handles: &[(std::thread::JoinHandle<()>, std::sync::mpsc::Receiv
|
||||
fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 {
|
||||
use half::bf16;
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
// GPU argmax fast path (4-byte D2H instead of the full logits row).
|
||||
if logits.dtype() == xserv_tensor::DType::BF16 && logits.is_contiguous() {
|
||||
let ids = xserv_kernels::argmax_bf16_to_host(logits);
|
||||
return *ids.last().unwrap();
|
||||
}
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
let data = logits_cpu.as_slice::<bf16>();
|
||||
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
|
||||
|
||||
last.iter().enumerate()
|
||||
last.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| {
|
||||
let af = a.1.to_f32();
|
||||
let bf = b.1.to_f32();
|
||||
af.partial_cmp(&bf).unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.map(|(i, _)| i as u32).unwrap()
|
||||
.map(|(i, _)| i as u32)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn get_arg<T: std::str::FromStr>(args: &[String], flag: &str) -> Option<T> {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
use xserv_model::gpt2::{sample_greedy, KVCache};
|
||||
use xserv_model::{loader, GPT2, ModelConfig};
|
||||
use xserv_model::gpt2::{KVCache, sample_greedy};
|
||||
use xserv_model::{GPT2, ModelConfig, loader};
|
||||
use xserv_tensor::Device;
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -104,9 +104,15 @@ fn main() {
|
||||
|
||||
let tbt_us = if !token_times_us.is_empty() {
|
||||
token_times_us.iter().sum::<u128>() / token_times_us.len() as u128
|
||||
} else { 0 };
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let total_gen_us: u128 = ttft_us + token_times_us.iter().sum::<u128>();
|
||||
let tpot_us = if num_generated > 0 { total_gen_us / num_generated as u128 } else { 0 };
|
||||
let tpot_us = if num_generated > 0 {
|
||||
total_gen_us / num_generated as u128
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let gen_text_escaped = generated_text
|
||||
.replace('\\', "\\\\")
|
||||
@@ -124,11 +130,16 @@ fn main() {
|
||||
print!("\"ttft_us\": {ttft_us}, ");
|
||||
print!("\"tbt_us\": {tbt_us}, ");
|
||||
print!("\"tpot_us\": {tpot_us}}}");
|
||||
if i < prompts.len() - 1 { println!(","); } else { println!(); }
|
||||
if i < prompts.len() - 1 {
|
||||
println!(",");
|
||||
} else {
|
||||
println!();
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"[{}/{}] input={input_len}tok gen={num_generated}tok ttft={:.1}ms tbt={:.1}ms | {}",
|
||||
i + 1, prompts.len(),
|
||||
i + 1,
|
||||
prompts.len(),
|
||||
ttft_us as f64 / 1000.0,
|
||||
tbt_us as f64 / 1000.0,
|
||||
&generated_text.replace('\n', " ")[..generated_text.len().min(60)]
|
||||
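The preview in this hunk slices the generated text with `[..generated_text.len().min(60)]`, which indexes by bytes. Rust panics when a string slice boundary falls inside a multi-byte UTF-8 character, so non-ASCII output can abort the benchmark. The Qwen3 benchmark later in this diff avoids this with `chars().take(60)`. A sketch of that character-based approach as a standalone helper follows; `preview` is a hypothetical name.

```rust
/// Single-line preview of at most `max_chars` characters.
/// Counts Unicode scalar values, so it never splits a multi-byte character.
/// Hypothetical helper for illustration.
fn preview(text: &str, max_chars: usize) -> String {
    text.chars()
        .map(|c| if c == '\n' { ' ' } else { c })
        .take(max_chars)
        .collect()
}
```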
@@ -138,12 +149,18 @@ fn main() {
|
||||
}
|
||||
|
||||
fn generate_with_cache(
|
||||
model: &GPT2, config: &ModelConfig, tokenizer: &Tokenizer,
|
||||
input_ids: &[u32], gen_tokens: usize,
|
||||
model: &GPT2,
|
||||
config: &ModelConfig,
|
||||
tokenizer: &Tokenizer,
|
||||
input_ids: &[u32],
|
||||
gen_tokens: usize,
|
||||
) -> (Vec<u32>, u128, Vec<u128>) {
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(), config.num_heads(), config.head_dim(),
|
||||
xserv_tensor::DType::F32, Device::Cuda(0),
|
||||
config.num_layers(),
|
||||
config.num_heads(),
|
||||
config.head_dim(),
|
||||
xserv_tensor::DType::F32,
|
||||
Device::Cuda(0),
|
||||
);
|
||||
|
||||
// Prefill
|
||||
@@ -163,15 +180,19 @@ fn generate_with_cache(
|
||||
let next = sample_greedy(&logits);
|
||||
token_times.push(t_start.elapsed().as_micros());
|
||||
generated.push(next);
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(generated, ttft_us, token_times)
|
||||
}
|
||||
|
||||
fn generate_no_cache(
|
||||
model: &GPT2, tokenizer: &Tokenizer,
|
||||
input_ids: &[u32], gen_tokens: usize,
|
||||
model: &GPT2,
|
||||
tokenizer: &Tokenizer,
|
||||
input_ids: &[u32],
|
||||
gen_tokens: usize,
|
||||
) -> (Vec<u32>, u128, Vec<u128>) {
|
||||
let mut all_ids = input_ids.to_vec();
|
||||
|
||||
@@ -191,7 +212,9 @@ fn generate_no_cache(
|
||||
token_times.push(t_start.elapsed().as_micros());
|
||||
all_ids.push(next);
|
||||
generated.push(next);
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(generated, ttft_us, token_times)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
use xserv_model::qwen3::sample_greedy;
|
||||
use xserv_model::{loader, DecodeGraphState, GpuKVCache, ModelConfig, Qwen3};
|
||||
use xserv_model::{DecodeGraphState, GpuKVCache, ModelConfig, Qwen3, loader};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -139,18 +139,35 @@ fn main() {
|
||||
} else {
|
||||
// Replay captured graphs
|
||||
let pos = cache.seq_len() as u32;
|
||||
graph.execute(last, pos, &mut cache, &layer_ptrs, embed, config.vocab_size as i32, config.hidden() as i32);
|
||||
graph.execute(
|
||||
last,
|
||||
pos,
|
||||
&mut cache,
|
||||
&layer_ptrs,
|
||||
embed,
|
||||
config.vocab_size as i32,
|
||||
config.hidden() as i32,
|
||||
);
|
||||
cache.advance_seq_len(1);
|
||||
// Read logits from graph buffer
|
||||
let vocab_size = config.vocab_size;
|
||||
let mut logits_bytes = vec![0u8; vocab_size * 2];
|
||||
graph.logits_buffer().copy_to_host(&mut logits_bytes).unwrap();
|
||||
graph
|
||||
.logits_buffer()
|
||||
.copy_to_host(&mut logits_bytes)
|
||||
.unwrap();
|
||||
let logits_data: &[half::bf16] = unsafe {
|
||||
std::slice::from_raw_parts(logits_bytes.as_ptr() as *const half::bf16, vocab_size)
|
||||
std::slice::from_raw_parts(
|
||||
logits_bytes.as_ptr() as *const half::bf16,
|
||||
vocab_size,
|
||||
)
|
||||
};
|
||||
logits_data.iter().enumerate()
|
||||
logits_data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(idx, _)| idx as u32).unwrap()
|
||||
.map(|(idx, _)| idx as u32)
|
||||
.unwrap()
|
||||
}
|
||||
} else {
|
||||
let logits = model.forward_gpu_cache(&[last], &mut cache);
|
||||
@@ -159,16 +176,24 @@ fn main() {
|
||||
|
||||
token_times.push(t_start.elapsed().as_micros());
|
||||
generated.push(next);
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let num_generated = generated.len();
|
||||
let generated_text = tokenizer.decode(&generated);
|
||||
let tbt_us = if !token_times.is_empty() {
|
||||
token_times.iter().sum::<u128>() / token_times.len() as u128
|
||||
} else { 0 };
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let total_gen_us: u128 = ttft_us + token_times.iter().sum::<u128>();
|
||||
let tpot_us = if num_generated > 0 { total_gen_us / num_generated as u128 } else { 0 };
|
||||
let tpot_us = if num_generated > 0 {
|
||||
total_gen_us / num_generated as u128
|
||||
} else {
|
||||
0
|
||||
};
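The two latency figures computed above differ only in what they average over: `tbt_us` is the mean of the recorded decode-step times, while `tpot_us` spreads TTFT plus all decode time across every generated token. A minimal sketch of the same arithmetic follows; the `LatencySummary` and `summarize` names are hypothetical and not part of the benchmark.

```rust
/// Per-request latency summary, all values in microseconds.
/// (Hypothetical helper type; the benchmark keeps these as plain locals.)
#[derive(Debug, PartialEq)]
struct LatencySummary {
    tbt_us: u128,
    tpot_us: u128,
}

/// `ttft_us` is the time to first token; `token_times` holds the measured
/// duration of each decode step; `num_generated` is the number of tokens kept.
fn summarize(ttft_us: u128, token_times: &[u128], num_generated: usize) -> LatencySummary {
    let decode_total: u128 = token_times.iter().sum();
    // Mean time between tokens: decode time only, guarded against an empty run.
    let tbt_us = if token_times.is_empty() {
        0
    } else {
        decode_total / token_times.len() as u128
    };
    // Time per output token: prefill (TTFT) is amortized over all tokens.
    let tpot_us = if num_generated > 0 {
        (ttft_us + decode_total) / num_generated as u128
    } else {
        0
    };
    LatencySummary { tbt_us, tpot_us }
}
```

With a long prompt and a short generation, `tpot_us` can be much larger than `tbt_us`, because the prefill cost dominates the numerator.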
|
||||
|
||||
let gen_text_escaped = generated_text
|
||||
.replace('\\', "\\\\")
|
||||
@@ -186,13 +211,18 @@ fn main() {
|
||||
print!("\"ttft_us\": {ttft_us}, ");
|
||||
print!("\"tbt_us\": {tbt_us}, ");
|
||||
print!("\"tpot_us\": {tpot_us}}}");
|
||||
if i < prompts.len() - 1 { println!(","); } else { println!(); }
|
||||
if i < prompts.len() - 1 {
|
||||
println!(",");
|
||||
} else {
|
||||
println!();
|
||||
}
|
||||
|
||||
let display_text = generated_text.replace('\n', " ");
|
||||
let truncated: String = display_text.chars().take(60).collect();
|
||||
eprintln!(
|
||||
"[{}/{}] input={input_len}tok gen={num_generated}tok ttft={:.1}ms tbt={:.1}ms | {}",
|
||||
i + 1, prompts.len(),
|
||||
i + 1,
|
||||
prompts.len(),
|
||||
ttft_us as f64 / 1000.0,
|
||||
tbt_us as f64 / 1000.0,
|
||||
truncated
|
||||
|
||||
@@ -18,7 +18,7 @@ use std::thread;
|
||||
use std::time::Instant;
|
||||
|
||||
use xserv_model::qwen3::sample_greedy;
|
||||
use xserv_model::{loader, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
|
||||
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -35,8 +35,13 @@ fn main() {
|
||||
std::process::exit(1);
|
||||
}
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let world: usize = arg(&args, "--tp").and_then(|s| s.parse().ok()).unwrap_or(1).max(1);
|
||||
let gen_tokens: usize = arg(&args, "--gen-tokens").and_then(|s| s.parse().ok()).unwrap_or(64);
|
||||
let world: usize = arg(&args, "--tp")
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1)
|
||||
.max(1);
|
||||
let gen_tokens: usize = arg(&args, "--gen-tokens")
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(64);
|
||||
let devices: Vec<u32> = match arg(&args, "--devices") {
|
||||
Some(s) => s.split(',').filter_map(|d| d.trim().parse().ok()).collect(),
|
||||
None => (0..world as u32).collect(),
|
||||
@@ -67,7 +72,11 @@ fn main() {
|
||||
// Tensors are not Send (their Storage holds a raw GPU pointer), so each rank
|
||||
// thread loads its own CPU copy of the weights and shards in-thread. Loading
|
||||
// is not part of the timed region.
|
||||
let id = if world > 1 { Some(xserv_distributed::get_unique_id()) } else { None };
|
||||
let id = if world > 1 {
|
||||
Some(xserv_distributed::get_unique_id())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let handles: Vec<_> = (0..world)
|
||||
.map(|rank| {
|
||||
@@ -76,7 +85,9 @@ fn main() {
|
||||
let prompt_ids = prompt_ids.clone();
|
||||
let device = devices[rank];
|
||||
thread::spawn(move || {
|
||||
run_rank(rank, world, device, id, config, model_dir, prompt_ids, gen_tokens, eos)
|
||||
run_rank(
|
||||
rank, world, device, id, config, model_dir, prompt_ids, gen_tokens, eos,
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
@@ -91,7 +102,10 @@ fn main() {
|
||||
|
||||
let results = rank0.expect("rank 0 produced no results");
|
||||
println!("\n=== TP={world} (devices {devices:?}) — Qwen3 E2E benchmark ===");
|
||||
println!("{:<45} {:>10} {:>12} {:>8}", "prompt", "TTFT(ms)", "decode tok/s", "gen");
|
||||
println!(
|
||||
"{:<45} {:>10} {:>12} {:>8}",
|
||||
"prompt", "TTFT(ms)", "decode tok/s", "gen"
|
||||
);
|
||||
let mut tps_sum = 0.0;
|
||||
for (i, r) in results.iter().enumerate() {
|
||||
let text = tokenizer.decode(&r.gen_ids).replace('\n', " ");
|
||||
@@ -99,16 +113,29 @@ fn main() {
|
||||
let p: String = prompts[i].chars().take(43).collect();
|
||||
println!(
|
||||
"{:<45} {:>10.1} {:>12.1} {:>8} | {}",
|
||||
p, r.ttft_ms, r.decode_tok_s, r.gen_ids.len(), short
|
||||
p,
|
||||
r.ttft_ms,
|
||||
r.decode_tok_s,
|
||||
r.gen_ids.len(),
|
||||
short
|
||||
);
|
||||
tps_sum += r.decode_tok_s;
|
||||
}
|
||||
println!("--- mean decode throughput: {:.1} tok/s ---", tps_sum / results.len() as f64);
|
||||
println!(
|
||||
"--- mean decode throughput: {:.1} tok/s ---",
|
||||
tps_sum / results.len() as f64
|
||||
);
|
||||
|
||||
// Machine-readable line for cross-TP correctness diffing (rank 0 token ids).
|
||||
let all_ids: Vec<String> = results
|
||||
.iter()
|
||||
.map(|r| r.gen_ids.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(","))
|
||||
.map(|r| {
|
||||
r.gen_ids
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
})
|
||||
.collect();
|
||||
println!("CORRECTNESS_IDS tp={world} {}", all_ids.join(" | "));
|
||||
}
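The `CORRECTNESS_IDS` line is meant for diffing token ids across TP settings. The sketch below reproduces the formatting and adds a hypothetical `payload` helper that strips the `tp=N` field so two runs can be compared directly; neither function exists in the repository under these names.

```rust
/// Format one machine-readable line: ids comma-separated per prompt,
/// prompts separated by " | " (same layout as the benchmark output).
fn correctness_line(tp: usize, per_prompt_ids: &[Vec<u32>]) -> String {
    let groups: Vec<String> = per_prompt_ids
        .iter()
        .map(|ids| {
            ids.iter()
                .map(|x| x.to_string())
                .collect::<Vec<_>>()
                .join(",")
        })
        .collect();
    format!("CORRECTNESS_IDS tp={tp} {}", groups.join(" | "))
}

/// Hypothetical helper: return the id payload without the "tp=N" field.
fn payload(line: &str) -> Option<&str> {
    line.strip_prefix("CORRECTNESS_IDS ")?
        .split_once(' ')
        .map(|(_tp_field, rest)| rest)
}
```

Two runs agree token-for-token when `payload` returns the same string for both lines.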
|
||||
@@ -126,7 +153,12 @@ fn run_rank(
|
||||
) -> Option<Vec<PromptResult>> {
|
||||
// Bind this thread to its GPU and set up the TP communicator.
|
||||
let tp = if world > 1 {
|
||||
Some(Arc::new(xserv_distributed::TpContext::init(rank, world, id.unwrap(), device)))
|
||||
Some(Arc::new(xserv_distributed::TpContext::init(
|
||||
rank,
|
||||
world,
|
||||
id.unwrap(),
|
||||
device,
|
||||
)))
|
||||
} else {
|
||||
xserv_cuda::device::set_device(device).unwrap();
|
||||
None
|
||||
@@ -142,7 +174,14 @@ fn run_rank(
|
||||
let max_blocks_per_seq = max_seq.div_ceil(BLOCK_SIZE);
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
let mut cache = PagedKVCache::new_tp(
|
||||
&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, device,
|
||||
&config,
|
||||
local_kv,
|
||||
total_blocks,
|
||||
0,
|
||||
1,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
device,
|
||||
);
|
||||
|
||||
// Warmup (init kernels / allocator / NCCL channels) — not timed.
|
||||
@@ -177,12 +216,20 @@ fn run_rank(
|
||||
steps += 1;
|
||||
}
|
||||
let decode_s = t1.elapsed().as_secs_f64();
|
||||
let decode_tok_s = if steps > 0 && decode_s > 0.0 { steps as f64 / decode_s } else { 0.0 };
|
||||
let decode_tok_s = if steps > 0 && decode_s > 0.0 {
|
||||
steps as f64 / decode_s
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
cache.free_sequence(0);
|
||||
|
||||
if rank == 0 {
|
||||
out.push(PromptResult { gen_ids: generated, ttft_ms, decode_tok_s });
|
||||
out.push(PromptResult {
|
||||
gen_ids: generated,
|
||||
ttft_ms,
|
||||
decode_tok_s,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,5 +237,8 @@ fn run_rank(
|
||||
}
|
||||
|
||||
fn arg<'a>(args: &'a [String], flag: &str) -> Option<&'a str> {
|
||||
args.iter().position(|a| a == flag).and_then(|i| args.get(i + 1)).map(|s| s.as_str())
|
||||
args.iter()
|
||||
.position(|a| a == flag)
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.map(|s| s.as_str())
|
||||
}
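For reference, here is how the `arg` helper combines with the `--tp` and `--devices` parsing shown earlier in this file. The `parse_world` and `parse_devices` wrappers are hypothetical names introduced only to make the behavior testable; the parsing chains themselves are copied from the diff.

```rust
/// Return the value following `flag`, if both are present.
fn arg<'a>(args: &'a [String], flag: &str) -> Option<&'a str> {
    args.iter()
        .position(|a| a == flag)
        .and_then(|i| args.get(i + 1))
        .map(|s| s.as_str())
}

/// TP world size: defaults to 1, and is clamped to at least 1.
fn parse_world(args: &[String]) -> usize {
    arg(args, "--tp")
        .and_then(|s| s.parse().ok())
        .unwrap_or(1)
        .max(1)
}

/// Device list: a comma-separated list, or 0..world when the flag is absent.
/// Entries that fail to parse are silently dropped by `filter_map`.
fn parse_devices(args: &[String], world: usize) -> Vec<u32> {
    match arg(args, "--devices") {
        Some(s) => s.split(',').filter_map(|d| d.trim().parse().ok()).collect(),
        None => (0..world as u32).collect(),
    }
}
```

Note that a malformed `--devices` entry shrinks the list rather than reporting an error, so the result can end up shorter than `world`.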
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use half::bf16;
|
||||
use std::path::PathBuf;
|
||||
use xserv_model::{loader, KVCache, ModelConfig, Qwen3};
|
||||
use xserv_model::{KVCache, ModelConfig, Qwen3, loader};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
use half::bf16;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
@@ -20,8 +20,11 @@ fn main() {
|
||||
eprintln!("Token IDs: {token_ids:?}");
|
||||
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(), config.num_kv_heads(), config.head_dim(),
|
||||
DType::BF16, Device::Cuda(0),
|
||||
config.num_layers(),
|
||||
config.num_kv_heads(),
|
||||
config.head_dim(),
|
||||
DType::BF16,
|
||||
Device::Cuda(0),
|
||||
);
|
||||
let logits = model.forward_with_cache(&token_ids, &mut cache);
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
@@ -31,7 +34,9 @@ fn main() {
|
||||
|
||||
// Print top-20 logits for the last position
|
||||
let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
let mut indexed: Vec<(usize, f32)> = last_row.iter().enumerate()
|
||||
let mut indexed: Vec<(usize, f32)> = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, v)| (i, v.to_f32()))
|
||||
.collect();
|
||||
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
use std::io::{self, IsTerminal, Read, Write};
|
||||
use std::path::PathBuf;
|
||||
|
||||
use std::sync::{mpsc, Arc};
|
||||
use std::sync::{Arc, mpsc};
|
||||
use std::thread;
|
||||
|
||||
use xserv_model::{loader, sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
|
||||
use xserv_model::{
|
||||
BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, SamplingParams,
|
||||
loader, sample, sample_greedy_penalized,
|
||||
};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -14,13 +17,24 @@ enum ChatModel {
|
||||
}
|
||||
|
||||
impl ChatModel {
|
||||
fn forward_prefill_paged(&self, tokens: &[u32], slot: usize, cache: &mut PagedKVCache) -> xserv_tensor::Tensor {
|
||||
fn forward_prefill_paged(
|
||||
&self,
|
||||
tokens: &[u32],
|
||||
slot: usize,
|
||||
cache: &mut PagedKVCache,
|
||||
) -> xserv_tensor::Tensor {
|
||||
match self {
|
||||
ChatModel::Qwen3(m) => m.forward_prefill_paged(tokens, slot, cache),
|
||||
ChatModel::GptOss(m) => m.forward_prefill_paged(tokens, slot, cache),
|
||||
}
|
||||
}
|
||||
fn forward_decode_paged(&self, tokens: &[u32], positions: &[usize], slots: &[usize], cache: &mut PagedKVCache) -> xserv_tensor::Tensor {
|
||||
fn forward_decode_paged(
|
||||
&self,
|
||||
tokens: &[u32],
|
||||
positions: &[usize],
|
||||
slots: &[usize],
|
||||
cache: &mut PagedKVCache,
|
||||
) -> xserv_tensor::Tensor {
|
||||
match self {
|
||||
ChatModel::Qwen3(m) => m.forward_decode_paged(tokens, positions, slots, cache),
|
||||
ChatModel::GptOss(m) => m.forward_decode_paged(tokens, positions, slots, cache),
|
||||
@@ -33,8 +47,15 @@ impl ChatModel {
|
||||
enum TpCommand {
|
||||
Register(usize),
|
||||
Free(usize),
|
||||
Prefill { tokens: Vec<u32>, slot: usize },
|
||||
Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
|
||||
Prefill {
|
||||
tokens: Vec<u32>,
|
||||
slot: usize,
|
||||
},
|
||||
Decode {
|
||||
tokens: Vec<u32>,
|
||||
positions: Vec<usize>,
|
||||
slots: Vec<usize>,
|
||||
},
|
||||
}
|
||||
|
||||
struct TpHandle {
|
||||
@@ -56,7 +77,8 @@ impl TpHandle {
|
||||
}
|
||||
|
||||
fn tp_worker_loop(
|
||||
rank: usize, world: usize,
|
||||
rank: usize,
|
||||
world: usize,
|
||||
id: xserv_distributed::UniqueId,
|
||||
model_dir: std::path::PathBuf,
|
||||
config: ModelConfig,
|
||||
@@ -64,28 +86,68 @@ fn tp_worker_loop(
|
||||
cmd_rx: mpsc::Receiver<TpCommand>,
|
||||
ack_tx: mpsc::Sender<()>,
|
||||
) {
|
||||
let tp = Arc::new(xserv_distributed::TpContext::init(rank, world, id, rank as u32));
|
||||
let tp = Arc::new(xserv_distributed::TpContext::init(
|
||||
rank,
|
||||
world,
|
||||
id,
|
||||
rank as u32,
|
||||
));
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
|
||||
let model = if config.is_moe() {
|
||||
ChatModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp)))
|
||||
ChatModel::GptOss(GptOss::from_weights_tp(
|
||||
config.clone(),
|
||||
weights,
|
||||
rank,
|
||||
world,
|
||||
rank as u32,
|
||||
Some(tp),
|
||||
))
|
||||
} else {
|
||||
ChatModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp)))
|
||||
ChatModel::Qwen3(Qwen3::from_weights_tp(
|
||||
config.clone(),
|
||||
weights,
|
||||
rank,
|
||||
world,
|
||||
rank as u32,
|
||||
Some(tp),
|
||||
))
|
||||
};
|
||||
let local_kv = config.num_kv_heads() / world;
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
let mut cache = PagedKVCache::new_tp(
|
||||
&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, rank as u32,
|
||||
&config,
|
||||
local_kv,
|
||||
total_blocks,
|
||||
0,
|
||||
1,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
rank as u32,
|
||||
);
|
||||
let mut decoder = GraphedGptOssDecoder::new();
|
||||
while let Ok(cmd) = cmd_rx.recv() {
|
||||
match cmd {
|
||||
TpCommand::Register(slot) => { let _ = cache.register_sequence(slot); }
|
||||
TpCommand::Register(slot) => {
|
||||
let _ = cache.register_sequence(slot);
|
||||
}
|
||||
TpCommand::Free(slot) => cache.free_sequence(slot),
|
||||
TpCommand::Prefill { tokens, slot } => {
|
||||
let _ = model.forward_prefill_paged(&tokens, slot, &mut cache);
|
||||
}
|
||||
TpCommand::Decode { tokens, positions, slots } => {
|
||||
let _ = model.forward_decode_paged(&tokens, &positions, &slots, &mut cache);
|
||||
TpCommand::Decode {
|
||||
tokens,
|
||||
positions,
|
||||
slots,
|
||||
} => {
|
||||
let _ = chat_decode(
|
||||
&model,
|
||||
&mut decoder,
|
||||
&tokens,
|
||||
&positions,
|
||||
&slots,
|
||||
&mut cache,
|
||||
);
|
||||
}
|
||||
}
|
||||
let _ = ack_tx.send(());
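The worker loop above follows a command/ack pattern: rank 0 broadcasts a command, every worker applies it to its own state, and each sends back one ack. The following is a GPU-free sketch of that pattern using only `std::sync::mpsc`. The `Command`, `Handle`, and `spawn_workers` names are hypothetical stand-ins for `TpCommand`, `TpHandle`, and the spawning code, and the one-ack-per-worker behavior of `wait` is an assumption, since the body of `TpHandle::wait` is not shown in this diff.

```rust
use std::sync::mpsc;
use std::thread;

#[derive(Clone, Debug)]
enum Command {
    Register(usize),
    Free(usize),
}

struct Handle {
    cmd_txs: Vec<mpsc::Sender<Command>>,
    ack_rx: mpsc::Receiver<()>,
}

impl Handle {
    /// Broadcast one command to every worker.
    fn send(&self, cmd: Command) {
        for tx in &self.cmd_txs {
            tx.send(cmd.clone()).expect("worker hung up");
        }
    }
    /// Block until each worker has acknowledged (assumed: one ack per worker).
    fn wait(&self) {
        for _ in &self.cmd_txs {
            self.ack_rx.recv().expect("worker hung up");
        }
    }
}

/// Spawn `n` workers; each returns its set of live slots when the channel closes.
fn spawn_workers(n: usize) -> (Handle, Vec<thread::JoinHandle<Vec<usize>>>) {
    let (ack_tx, ack_rx) = mpsc::channel();
    let mut cmd_txs = Vec::new();
    let mut joins = Vec::new();
    for _ in 0..n {
        let (tx, rx) = mpsc::channel::<Command>();
        let ack_tx = ack_tx.clone();
        cmd_txs.push(tx);
        joins.push(thread::spawn(move || {
            let mut live = Vec::new();
            // The loop ends when the sender side is dropped.
            while let Ok(cmd) = rx.recv() {
                match cmd {
                    Command::Register(slot) => live.push(slot),
                    Command::Free(slot) => live.retain(|&s| s != slot),
                }
                let _ = ack_tx.send(());
            }
            live
        }));
    }
    (Handle { cmd_txs, ack_rx }, joins)
}
```

Sending before doing the local forward pass and waiting afterwards, as the chat binary does, lets rank 0 and the workers run the same step concurrently.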
|
||||
@@ -220,7 +282,13 @@ fn read_line_edited(prompt: &str) -> Line {
|
||||
}
|
||||
b => {
|
||||
// UTF-8 multi-byte: read the continuation bytes for this char.
|
||||
let extra = if b >= 0xF0 { 3 } else if b >= 0xE0 { 2 } else { 1 };
|
||||
let extra = if b >= 0xF0 {
|
||||
3
|
||||
} else if b >= 0xE0 {
|
||||
2
|
||||
} else {
|
||||
1
|
||||
};
|
||||
let mut bytes = vec![b];
|
||||
let mut cont = [0u8; 1];
|
||||
let mut ok = true;
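The `extra` computation above maps a UTF-8 lead byte to its number of continuation bytes. The sketch below isolates it, assuming it is only reached for non-ASCII bytes (`b >= 0x80`), since ASCII appears to be handled by earlier match arms. `continuation_bytes` and `read_char` are hypothetical names.

```rust
/// Number of continuation bytes that follow a non-ASCII UTF-8 lead byte.
/// 0xF0.. starts a 4-byte sequence, 0xE0.. a 3-byte one, anything else is
/// treated as a 2-byte sequence (same thresholds as the diff).
fn continuation_bytes(lead: u8) -> usize {
    if lead >= 0xF0 {
        3
    } else if lead >= 0xE0 {
        2
    } else {
        1
    }
}

/// Read one character given its first byte and a source of further bytes.
/// Returns None on truncated or invalid input instead of panicking.
fn read_char(first: u8, rest: &mut impl Iterator<Item = u8>) -> Option<char> {
    let mut bytes = vec![first];
    for _ in 0..continuation_bytes(first) {
        bytes.push(rest.next()?);
    }
    std::str::from_utf8(&bytes).ok()?.chars().next()
}
```

A stray continuation byte (0x80 to 0xBF) in the lead position is treated as a 2-byte lead here and then rejected by `from_utf8`.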
|
||||
@@ -274,7 +342,8 @@ fn main() {
|
||||
if world > 1 {
|
||||
assert!(
|
||||
config.num_kv_heads() % world == 0,
|
||||
"num_kv_heads {} not divisible by tp {world}", config.num_kv_heads()
|
||||
"num_kv_heads {} not divisible by tp {world}",
|
||||
config.num_kv_heads()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -289,7 +358,16 @@ fn main() {
|
||||
let model_dir = opts.model_dir.clone();
|
||||
let config = config.clone();
|
||||
thread::spawn(move || {
|
||||
tp_worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx);
|
||||
tp_worker_loop(
|
||||
rank,
|
||||
world,
|
||||
id,
|
||||
model_dir,
|
||||
config,
|
||||
max_seq_len,
|
||||
ctx_rx,
|
||||
ack_tx,
|
||||
);
|
||||
});
|
||||
}
|
||||
eprintln!("Loading weights (tp={world})...");
|
||||
@@ -297,14 +375,37 @@ fn main() {
|
||||
let weights = loader::load_model_dir(&opts.model_dir, Device::Cpu);
|
||||
eprintln!("Loaded {} tensors", weights.len());
|
||||
let m = if is_moe {
|
||||
ChatModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp)))
|
||||
ChatModel::GptOss(GptOss::from_weights_tp(
|
||||
config.clone(),
|
||||
weights,
|
||||
0,
|
||||
world,
|
||||
0,
|
||||
Some(tp),
|
||||
))
|
||||
} else {
|
||||
ChatModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp)))
|
||||
ChatModel::Qwen3(Qwen3::from_weights_tp(
|
||||
config.clone(),
|
||||
weights,
|
||||
0,
|
||||
world,
|
||||
0,
|
||||
Some(tp),
|
||||
))
|
||||
};
|
||||
let local_kv = config.num_kv_heads() / world;
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
let c = PagedKVCache::new_tp(&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, 0);
|
||||
let c = PagedKVCache::new_tp(
|
||||
&config,
|
||||
local_kv,
|
||||
total_blocks,
|
||||
0,
|
||||
1,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
0,
|
||||
);
|
||||
let h = TpHandle { cmd_txs, ack_rx };
|
||||
(m, c, Some(h))
|
||||
} else {
|
||||
@@ -321,7 +422,11 @@ fn main() {
|
||||
};
|
||||
|
||||
let tokenizer = Tokenizer::from_file(&opts.model_dir.join("tokenizer.json"));
|
||||
if let Some(h) = &tp_handle { h.send(TpCommand::Register(SLOT)); h.wait(); }
|
||||
let mut decoder = GraphedGptOssDecoder::new();
|
||||
if let Some(h) = &tp_handle {
|
||||
h.send(TpCommand::Register(SLOT));
|
||||
h.wait();
|
||||
}
|
||||
cache.register_sequence(SLOT).expect("register chat slot");
|
||||
let use_color = opts.color && io::stdout().is_terminal();
|
||||
|
||||
@@ -363,11 +468,8 @@ fn main() {
|
||||
if is_moe {
|
||||
// Harmony multi-turn: re-render the whole conversation (prior
|
||||
// analysis dropped) and re-prefill into a freshly cleared slot.
|
||||
let prompt = build_conversation_gpt_oss(
|
||||
opts.system_prompt.as_deref(),
|
||||
&moe_history,
|
||||
input,
|
||||
);
|
||||
let prompt =
|
||||
build_conversation_gpt_oss(opts.system_prompt.as_deref(), &moe_history, input);
|
||||
let prompt_tokens = tokenizer.encode(&prompt);
|
||||
if prompt_tokens.is_empty() {
|
||||
continue;
|
||||
@@ -384,8 +486,17 @@ fn main() {
|
||||
print!("assistant> ");
|
||||
io::stdout().flush().unwrap();
|
||||
let (_finish, answer) = generate_with_paged_cache(
|
||||
&model, &mut cache, &tokenizer, &prompt_tokens, &opts.sampling,
|
||||
max_new_tokens, use_color, &tp_handle, is_moe, opts.enable_thinking,
|
||||
&model,
|
||||
&mut decoder,
|
||||
&mut cache,
|
||||
&tokenizer,
|
||||
&prompt_tokens,
|
||||
&opts.sampling,
|
||||
max_new_tokens,
|
||||
use_color,
|
||||
&tp_handle,
|
||||
is_moe,
|
||||
opts.enable_thinking,
|
||||
);
|
||||
moe_history.push((input.to_string(), answer));
|
||||
println!();
|
||||
@@ -421,6 +532,7 @@ fn main() {
|
||||
io::stdout().flush().unwrap();
|
||||
let (finish, _answer) = generate_with_paged_cache(
|
||||
&model,
|
||||
&mut decoder,
|
||||
&mut cache,
|
||||
&tokenizer,
|
||||
&prompt_tokens,
|
||||
@@ -433,10 +545,24 @@ fn main() {
|
||||
);
|
||||
match finish {
|
||||
Finish::Stop { token_id } => {
|
||||
append_after_stop(&model, &mut cache, &tokenizer, max_seq_len, token_id, &tp_handle);
|
||||
append_after_stop(
|
||||
&model,
|
||||
&mut cache,
|
||||
&tokenizer,
|
||||
max_seq_len,
|
||||
token_id,
|
||||
&tp_handle,
|
||||
);
|
||||
}
|
||||
Finish::Length => {
|
||||
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, "<|im_end|>\n", &tp_handle);
|
||||
append_text_to_cache(
|
||||
&model,
|
||||
&mut cache,
|
||||
&tokenizer,
|
||||
max_seq_len,
|
||||
"<|im_end|>\n",
|
||||
&tp_handle,
|
||||
);
|
||||
}
|
||||
}
|
||||
println!();
|
||||
@@ -445,9 +571,15 @@ fn main() {
|
||||
|
||||
/// Free and re-register the single chat KV slot (clears all cached context).
|
||||
fn reset_slot(cache: &mut PagedKVCache, tp: &Option<TpHandle>) {
|
||||
if let Some(h) = tp { h.send(TpCommand::Free(SLOT)); h.wait(); }
|
||||
if let Some(h) = tp {
|
||||
h.send(TpCommand::Free(SLOT));
|
||||
h.wait();
|
||||
}
|
||||
cache.free_sequence(SLOT);
|
||||
if let Some(h) = tp { h.send(TpCommand::Register(SLOT)); h.wait(); }
|
||||
if let Some(h) = tp {
|
||||
h.send(TpCommand::Register(SLOT));
|
||||
h.wait();
|
||||
}
|
||||
cache.register_sequence(SLOT).expect("register chat slot");
|
||||
}
|
||||
|
||||
@@ -585,7 +717,15 @@ fn new_paged_cache(config: &ModelConfig, max_seq_len: usize) -> PagedKVCache {
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = (max_blocks_per_seq + 1).max(2);
|
||||
// Single-slot interactive CLI: no swap pool (cpu_total_blocks = 0).
|
||||
PagedKVCache::new(config, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, 0)
|
||||
PagedKVCache::new(
|
||||
config,
|
||||
total_blocks,
|
||||
0,
|
||||
1,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
0,
|
||||
)
|
||||
}
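The block-count arithmetic in `new_paged_cache` is a ceiling division, which the benchmark file in this same diff spells as `div_ceil`. A small sketch shows that the two spellings agree; `BLOCK_SIZE = 16` is an assumed value for illustration only, as the real constant lives in `xserv_model`.

```rust
/// Assumed block size for illustration; the real value is `xserv_model::BLOCK_SIZE`.
const BLOCK_SIZE: usize = 16;

/// Blocks needed to hold `max_seq_len` tokens (std `usize::div_ceil`).
fn blocks_for(max_seq_len: usize) -> usize {
    max_seq_len.div_ceil(BLOCK_SIZE)
}

/// The manual spelling used in the chat binary.
fn blocks_for_manual(max_seq_len: usize) -> usize {
    (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE
}

/// Pool size for the single-slot CLI: one spare block, never fewer than two.
fn single_slot_pool(max_seq_len: usize) -> usize {
    (blocks_for(max_seq_len) + 1).max(2)
}
```

`div_ceil` also avoids the (unlikely) overflow of `max_seq_len + BLOCK_SIZE - 1` near `usize::MAX`.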
|
||||
|
||||
fn build_turn_prompt(
|
||||
@@ -665,7 +805,10 @@ fn build_conversation_gpt_oss(
|
||||
/// civil-calendar conversion (same algorithm the server uses for strftime_now).
|
||||
fn today_ymd() -> String {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
let secs = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
|
||||
let secs = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
let z = (secs / 86400) as i64 + 719468;
|
||||
let era = (if z >= 0 { z } else { z - 146096 }) / 146097;
|
||||
let doe = z - era * 146097;
|
||||
@@ -679,8 +822,23 @@ fn today_ymd() -> String {
|
||||
format!("{y:04}-{m:02}-{d:02}")
|
||||
}
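The middle of `today_ymd` is elided by the hunk boundary. The visible constants (719468, 146097, 146096) match the well-known days-to-civil-date conversion for the proleptic Gregorian calendar, so the sketch below fills in the remaining steps from that algorithm. Treat it as an assumption about what the elided lines do, not a copy of them; `civil_from_days` and `ymd_from_unix_secs` are hypothetical names.

```rust
/// Convert days since 1970-01-01 to (year, month, day).
fn civil_from_days(days: i64) -> (i64, u32, u32) {
    // Shift the epoch to 0000-03-01 so leap days fall at the end of a year.
    let z = days + 719_468;
    // 400-year eras of exactly 146097 days each.
    let era = (if z >= 0 { z } else { z - 146_096 }) / 146_097;
    let doe = z - era * 146_097; // day of era, 0..=146096
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // year of era
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year
    let mp = (5 * doy + 2) / 153; // March-based month, 0 = March
    let d = (doy - (153 * mp + 2) / 5 + 1) as u32;
    let m = (if mp < 10 { mp + 3 } else { mp - 9 }) as u32;
    // January and February belong to the next civil year.
    let y = yoe + era * 400 + i64::from(m <= 2);
    (y, m, d)
}

/// Format a Unix timestamp (UTC seconds) as YYYY-MM-DD.
fn ymd_from_unix_secs(secs: u64) -> String {
    let (y, m, d) = civil_from_days((secs / 86_400) as i64);
    format!("{y:04}-{m:02}-{d:02}")
}
```

Taking the timestamp as a parameter, rather than calling `SystemTime::now()` inside, keeps the conversion testable.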
|
||||
|
||||
fn chat_decode(
|
||||
model: &ChatModel,
|
||||
decoder: &mut GraphedGptOssDecoder,
|
||||
tokens: &[u32],
|
||||
positions: &[usize],
|
||||
slots: &[usize],
|
||||
cache: &mut PagedKVCache,
|
||||
) -> xserv_tensor::Tensor {
|
||||
match model {
|
||||
ChatModel::GptOss(m) => decoder.decode(m, tokens, positions, slots, cache),
|
||||
ChatModel::Qwen3(_) => model.forward_decode_paged(tokens, positions, slots, cache),
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_with_paged_cache(
|
||||
model: &ChatModel,
|
||||
decoder: &mut GraphedGptOssDecoder,
|
||||
cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
prompt_tokens: &[u32],
|
||||
@@ -691,12 +849,32 @@ fn generate_with_paged_cache(
|
||||
is_moe: bool,
|
||||
enable_thinking: bool,
|
||||
) -> (Finish, String) {
|
||||
let harmony_end_id = if is_moe { tokenizer.special_token_id("<|end|>") } else { None };
|
||||
let harmony_channel_id = if is_moe { tokenizer.special_token_id("<|channel|>") } else { None };
|
||||
let harmony_message_id = if is_moe { tokenizer.special_token_id("<|message|>") } else { None };
|
||||
let harmony_end_id = if is_moe {
|
||||
tokenizer.special_token_id("<|end|>")
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let harmony_channel_id = if is_moe {
|
||||
tokenizer.special_token_id("<|channel|>")
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let harmony_message_id = if is_moe {
|
||||
tokenizer.special_token_id("<|message|>")
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let harmony_special: Vec<u32> = if is_moe {
|
||||
["<|channel|>", "<|start|>", "<|end|>", "<|message|>", "<|return|>"]
|
||||
.iter().filter_map(|s| tokenizer.special_token_id(s)).collect()
|
||||
[
|
||||
"<|channel|>",
|
||||
"<|start|>",
|
||||
"<|end|>",
|
||||
"<|message|>",
|
||||
"<|return|>",
|
||||
]
|
||||
.iter()
|
||||
.filter_map(|s| tokenizer.special_token_id(s))
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
@@ -704,18 +882,29 @@ fn generate_with_paged_cache(
|
||||
// "analysis" channel is rendered as thinking (gray). After <|channel|>
|
||||
// we read the channel name tokens until <|message|>.
|
||||
#[derive(PartialEq, Clone, Copy)]
|
||||
enum HarmonyState { Normal, ReadingChannel, InAnalysis, InFinal }
|
||||
let mut hstate = if is_moe { HarmonyState::InFinal } else { HarmonyState::Normal };
|
||||
enum HarmonyState {
|
||||
Normal,
|
||||
ReadingChannel,
|
||||
InAnalysis,
|
||||
InFinal,
|
||||
}
|
||||
let mut hstate = if is_moe {
|
||||
HarmonyState::InFinal
|
||||
} else {
|
||||
HarmonyState::Normal
|
||||
};
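To make the channel state machine concrete, here is a simplified sketch that works on string tokens rather than token ids. It follows the comments in this function: the channel name is read between `<|channel|>` and `<|message|>`, non-final channels count as thinking, and `<|end|>` in the final channel stops generation. Everything else (the `Split` type, treating `<|return|>` like `<|end|>`, dropping header text after `<|start|>`) is an assumption made for illustration.

```rust
#[derive(Debug, PartialEq, Clone, Copy)]
enum HarmonyState {
    Normal,
    ReadingChannel,
    InAnalysis,
    InFinal,
}

/// Hypothetical result type: reasoning text and answer text, separated.
#[derive(Debug, Default)]
struct Split {
    thinking: String,
    answer: String,
}

fn split_channels(tokens: &[&str]) -> Split {
    // The diff starts gpt-oss streams in InFinal; mirrored here.
    let mut state = HarmonyState::InFinal;
    let mut channel = String::new();
    let mut out = Split::default();
    for &tok in tokens {
        match tok {
            "<|channel|>" => {
                state = HarmonyState::ReadingChannel;
                channel.clear();
            }
            "<|message|>" => {
                if state == HarmonyState::ReadingChannel {
                    state = if channel.trim() == "final" {
                        HarmonyState::InFinal
                    } else {
                        HarmonyState::InAnalysis
                    };
                }
            }
            "<|end|>" | "<|return|>" => {
                if state == HarmonyState::InFinal {
                    break; // the answer is complete
                }
                state = HarmonyState::Normal;
            }
            "<|start|>" => state = HarmonyState::Normal,
            _ => match state {
                HarmonyState::ReadingChannel => channel.push_str(tok),
                HarmonyState::InAnalysis => out.thinking.push_str(tok),
                HarmonyState::InFinal => out.answer.push_str(tok),
                HarmonyState::Normal => {} // header text such as the role name
            },
        }
    }
    out
}
```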
|
||||
|
||||
// Off by default. A repetition penalty over a harmony stream penalizes the
|
||||
// control tokens (<|channel|>, <|message|>, <|start|>) that MUST repeat to
|
||||
// open the final channel — so a non-1.0 default makes gpt-oss stop right
|
||||
// after the analysis block, before emitting any answer. Opt in via the env
|
||||
// var if you want it for plain (non-harmony) generation.
|
||||
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY").ok()
|
||||
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1.0);
|
||||
let rep_window: usize = std::env::var("XSERV_REP_WINDOW").ok()
|
||||
let rep_window: usize = std::env::var("XSERV_REP_WINDOW")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(512);
|
||||
let mut history: Vec<u32> = Vec::new();
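The comment above explains why the penalty defaults to 1.0. The sketch below shows the mechanism with a common repetition-penalty rule (divide positive logits by the penalty, multiply negative ones). This rule is an assumption, and xserv's `sample_greedy_penalized` may differ in detail. `env_or` and `penalized_argmax` are hypothetical names.

```rust
/// Read an environment variable and parse it, falling back to `default`
/// when it is unset or malformed (same chain as the diff).
fn env_or<T: std::str::FromStr>(name: &str, default: T) -> T {
    std::env::var(name)
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(default)
}

/// Greedy pick after penalizing every token id seen in the last `window`
/// entries of `history`. A penalty of 1.0 leaves the logits unchanged.
fn penalized_argmax(logits: &[f32], history: &[u32], penalty: f32, window: usize) -> Option<usize> {
    let mut scores = logits.to_vec();
    let recent = &history[history.len().saturating_sub(window)..];
    let mut seen = vec![false; scores.len()];
    for &id in recent {
        let i = id as usize;
        // Penalize each distinct token once, however often it occurred.
        if i < scores.len() && !seen[i] {
            seen[i] = true;
            scores[i] = if scores[i] > 0.0 {
                scores[i] / penalty
            } else {
                scores[i] * penalty
            };
        }
    }
    scores
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
}
```

If token 1 stands for a control token that has to recur, a penalty above 1.0 can push it below a competitor. That is the failure mode the comment describes for harmony streams.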
|
||||
@@ -729,9 +918,16 @@ fn generate_with_paged_cache(
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: prompt_tokens.to_vec(), slot: SLOT }); }
|
||||
if let Some(h) = tp {
|
||||
h.send(TpCommand::Prefill {
|
||||
tokens: prompt_tokens.to_vec(),
|
||||
slot: SLOT,
|
||||
});
|
||||
}
|
||||
let logits = model.forward_prefill_paged(prompt_tokens, SLOT, cache);
|
||||
if let Some(h) = tp { h.wait(); }
|
||||
if let Some(h) = tp {
|
||||
h.wait();
|
||||
}
|
||||
let mut next = pick(&logits, sampling, &history);
|
||||
let mut decode_buffer = Vec::new();
|
||||
let mut in_thinking = false;
|
||||
@@ -744,9 +940,17 @@ fn generate_with_paged_cache(
|
||||
|
||||
for _ in 0..max_tokens {
|
||||
let position = cache.seq_len(SLOT);
|
||||
if let Some(h) = tp { h.send(TpCommand::Decode { tokens: vec![next], positions: vec![position], slots: vec![SLOT] }); }
|
||||
let logits = model.forward_decode_paged(&[next], &[position], &[SLOT], cache);
|
||||
if let Some(h) = tp { h.wait(); }
|
||||
if let Some(h) = tp {
|
||||
h.send(TpCommand::Decode {
|
||||
tokens: vec![next],
|
||||
positions: vec![position],
|
||||
slots: vec![SLOT],
|
||||
});
|
||||
}
|
||||
let logits = chat_decode(model, decoder, &[next], &[position], &[SLOT], cache);
|
||||
if let Some(h) = tp {
|
||||
h.wait();
|
||||
}
|
||||
if tokenizer.is_eos(next) {
|
||||
print_stream_text(
|
||||
&tokenizer.flush_decode_stream(&mut decode_buffer),
|
||||
@@ -757,7 +961,10 @@ fn generate_with_paged_cache(
|
||||
print_stream_text("\n</think>\n\n", true, use_color);
|
||||
}
|
||||
io::stdout().flush().unwrap();
|
||||
return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids));
|
||||
return (
|
||||
Finish::Stop { token_id: next },
|
||||
tokenizer.decode(&answer_ids),
|
||||
);
|
||||
}
|
||||
if harmony_end_id == Some(next) {
|
||||
// <|end|> closes current segment; if in final channel, we're done
|
||||
@@ -768,7 +975,10 @@ fn generate_with_paged_cache(
|
||||
);
|
||||
if hstate == HarmonyState::InFinal {
|
||||
io::stdout().flush().unwrap();
|
||||
return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids));
|
||||
return (
|
||||
Finish::Stop { token_id: next },
|
||||
tokenizer.decode(&answer_ids),
|
||||
);
|
||||
}
|
||||
// Closing a thinking (analysis/commentary) channel: emit the </think>
|
||||
// marker so it renders like Qwen3's thinking block.
|
||||
@@ -824,7 +1034,13 @@ fn generate_with_paged_cache(
|
||||
// Analysis channel = the model's reasoning. With --think, show it as a
|
||||
// thinking block (gray if color); otherwise suppress it (answer only).
|
||||
if show_thinking {
|
||||
print_generated_token(tokenizer, next, &mut decode_buffer, &mut in_thinking, use_color);
|
||||
print_generated_token(
|
||||
tokenizer,
|
||||
next,
|
||||
&mut decode_buffer,
|
||||
&mut in_thinking,
|
||||
use_color,
|
||||
);
|
||||
io::stdout().flush().unwrap();
|
||||
}
|
||||
next = pick(&logits, sampling, &history);
|
||||
@@ -886,9 +1102,16 @@ fn append_text_to_cache(
|
||||
if tokens.is_empty() || cache.seq_len(SLOT) + tokens.len() > max_seq_len {
|
||||
return;
|
||||
}
|
||||
if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: tokens.clone(), slot: SLOT }); }
|
||||
if let Some(h) = tp {
|
||||
h.send(TpCommand::Prefill {
|
||||
tokens: tokens.clone(),
|
||||
slot: SLOT,
|
||||
});
|
||||
}
|
||||
let _ = model.forward_prefill_paged(&tokens, SLOT, cache);
|
||||
if let Some(h) = tp { h.wait(); }
|
||||
if let Some(h) = tp {
|
||||
h.wait();
|
||||
}
|
||||
}
|
||||
|
||||
fn print_generated_token(
|
||||
@@ -934,4 +1157,3 @@ fn print_stream_text(text: &str, in_thinking: bool, use_color: bool) {
|
||||
print!("{text}");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::io::{self, Write};
|
||||
use std::path::PathBuf;
|
||||
use xserv_model::{loader, KVCache, ModelConfig, PagedKVCache, BLOCK_SIZE};
|
||||
use xserv_model::{BLOCK_SIZE, KVCache, ModelConfig, PagedKVCache, loader};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -21,14 +21,21 @@ fn main() {
|
||||
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
let info = xserv_cuda::device::device_info(0).unwrap();
|
||||
eprintln!("GPU: {} ({} MB free)", info.name, info.free_memory / 1024 / 1024);
|
||||
eprintln!(
|
||||
"GPU: {} ({} MB free)",
|
||||
info.name,
|
||||
info.free_memory / 1024 / 1024
|
||||
);
|
||||
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
let model_type = config.model_type.as_deref().unwrap_or("unknown");
|
||||
eprintln!(
|
||||
"Model: {model_type}, layers={}, hidden={}, heads={}/{} kv, vocab={}",
|
||||
config.num_layers(), config.hidden(), config.num_heads(),
|
||||
config.num_kv_heads(), config.vocab_size
|
||||
config.num_layers(),
|
||||
config.hidden(),
|
||||
config.num_heads(),
|
||||
config.num_kv_heads(),
|
||||
config.vocab_size
|
||||
);
|
||||
|
||||
eprintln!("Loading weights...");
|
||||
@@ -37,7 +44,11 @@ fn main() {
|
||||
|
||||
let is_qwen3 = model_type.contains("qwen");
|
||||
let is_gpt_oss = model_type.contains("gpt_oss");
|
||||
let dtype = if is_qwen3 || is_gpt_oss { DType::BF16 } else { DType::F32 };
|
||||
let dtype = if is_qwen3 || is_gpt_oss {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
|
||||
// Build model
|
||||
enum Model {
|
||||
@@ -60,10 +71,16 @@ fn main() {
|
||||
print!("xserv> ");
|
||||
io::stdout().flush().unwrap();
|
||||
let mut input = String::new();
|
||||
if io::stdin().read_line(&mut input).unwrap() == 0 { break; }
|
||||
if io::stdin().read_line(&mut input).unwrap() == 0 {
|
||||
break;
|
||||
}
|
||||
let input = input.trim();
|
||||
if input.is_empty() { continue; }
|
||||
if input == "quit" || input == "exit" { break; }
|
||||
if input.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if input == "quit" || input == "exit" {
|
||||
break;
|
||||
}
|
||||
|
||||
let token_ids = tokenizer.encode(input);
|
||||
|
||||
@@ -73,12 +90,21 @@ fn main() {
|
||||
let max_blocks_per_seq = (max_seq + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = max_blocks_per_seq + 64;
|
||||
let mut paged_cache = PagedKVCache::new(
|
||||
&config, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, 0,
|
||||
&config,
|
||||
total_blocks,
|
||||
0,
|
||||
4,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
0,
|
||||
);
|
||||
let slot = 0;
|
||||
paged_cache.register_sequence(slot).expect("register slot");
|
||||
|
||||
let model = match &model { Model::GptOss(m) => m, _ => unreachable!() };
|
||||
let model = match &model {
|
||||
Model::GptOss(m) => m,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let logits = model.forward_prefill_paged(&token_ids, slot, &mut paged_cache);
|
||||
let mut next = sample_greedy_last(&logits);
|
||||
|
||||
@@ -90,20 +116,28 @@ fn main() {
|
||||
print!("{text}");
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = paged_cache.seq_len(slot);
|
||||
let logits = model.forward_decode_paged(
|
||||
&[next], &[pos], &[slot], &mut paged_cache,
|
||||
);
|
||||
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut paged_cache);
|
||||
next = sample_greedy_last(&logits);
|
||||
}
|
||||
println!();
|
||||
paged_cache.free_sequence(slot);
|
||||
} else {
|
||||
let kv_heads = if is_qwen3 { config.num_kv_heads() } else { config.num_heads() };
|
||||
let kv_heads = if is_qwen3 {
|
||||
config.num_kv_heads()
|
||||
} else {
|
||||
config.num_heads()
|
||||
};
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(), kv_heads, config.head_dim(), dtype, Device::Cuda(0),
|
||||
config.num_layers(),
|
||||
kv_heads,
|
||||
config.head_dim(),
|
||||
dtype,
|
||||
Device::Cuda(0),
|
||||
);
|
||||
|
||||
let logits = match &model {
|
||||
@@ -125,7 +159,9 @@ fn main() {
|
||||
print!("{text}");
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
|
||||
let logits = match &model {
|
||||
Model::GPT2(m) => m.forward_with_cache(&[next], &mut cache),
|
||||
@@ -151,7 +187,9 @@ fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 {
|
||||
let seq_len = logits.shape()[0];
|
||||
let data = logits_cpu.as_slice::<bf16>();
|
||||
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
last.iter().enumerate()
|
||||
last.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(i, _)| i as u32).unwrap()
|
||||
.map(|(i, _)| i as u32)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
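The `sample_greedy_last` hunk above picks the arg-max over the last row of a `[seq_len, vocab_size]` logits buffer. Below is a minimal CPU sketch of that selection, assuming `f32` logits in row-major order (the code in the diff reads `bf16` values and converts with `to_f32()`). `argmax_last_row` is a hypothetical name, and `total_cmp` is swapped in for `partial_cmp(..).unwrap()` only so the sketch cannot panic on NaN.

```rust
/// Returns the index of the largest value in the last row of a row-major
/// `[seq_len, vocab_size]` buffer. Hypothetical helper for illustration only.
fn argmax_last_row(logits: &[f32], seq_len: usize, vocab_size: usize) -> u32 {
    assert!(seq_len > 0 && vocab_size > 0);
    assert_eq!(logits.len(), seq_len * vocab_size);
    // Same slicing as the diff: only the final position's logits matter.
    let last = &logits[(seq_len - 1) * vocab_size..seq_len * vocab_size];
    last.iter()
        .enumerate()
        // total_cmp defines a total order on f32, so NaN inputs do not panic.
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i as u32)
        .unwrap()
}
```

Whether a NaN-tolerant comparison is preferable to failing loudly is a design choice; the diff keeps the panicking variant.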
@@ -88,23 +88,33 @@ impl ModelConfig {
|
||||
}
|
||||
|
||||
pub fn hidden(&self) -> usize {
|
||||
self.hidden_size.or(self.n_embd).expect("hidden_size or n_embd required")
|
||||
self.hidden_size
|
||||
.or(self.n_embd)
|
||||
.expect("hidden_size or n_embd required")
|
||||
}
|
||||
|
||||
pub fn num_heads(&self) -> usize {
|
||||
self.num_attention_heads.or(self.n_head).expect("num_attention_heads or n_head required")
|
||||
self.num_attention_heads
|
||||
.or(self.n_head)
|
||||
.expect("num_attention_heads or n_head required")
|
||||
}
|
||||
|
||||
pub fn num_layers(&self) -> usize {
|
||||
self.num_hidden_layers.or(self.n_layer).expect("num_hidden_layers or n_layer required")
|
||||
self.num_hidden_layers
|
||||
.or(self.n_layer)
|
||||
.expect("num_hidden_layers or n_layer required")
|
||||
}
|
||||
|
||||
pub fn max_seq_len(&self) -> usize {
|
||||
self.max_position_embeddings.or(self.n_positions).unwrap_or(2048)
|
||||
self.max_position_embeddings
|
||||
.or(self.n_positions)
|
||||
.unwrap_or(2048)
|
||||
}
|
||||
|
||||
pub fn ffn_hidden(&self) -> usize {
|
||||
self.intermediate_size.or(self.n_inner).unwrap_or(self.hidden() * 4)
|
||||
self.intermediate_size
|
||||
.or(self.n_inner)
|
||||
.unwrap_or(self.hidden() * 4)
|
||||
}
|
||||
|
||||
pub fn num_kv_heads(&self) -> usize {
|
||||
@@ -112,7 +122,8 @@ impl ModelConfig {
|
||||
}
|
||||
|
||||
pub fn head_dim(&self) -> usize {
|
||||
self.explicit_head_dim.unwrap_or_else(|| self.hidden() / self.num_heads())
|
||||
self.explicit_head_dim
|
||||
.unwrap_or_else(|| self.hidden() / self.num_heads())
|
||||
}
|
||||
|
||||
pub fn ln_eps(&self) -> f32 {
|
||||
|
||||
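The `ModelConfig` accessors in this hunk all follow one pattern: try the primary key, fall back to the GPT-2 style key, then either panic or use a default. The sketch below reproduces that pattern on a cut-down struct. The field names are taken from the diff, but `ConfigSketch` itself is an illustrative assumption, not the project's type.

```rust
/// Cut-down stand-in for the config struct (illustrative assumption).
#[derive(Default)]
struct ConfigSketch {
    hidden_size: Option<usize>,         // primary key
    n_embd: Option<usize>,              // GPT-2 style fallback key
    num_attention_heads: Option<usize>, // primary key
    n_head: Option<usize>,              // GPT-2 style fallback key
    explicit_head_dim: Option<usize>,   // overrides the derived value
}

impl ConfigSketch {
    fn hidden(&self) -> usize {
        self.hidden_size
            .or(self.n_embd)
            .expect("hidden_size or n_embd required")
    }

    fn num_heads(&self) -> usize {
        self.num_attention_heads
            .or(self.n_head)
            .expect("num_attention_heads or n_head required")
    }

    /// Falls back to hidden / num_heads only when no explicit value is set.
    fn head_dim(&self) -> usize {
        self.explicit_head_dim
            .unwrap_or_else(|| self.hidden() / self.num_heads())
    }
}
```

Using `unwrap_or_else` rather than `unwrap_or` keeps the fallback lazy, so a config that sets `explicit_head_dim` never evaluates the division.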
@@ -18,19 +18,19 @@ use crate::kv_cache::GpuKVCache;
|
||||
/// All buffers have stable GPU addresses for CUDA Graph replay.
|
||||
struct DecodeBuffers {
|
||||
// Hidden-size buffers: [1, hidden]
|
||||
x: GpuBuffer, // running hidden state
|
||||
normed: GpuBuffer, // rmsnorm output
|
||||
attn_out: GpuBuffer, // attention output [1, num_heads, 1, head_dim]
|
||||
attn_merged: GpuBuffer, // merge_heads output [1, hidden]
|
||||
o_proj: GpuBuffer, // O projection output [1, hidden]
|
||||
normed2: GpuBuffer, // post-attn norm output [1, hidden]
|
||||
sum_out: GpuBuffer, // add_rmsnorm sum output [1, hidden]
|
||||
down: GpuBuffer, // down projection output [1, hidden]
|
||||
x: GpuBuffer, // running hidden state
|
||||
normed: GpuBuffer, // rmsnorm output
|
||||
attn_out: GpuBuffer, // attention output [1, num_heads, 1, head_dim]
|
||||
attn_merged: GpuBuffer, // merge_heads output [1, hidden]
|
||||
o_proj: GpuBuffer, // O projection output [1, hidden]
|
||||
normed2: GpuBuffer, // post-attn norm output [1, hidden]
|
||||
sum_out: GpuBuffer, // add_rmsnorm sum output [1, hidden]
|
||||
down: GpuBuffer, // down projection output [1, hidden]
|
||||
|
||||
// QKV projection outputs
|
||||
q_proj: GpuBuffer, // [1, num_heads * head_dim]
|
||||
k_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
|
||||
v_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
|
||||
q_proj: GpuBuffer, // [1, num_heads * head_dim]
|
||||
k_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
|
||||
v_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
|
||||
|
||||
// Reshaped: [1, H, 1, D]
|
||||
q_reshaped: GpuBuffer,
|
||||
@@ -50,23 +50,23 @@ struct DecodeBuffers {
|
||||
k_final: GpuBuffer,
|
||||
|
||||
// FFN intermediates
|
||||
gate: GpuBuffer, // [1, intermediate]
|
||||
up: GpuBuffer, // [1, intermediate]
|
||||
silu_out: GpuBuffer, // [1, intermediate]
|
||||
gate: GpuBuffer, // [1, intermediate]
|
||||
up: GpuBuffer, // [1, intermediate]
|
||||
silu_out: GpuBuffer, // [1, intermediate]
|
||||
|
||||
// GEMV fp32 accumulators (separate per output dimension)
|
||||
fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs
|
||||
fp32_q: GpuBuffer, // for Q projection
|
||||
fp32_kv: GpuBuffer, // for K/V projection
|
||||
fp32_intermediate: GpuBuffer,// for gate/up projections
|
||||
fp32_vocab: GpuBuffer, // for lm_head
|
||||
fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs
|
||||
fp32_q: GpuBuffer, // for Q projection
|
||||
fp32_kv: GpuBuffer, // for K/V projection
|
||||
fp32_intermediate: GpuBuffer, // for gate/up projections
|
||||
fp32_vocab: GpuBuffer, // for lm_head
|
||||
|
||||
// Token ID and position (GPU-resident, updated before replay)
|
||||
token_id_gpu: GpuBuffer, // 4 bytes (u32)
|
||||
position_gpu: GpuBuffer, // 4 bytes (u32)
|
||||
token_id_gpu: GpuBuffer, // 4 bytes (u32)
|
||||
position_gpu: GpuBuffer, // 4 bytes (u32)
|
||||
|
||||
// Final output
|
||||
logits: GpuBuffer, // [1, vocab_size]
|
||||
logits: GpuBuffer, // [1, vocab_size]
|
||||
}
|
||||
|
||||
pub struct DecodeGraphState {
|
||||
@@ -199,127 +199,296 @@ impl DecodeGraphState {
|
||||
let cublas = cublas_handle();
|
||||
|
||||
// Set cuBLAS to use our stream
|
||||
unsafe { dispatch::set_cublas_stream(cublas, s); }
|
||||
unsafe {
|
||||
dispatch::set_cublas_stream(cublas, s);
|
||||
}
|
||||
|
||||
for (l, lw) in layers.iter().enumerate() {
|
||||
// === Pre-attention graph ===
|
||||
self.pre_attn_graphs[l].begin_capture(&self.stream).expect("begin pre-attn capture");
|
||||
self.pre_attn_graphs[l]
|
||||
.begin_capture(&self.stream)
|
||||
.expect("begin pre-attn capture");
|
||||
unsafe {
|
||||
// RMSNorm
|
||||
dispatch::rmsnorm_bf16(
|
||||
self.buffers.x.as_ptr() as _, lw.input_norm, self.buffers.normed.as_mut_ptr() as _,
|
||||
1, h, eps, s,
|
||||
self.buffers.x.as_ptr() as _,
|
||||
lw.input_norm,
|
||||
self.buffers.normed.as_mut_ptr() as _,
|
||||
1,
|
||||
h,
|
||||
eps,
|
||||
s,
|
||||
);
|
||||
|
||||
// Q projection (GEMV)
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _, lw.q_proj_wt, self.buffers.q_proj.as_mut_ptr() as _,
|
||||
self.buffers.normed.as_ptr() as _,
|
||||
lw.q_proj_wt,
|
||||
self.buffers.q_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_q.as_mut_ptr() as _,
|
||||
h, nh * hd, s,
|
||||
h,
|
||||
nh * hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// K projection (GEMV)
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _, lw.k_proj_wt, self.buffers.k_proj.as_mut_ptr() as _,
|
||||
self.buffers.normed.as_ptr() as _,
|
||||
lw.k_proj_wt,
|
||||
self.buffers.k_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_kv.as_mut_ptr() as _,
|
||||
h, nkv * hd, s,
|
||||
h,
|
||||
nkv * hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// V projection (GEMV)
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _, lw.v_proj_wt, self.buffers.v_proj.as_mut_ptr() as _,
|
||||
self.buffers.normed.as_ptr() as _,
|
||||
lw.v_proj_wt,
|
||||
self.buffers.v_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_kv.as_mut_ptr() as _,
|
||||
h, nkv * hd, s,
|
||||
h,
|
||||
nkv * hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// Reshape heads: [1, H*D] -> [1, H, 1, D]
|
||||
dispatch::reshape_heads_bf16(self.buffers.q_proj.as_ptr() as _, self.buffers.q_reshaped.as_mut_ptr() as _, 1, nh, hd, s);
|
||||
dispatch::reshape_heads_bf16(self.buffers.k_proj.as_ptr() as _, self.buffers.k_reshaped.as_mut_ptr() as _, 1, nkv, hd, s);
|
||||
dispatch::reshape_heads_bf16(self.buffers.v_proj.as_ptr() as _, self.buffers.v_reshaped.as_mut_ptr() as _, 1, nkv, hd, s);
|
||||
dispatch::reshape_heads_bf16(
|
||||
self.buffers.q_proj.as_ptr() as _,
|
||||
self.buffers.q_reshaped.as_mut_ptr() as _,
|
||||
1,
|
||||
nh,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
dispatch::reshape_heads_bf16(
|
||||
self.buffers.k_proj.as_ptr() as _,
|
||||
self.buffers.k_reshaped.as_mut_ptr() as _,
|
||||
1,
|
||||
nkv,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
dispatch::reshape_heads_bf16(
|
||||
self.buffers.v_proj.as_ptr() as _,
|
||||
self.buffers.v_reshaped.as_mut_ptr() as _,
|
||||
1,
|
||||
nkv,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// QK norm (head-level rmsnorm: treat [1,H,1,D] as [H, D])
|
||||
dispatch::rmsnorm_bf16(self.buffers.q_reshaped.as_ptr() as _, lw.q_norm, self.buffers.q_normed.as_mut_ptr() as _, nh, hd, eps, s);
|
||||
dispatch::rmsnorm_bf16(self.buffers.k_reshaped.as_ptr() as _, lw.k_norm, self.buffers.k_normed.as_mut_ptr() as _, nkv, hd, eps, s);
|
||||
dispatch::rmsnorm_bf16(
|
||||
self.buffers.q_reshaped.as_ptr() as _,
|
||||
lw.q_norm,
|
||||
self.buffers.q_normed.as_mut_ptr() as _,
|
||||
nh,
|
||||
hd,
|
||||
eps,
|
||||
s,
|
||||
);
|
||||
dispatch::rmsnorm_bf16(
|
||||
self.buffers.k_reshaped.as_ptr() as _,
|
||||
lw.k_norm,
|
||||
self.buffers.k_normed.as_mut_ptr() as _,
|
||||
nkv,
|
||||
hd,
|
||||
eps,
|
||||
s,
|
||||
);
|
||||
|
||||
// Transpose for RoPE: [1,H,1,D] -> [1,H,D]
|
||||
dispatch::transpose_hsd_to_shd_bf16(self.buffers.q_normed.as_ptr() as _, self.buffers.q_rope.as_mut_ptr() as _, 1, nh, hd, s);
|
||||
dispatch::transpose_hsd_to_shd_bf16(self.buffers.k_normed.as_ptr() as _, self.buffers.k_rope.as_mut_ptr() as _, 1, nkv, hd, s);
|
||||
dispatch::transpose_hsd_to_shd_bf16(
|
||||
self.buffers.q_normed.as_ptr() as _,
|
||||
self.buffers.q_rope.as_mut_ptr() as _,
|
||||
1,
|
||||
nh,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
dispatch::transpose_hsd_to_shd_bf16(
|
||||
self.buffers.k_normed.as_ptr() as _,
|
||||
self.buffers.k_rope.as_mut_ptr() as _,
|
||||
1,
|
||||
nkv,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// RoPE (in-place, reads position_gpu)
|
||||
dispatch::rope_bf16(self.buffers.q_rope.as_mut_ptr() as _, rope_cos, rope_sin, self.buffers.position_gpu.as_ptr() as _, 1, nh, hd, s);
|
||||
dispatch::rope_bf16(self.buffers.k_rope.as_mut_ptr() as _, rope_cos, rope_sin, self.buffers.position_gpu.as_ptr() as _, 1, nkv, hd, s);
|
||||
dispatch::rope_bf16(
|
||||
self.buffers.q_rope.as_mut_ptr() as _,
|
||||
rope_cos,
|
||||
rope_sin,
|
||||
self.buffers.position_gpu.as_ptr() as _,
|
||||
1,
|
||||
nh,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
dispatch::rope_bf16(
|
||||
self.buffers.k_rope.as_mut_ptr() as _,
|
||||
rope_cos,
|
||||
rope_sin,
|
||||
self.buffers.position_gpu.as_ptr() as _,
|
||||
1,
|
||||
nkv,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// Transpose back: [1,H,D] -> [1,H,1,D]
|
||||
dispatch::transpose_shd_to_hsd_bf16(self.buffers.q_rope.as_ptr() as _, self.buffers.q_final.as_mut_ptr() as _, 1, nh, hd, s);
|
||||
dispatch::transpose_shd_to_hsd_bf16(self.buffers.k_rope.as_ptr() as _, self.buffers.k_final.as_mut_ptr() as _, 1, nkv, hd, s);
|
||||
dispatch::transpose_shd_to_hsd_bf16(
|
||||
self.buffers.q_rope.as_ptr() as _,
|
||||
self.buffers.q_final.as_mut_ptr() as _,
|
||||
1,
|
||||
nh,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
dispatch::transpose_shd_to_hsd_bf16(
|
||||
self.buffers.k_rope.as_ptr() as _,
|
||||
self.buffers.k_final.as_mut_ptr() as _,
|
||||
1,
|
||||
nkv,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
}
|
||||
self.pre_attn_graphs[l].end_capture(&self.stream).expect("end pre-attn capture");
|
||||
self.pre_attn_graphs[l]
|
||||
.end_capture(&self.stream)
|
||||
.expect("end pre-attn capture");
|
||||
|
||||
// === Post-attention graph ===
|
||||
self.post_attn_graphs[l].begin_capture(&self.stream).expect("begin post-attn capture");
|
||||
self.post_attn_graphs[l]
|
||||
.begin_capture(&self.stream)
|
||||
.expect("begin post-attn capture");
|
||||
unsafe {
|
||||
// Merge heads: [1,H,1,D] -> [1, hidden]
|
||||
// attn_out is written by ungraphed attention
|
||||
dispatch::merge_heads_bf16(self.buffers.attn_out.as_ptr() as _, self.buffers.attn_merged.as_mut_ptr() as _, 1, nh, hd, s);
|
||||
dispatch::merge_heads_bf16(
|
||||
self.buffers.attn_out.as_ptr() as _,
|
||||
self.buffers.attn_merged.as_mut_ptr() as _,
|
||||
1,
|
||||
nh,
|
||||
hd,
|
||||
s,
|
||||
);
|
||||
|
||||
// O projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.attn_merged.as_ptr() as _, lw.o_proj_wt, self.buffers.o_proj.as_mut_ptr() as _,
|
||||
self.buffers.attn_merged.as_ptr() as _,
|
||||
lw.o_proj_wt,
|
||||
self.buffers.o_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_hidden.as_mut_ptr() as _,
|
||||
nh * hd, h, s,
|
||||
nh * hd,
|
||||
h,
|
||||
s,
|
||||
);
|
||||
|
||||
// Fused Add+RMSNorm: normed2 = rmsnorm(o_proj + x), sum_out = o_proj + x
|
||||
dispatch::add_rmsnorm_bf16(
|
||||
self.buffers.o_proj.as_ptr() as _, self.buffers.x.as_ptr() as _, lw.post_norm,
|
||||
self.buffers.normed2.as_mut_ptr() as _, self.buffers.sum_out.as_mut_ptr() as _,
|
||||
1, h, eps, s,
|
||||
self.buffers.o_proj.as_ptr() as _,
|
||||
self.buffers.x.as_ptr() as _,
|
||||
lw.post_norm,
|
||||
self.buffers.normed2.as_mut_ptr() as _,
|
||||
self.buffers.sum_out.as_mut_ptr() as _,
|
||||
1,
|
||||
h,
|
||||
eps,
|
||||
s,
|
||||
);
|
||||
|
||||
// Gate projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed2.as_ptr() as _, lw.gate_proj_wt, self.buffers.gate.as_mut_ptr() as _,
|
||||
self.buffers.normed2.as_ptr() as _,
|
||||
lw.gate_proj_wt,
|
||||
self.buffers.gate.as_mut_ptr() as _,
|
||||
self.buffers.fp32_intermediate.as_mut_ptr() as _,
|
||||
h, inter, s,
|
||||
h,
|
||||
inter,
|
||||
s,
|
||||
);
|
||||
|
||||
// Up projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed2.as_ptr() as _, lw.up_proj_wt, self.buffers.up.as_mut_ptr() as _,
|
||||
self.buffers.normed2.as_ptr() as _,
|
||||
lw.up_proj_wt,
|
||||
self.buffers.up.as_mut_ptr() as _,
|
||||
self.buffers.fp32_intermediate.as_mut_ptr() as _,
|
||||
h, inter, s,
|
||||
h,
|
||||
inter,
|
||||
s,
|
||||
);
|
||||
|
||||
// Fused SiLU x Mul
|
||||
dispatch::silu_mul_bf16(self.buffers.gate.as_ptr() as _, self.buffers.up.as_ptr() as _, self.buffers.silu_out.as_mut_ptr() as _, inter, s);
|
||||
dispatch::silu_mul_bf16(
|
||||
self.buffers.gate.as_ptr() as _,
|
||||
self.buffers.up.as_ptr() as _,
|
||||
self.buffers.silu_out.as_mut_ptr() as _,
|
||||
inter,
|
||||
s,
|
||||
);
|
||||
|
||||
// Down projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.silu_out.as_ptr() as _, lw.down_proj_wt, self.buffers.down.as_mut_ptr() as _,
|
||||
self.buffers.silu_out.as_ptr() as _,
|
||||
lw.down_proj_wt,
|
||||
self.buffers.down.as_mut_ptr() as _,
|
||||
self.buffers.fp32_hidden.as_mut_ptr() as _,
|
||||
inter, h, s,
|
||||
inter,
|
||||
h,
|
||||
s,
|
||||
);
|
||||
|
||||
// x = sum_out + down (residual connection for next layer)
|
||||
dispatch::add_bf16(self.buffers.sum_out.as_ptr() as _, self.buffers.down.as_ptr() as _, self.buffers.x.as_mut_ptr() as _, h, s);
|
||||
dispatch::add_bf16(
|
||||
self.buffers.sum_out.as_ptr() as _,
|
||||
self.buffers.down.as_ptr() as _,
|
||||
self.buffers.x.as_mut_ptr() as _,
|
||||
h,
|
||||
s,
|
||||
);
|
||||
}
|
||||
self.post_attn_graphs[l].end_capture(&self.stream).expect("end post-attn capture");
|
||||
self.post_attn_graphs[l]
|
||||
.end_capture(&self.stream)
|
||||
.expect("end post-attn capture");
|
||||
}
|
||||
|
||||
// === Final graph: norm + lm_head ===
|
||||
self.final_graph.begin_capture(&self.stream).expect("begin final capture");
|
||||
self.final_graph
|
||||
.begin_capture(&self.stream)
|
||||
.expect("begin final capture");
|
||||
unsafe {
|
||||
dispatch::rmsnorm_bf16(self.buffers.x.as_ptr() as _, norm_weight, self.buffers.normed.as_mut_ptr() as _, 1, h, eps, s);
|
||||
dispatch::rmsnorm_bf16(
|
||||
self.buffers.x.as_ptr() as _,
|
||||
norm_weight,
|
||||
self.buffers.normed.as_mut_ptr() as _,
|
||||
1,
|
||||
h,
|
||||
eps,
|
||||
s,
|
||||
);
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _, lm_head_wt, self.buffers.logits.as_mut_ptr() as _,
|
||||
self.buffers.normed.as_ptr() as _,
|
||||
lm_head_wt,
|
||||
self.buffers.logits.as_mut_ptr() as _,
|
||||
self.buffers.fp32_vocab.as_mut_ptr() as _,
|
||||
h, vocab, s,
|
||||
h,
|
||||
vocab,
|
||||
s,
|
||||
);
|
||||
}
|
||||
self.final_graph.end_capture(&self.stream).expect("end final capture");
|
||||
self.final_graph
|
||||
.end_capture(&self.stream)
|
||||
.expect("end final capture");
|
||||
|
||||
// Reset cuBLAS back to null stream
|
||||
unsafe { dispatch::set_cublas_stream(cublas, std::ptr::null_mut()); }
|
||||
unsafe {
|
||||
dispatch::set_cublas_stream(cublas, std::ptr::null_mut());
|
||||
}
|
||||
|
||||
self.captured = true;
|
||||
}
|
||||
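The captured post-attention graph chains a fused add + RMSNorm, the gate/up projections, a fused SiLU multiply, the down projection, and a residual add. The element-wise pieces can be sketched as a CPU reference in `f32`. This assumes the kernels implement the textbook definitions, `y = x / sqrt(mean(x^2) + eps) * w` and `silu(g) * u`; the diff does not show the kernel bodies, so treat the functions below as hypothetical references rather than the project's implementation.

```rust
/// Reference RMSNorm over one row (assumed definition, see lead-in).
fn rmsnorm_ref(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    assert_eq!(x.len(), weight.len());
    let mean_sq = x.iter().map(|&v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(weight).map(|(&v, &w)| v * inv_rms * w).collect()
}

/// Fused add + RMSNorm: returns (rmsnorm(a + b), a + b), mirroring the
/// `normed2` / `sum_out` output pair named in the diff's comment.
fn add_rmsnorm_ref(a: &[f32], b: &[f32], weight: &[f32], eps: f32) -> (Vec<f32>, Vec<f32>) {
    assert_eq!(a.len(), b.len());
    let sum: Vec<f32> = a.iter().zip(b).map(|(&x, &y)| x + y).collect();
    let normed = rmsnorm_ref(&sum, weight, eps);
    (normed, sum)
}

/// Fused SiLU x Mul: silu(gate) * up, with silu(g) = g * sigmoid(g).
fn silu_mul_ref(gate: &[f32], up: &[f32]) -> Vec<f32> {
    assert_eq!(gate.len(), up.len());
    gate.iter()
        .zip(up)
        .map(|(&g, &u)| (g / (1.0 + (-g).exp())) * u)
        .collect()
}
```

A reference like this is mainly useful for comparing against GPU output within a BF16-sized tolerance.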
@@ -343,8 +512,14 @@ impl DecodeGraphState {
|
||||
let es = 2usize; // BF16
|
||||
|
||||
// Upload token ID and position to fixed GPU buffers
|
||||
self.buffers.token_id_gpu.copy_from_host(&token_id.to_le_bytes()).unwrap();
|
||||
self.buffers.position_gpu.copy_from_host(&position.to_le_bytes()).unwrap();
|
||||
self.buffers
|
||||
.token_id_gpu
|
||||
.copy_from_host(&token_id.to_le_bytes())
|
||||
.unwrap();
|
||||
self.buffers
|
||||
.position_gpu
|
||||
.copy_from_host(&position.to_le_bytes())
|
||||
.unwrap();
|
||||
|
||||
// Embedding (outside graph since token_id changes each step)
|
||||
unsafe {
|
||||
@@ -352,13 +527,18 @@ impl DecodeGraphState {
|
||||
embed_table,
|
||||
self.buffers.token_id_gpu.as_ptr() as _,
|
||||
self.buffers.x.as_mut_ptr() as _,
|
||||
1, hidden_size, vocab_size, s,
|
||||
1,
|
||||
hidden_size,
|
||||
vocab_size,
|
||||
s,
|
||||
);
|
||||
}
|
||||
|
||||
for l in 0..self.num_layers {
|
||||
// Pre-attention graph (norm + QKV + reshape + QK-norm + RoPE)
|
||||
self.pre_attn_graphs[l].launch(&self.stream).expect("launch pre-attn graph");
|
||||
self.pre_attn_graphs[l]
|
||||
.launch(&self.stream)
|
||||
.expect("launch pre-attn graph");
|
||||
|
||||
// Ungraphed: KV cache append
|
||||
// k_final shape: [1, num_kv_heads, 1, head_dim] (after RoPE pipeline)
|
||||
@@ -402,9 +582,13 @@ impl DecodeGraphState {
|
||||
k_full.data_ptr() as _,
|
||||
v_full.data_ptr() as _,
|
||||
self.buffers.attn_out.as_mut_ptr() as _,
|
||||
1, nh as i32, nkv as i32,
|
||||
kv_len, hd as i32,
|
||||
scale, s,
|
||||
1,
|
||||
nh as i32,
|
||||
nkv as i32,
|
||||
kv_len,
|
||||
hd as i32,
|
||||
scale,
|
||||
s,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -412,11 +596,15 @@ impl DecodeGraphState {
|
||||
self.stream.synchronize().expect("sync before post-attn");
|
||||
|
||||
// Post-attention graph (merge + O-proj + add_rmsnorm + FFN + residual)
|
||||
self.post_attn_graphs[l].launch(&self.stream).expect("launch post-attn graph");
|
||||
self.post_attn_graphs[l]
|
||||
.launch(&self.stream)
|
||||
.expect("launch post-attn graph");
|
||||
}
|
||||
|
||||
// Final graph (norm + lm_head)
|
||||
self.final_graph.launch(&self.stream).expect("launch final graph");
|
||||
self.final_graph
|
||||
.launch(&self.stream)
|
||||
.expect("launch final graph");
|
||||
|
||||
// Sync to ensure logits are ready
|
||||
self.stream.synchronize().expect("sync after decode");
|
||||
|
||||
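The replay path in these hunks interleaves captured graphs with ungraphed work: the embedding lookup runs first, then each layer launches its pre-attention graph, performs the KV append and attention outside any graph, and launches its post-attention graph, before a final norm + `lm_head` graph. The sketch below only models that ordering as data. `Step` and `decode_schedule` are hypothetical names, and no CUDA calls are involved.

```rust
/// One unit of work in a single decode step (illustrative model only).
#[derive(Debug, Clone, PartialEq)]
enum Step {
    Embed,                     // outside any graph: token_id changes each step
    PreAttnGraph(usize),       // norm + QKV + reshape + QK-norm + RoPE
    UngraphedAttention(usize), // KV cache append + attention
    PostAttnGraph(usize),      // merge + O-proj + add_rmsnorm + FFN + residual
    FinalGraph,                // final norm + lm_head
}

/// Lists the launch order for one decode step over `num_layers` layers.
fn decode_schedule(num_layers: usize) -> Vec<Step> {
    let mut steps = Vec::with_capacity(3 * num_layers + 2);
    steps.push(Step::Embed);
    for l in 0..num_layers {
        steps.push(Step::PreAttnGraph(l));
        steps.push(Step::UngraphedAttention(l));
        steps.push(Step::PostAttnGraph(l));
    }
    steps.push(Step::FinalGraph);
    steps
}
```

Attention is presumably left ungraphed because `kv_len` grows every step, while the captured graphs rely on fixed buffer addresses and shapes.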
@@ -31,7 +31,7 @@ struct GPT2Block {
|
||||
|
||||
pub struct KVCache {
|
||||
// Per layer, per head: raw bytes (works for both f32 and bf16)
|
||||
k: Vec<Vec<Vec<u8>>>, // [num_layers][num_heads][seq_len * head_dim * elem_size]
|
||||
k: Vec<Vec<Vec<u8>>>, // [num_layers][num_heads][seq_len * head_dim * elem_size]
|
||||
v: Vec<Vec<Vec<u8>>>,
|
||||
len: usize,
|
||||
num_heads: usize,
|
||||
@@ -42,7 +42,13 @@ pub struct KVCache {
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
pub fn new(num_layers: usize, num_heads: usize, head_dim: usize, dtype: DType, device: Device) -> Self {
|
||||
pub fn new(
|
||||
num_layers: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
) -> Self {
|
||||
Self {
|
||||
k: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(),
|
||||
v: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(),
|
||||
@@ -55,10 +61,18 @@ impl KVCache {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn seq_len(&self) -> usize { self.len }
|
||||
pub fn seq_len(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
/// Append from a CPU tensor with shape [1, H, new_tokens, D].
|
||||
pub fn append_kv_tensor(&mut self, layer: usize, k_cpu: &Tensor, v_cpu: &Tensor, new_tokens: usize) {
|
||||
pub fn append_kv_tensor(
|
||||
&mut self,
|
||||
layer: usize,
|
||||
k_cpu: &Tensor,
|
||||
v_cpu: &Tensor,
|
||||
new_tokens: usize,
|
||||
) {
|
||||
let hd = self.head_dim;
|
||||
let es = self.elem_size;
|
||||
let k_bytes = k_cpu.storage().as_cpu_bytes();
|
||||
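The comments in this hunk describe the CPU-side `KVCache` as per-layer, per-head byte vectors that grow as tokens are appended from a `[1, H, new_tokens, D]` tensor. The sketch below shows one way that append can work on raw bytes. The body of `append_kv_tensor` is not visible in the diff, so `KvSketch` and its methods are assumptions that follow only the documented layout.

```rust
/// Keys only, for brevity; values would be stored the same way.
struct KvSketch {
    k: Vec<Vec<Vec<u8>>>, // [num_layers][num_heads][seq_len * head_dim * elem_size]
    head_dim: usize,
    elem_size: usize, // 2 for bf16, 4 for f32
}

impl KvSketch {
    fn new(num_layers: usize, num_heads: usize, head_dim: usize, elem_size: usize) -> Self {
        Self {
            k: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(),
            head_dim,
            elem_size,
        }
    }

    /// `k_bytes` holds a contiguous [H, new_tokens, D] block, so each head
    /// owns one contiguous chunk that is appended to that head's history.
    fn append_k(&mut self, layer: usize, k_bytes: &[u8], new_tokens: usize) {
        let chunk = new_tokens * self.head_dim * self.elem_size;
        let num_heads = self.k[layer].len();
        assert_eq!(k_bytes.len(), num_heads * chunk);
        for (h, dst) in self.k[layer].iter_mut().enumerate() {
            dst.extend_from_slice(&k_bytes[h * chunk..(h + 1) * chunk]);
        }
    }

    /// Number of cached tokens for one head, derived from the byte length.
    fn tokens(&self, layer: usize, head: usize) -> usize {
        self.k[layer][head].len() / (self.head_dim * self.elem_size)
    }
}
```

Storing raw bytes is what lets one structure serve both `f32` and `bf16`, at the cost of reinterpreting the bytes on every read.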
@@ -118,7 +132,8 @@ impl GPT2 {
|
||||
pub fn from_weights(config: ModelConfig, mut w: HashMap<String, Tensor>) -> Self {
|
||||
crate::init_kernels();
|
||||
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
|
||||
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
w.remove(name)
|
||||
.unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
};
|
||||
|
||||
let wte = take(&mut w, "wte.weight");
|
||||
@@ -147,7 +162,15 @@ impl GPT2 {
|
||||
});
|
||||
}
|
||||
|
||||
Self { config, wte, wpe, layers, ln_f_g, ln_f_b, lm_head }
|
||||
Self {
|
||||
config,
|
||||
wte,
|
||||
wpe,
|
||||
layers,
|
||||
ln_f_g,
|
||||
ln_f_b,
|
||||
lm_head,
|
||||
}
|
||||
}
|
||||
|
||||
/// Full forward pass without KV cache (for testing / correctness comparison).
|
||||
@@ -179,14 +202,22 @@ impl GPT2 {
|
||||
let head_dim = self.config.head_dim();
|
||||
|
||||
let tok_emb = embedding(&self.wte, token_ids);
|
||||
let pos_ids: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
|
||||
let pos_ids: Vec<u32> = (pos_offset..pos_offset + new_tokens)
|
||||
.map(|p| p as u32)
|
||||
.collect();
|
||||
let pos_emb = embedding(&self.wpe, &pos_ids);
|
||||
let mut x = add_tensors(&tok_emb, &pos_emb);
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
x = self.transformer_block(
|
||||
layer, &x, Some((cache, layer_idx)),
|
||||
pos_offset, new_tokens, num_heads, head_dim, hidden,
|
||||
layer,
|
||||
&x,
|
||||
Some((cache, layer_idx)),
|
||||
pos_offset,
|
||||
new_tokens,
|
||||
num_heads,
|
||||
head_dim,
|
||||
hidden,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -199,7 +230,7 @@ impl GPT2 {
|
||||
layer: &GPT2Block,
|
||||
x: &Tensor,
|
||||
cache: Option<(&mut KVCache, usize)>,
|
||||
pos_offset: usize,
|
||||
_pos_offset: usize,
|
||||
new_tokens: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
@@ -238,7 +269,11 @@ impl GPT2 {
|
||||
|
||||
fn linear(x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Tensor {
|
||||
let out = matmul_2d(x, weight);
|
||||
if let Some(b) = bias { add_bias(&out, b) } else { out }
|
||||
if let Some(b) = bias {
|
||||
add_bias(&out, b)
|
||||
} else {
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
@@ -277,7 +312,12 @@ fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) -> (Tensor, Tensor, Tensor) {
|
||||
fn split_qkv(
|
||||
qkv: &Tensor,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
seq_len: usize,
|
||||
) -> (Tensor, Tensor, Tensor) {
|
||||
let hidden = num_heads * head_dim;
|
||||
let qkv_cpu = qkv.to_device(Device::Cpu);
|
||||
let device = qkv.device();
|
||||
@@ -294,14 +334,21 @@ fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) ->
|
||||
for h in 0..num_heads {
|
||||
let src_off = h * head_dim;
|
||||
let dst_off = (h * seq_len + s) * head_dim;
|
||||
q_data[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]);
|
||||
k_data[dst_off..dst_off + head_dim].copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
|
||||
v_data[dst_off..dst_off + head_dim].copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]);
|
||||
q_data[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&row[src_off..src_off + head_dim]);
|
||||
k_data[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
|
||||
v_data[dst_off..dst_off + head_dim].copy_from_slice(
|
||||
&row[2 * hidden + src_off..2 * hidden + src_off + head_dim],
|
||||
);
|
||||
}
|
||||
}
|
||||
let q = Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let k = Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let v = Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let q =
|
||||
Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let k =
|
||||
Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let v =
|
||||
Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
(q, k, v)
|
||||
}
|
||||
DType::BF16 => {
|
||||
@@ -314,14 +361,21 @@ fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) ->
|
||||
for h in 0..num_heads {
|
||||
let src_off = h * head_dim;
|
||||
let dst_off = (h * seq_len + s) * head_dim;
|
||||
q_data[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]);
|
||||
k_data[dst_off..dst_off + head_dim].copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
|
||||
v_data[dst_off..dst_off + head_dim].copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]);
|
||||
q_data[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&row[src_off..src_off + head_dim]);
|
||||
k_data[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
|
||||
v_data[dst_off..dst_off + head_dim].copy_from_slice(
|
||||
&row[2 * hidden + src_off..2 * hidden + src_off + head_dim],
|
||||
);
|
||||
}
|
||||
}
|
||||
let q = Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let k = Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let v = Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let q =
|
||||
Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let k =
|
||||
Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let v =
|
||||
Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
(q, k, v)
|
||||
}
|
||||
_ => panic!("unsupported dtype {:?} in split_qkv", dtype),
|
||||
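The index arithmetic in `split_qkv` moves data from a fused `[seq_len, 3 * hidden]` projection into three `[num_heads, seq_len, head_dim]` buffers. The following sketch isolates that arithmetic on plain `f32` slices. It assumes each `row` is the `3 * hidden` wide slice for position `s`, which the offsets in the diff imply but the visible lines do not show, and `split_qkv_ref` is a hypothetical name.

```rust
/// Splits a fused row-major [seq_len, 3 * hidden] buffer into Q, K, V,
/// each laid out as [num_heads, seq_len, head_dim].
fn split_qkv_ref(
    qkv: &[f32],
    num_heads: usize,
    head_dim: usize,
    seq_len: usize,
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let hidden = num_heads * head_dim;
    assert_eq!(qkv.len(), seq_len * 3 * hidden);
    let mut q = vec![0.0f32; seq_len * hidden];
    let mut k = vec![0.0f32; seq_len * hidden];
    let mut v = vec![0.0f32; seq_len * hidden];
    for s in 0..seq_len {
        // Assumed row layout: [Q (hidden) | K (hidden) | V (hidden)].
        let row = &qkv[s * 3 * hidden..(s + 1) * 3 * hidden];
        for h in 0..num_heads {
            let src_off = h * head_dim;
            let dst_off = (h * seq_len + s) * head_dim;
            q[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]);
            k[dst_off..dst_off + head_dim]
                .copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
            v[dst_off..dst_off + head_dim]
                .copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]);
        }
    }
    (q, k, v)
}
```

With two heads, `head_dim = 2`, and two positions filled with `0..24`, head 0 of Q ends up holding `[0, 1]` from position 0 followed by `[12, 13]` from position 1.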
@@ -343,7 +397,8 @@ fn merge_heads(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor {
|
||||
for h in 0..num_heads {
|
||||
let src_off = (h * seq_len + s) * head_dim;
|
||||
let dst_off = s * hidden + h * head_dim;
|
||||
out[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]);
|
||||
out[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&src[src_off..src_off + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(device)
|
||||
@@ -355,7 +410,8 @@ fn merge_heads(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor {
|
||||
for h in 0..num_heads {
|
||||
let src_off = (h * seq_len + s) * head_dim;
|
||||
let dst_off = s * hidden + h * head_dim;
|
||||
out[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]);
|
||||
out[dst_off..dst_off + head_dim]
|
||||
.copy_from_slice(&src[src_off..src_off + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(device)
|
||||
@@ -372,7 +428,8 @@ pub fn sample_greedy(logits: &Tensor) -> u32 {
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
last_row.iter()
|
||||
last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
||||
.map(|(idx, _)| idx as u32)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use half::bf16;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use half::bf16;
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
@@ -49,10 +49,10 @@ struct GptOssBlock {
|
||||
expert_down_bias: Tensor, // [local_experts, hidden]
|
||||
// FP8 quantized expert weights (Some when running FP8 W8A8)
|
||||
// Transposed layout [E, N, K] for cuBLASLt FP8 (Blackwell requires transA=T)
|
||||
expert_gate_up_fp8: Option<Tensor>, // [local_experts, 2*inter, hidden] FP8E4M3
|
||||
expert_gate_up_scale: Option<Tensor>,// [local_experts] F32
|
||||
expert_down_fp8: Option<Tensor>, // [local_experts, hidden, inter] FP8E4M3
|
||||
expert_down_scale: Option<Tensor>, // [local_experts] F32
|
||||
expert_gate_up_fp8: Option<Tensor>, // [local_experts, 2*inter, hidden] FP8E4M3
|
||||
expert_gate_up_scale: Option<Tensor>, // [local_experts] F32
|
||||
expert_down_fp8: Option<Tensor>, // [local_experts, hidden, inter] FP8E4M3
|
||||
expert_down_scale: Option<Tensor>, // [local_experts] F32
|
||||
// MXFP4 W4A16 expert weights (Some when running 4-bit weight-only).
|
||||
// (packed [E, N, K/2] u8, scales [E, N, K/32] u8) in [E, N, K] layout.
|
||||
expert_gate_up_mxfp4: Option<(Tensor, Tensor)>,
|
||||
@@ -79,16 +79,23 @@ impl GptOss {
|
||||
crate::init_kernels();
|
||||
let dev = Device::Cuda(device);
|
||||
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
|
||||
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
w.remove(name)
|
||||
.unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
};
|
||||
let repl = |t: Tensor| -> Tensor { t.to_device(dev) };
|
||||
// column-parallel: shard rows of [out, in], transpose → [in, out/world]
|
||||
let col = |t: Tensor| -> Tensor {
|
||||
shard_rows(&t, rank, world).to_device(dev).transpose(0, 1).contiguous()
|
||||
shard_rows(&t, rank, world)
|
||||
.to_device(dev)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
};
|
||||
// row-parallel: shard cols of [out, in], transpose → [in/world, out]
|
||||
let row = |t: Tensor| -> Tensor {
|
||||
shard_cols(&t, rank, world).to_device(dev).transpose(0, 1).contiguous()
|
||||
shard_cols(&t, rank, world)
|
||||
.to_device(dev)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
};
|
||||
// Bias sharding helpers
|
||||
let col_bias = |t: Tensor| -> Tensor { shard_1d(&t, rank, world).to_device(dev) };
|
||||
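The `col` and `row` closures implement column-parallel and row-parallel sharding by slicing an `[out, in]` weight along rows or columns before transposing. The sketch below shows the slicing step on a row-major `f32` buffer. The real `shard_rows` / `shard_cols` signatures take a `Tensor` and are not shown in the diff, so the `_sketch` functions and the even-divisibility requirement are assumptions.

```rust
/// Keeps this rank's contiguous block of rows from a row-major [rows, cols] matrix.
fn shard_rows_sketch(data: &[f32], rows: usize, cols: usize, rank: usize, world: usize) -> Vec<f32> {
    assert_eq!(data.len(), rows * cols);
    assert!(rank < world && rows % world == 0);
    let per = rows / world;
    data[rank * per * cols..(rank + 1) * per * cols].to_vec()
}

/// Keeps this rank's block of columns from every row of a row-major [rows, cols] matrix.
fn shard_cols_sketch(data: &[f32], rows: usize, cols: usize, rank: usize, world: usize) -> Vec<f32> {
    assert_eq!(data.len(), rows * cols);
    assert!(rank < world && cols % world == 0);
    let per = cols / world;
    let mut out = Vec::with_capacity(rows * per);
    for r in 0..rows {
        let start = r * cols + rank * per;
        out.extend_from_slice(&data[start..start + per]);
    }
    out
}
```

Row sharding is a single contiguous copy, while column sharding gathers one strided slice per row, which may be one reason to do it once at load time rather than per step.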
@@ -97,7 +104,9 @@ impl GptOss {
|
||||
let embed_tokens = repl(take(&mut w, "model.embed_tokens.weight"));
|
||||
let norm = repl(take(&mut w, "model.norm.weight"));
|
||||
let norm_bias = w.remove("model.norm.bias").map(|t| repl(t));
|
||||
let lm_head_t = repl(take(&mut w, "lm_head.weight")).transpose(0, 1).contiguous();
|
||||
let lm_head_t = repl(take(&mut w, "lm_head.weight"))
|
||||
.transpose(0, 1)
|
||||
.contiguous();
|
||||
|
||||
let head_dim = config.head_dim();
|
||||
let rope_theta = config.rope_theta.unwrap_or(150000.0);
|
||||
@@ -176,15 +185,30 @@ impl GptOss {
|
||||
// MXFP4 stores 4-bit weights in an FP8E4M3 byte container (same dtype
|
||||
// as FP8), so distinguish by the scale rank: FP8 scale is 1-D [E],
|
||||
// MXFP4 scale is 3-D [E, N, K/32].
|
||||
let is_mxfp4 = gate_up_scale.as_ref().map(|s| s.ndim() == 3).unwrap_or(false);
|
||||
let is_mxfp4 = gate_up_scale
|
||||
.as_ref()
|
||||
.map(|s| s.ndim() == 3)
|
||||
.unwrap_or(false);
|
||||
let is_fp8 = !is_mxfp4 && gate_up_3d.dtype() == xserv_tensor::DType::FP8E4M3;
|
||||
|
||||
let mut expert_gate_up_mxfp4: Option<(Tensor, Tensor)> = None;
|
||||
let mut expert_down_mxfp4: Option<(Tensor, Tensor)> = None;
|
||||
|
||||
let inter2 = if is_mxfp4 { gate_up_3d.shape()[1] } else { gate_up_3d.shape()[2] }; // 2*inter (N)
|
||||
let hidden = if is_mxfp4 { gate_up_3d.shape()[2] * 2 } else { gate_up_3d.shape()[1] };
|
||||
let inter = if is_mxfp4 { down_3d.shape()[2] * 2 } else { down_3d.shape()[1] };
|
||||
let inter2 = if is_mxfp4 {
|
||||
gate_up_3d.shape()[1]
|
||||
} else {
|
||||
gate_up_3d.shape()[2]
|
||||
}; // 2*inter (N)
|
||||
let hidden = if is_mxfp4 {
|
||||
gate_up_3d.shape()[2] * 2
|
||||
} else {
|
||||
gate_up_3d.shape()[1]
|
||||
};
|
||||
let inter = if is_mxfp4 {
|
||||
down_3d.shape()[2] * 2
|
||||
} else {
|
||||
down_3d.shape()[1]
|
||||
};
|
||||
|
||||
// Slice the rank's range of experts as contiguous 3D tensors on GPU
|
||||
let expert_gate_up_wt;
|
||||
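The shape bookkeeping in this hunk can be checked in isolation. The sketch below mirrors the `inter2` / `hidden` / `inter` expressions from the diff on plain shape arrays. `ExpertDims` and `expert_dims` are hypothetical names introduced for illustration, and the layouts in the comments are read off the surrounding comments rather than from the loader itself.

```rust
/// Logical dimensions of the fused gate/up and down expert weights.
/// (Hypothetical helper type, not part of xserv.)
#[derive(Debug, PartialEq)]
struct ExpertDims {
    inter2: usize, // 2 * inter, the N dimension of gate_up
    hidden: usize,
    inter: usize,
}

/// Recover the logical dimensions from the stored 3-D shapes.
///
/// Layouts assumed from the diff:
/// - MXFP4:    gate_up = [E, N, K/2], down = [E, hidden, inter/2]
///             (two 4-bit values per stored byte, hence the `* 2`)
/// - FP8/BF16: gate_up = [E, hidden, 2*inter], down = [E, inter, hidden]
fn expert_dims(gate_up: [usize; 3], down: [usize; 3], is_mxfp4: bool) -> ExpertDims {
    if is_mxfp4 {
        ExpertDims {
            inter2: gate_up[1],
            hidden: gate_up[2] * 2,
            inter: down[2] * 2,
        }
    } else {
        ExpertDims {
            inter2: gate_up[2],
            hidden: gate_up[1],
            inter: down[1],
        }
    }
}
```

Both layouts should describe the same logical model, so a packed and an unpacked checkpoint of the same size are expected to yield identical `ExpertDims`.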
@@ -199,10 +223,38 @@ impl GptOss {
|
||||
// + scales [E, N, K/32]. Slice this rank's experts (raw bytes).
|
||||
let gu_s = gate_up_scale.expect("MXFP4 model missing gate_up_proj_scale");
|
||||
let d_s = down_scale.expect("MXFP4 model missing down_proj_scale");
|
||||
let gu_packed = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, inter2, hidden / 2).to_device(dev);
|
||||
let gu_scl = slice_expert_range_3d_raw(&gu_s, expert_start, local_experts, inter2, hidden / 32).to_device(dev);
|
||||
let dn_packed = slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, hidden, inter / 2).to_device(dev);
|
||||
let dn_scl = slice_expert_range_3d_raw(&d_s, expert_start, local_experts, hidden, inter / 32).to_device(dev);
|
||||
let gu_packed = slice_expert_range_3d_raw(
|
||||
&gate_up_3d,
|
||||
expert_start,
|
||||
local_experts,
|
||||
inter2,
|
||||
hidden / 2,
|
||||
)
|
||||
.to_device(dev);
|
||||
let gu_scl = slice_expert_range_3d_raw(
|
||||
&gu_s,
|
||||
expert_start,
|
||||
local_experts,
|
||||
inter2,
|
||||
hidden / 32,
|
||||
)
|
||||
.to_device(dev);
|
||||
let dn_packed = slice_expert_range_3d_raw(
|
||||
&down_3d,
|
||||
expert_start,
|
||||
local_experts,
|
||||
hidden,
|
||||
inter / 2,
|
||||
)
|
||||
.to_device(dev);
|
||||
let dn_scl = slice_expert_range_3d_raw(
|
||||
&d_s,
|
||||
expert_start,
|
||||
local_experts,
|
||||
hidden,
|
||||
inter / 32,
|
||||
)
|
||||
.to_device(dev);
|
||||
expert_gate_up_mxfp4 = Some((gu_packed, gu_scl));
|
||||
expert_down_mxfp4 = Some((dn_packed, dn_scl));
|
||||
expert_gate_up_fp8 = None;
|
||||
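For readers unfamiliar with the packed layout sliced above, here is a minimal dequantization sketch for one 32-element MXFP4 block. It follows the OCP Microscaling format as commonly described (E2M1 elements, one shared E8M0 scale per 32 values). The nibble order and the scale encoding xserv actually uses are not visible in this diff, so both are assumptions, and all function names are hypothetical.

```rust
/// E2M1 magnitudes selected by the low 3 bits of a 4-bit code (bit 3 is the sign).
const E2M1: [f32; 8] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0];

/// Decode an E8M0 scale byte as 2^(byte - 127).
/// ASSUMPTION: the NaN encoding (0xFF) is not handled in this sketch.
fn e8m0_to_f32(scale: u8) -> f32 {
    2f32.powi(scale as i32 - 127)
}

/// Decode one 4-bit E2M1 code to f32.
fn decode_fp4(code: u8) -> f32 {
    let mag = E2M1[(code & 0x7) as usize];
    if code & 0x8 != 0 {
        -mag
    } else {
        mag
    }
}

/// Dequantize one block: 16 packed bytes (32 values) sharing one scale byte.
/// ASSUMPTION: the low nibble holds the even-indexed element.
fn dequant_block(packed: &[u8; 16], scale: u8) -> [f32; 32] {
    let s = e8m0_to_f32(scale);
    let mut out = [0.0f32; 32];
    for (i, &byte) in packed.iter().enumerate() {
        out[2 * i] = decode_fp4(byte & 0x0F) * s;
        out[2 * i + 1] = decode_fp4(byte >> 4) * s;
    }
    out
}
```

This is why the packed tensor has `K/2` bytes per row and the scale tensor `K/32` entries per row: each byte carries two elements and each scale covers 32 of them.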
@@ -214,36 +266,65 @@ impl GptOss {
|
||||
} else if is_fp8 {
|
||||
// FP8 W8A8 path: load and TRANSPOSE weights for cuBLASLt (requires transA=T on Blackwell).
|
||||
// Original: [E, K, N] → Transposed: [E, N, K]
|
||||
let gu_sliced = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, hidden, inter2);
|
||||
let dn_sliced = slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden);
|
||||
expert_gate_up_fp8 = Some(transpose_3d_inner_raw(&gu_sliced, local_experts, hidden, inter2).to_device(dev));
|
||||
expert_down_fp8 = Some(transpose_3d_inner_raw(&dn_sliced, local_experts, inter, hidden).to_device(dev));
|
||||
let gu_sliced = slice_expert_range_3d_raw(
|
||||
&gate_up_3d,
|
||||
expert_start,
|
||||
local_experts,
|
||||
hidden,
|
||||
inter2,
|
||||
);
|
||||
let dn_sliced =
|
||||
slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden);
|
||||
expert_gate_up_fp8 = Some(
|
||||
transpose_3d_inner_raw(&gu_sliced, local_experts, hidden, inter2)
|
||||
.to_device(dev),
|
||||
);
|
||||
expert_down_fp8 = Some(
|
||||
transpose_3d_inner_raw(&dn_sliced, local_experts, inter, hidden).to_device(dev),
|
||||
);
|
||||
// Scales: [num_experts] F32 → slice to [local_experts]
|
||||
let gu_s = gate_up_scale.expect("FP8 model missing gate_up_proj_scale");
|
||||
let d_s = down_scale.expect("FP8 model missing down_proj_scale");
|
||||
expert_gate_up_scale_gpu = Some(slice_scale_range(&gu_s, expert_start, local_experts).to_device(dev));
|
||||
expert_down_scale_gpu = Some(slice_scale_range(&d_s, expert_start, local_experts).to_device(dev));
|
||||
expert_gate_up_scale_gpu =
|
||||
Some(slice_scale_range(&gu_s, expert_start, local_experts).to_device(dev));
|
||||
expert_down_scale_gpu =
|
||||
Some(slice_scale_range(&d_s, expert_start, local_experts).to_device(dev));
|
||||
// Dummy BF16 tensors (never read in FP8 path)
|
||||
expert_gate_up_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
|
||||
expert_down_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
|
||||
} else {
|
||||
// BF16 path: existing behavior
|
||||
expert_gate_up_wt = slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2).to_device(dev);
|
||||
expert_down_wt = slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden).to_device(dev);
|
||||
expert_gate_up_wt =
|
||||
slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2)
|
||||
.to_device(dev);
|
||||
expert_down_wt =
|
||||
slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden)
|
||||
.to_device(dev);
|
||||
expert_gate_up_fp8 = None;
|
||||
expert_gate_up_scale_gpu = None;
|
||||
expert_down_fp8 = None;
|
||||
expert_down_scale_gpu = None;
|
||||
}
|
||||
let expert_gate_up_bias = slice_expert_range_2d(&gate_up_bias_2d, expert_start, local_experts, inter2).to_device(dev);
|
||||
let expert_down_bias = slice_expert_range_2d(&down_bias_2d, expert_start, local_experts, hidden).to_device(dev);
|
||||
let expert_gate_up_bias =
|
||||
slice_expert_range_2d(&gate_up_bias_2d, expert_start, local_experts, inter2)
|
||||
.to_device(dev);
|
||||
let expert_down_bias =
|
||||
slice_expert_range_2d(&down_bias_2d, expert_start, local_experts, hidden)
|
||||
.to_device(dev);
|
||||
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
|
||||
let input_norm = repl(take(&mut w, &format!("{p}.input_layernorm.weight")));
|
||||
let input_norm_bias = w.remove(&format!("{p}.input_layernorm.bias")).map(|t| repl(t));
|
||||
let post_norm = repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight")));
|
||||
let post_norm_bias = w.remove(&format!("{p}.post_attention_layernorm.bias")).map(|t| repl(t));
|
||||
let input_norm_bias = w
|
||||
.remove(&format!("{p}.input_layernorm.bias"))
|
||||
.map(|t| repl(t));
|
||||
let post_norm = repl(take(
|
||||
&mut w,
|
||||
&format!("{p}.post_attention_layernorm.weight"),
|
||||
));
|
||||
let post_norm_bias = w
|
||||
.remove(&format!("{p}.post_attention_layernorm.bias"))
|
||||
.map(|t| repl(t));
|
||||
|
||||
layers.push(GptOssBlock {
|
||||
input_norm,
|
||||
@@ -283,17 +364,27 @@ impl GptOss {
|
||||
let local_num_kv_heads = config.num_kv_heads() / world;
|
||||
|
||||
let has_norm_bias = norm_bias.is_some();
|
||||
let is_fp8 = layers.first().map(|l| l.expert_gate_up_fp8.is_some()).unwrap_or(false);
|
||||
let is_mxfp4 = layers.first().map(|l| l.expert_gate_up_mxfp4.is_some()).unwrap_or(false);
|
||||
let is_fp8 = layers
|
||||
.first()
|
||||
.map(|l| l.expert_gate_up_fp8.is_some())
|
||||
.unwrap_or(false);
|
||||
let is_mxfp4 = layers
|
||||
.first()
|
||||
.map(|l| l.expert_gate_up_mxfp4.is_some())
|
||||
.unwrap_or(false);
|
||||
if rank == 0 {
|
||||
if has_norm_bias {
|
||||
eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm");
|
||||
}
|
||||
if is_fp8 {
|
||||
eprintln!("gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)");
|
||||
eprintln!(
|
||||
"gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)"
|
||||
);
|
||||
}
|
||||
if is_mxfp4 {
|
||||
eprintln!("gpt-oss: MXFP4 quantized expert weights detected (W4A16 fused-GEMV mode)");
|
||||
eprintln!(
|
||||
"gpt-oss: MXFP4 quantized expert weights detected (W4A16 fused-GEMV mode)"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -341,7 +432,13 @@ impl GptOss {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn add_norm(x: &Tensor, residual: &Tensor, weight: &Tensor, bias: &Option<Tensor>, eps: f32) -> (Tensor, Tensor) {
|
||||
fn add_norm(
|
||||
x: &Tensor,
|
||||
residual: &Tensor,
|
||||
weight: &Tensor,
|
||||
bias: &Option<Tensor>,
|
||||
eps: f32,
|
||||
) -> (Tensor, Tensor) {
|
||||
match bias {
|
||||
Some(b) => {
|
||||
let sum = xserv_kernels::add(x, residual);
|
||||
@@ -373,24 +470,62 @@ impl GptOss {
|
||||
assert_eq!(seq_slots.len(), batch);
|
||||
assert!(batch > 0);
|
||||
|
||||
let num_heads = self.local_num_heads;
|
||||
let num_kv_heads = self.local_num_kv_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.norm_eps();
|
||||
self.decode_prepare(positions, seq_slots, paged_cache);
|
||||
|
||||
// Upload token ids + positions, then run the pure-GPU core.
|
||||
let ids_gpu = upload_u32(tokens);
|
||||
let positions_u32: Vec<u32> = positions.iter().map(|&p| p as u32).collect();
|
||||
let pos_gpu = upload_u32(&positions_u32);
|
||||
let logits = self.decode_core(
|
||||
ids_gpu.as_ptr() as *const c_void,
|
||||
pos_gpu.as_ptr() as *const c_void,
|
||||
batch,
|
||||
paged_cache,
|
||||
);
|
||||
|
||||
for &slot in seq_slots {
|
||||
paged_cache.advance_seq_len(slot, 1);
|
||||
}
|
||||
logits
|
||||
}
|
||||
|
||||
/// Host-side per-step cache bookkeeping: block allocation + uploading
|
||||
/// block tables / context lens to their (stable-address) GPU buffers.
|
||||
/// Runs OUTSIDE the CUDA-graph captured region.
|
||||
pub fn decode_prepare(
|
||||
&self,
|
||||
positions: &[usize],
|
||||
seq_slots: &[usize],
|
||||
paged_cache: &mut PagedKVCache,
|
||||
) {
|
||||
let kv_lens: Vec<i32> = positions.iter().map(|&p| (p + 1) as i32).collect();
|
||||
for (b, &slot) in seq_slots.iter().enumerate() {
|
||||
paged_cache.ensure_capacity(slot, positions[b] + 1);
|
||||
}
|
||||
paged_cache.sync_active_batch_with_lens(seq_slots, &kv_lens);
|
||||
}
|
||||
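The host-side arithmetic in `decode_prepare` is small enough to sketch on its own. A token decoded at position `p` is written at index `p`, so the sequence holds `p + 1` entries afterwards. The block count below assumes a fixed `block_size` parameter, which is a hypothetical stand-in for whatever `PagedKVCache::ensure_capacity` uses internally.

```rust
/// KV length of each sequence once this decode step has written position `p`.
fn kv_lens(positions: &[usize]) -> Vec<i32> {
    positions.iter().map(|&p| (p + 1) as i32).collect()
}

/// Number of fixed-size blocks needed to hold `tokens` KV entries.
/// `block_size` is a hypothetical parameter, not taken from the diff.
fn blocks_needed(tokens: usize, block_size: usize) -> usize {
    assert!(block_size > 0, "block_size must be positive");
    (tokens + block_size - 1) / block_size
}
```

With a block size of 16, the step that writes position 16 is the one that crosses into a second block, which is the allocation that has to happen on the host before the captured region runs.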
|
||||
/// The pure-GPU decode step: embedding → 24 layers → final norm → logits.
|
||||
/// Token ids and positions are read from device buffers; every other input
|
||||
/// (weights, KV pools, block table, context lens) has a stable address —
|
||||
/// which is exactly what makes this region CUDA-graph capturable.
|
||||
pub fn decode_core(
|
||||
&self,
|
||||
ids_gpu: *const c_void,
|
||||
pos_gpu: *const c_void,
|
||||
batch: usize,
|
||||
paged_cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
let num_heads = self.local_num_heads;
|
||||
let num_kv_heads = self.local_num_kv_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.norm_eps();
|
||||
|
||||
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
|
||||
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
|
||||
let max_blocks = paged_cache.max_blocks_per_seq();
|
||||
|
||||
let positions_u32: Vec<u32> = positions.iter().map(|&p| p as u32).collect();
|
||||
|
||||
let mut x = embedding(&self.embed_tokens, tokens);
|
||||
let mut x = embedding_device_ids(&self.embed_tokens, ids_gpu, batch);
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
@@ -401,14 +536,13 @@ impl GptOss {
|
||||
let k_all = add_bias(&matmul_2d(&normed, &layer.k_proj_wt), &layer.k_proj_bias);
|
||||
let v_all = add_bias(&matmul_2d(&normed, &layer.v_proj_wt), &layer.v_proj_bias);
|
||||
|
||||
|
||||
// Reshape for RoPE: [B, H*D] → [B, H, D]
|
||||
let q_3d = q_all.reshape(&[batch, num_heads, head_dim]);
|
||||
let k_3d = k_all.reshape(&[batch, num_kv_heads, head_dim]);
|
||||
|
||||
// RoPE (no QK-norm for gpt-oss)
|
||||
rope_inplace(&q_3d, &self.rope_cache, &positions_u32);
|
||||
rope_inplace(&k_3d, &self.rope_cache, &positions_u32);
|
||||
rope_inplace_device_pos(&q_3d, &self.rope_cache, pos_gpu);
|
||||
rope_inplace_device_pos(&k_3d, &self.rope_cache, pos_gpu);
|
||||
|
||||
let v_3d = v_all.reshape(&[batch, num_kv_heads, head_dim]);
|
||||
|
||||
@@ -422,9 +556,17 @@ impl GptOss {
|
||||
let sinks_ptr = layer.sinks.data_ptr() as *const c_void;
|
||||
|
||||
let attn_out = paged_decode_attention_sinks(
|
||||
&q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
|
||||
&q_4d,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
bt_ptr,
|
||||
cl_ptr,
|
||||
sinks_ptr,
|
||||
batch, num_heads, num_kv_heads, head_dim, max_blocks,
|
||||
batch,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
max_blocks,
|
||||
layer.window_size,
|
||||
);
|
||||
|
||||
@@ -433,9 +575,14 @@ impl GptOss {
|
||||
self.all_reduce(&attn_proj);
|
||||
let attn_proj = add_bias(&attn_proj, &layer.o_proj_bias);
|
||||
|
||||
|
||||
// Residual + post-norm
|
||||
let (normed, x_new) = Self::add_norm(&attn_proj, &residual, &layer.post_norm, &layer.post_norm_bias, eps);
|
||||
let (normed, x_new) = Self::add_norm(
|
||||
&attn_proj,
|
||||
&residual,
|
||||
&layer.post_norm,
|
||||
&layer.post_norm_bias,
|
||||
eps,
|
||||
);
|
||||
|
||||
let residual = x_new;
|
||||
let normed = normed.contiguous();
|
||||
@@ -445,17 +592,8 @@ impl GptOss {
|
||||
x = xserv_kernels::add(&residual, &moe_out);
|
||||
}
|
||||
|
||||
// Advance KV cache
|
||||
for &slot in seq_slots {
|
||||
paged_cache.advance_seq_len(slot, 1);
|
||||
}
|
||||
|
||||
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
|
||||
let x = Self::norm(&x, &self.norm, &self.norm_bias, eps);
|
||||
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
|
||||
let logits = matmul_2d(&x, &self.lm_head_t);
|
||||
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
|
||||
logits
|
||||
matmul_2d(&x, &self.lm_head_t)
|
||||
}
|
||||
|
||||
/// Paged prefill: process full prompt tokens.
|
||||
@@ -476,7 +614,9 @@ impl GptOss {
|
||||
paged_cache.advance_seq_len(slot, new_tokens);
|
||||
|
||||
let mut x = embedding(&self.embed_tokens, token_ids);
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
|
||||
.map(|p| p as u32)
|
||||
.collect();
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
@@ -503,14 +643,21 @@ impl GptOss {
|
||||
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
|
||||
|
||||
// Flash attention with gpt-oss sinks + (per-layer) sliding window.
|
||||
let attn_out = flash_attention_sinks(&q, &k_full, &v_full, &layer.sinks, layer.window_size);
|
||||
let attn_out =
|
||||
flash_attention_sinks(&q, &k_full, &v_full, &layer.sinks, layer.window_size);
|
||||
|
||||
let attn_merged = merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
self.all_reduce(&attn_proj);
|
||||
let attn_proj = add_bias(&attn_proj, &layer.o_proj_bias);
|
||||
|
||||
let (normed, x_new) = Self::add_norm(&attn_proj, &residual, &layer.post_norm, &layer.post_norm_bias, eps);
|
||||
let (normed, x_new) = Self::add_norm(
|
||||
&attn_proj,
|
||||
&residual,
|
||||
&layer.post_norm,
|
||||
&layer.post_norm_bias,
|
||||
eps,
|
||||
);
|
||||
let residual = x_new;
|
||||
|
||||
// MoE MLP
|
||||
@@ -519,9 +666,7 @@ impl GptOss {
|
||||
}
|
||||
|
||||
let x = Self::norm(&x, &self.norm, &self.norm_bias, eps);
|
||||
let logits = matmul_2d(&x, &self.lm_head_t);
|
||||
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
|
||||
logits
|
||||
matmul_2d(&x, &self.lm_head_t)
|
||||
}
|
||||
|
||||
/// MoE forward pass — fully on GPU via batched GEMM.
|
||||
@@ -539,15 +684,101 @@ impl GptOss {
|
||||
let expert_start = rank * local_experts;
|
||||
|
||||
// 1. Router: [tokens, hidden] @ [hidden, num_experts] + bias → [tokens, num_experts]
|
||||
let router_logits = add_bias(
|
||||
&matmul_2d(x, &layer.router_wt),
|
||||
&layer.router_bias,
|
||||
);
|
||||
let router_logits = add_bias(&matmul_2d(x, &layer.router_wt), &layer.router_bias);
|
||||
|
||||
// 2. GPU top-k + softmax
|
||||
let (topk_ids, topk_weights) = xserv_kernels::moe::moe_topk_softmax(
|
||||
&router_logits, num_experts, top_k,
|
||||
);
|
||||
let (topk_ids, topk_weights) =
|
||||
xserv_kernels::moe::moe_topk_softmax(&router_logits, num_experts, top_k);
|
||||
|
||||
// Sparse decode path: compute ONLY the routed experts. The dense path
// below reads every local expert's weights on each forward pass; the
// sparse GEMVs read only ~top_k/num_experts of those bytes, and weight
// reads dominate decode (it is memory-bound). Dense reads each weight
// once for ALL tokens, so it overtakes sparse again at
// num_tokens ≈ local_experts / E[local hits] ≈ 8.
|
||||
const SPARSE_MAX_TOKENS: usize = 8;
|
||||
let quantized = layer.expert_gate_up_fp8.is_some() || layer.expert_gate_up_mxfp4.is_some();
|
||||
if num_tokens <= SPARSE_MAX_TOKENS && quantized && !dense_moe_forced() {
|
||||
let gate_up = if let Some((ref packed, ref scales)) = layer.expert_gate_up_mxfp4 {
|
||||
let n = packed.shape()[1];
|
||||
let k = packed.shape()[2] * 2;
|
||||
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
|
||||
x,
|
||||
packed,
|
||||
scales,
|
||||
&layer.expert_gate_up_bias,
|
||||
&topk_ids,
|
||||
num_tokens,
|
||||
top_k,
|
||||
n,
|
||||
k,
|
||||
expert_start,
|
||||
local_experts,
|
||||
false,
|
||||
)
|
||||
} else {
|
||||
xserv_kernels::moe::moe_sparse_gemv_fp8(
|
||||
x,
|
||||
layer.expert_gate_up_fp8.as_ref().unwrap(),
|
||||
layer.expert_gate_up_scale.as_ref().unwrap(),
|
||||
&layer.expert_gate_up_bias,
|
||||
&topk_ids,
|
||||
num_tokens,
|
||||
top_k,
|
||||
expert_start,
|
||||
local_experts,
|
||||
false,
|
||||
)
|
||||
};
|
||||
|
||||
// GLU over all slots. Non-local slots hold unwritten memory; they
|
||||
// are never consumed (the down GEMV and the weighted sum both skip
|
||||
// slots whose expert this rank does not own).
|
||||
let inter2 = gate_up.shape()[2];
|
||||
let gate_up_flat = gate_up.reshape(&[num_tokens * top_k, inter2]);
|
||||
let activated = gpt_oss_glu(&gate_up_flat, layer.glu_alpha, layer.glu_limit);
|
||||
|
||||
let down = if let Some((ref packed, ref scales)) = layer.expert_down_mxfp4 {
|
||||
let n = packed.shape()[1];
|
||||
let k = packed.shape()[2] * 2;
|
||||
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
|
||||
&activated,
|
||||
packed,
|
||||
scales,
|
||||
&layer.expert_down_bias,
|
||||
&topk_ids,
|
||||
num_tokens,
|
||||
top_k,
|
||||
n,
|
||||
k,
|
||||
expert_start,
|
||||
local_experts,
|
||||
true,
|
||||
)
|
||||
} else {
|
||||
xserv_kernels::moe::moe_sparse_gemv_fp8(
|
||||
&activated,
|
||||
layer.expert_down_fp8.as_ref().unwrap(),
|
||||
layer.expert_down_scale.as_ref().unwrap(),
|
||||
&layer.expert_down_bias,
|
||||
&topk_ids,
|
||||
num_tokens,
|
||||
top_k,
|
||||
expert_start,
|
||||
local_experts,
|
||||
true,
|
||||
)
|
||||
};
|
||||
|
||||
let moe_out = xserv_kernels::moe::moe_weighted_sum_sparse(
|
||||
&down,
|
||||
&topk_ids,
|
||||
&topk_weights,
|
||||
expert_start,
|
||||
local_experts,
|
||||
);
|
||||
self.all_reduce(&moe_out);
|
||||
return moe_out;
|
||||
}
|
||||
|
||||
// 3. Replicate input: [tokens, hidden] → [local_experts, tokens, hidden]
|
||||
let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
|
||||
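The `SPARSE_MAX_TOKENS = 8` threshold follows from a simple byte-count argument, sketched here under the assumption of uniform routing and no weight reuse between tokens in the sparse path. Both function names are hypothetical.

```rust
/// Expected number of a token's routed experts that live on this rank,
/// assuming every expert is equally likely to be selected.
fn expected_local_hits(top_k: usize, local_experts: usize, num_experts: usize) -> f64 {
    top_k as f64 * local_experts as f64 / num_experts as f64
}

/// Token count at which the sparse path reads as many expert weight
/// matrices as the dense path (which reads all `local_experts` once).
fn crossover_tokens(top_k: usize, local_experts: usize, num_experts: usize) -> f64 {
    local_experts as f64 / expected_local_hits(top_k, local_experts, num_experts)
}
```

For the 32-expert, top-4 configuration the crossover simplifies to `num_experts / top_k = 8` regardless of how many experts each rank owns, which matches the constant in the code.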
@@ -563,14 +794,24 @@ impl GptOss {
|
||||
xserv_kernels::quantization::batched_gemv_mxfp4(&x2, packed, scales, n, k)
|
||||
.reshape(&[local_experts, 1, n])
|
||||
} else {
|
||||
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k);
|
||||
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(
|
||||
packed,
|
||||
scales,
|
||||
local_experts,
|
||||
n,
|
||||
k,
|
||||
);
|
||||
xserv_kernels::moe::batched_gemm_strided(&x_rep, &w_bf16)
|
||||
}
|
||||
} else if let Some(ref wt_fp8_t) = layer.expert_gate_up_fp8 {
|
||||
// W8A8: quantize activations with per-expert scalar scale, use cuBLASLt FP8 GEMM
|
||||
let (x_fp8, x_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep);
|
||||
let (x_fp8, x_scales) =
|
||||
xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep);
|
||||
xserv_kernels::quantization::batched_gemm_fp8(
|
||||
&x_fp8, &x_scales, wt_fp8_t, layer.expert_gate_up_scale.as_ref().unwrap(),
|
||||
&x_fp8,
|
||||
&x_scales,
|
||||
wt_fp8_t,
|
||||
layer.expert_gate_up_scale.as_ref().unwrap(),
|
||||
)
|
||||
} else {
|
||||
xserv_kernels::moe::batched_gemm_strided(&x_rep, &layer.expert_gate_up_wt)
|
||||
@@ -596,14 +837,24 @@ impl GptOss {
|
||||
xserv_kernels::quantization::batched_gemv_mxfp4(&a2, packed, scales, n, k)
|
||||
.reshape(&[local_experts, 1, n])
|
||||
} else {
|
||||
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k);
|
||||
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(
|
||||
packed,
|
||||
scales,
|
||||
local_experts,
|
||||
n,
|
||||
k,
|
||||
);
|
||||
xserv_kernels::moe::batched_gemm_strided(&activated, &w_bf16)
|
||||
}
|
||||
} else if let Some(ref wt_fp8) = layer.expert_down_fp8 {
|
||||
// W8A8: quantize post-GLU activations to FP8, use cuBLASLt FP8 GEMM
|
||||
let (act_fp8, act_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated);
|
||||
let (act_fp8, act_scales) =
|
||||
xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated);
|
||||
xserv_kernels::quantization::batched_gemm_fp8(
|
||||
&act_fp8, &act_scales, wt_fp8, layer.expert_down_scale.as_ref().unwrap(),
|
||||
&act_fp8,
|
||||
&act_scales,
|
||||
wt_fp8,
|
||||
layer.expert_down_scale.as_ref().unwrap(),
|
||||
)
|
||||
} else {
|
||||
xserv_kernels::moe::batched_gemm_strided(&activated, &layer.expert_down_wt)
|
||||
@@ -614,8 +865,12 @@ impl GptOss {
|
||||
|
||||
// 9. Weighted sum across experts → [tokens, hidden]
|
||||
let moe_out = xserv_kernels::moe::moe_weighted_sum(
|
||||
&down, &topk_ids, &topk_weights,
|
||||
expert_start, local_experts, top_k,
|
||||
&down,
|
||||
&topk_ids,
|
||||
&topk_weights,
|
||||
expert_start,
|
||||
local_experts,
|
||||
top_k,
|
||||
);
|
||||
|
||||
self.all_reduce(&moe_out);
|
||||
@@ -625,45 +880,45 @@ impl GptOss {
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
/// Upload a u32 slice to a pooled GPU buffer (synchronous H2D).
|
||||
fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer {
|
||||
let bytes = unsafe { std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4) };
|
||||
let mut buf = xserv_cuda::allocator::cached_alloc(bytes.len()).expect("alloc u32 upload");
|
||||
buf.copy_from_host(bytes).unwrap();
|
||||
buf
|
||||
}
|
||||
|
||||
/// XSERV_DENSE_MOE=1 forces the dense all-expert path (A/B benchmarking).
|
||||
fn dense_moe_forced() -> bool {
|
||||
static FORCED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
|
||||
*FORCED.get_or_init(|| std::env::var("XSERV_DENSE_MOE").is_ok_and(|v| v != "0"))
|
||||
}
|
||||
|
||||
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert_eq!(a.ndim(), 2);
|
||||
assert_eq!(b.ndim(), 2);
|
||||
matmul(a, b, GemmBackend::CuBlas)
|
||||
}
|
||||
|
||||
/// Add bias to a 2D tensor: [rows, cols] + [cols] → [rows, cols]
|
||||
/// Add bias to a 2D tensor: [rows, cols] + [cols] → [rows, cols].
|
||||
/// Single GPU broadcast kernel — the old rows>1 path tiled the bias on the
|
||||
/// CPU (D2H + host loop + H2D) on every call, 96×/prefill in the hot path.
|
||||
fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor {
|
||||
assert_eq!(x.ndim(), 2);
|
||||
assert_eq!(bias.ndim(), 1);
|
||||
let rows = x.shape()[0];
|
||||
let cols = x.shape()[1];
|
||||
assert_eq!(bias.shape()[0], cols, "bias size {} != cols {}", bias.shape()[0], cols);
|
||||
|
||||
let x_c = x.contiguous();
|
||||
|
||||
if rows == 1 {
|
||||
// Fast path: reshape bias [cols] → [1, cols] (zero-copy), add directly on GPU
|
||||
let bias_2d = bias.reshape(&[1, cols]);
|
||||
return xserv_kernels::add(&x_c, &bias_2d);
|
||||
}
|
||||
|
||||
// General path: tile bias to [rows, cols] via CPU, then add on GPU
|
||||
let bias_cpu = bias.to_device(Device::Cpu);
|
||||
let bias_data = bias_cpu.as_slice::<bf16>();
|
||||
let mut tiled = Vec::with_capacity(rows * cols);
|
||||
for _ in 0..rows {
|
||||
tiled.extend_from_slice(bias_data);
|
||||
}
|
||||
let bias_tiled = Tensor::from_slice(&tiled, &[rows, cols]).to_device(x.device());
|
||||
xserv_kernels::add(&x_c, &bias_tiled)
|
||||
xserv_kernels::bias_add_2d(&x_c, bias)
|
||||
}
|
||||
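A CPU reference is handy when validating a broadcast kernel such as `bias_add_2d`. The sketch below is a hypothetical test helper over `f32` slices, not the kernel's implementation, and it assumes row-major storage.

```rust
/// Reference broadcast add: out[r][c] = x[r][c] + bias[c], row-major.
/// (Hypothetical helper for checking a GPU kernel's output.)
fn bias_add_2d_ref(x: &[f32], bias: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(x.len(), rows * cols, "x must hold rows * cols elements");
    assert_eq!(bias.len(), cols, "bias size must equal cols");
    let mut out = x.to_vec();
    for r in 0..rows {
        for c in 0..cols {
            out[r * cols + c] += bias[c];
        }
    }
    out
}
```

Comparing the kernel against this for both `rows == 1` and `rows > 1` covers the two cases the old implementation handled separately.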
|
||||
fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
||||
if world == 1 { return t.clone(); }
|
||||
if world == 1 {
|
||||
return t.clone();
|
||||
}
|
||||
let shape = t.shape();
|
||||
assert_eq!(shape.len(), 2);
|
||||
let (rows, cols) = (shape[0], shape[1]);
|
||||
assert!(rows % world == 0, "rows {rows} not divisible by world {world}");
|
||||
assert!(
|
||||
rows % world == 0,
|
||||
"rows {rows} not divisible by world {world}"
|
||||
);
|
||||
let local = rows / world;
|
||||
let host = t.to_device(Device::Cpu);
|
||||
let data = host.as_slice::<bf16>();
|
||||
@@ -673,11 +928,16 @@ fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
||||
}
|
||||
|
||||
fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
||||
if world == 1 { return t.clone(); }
|
||||
if world == 1 {
|
||||
return t.clone();
|
||||
}
|
||||
let shape = t.shape();
|
||||
assert_eq!(shape.len(), 2);
|
||||
let (rows, cols) = (shape[0], shape[1]);
|
||||
assert!(cols % world == 0, "cols {cols} not divisible by world {world}");
|
||||
assert!(
|
||||
cols % world == 0,
|
||||
"cols {cols} not divisible by world {world}"
|
||||
);
|
||||
let local = cols / world;
|
||||
let c0 = rank * local;
|
||||
let host = t.to_device(Device::Cpu);
|
||||
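The body of `shard_cols` is elided by the hunk boundary, so the following is only a sketch of what a column shard of a row-major matrix looks like, written over `f32` for simplicity. `shard_cols_ref` is a hypothetical name and the copy loop is an assumption about the elided code.

```rust
/// Take columns [rank * local, (rank + 1) * local) from a row-major
/// [rows, cols] matrix, where local = cols / world.
fn shard_cols_ref(data: &[f32], rows: usize, cols: usize, rank: usize, world: usize) -> Vec<f32> {
    assert_eq!(data.len(), rows * cols);
    assert!(world > 0 && rank < world, "rank {rank} out of range for world {world}");
    assert!(cols % world == 0, "cols {cols} not divisible by world {world}");
    let local = cols / world;
    let c0 = rank * local;
    let mut out = Vec::with_capacity(rows * local);
    for r in 0..rows {
        // Each row contributes one contiguous run of `local` elements.
        let start = r * cols + c0;
        out.extend_from_slice(&data[start..start + local]);
    }
    out
}
```

Unlike a row shard, which is a single contiguous range, a column shard needs one copy per row, which is why the helper goes through a host buffer.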
@@ -691,11 +951,16 @@ fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
||||
}
|
||||
|
||||
fn shard_1d(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
||||
if world == 1 { return t.clone(); }
|
||||
if world == 1 {
|
||||
return t.clone();
|
||||
}
|
||||
let shape = t.shape();
|
||||
assert_eq!(shape.len(), 1);
|
||||
let total = shape[0];
|
||||
assert!(total % world == 0, "dim {total} not divisible by world {world}");
|
||||
assert!(
|
||||
total % world == 0,
|
||||
"dim {total} not divisible by world {world}"
|
||||
);
|
||||
let local = total / world;
|
||||
let host = t.to_device(Device::Cpu);
|
||||
let data = host.as_slice::<bf16>();
|
||||
@@ -726,7 +991,13 @@ fn transpose_3d_inner_raw(t: &Tensor, batch: usize, rows: usize, cols: usize) ->
|
||||
}
|
||||
|
||||
/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor (any dtype, raw bytes).
|
||||
fn slice_expert_range_3d_raw(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
|
||||
fn slice_expert_range_3d_raw(
|
||||
t: &Tensor,
|
||||
start: usize,
|
||||
count: usize,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(t.ndim(), 3);
|
||||
let host = t.to_device(Device::Cpu);
|
||||
let elem_size = t.dtype().size_bytes();
|
||||
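Because experts are the outermost dimension, extracting a contiguous expert range from raw bytes is a single slice. The sketch below illustrates that on a plain byte buffer. The function name and the explicit `num_experts` / `elem_size` parameters are hypothetical, and the real helper's body is not shown in this diff.

```rust
/// Copy experts [start, start + count) out of a row-major
/// [num_experts, rows, cols] buffer, treating elements as opaque bytes.
fn slice_expert_range_bytes(
    bytes: &[u8],
    num_experts: usize,
    start: usize,
    count: usize,
    rows: usize,
    cols: usize,
    elem_size: usize,
) -> Vec<u8> {
    let per_expert = rows * cols * elem_size;
    assert_eq!(bytes.len(), num_experts * per_expert, "buffer size mismatch");
    assert!(start + count <= num_experts, "expert range out of bounds");
    bytes[start * per_expert..(start + count) * per_expert].to_vec()
}
```

Working on raw bytes is what lets one helper serve FP8 weights, packed MXFP4 weights, and MXFP4 scales without knowing their element type.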
@@ -748,7 +1019,13 @@ fn slice_scale_range(t: &Tensor, start: usize, count: usize) -> Tensor {
|
||||
}
|
||||
|
||||
/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor
|
||||
fn slice_expert_range_3d(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
|
||||
fn slice_expert_range_3d(
|
||||
t: &Tensor,
|
||||
start: usize,
|
||||
count: usize,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(t.ndim(), 3);
|
||||
let host = t.to_device(Device::Cpu);
|
||||
let data = host.as_slice::<bf16>();
|
||||
|
||||
195
crates/xserv-model/src/gpt_oss_graph.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
//! CUDA-graph replay for gpt-oss batch=1 decode (Phase 21).
//!
//! A decode step launches ~200 kernels; with sparse MoE the GPU work is only
//! a few ms, so launch overhead dominates TPOT. The whole step (embedding →
//! 24 layers → logits) is captured ONCE into a CUDA graph and replayed per
//! token with a single `cudaGraphLaunch`.
//!
//! Why the existing forward is capturable as-is:
//! - Every per-step variable input lives in a stable-address device buffer
//!   whose CONTENTS are updated outside the captured region: token id and
//!   position (persistent buffers owned here), block table and context lens
//!   (PagedKVCache GPU buffers, refreshed by `decode_prepare`). The KV scatter
//!   and paged attention kernels read their write/read positions from those
//!   buffers, and the sparse-MoE GEMVs read expert ids from `topk_ids` written
//!   earlier in the same graph. Everything is data-dependent, with no host
//!   branching.
//! - Kernel launches go through the thread-local launch stream
//!   (`xserv_cuda::stream::push_stream`), so the capture stream sees them.
//! - Intermediate tensors come from the caching allocator. Blocks freed while
//!   capturing are quarantined (`allocator::begin_retain`) for the graph's
//!   lifetime so no later allocation can take ownership of memory the graph
//!   still references on every replay.
//!
//! Capture preconditions: at least one EAGER decode step must have run first,
//! so the allocator pool already holds every bucket size the step needs
//! (a pool miss inside capture would call cudaMalloc, which is illegal while
//! capturing) and cuBLAS has finished its one-time per-shape setup.
|
||||
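The "stable address, changing contents" idea can be illustrated without a GPU. The sketch below is a host-only analogy and not CUDA code: closures stand in for captured kernels, `Rc<RefCell<..>>` stands in for device buffers, and every name in it is hypothetical.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// Stand-in for a stable-address device buffer (hypothetical type).
type Buf = Rc<RefCell<Vec<u32>>>;

/// Stand-in for a captured graph: a fixed op list bound to fixed buffers.
struct RecordedGraph {
    ops: Vec<Box<dyn Fn()>>,
}

impl RecordedGraph {
    /// Re-run the recorded ops against whatever the buffers hold now.
    fn replay(&self) {
        for op in &self.ops {
            op();
        }
    }
}

/// Record `output = input * 2 + 1` without executing it.
/// The intermediate `tmp` is owned by the recorded ops, so it stays alive
/// for as long as the graph does (the analogue of quarantined blocks).
fn capture(input: &Buf, output: &Buf) -> RecordedGraph {
    let tmp: Buf = Rc::new(RefCell::new(vec![0]));
    let (i, t) = (Rc::clone(input), Rc::clone(&tmp));
    let double: Box<dyn Fn()> = Box::new(move || {
        let v = i.borrow()[0] * 2;
        t.borrow_mut()[0] = v;
    });
    let (t, o) = (tmp, Rc::clone(output));
    let add_one: Box<dyn Fn()> = Box::new(move || {
        let v = t.borrow()[0] + 1;
        o.borrow_mut()[0] = v;
    });
    RecordedGraph {
        ops: vec![double, add_one],
    }
}
```

Recording leaves `output` untouched, and writing a new value into `input` before each `replay` changes the result without re-recording, which is the role `ids_buf` and `pos_buf` play for the decode graph.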
|
||||
use std::ffi::c_void;
|
||||
|
||||
use xserv_cuda::allocator::{self, RetainedBlocks};
|
||||
use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer};
|
||||
use xserv_tensor::Tensor;
|
||||
|
||||
use crate::gpt_oss::GptOss;
|
||||
use crate::paged_kv_cache::PagedKVCache;
|
||||
|
||||
pub struct GptOssDecodeGraph {
|
||||
stream: CudaStream,
|
||||
graph: CudaGraph,
|
||||
ids_buf: GpuBuffer, // [1] u32, persistent graph input
|
||||
pos_buf: GpuBuffer, // [1] u32, persistent graph input
|
||||
logits: Tensor, // graph output; rewritten in place by every replay
|
||||
_arena: RetainedBlocks,
|
||||
}
|
||||
|
||||
impl GptOssDecodeGraph {
|
||||
/// Capture one batch=1 decode step and replay it once (capture records
|
||||
/// without executing, so the replay performs this token's computation).
|
||||
pub fn capture(
|
||||
model: &GptOss,
|
||||
token: u32,
|
||||
position: usize,
|
||||
slot: usize,
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Self {
|
||||
let stream = CudaStream::new().expect("create capture stream");
|
||||
let mut ids_buf = allocator::cached_alloc(4).expect("alloc ids buf");
|
||||
let mut pos_buf = allocator::cached_alloc(4).expect("alloc pos buf");
|
||||
|
||||
model.decode_prepare(&[position], &[slot], cache);
|
||||
ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
|
||||
pos_buf
|
||||
.copy_from_host(&(position as u32).to_le_bytes())
|
||||
.unwrap();
|
||||
|
||||
// Retained warmup: run the exact step once eagerly with the quarantine
// ON. Freed intermediates are held back instead of recycled, so the
// pool ends up stocked with a dedicated block for EVERY allocation the
// step performs. The capture below repeats the same allocation
// sequence and therefore never misses the pool. A pool miss would
// call cudaMalloc, which is illegal while a stream is capturing (this
// is also why one block per bucket is not enough: the capture's own
// quarantine keeps freed blocks out of reuse). Re-running the step is
// idempotent: the KV scatter rewrites the same cache position.
|
||||
allocator::begin_retain();
|
||||
{
|
||||
let _guard = xserv_cuda::stream::push_stream(&stream);
|
||||
let _ = model.decode_core(
|
||||
ids_buf.as_ptr() as *const c_void,
|
||||
pos_buf.as_ptr() as *const c_void,
|
||||
1,
|
||||
cache,
|
||||
);
|
||||
}
|
||||
drop(allocator::end_retain()); // release the warmup blocks to the pool
|
||||
stream.synchronize().expect("warmup sync");
|
||||
|
||||
allocator::begin_retain();
|
||||
let mut graph = CudaGraph::new();
|
||||
let logits;
|
||||
{
|
||||
let _guard = xserv_cuda::stream::push_stream(&stream);
|
||||
graph
|
||||
.begin_capture(&stream)
|
||||
.expect("begin decode-graph capture");
|
||||
logits = model.decode_core(
|
||||
ids_buf.as_ptr() as *const c_void,
|
||||
pos_buf.as_ptr() as *const c_void,
|
||||
1,
|
||||
cache,
|
||||
);
|
||||
graph
|
||||
.end_capture(&stream)
|
||||
.expect("end decode-graph capture");
|
||||
}
|
||||
let arena = allocator::end_retain();
|
||||
|
||||
graph.launch(&stream).expect("first decode-graph replay");
|
||||
cache.advance_seq_len(slot, 1);
|
||||
|
||||
Self {
|
||||
stream,
|
||||
graph,
|
||||
ids_buf,
|
||||
pos_buf,
|
||||
logits,
|
||||
_arena: arena,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run one decode step by replaying the captured graph.
|
||||
pub fn step(
|
||||
&mut self,
|
||||
model: &GptOss,
|
||||
token: u32,
|
||||
position: usize,
|
||||
slot: usize,
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
model.decode_prepare(&[position], &[slot], cache);
|
||||
self.ids_buf.copy_from_host(&token.to_le_bytes()).unwrap();
|
||||
self.pos_buf
|
||||
.copy_from_host(&(position as u32).to_le_bytes())
|
||||
.unwrap();
|
||||
self.graph
|
||||
.launch(&self.stream)
|
||||
.expect("decode-graph replay");
|
||||
cache.advance_seq_len(slot, 1);
|
||||
// Shallow clone: the caller must read these logits before the next replay
|
||||
// rewrites the underlying buffer.
|
||||
self.logits.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Lazy capture policy: first decode step of the process runs eager (warms the
|
||||
/// allocator pool + cuBLAS so capture performs no "unsafe" CUDA calls), the
|
||||
/// second is captured, the rest replay. Batch>1 always falls back to eager.
|
||||
/// Disable with XSERV_DECODE_GRAPH=0.
|
||||
pub struct GraphedGptOssDecoder {
|
||||
graph: Option<GptOssDecodeGraph>,
|
||||
eager_steps: u32,
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl GraphedGptOssDecoder {
|
||||
pub fn new() -> Self {
|
||||
let enabled = std::env::var("XSERV_DECODE_GRAPH")
|
||||
.map(|v| v != "0")
|
||||
.unwrap_or(true);
|
||||
Self {
|
||||
graph: None,
|
||||
eager_steps: 0,
|
||||
enabled,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode(
|
||||
&mut self,
|
||||
model: &GptOss,
|
||||
tokens: &[u32],
|
||||
positions: &[usize],
|
||||
slots: &[usize],
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
if self.enabled && tokens.len() == 1 {
|
||||
if let Some(g) = self.graph.as_mut() {
|
||||
return g.step(model, tokens[0], positions[0], slots[0], cache);
|
||||
}
|
||||
if self.eager_steps >= 1 {
|
||||
let g = GptOssDecodeGraph::capture(model, tokens[0], positions[0], slots[0], cache);
|
||||
let logits = g.logits.clone();
|
||||
self.graph = Some(g);
|
||||
return logits;
|
||||
}
|
||||
}
|
||||
self.eager_steps += 1;
|
||||
model.forward_decode_paged(tokens, positions, slots, cache)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GraphedGptOssDecoder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
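The lazy capture policy in `GraphedGptOssDecoder::decode` can be read as a three-state decision: eager, capture, replay. The sketch below reproduces only that control flow with no CUDA objects. `LazyCapturePolicy`, `DecodePath` and `next_path` are hypothetical names introduced for illustration and are not part of xserv.

```rust
/// Path taken by one decode step under the lazy capture policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodePath {
    Eager,
    Capture,
    Replay,
}

/// Hypothetical stand-in for `GraphedGptOssDecoder` that keeps only the
/// policy state (no stream, no graph, no buffers).
pub struct LazyCapturePolicy {
    captured: bool,
    eager_steps: u32,
    enabled: bool,
}

impl LazyCapturePolicy {
    pub fn new(enabled: bool) -> Self {
        Self {
            captured: false,
            eager_steps: 0,
            enabled,
        }
    }

    /// Decide which path a step with `batch` sequences takes.
    pub fn next_path(&mut self, batch: usize) -> DecodePath {
        if self.enabled && batch == 1 {
            if self.captured {
                return DecodePath::Replay;
            }
            // At least one eager step has already warmed the pool.
            if self.eager_steps >= 1 {
                self.captured = true;
                return DecodePath::Capture;
            }
        }
        // Disabled, batch > 1, or the very first step: run eagerly.
        self.eager_steps += 1;
        DecodePath::Eager
    }
}
```

A batch larger than one never invalidates an existing graph in this sketch; it takes the eager path for that step and later single-sequence steps replay again, which matches the control flow in the diff above.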
@@ -1,6 +1,6 @@
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use crate::config::ModelConfig;
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_tensor::{DType, Tensor};
|
||||
|
||||
/// GPU-resident KV cache. Pre-allocates max_seq_len on GPU,
|
||||
/// appends new K/V via D2D copy at offset (no CPU round-trip).
|
||||
@@ -46,17 +46,43 @@ impl GpuKVCache {
|
||||
v_staging.push(GpuBuffer::alloc(buf_size).expect("alloc KV staging V"));
|
||||
}
|
||||
|
||||
Self { k_bufs, v_bufs, k_staging, v_staging, seq_len: 0, max_seq_len, num_kv_heads, head_dim, elem_size, dtype, device }
|
||||
Self {
|
||||
k_bufs,
|
||||
v_bufs,
|
||||
k_staging,
|
||||
v_staging,
|
||||
seq_len: 0,
|
||||
max_seq_len,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
elem_size,
|
||||
dtype,
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn seq_len(&self) -> usize { self.seq_len }
|
||||
pub fn max_seq_len(&self) -> usize { self.max_seq_len }
|
||||
pub fn seq_len(&self) -> usize {
|
||||
self.seq_len
|
||||
}
|
||||
pub fn max_seq_len(&self) -> usize {
|
||||
self.max_seq_len
|
||||
}
|
||||
|
||||
/// Append new K/V tensors for a given layer.
|
||||
/// k_new, v_new: [1, num_kv_heads, new_tokens, head_dim] on GPU, contiguous.
|
||||
/// `write_pos` is the sequence position to write at (caller manages this).
|
||||
pub fn append(&mut self, layer: usize, k_new: &Tensor, v_new: &Tensor, new_tokens: usize, write_pos: usize) {
|
||||
assert!(write_pos + new_tokens <= self.max_seq_len, "KV cache overflow");
|
||||
pub fn append(
|
||||
&mut self,
|
||||
layer: usize,
|
||||
k_new: &Tensor,
|
||||
v_new: &Tensor,
|
||||
new_tokens: usize,
|
||||
write_pos: usize,
|
||||
) {
|
||||
assert!(
|
||||
write_pos + new_tokens <= self.max_seq_len,
|
||||
"KV cache overflow"
|
||||
);
|
||||
let es = self.elem_size;
|
||||
let hd = self.head_dim;
|
||||
let max_s = self.max_seq_len;
|
||||
@@ -69,14 +95,23 @@ impl GpuKVCache {
|
||||
let src_off = h * new_tokens * hd * es;
|
||||
let dst_off = (h * max_s + write_pos) * hd * es;
|
||||
let count = new_tokens * hd * es;
|
||||
self.k_bufs[layer].copy_from_device_at(k_src, src_off, dst_off, count).unwrap();
|
||||
self.v_bufs[layer].copy_from_device_at(v_src, src_off, dst_off, count).unwrap();
|
||||
self.k_bufs[layer]
|
||||
.copy_from_device_at(k_src, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
self.v_bufs[layer]
|
||||
.copy_from_device_at(v_src, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn advance_seq_len(&mut self, new_tokens: usize) {
|
||||
self.seq_len += new_tokens;
|
||||
assert!(self.seq_len <= self.max_seq_len, "KV cache seq_len ({}) exceeds max_seq_len ({})", self.seq_len, self.max_seq_len);
|
||||
assert!(
|
||||
self.seq_len <= self.max_seq_len,
|
||||
"KV cache seq_len ({}) exceeds max_seq_len ({})",
|
||||
self.seq_len,
|
||||
self.max_seq_len
|
||||
);
|
||||
}
|
||||
|
||||
/// Get K/V cache tensors for a layer up to `seq_len` tokens: [1, num_kv_heads, seq_len, head_dim]
|
||||
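The offset arithmetic in `GpuKVCache::append` is easier to check in isolation. The sketch below assumes the layouts stated in the doc comments: a contiguous source of shape `[1, num_kv_heads, new_tokens, head_dim]` and a per-layer cache buffer laid out as `[num_kv_heads, max_seq_len, head_dim]`. `HeadCopy` and `head_copy` are hypothetical helpers, not xserv functions.

```rust
/// Byte offsets and length for copying one head's new K (or V) rows.
#[derive(Debug, PartialEq, Eq)]
pub struct HeadCopy {
    pub src_off: usize,
    pub dst_off: usize,
    pub count: usize,
}

/// Compute the device-to-device copy parameters for head `h`.
/// All sizes are in elements except `elem_size`, which is in bytes.
pub fn head_copy(
    h: usize,
    new_tokens: usize,
    write_pos: usize,
    max_seq_len: usize,
    head_dim: usize,
    elem_size: usize,
) -> HeadCopy {
    assert!(write_pos + new_tokens <= max_seq_len, "KV cache overflow");
    HeadCopy {
        // Source rows for head h are packed back to back.
        src_off: h * new_tokens * head_dim * elem_size,
        // Destination rows for head h start at h * max_seq_len, then skip
        // the write_pos rows already present.
        dst_off: (h * max_seq_len + write_pos) * head_dim * elem_size,
        count: new_tokens * head_dim * elem_size,
    }
}
```

Because each head owns a `max_seq_len`-row stripe in the destination, one copy per head is needed; the rows of a single head are contiguous, so `count` covers all `new_tokens` rows at once.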
@@ -86,7 +121,11 @@ impl GpuKVCache {
|
||||
}
|
||||
|
||||
pub fn get_kv_len(&mut self, layer: usize, sl: usize) -> (Tensor, Tensor) {
|
||||
assert!(sl <= self.max_seq_len, "get_kv_len: sl ({sl}) exceeds max_seq_len ({})", self.max_seq_len);
|
||||
assert!(
|
||||
sl <= self.max_seq_len,
|
||||
"get_kv_len: sl ({sl}) exceeds max_seq_len ({})",
|
||||
self.max_seq_len
|
||||
);
|
||||
let hd = self.head_dim;
|
||||
let nh = self.num_kv_heads;
|
||||
let es = self.elem_size;
|
||||
@@ -104,8 +143,12 @@ impl GpuKVCache {
|
||||
let src_off = (h * max_s) * hd * es;
|
||||
let dst_off = (h * sl) * hd * es;
|
||||
let count = sl * hd * es;
|
||||
k_stg.copy_from_device_at(k_buf, src_off, dst_off, count).unwrap();
|
||||
v_stg.copy_from_device_at(v_buf, src_off, dst_off, count).unwrap();
|
||||
k_stg
|
||||
.copy_from_device_at(k_buf, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
v_stg
|
||||
.copy_from_device_at(v_buf, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
}
|
||||
// Grab raw pointers before dropping the mutable borrows
|
||||
let k_ptr = k_stg.as_mut_ptr();
|
||||
@@ -117,20 +160,35 @@ impl GpuKVCache {
|
||||
// get_kv_len call overwrites the staging buffer).
|
||||
let shape = &[1usize, nh, sl, hd];
|
||||
let k = unsafe {
|
||||
tensor_from_gpu_buffer(GpuBuffer::borrow_raw(k_ptr, out_size), shape, self.dtype, self.device)
|
||||
tensor_from_gpu_buffer(
|
||||
GpuBuffer::borrow_raw(k_ptr, out_size),
|
||||
shape,
|
||||
self.dtype,
|
||||
self.device,
|
||||
)
|
||||
};
|
||||
let v = unsafe {
|
||||
tensor_from_gpu_buffer(GpuBuffer::borrow_raw(v_ptr, out_size), shape, self.dtype, self.device)
|
||||
tensor_from_gpu_buffer(
|
||||
GpuBuffer::borrow_raw(v_ptr, out_size),
|
||||
shape,
|
||||
self.dtype,
|
||||
self.device,
|
||||
)
|
||||
};
|
||||
(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a Tensor from a GpuBuffer (takes ownership).
|
||||
unsafe fn tensor_from_gpu_buffer(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
|
||||
use xserv_tensor::storage::Storage;
|
||||
use xserv_tensor::shape::contiguous_strides;
|
||||
unsafe fn tensor_from_gpu_buffer(
|
||||
buf: GpuBuffer,
|
||||
shape: &[usize],
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Tensor {
|
||||
use smallvec::SmallVec;
|
||||
use xserv_tensor::shape::contiguous_strides;
|
||||
use xserv_tensor::storage::Storage;
|
||||
|
||||
let storage = Storage::cuda(buf, device);
|
||||
Tensor::from_storage(
|
||||
@@ -146,6 +204,11 @@ unsafe fn tensor_from_gpu_buffer(buf: GpuBuffer, shape: &[usize], dtype: DType,
|
||||
///
|
||||
/// # Safety
|
||||
/// `buf` must be a valid GPU allocation with at least `product(shape) * dtype.size_bytes()` bytes.
|
||||
pub unsafe fn tensor_from_gpu_buffer_pub(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
|
||||
pub unsafe fn tensor_from_gpu_buffer_pub(
|
||||
buf: GpuBuffer,
|
||||
shape: &[usize],
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Tensor {
|
||||
tensor_from_gpu_buffer(buf, shape, dtype, device)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ pub mod config;
|
||||
pub mod decode_graph;
|
||||
pub mod gpt2;
|
||||
pub mod gpt_oss;
|
||||
pub mod gpt_oss_graph;
|
||||
pub mod kv_cache;
|
||||
pub mod loader;
|
||||
pub mod paged_kv_cache;
|
||||
@@ -10,10 +11,11 @@ pub mod sampling;
|
||||
|
||||
pub use config::ModelConfig;
|
||||
pub use decode_graph::{DecodeGraphState, LayerWeightPtrs};
|
||||
pub use gpt2::{GPT2, KVCache};
|
||||
pub use gpt_oss::GptOss;
|
||||
pub use gpt_oss_graph::{GptOssDecodeGraph, GraphedGptOssDecoder};
|
||||
pub use gpt2::{GPT2, KVCache};
|
||||
pub use kv_cache::GpuKVCache;
|
||||
pub use paged_kv_cache::{BlockAllocator, Location, PagedKVCache, BLOCK_SIZE};
|
||||
pub use paged_kv_cache::{BLOCK_SIZE, BlockAllocator, Location, PagedKVCache};
|
||||
pub use qwen3::Qwen3;
|
||||
pub use sampling::{SamplingParams, sample, sample_greedy_penalized};
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ use std::path::Path;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
pub fn load_safetensors(path: &Path, device: Device) -> HashMap<String, Tensor> {
|
||||
let data = std::fs::read(path)
|
||||
.unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
|
||||
let data =
|
||||
std::fs::read(path).unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
|
||||
let st = SafeTensors::deserialize(&data)
|
||||
.unwrap_or_else(|e| panic!("failed to parse safetensors {}: {e}", path.display()));
|
||||
|
||||
@@ -60,7 +60,11 @@ pub fn load_model_dir(dir: &Path, device: Device) -> HashMap<String, Tensor> {
|
||||
all_tensors.extend(tensors);
|
||||
}
|
||||
|
||||
assert!(!all_tensors.is_empty(), "no safetensors files found in {}", dir.display());
|
||||
assert!(
|
||||
!all_tensors.is_empty(),
|
||||
"no safetensors files found in {}",
|
||||
dir.display()
|
||||
);
|
||||
all_tensors
|
||||
}
|
||||
|
||||
@@ -84,8 +88,6 @@ fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
|
||||
};
|
||||
Tensor::from_slice(bfs, shape)
|
||||
}
|
||||
DType::FP8E4M3 => {
|
||||
Tensor::from_raw_bytes(raw_bytes, shape, DType::FP8E4M3)
|
||||
}
|
||||
DType::FP8E4M3 => Tensor::from_raw_bytes(raw_bytes, shape, DType::FP8E4M3),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,7 +29,10 @@ impl BlockAllocator {
|
||||
for b in (1..total_blocks).rev() {
|
||||
free_stack.push(b as u32);
|
||||
}
|
||||
Self { free_stack, total: total_blocks }
|
||||
Self {
|
||||
free_stack,
|
||||
total: total_blocks,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn alloc(&mut self) -> Option<u32> {
|
||||
@@ -136,8 +139,14 @@ impl PagedKVCache {
|
||||
device: u32,
|
||||
) -> Self {
|
||||
Self::new_tp(
|
||||
config, config.num_kv_heads(), total_blocks, cpu_total_blocks,
|
||||
max_seqs, max_blocks_per_seq, dtype, device,
|
||||
config,
|
||||
config.num_kv_heads(),
|
||||
total_blocks,
|
||||
cpu_total_blocks,
|
||||
max_seqs,
|
||||
max_blocks_per_seq,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -155,7 +164,10 @@ impl PagedKVCache {
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Self {
|
||||
assert!(total_blocks >= 2, "need at least 2 blocks (one is sentinel)");
|
||||
assert!(
|
||||
total_blocks >= 2,
|
||||
"need at least 2 blocks (one is sentinel)"
|
||||
);
|
||||
let num_layers = config.num_layers();
|
||||
let head_dim = config.head_dim();
|
||||
let elem_size = dtype.size_bytes();
|
||||
@@ -179,11 +191,17 @@ impl PagedKVCache {
|
||||
if cpu_total_blocks >= 2 {
|
||||
let cpu_pool_bytes = cpu_total_blocks * block_bytes;
|
||||
for _ in 0..num_layers {
|
||||
cpu_k_pools.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU K swap pool"));
|
||||
cpu_v_pools.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU V swap pool"));
|
||||
cpu_k_pools
|
||||
.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU K swap pool"));
|
||||
cpu_v_pools
|
||||
.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU V swap pool"));
|
||||
}
|
||||
}
|
||||
let cpu_allocator = BlockAllocator::new(if cpu_total_blocks >= 2 { cpu_total_blocks } else { 0 });
|
||||
let cpu_allocator = BlockAllocator::new(if cpu_total_blocks >= 2 {
|
||||
cpu_total_blocks
|
||||
} else {
|
||||
0
|
||||
});
|
||||
|
||||
let block_table_gpu =
|
||||
GpuBuffer::alloc(max_seqs * max_blocks_per_seq * std::mem::size_of::<i32>())
|
||||
@@ -220,22 +238,49 @@ impl PagedKVCache {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn num_layers(&self) -> usize { self.num_layers }
|
||||
pub fn num_kv_heads(&self) -> usize { self.num_kv_heads }
|
||||
pub fn head_dim(&self) -> usize { self.head_dim }
|
||||
pub fn dtype(&self) -> DType { self.dtype }
|
||||
pub fn max_seqs(&self) -> usize { self.max_seqs }
|
||||
pub fn max_blocks_per_seq(&self) -> usize { self.max_blocks_per_seq }
|
||||
pub fn free_blocks(&self) -> usize { self.allocator.free_count() }
|
||||
pub fn total_blocks(&self) -> usize { self.allocator.total() }
|
||||
pub fn num_layers(&self) -> usize {
|
||||
self.num_layers
|
||||
}
|
||||
pub fn num_kv_heads(&self) -> usize {
|
||||
self.num_kv_heads
|
||||
}
|
||||
pub fn head_dim(&self) -> usize {
|
||||
self.head_dim
|
||||
}
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
pub fn max_seqs(&self) -> usize {
|
||||
self.max_seqs
|
||||
}
|
||||
pub fn max_blocks_per_seq(&self) -> usize {
|
||||
self.max_blocks_per_seq
|
||||
}
|
||||
pub fn free_blocks(&self) -> usize {
|
||||
self.allocator.free_count()
|
||||
}
|
||||
pub fn total_blocks(&self) -> usize {
|
||||
self.allocator.total()
|
||||
}
|
||||
|
||||
pub fn k_pool(&self, layer: usize) -> &GpuBuffer { &self.k_pools[layer] }
|
||||
pub fn v_pool(&self, layer: usize) -> &GpuBuffer { &self.v_pools[layer] }
|
||||
pub fn block_table_gpu(&self) -> &GpuBuffer { &self.block_table_gpu }
|
||||
pub fn context_lens_gpu(&self) -> &GpuBuffer { &self.context_lens_gpu }
|
||||
pub fn k_pool(&self, layer: usize) -> &GpuBuffer {
|
||||
&self.k_pools[layer]
|
||||
}
|
||||
pub fn v_pool(&self, layer: usize) -> &GpuBuffer {
|
||||
&self.v_pools[layer]
|
||||
}
|
||||
pub fn block_table_gpu(&self) -> &GpuBuffer {
|
||||
&self.block_table_gpu
|
||||
}
|
||||
pub fn context_lens_gpu(&self) -> &GpuBuffer {
|
||||
&self.context_lens_gpu
|
||||
}
|
||||
|
||||
pub fn seq_len(&self, slot: usize) -> usize {
|
||||
self.seq_states[slot].as_ref().map(|s| s.seq_len).unwrap_or(0)
|
||||
self.seq_states[slot]
|
||||
.as_ref()
|
||||
.map(|s| s.seq_len)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn is_slot_free(&self, slot: usize) -> bool {
|
||||
@@ -280,7 +325,11 @@ impl PagedKVCache {
|
||||
let state = self.seq_states[slot].as_ref().expect("unregistered slot");
|
||||
let cur = state.block_ids.len();
|
||||
let needed_total = (state.seq_len + new_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
if needed_total > cur { needed_total - cur } else { 0 }
|
||||
if needed_total > cur {
|
||||
needed_total - cur
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// Pre-allocate enough physical blocks in `slot` to cover positions
|
||||
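The block-count computation above is a ceiling division followed by a clamped subtraction. A minimal sketch, assuming a block size of 16 purely for illustration (the real `BLOCK_SIZE` is defined in `paged_kv_cache` and is not shown in this diff); `blocks_needed` is a hypothetical free function:

```rust
/// Illustrative value only; the real constant lives in `paged_kv_cache`.
const BLOCK_SIZE: usize = 16;

/// Number of additional blocks a sequence needs to hold `new_tokens` more
/// tokens, given its current length and the blocks it already holds.
pub fn blocks_needed(seq_len: usize, new_tokens: usize, held_blocks: usize) -> usize {
    // ceil((seq_len + new_tokens) / BLOCK_SIZE)
    let needed_total = (seq_len + new_tokens).div_ceil(BLOCK_SIZE);
    // Never negative: a sequence may already hold more blocks than it needs.
    needed_total.saturating_sub(held_blocks)
}
```

`usize::div_ceil` and `saturating_sub` express the same result as the `(a + b - 1) / b` and `if needed_total > cur` forms in the diff, without the explicit branch.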
@@ -290,8 +339,14 @@ impl PagedKVCache {
|
||||
let state = self.seq_states[slot].as_mut().expect("unregistered slot");
|
||||
let needed_total = (end_pos + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
while state.block_ids.len() < needed_total {
|
||||
let b = self.allocator.alloc().expect("out of blocks (caller must check)");
|
||||
assert!(state.block_ids.len() < self.max_blocks_per_seq, "block table overflow");
|
||||
let b = self
|
||||
.allocator
|
||||
.alloc()
|
||||
.expect("out of blocks (caller must check)");
|
||||
assert!(
|
||||
state.block_ids.len() < self.max_blocks_per_seq,
|
||||
"block table overflow"
|
||||
);
|
||||
state.block_ids.push(b);
|
||||
}
|
||||
}
|
||||
@@ -318,7 +373,9 @@ impl PagedKVCache {
|
||||
num_tokens: usize,
|
||||
start_pos: usize,
|
||||
) {
|
||||
if num_tokens == 0 { return; }
|
||||
if num_tokens == 0 {
|
||||
return;
|
||||
}
|
||||
// Make sure blocks exist for the target range.
|
||||
self.ensure_capacity(slot, start_pos + num_tokens);
|
||||
|
||||
@@ -328,15 +385,21 @@ impl PagedKVCache {
|
||||
|
||||
// Stage block_ids on the GPU. Pool-allocated so this is essentially
|
||||
// free after the first call (same bucket every step).
|
||||
let block_ids: Vec<i32> = self.seq_states[slot].as_ref().unwrap()
|
||||
.block_ids.iter().map(|&b| b as i32).collect();
|
||||
let block_ids: Vec<i32> = self.seq_states[slot]
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.block_ids
|
||||
.iter()
|
||||
.map(|&b| b as i32)
|
||||
.collect();
|
||||
let bytes = block_ids.len() * std::mem::size_of::<i32>();
|
||||
let mut block_ids_gpu = xserv_cuda::allocator::cached_alloc(bytes)
|
||||
.expect("alloc append block_ids");
|
||||
let block_ids_bytes = unsafe {
|
||||
std::slice::from_raw_parts(block_ids.as_ptr() as *const u8, bytes)
|
||||
};
|
||||
block_ids_gpu.copy_from_host(block_ids_bytes).expect("upload block_ids");
|
||||
let mut block_ids_gpu =
|
||||
xserv_cuda::allocator::cached_alloc(bytes).expect("alloc append block_ids");
|
||||
let block_ids_bytes =
|
||||
unsafe { std::slice::from_raw_parts(block_ids.as_ptr() as *const u8, bytes) };
|
||||
block_ids_gpu
|
||||
.copy_from_host(block_ids_bytes)
|
||||
.expect("upload block_ids");
|
||||
|
||||
let k_src = k_new.data_ptr() as *const std::ffi::c_void;
|
||||
let v_src = v_new.data_ptr() as *const std::ffi::c_void;
|
||||
@@ -345,11 +408,17 @@ impl PagedKVCache {
|
||||
|
||||
unsafe {
|
||||
xserv_kernels::reshape_and_cache_bf16(
|
||||
k_src, v_src,
|
||||
k_pool_ptr, v_pool_ptr,
|
||||
k_src,
|
||||
v_src,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
block_ids_gpu.as_ptr() as *const i32,
|
||||
num_tokens, nkv, hd, start_pos, bs,
|
||||
std::ptr::null_mut(),
|
||||
num_tokens,
|
||||
nkv,
|
||||
hd,
|
||||
start_pos,
|
||||
bs,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
// block_ids_gpu drops here; the launch on the current stream will have
|
||||
@@ -378,7 +447,9 @@ impl PagedKVCache {
|
||||
v_new: &Tensor,
|
||||
batch: usize,
|
||||
) {
|
||||
if batch == 0 { return; }
|
||||
if batch == 0 {
|
||||
return;
|
||||
}
|
||||
let nkv = self.num_kv_heads;
|
||||
let hd = self.head_dim;
|
||||
debug_assert_eq!(k_new.shape(), &[batch, nkv, hd]);
|
||||
@@ -393,11 +464,18 @@ impl PagedKVCache {
|
||||
|
||||
unsafe {
|
||||
xserv_kernels::reshape_and_cache_batched_bf16(
|
||||
k_src, v_src,
|
||||
k_pool_ptr, v_pool_ptr,
|
||||
bt_ptr, cl_ptr,
|
||||
batch, nkv, hd, BLOCK_SIZE, self.max_blocks_per_seq,
|
||||
std::ptr::null_mut(),
|
||||
k_src,
|
||||
v_src,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
bt_ptr,
|
||||
cl_ptr,
|
||||
batch,
|
||||
nkv,
|
||||
hd,
|
||||
BLOCK_SIZE,
|
||||
self.max_blocks_per_seq,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -447,7 +525,10 @@ impl PagedKVCache {
|
||||
/// before advance_seq_len has run).
|
||||
pub fn sync_active_batch_with_lens(&mut self, slots: &[usize], kv_lens: &[i32]) {
|
||||
assert_eq!(slots.len(), kv_lens.len());
|
||||
assert!(slots.len() <= self.max_seqs, "active batch exceeds max_seqs");
|
||||
assert!(
|
||||
slots.len() <= self.max_seqs,
|
||||
"active batch exceeds max_seqs"
|
||||
);
|
||||
let stride = self.max_blocks_per_seq;
|
||||
for row in &mut self.block_table_host {
|
||||
*row = 0;
|
||||
@@ -456,7 +537,9 @@ impl PagedKVCache {
|
||||
*cl = 0;
|
||||
}
|
||||
for (i, &slot) in slots.iter().enumerate() {
|
||||
let s = self.seq_states[slot].as_ref().expect("unregistered slot in active batch");
|
||||
let s = self.seq_states[slot]
|
||||
.as_ref()
|
||||
.expect("unregistered slot in active batch");
|
||||
let row = &mut self.block_table_host[i * stride..(i + 1) * stride];
|
||||
for (j, b) in s.block_ids.iter().enumerate() {
|
||||
row[j] = *b as i32;
|
||||
@@ -515,8 +598,12 @@ impl PagedKVCache {
|
||||
let src_off = ((phys * nkv + h) * bs + slot_in_blk) * hd * es;
|
||||
let dst_off = (h * sl + p) * hd * es;
|
||||
let count = chunk * hd * es;
|
||||
k_dst.copy_from_device_at(k_pool, src_off, dst_off, count).unwrap();
|
||||
v_dst.copy_from_device_at(v_pool, src_off, dst_off, count).unwrap();
|
||||
k_dst
|
||||
.copy_from_device_at(k_pool, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
v_dst
|
||||
.copy_from_device_at(v_pool, src_off, dst_off, count)
|
||||
.unwrap();
|
||||
}
|
||||
p += chunk;
|
||||
}
|
||||
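The `src_off` expression in the gather loop implies a pool layout of `[num_blocks, num_kv_heads, block_size, head_dim]`. The sketch below shows how a logical token position could map to a byte offset under that layout. It assumes the physical block is `block_ids[pos / block_size]` and the in-block slot is `pos % block_size`; that mapping is not visible in this hunk, and `paged_offset` is a hypothetical helper.

```rust
/// Byte offset of (head `h`, logical position `pos`) inside a paged KV pool
/// laid out as [num_blocks, num_kv_heads, block_size, head_dim].
pub fn paged_offset(
    block_ids: &[u32],
    pos: usize,
    h: usize,
    num_kv_heads: usize,
    block_size: usize,
    head_dim: usize,
    elem_size: usize,
) -> usize {
    // Assumed mapping from logical position to physical block and slot.
    let phys = block_ids[pos / block_size] as usize;
    let slot_in_blk = pos % block_size;
    ((phys * num_kv_heads + h) * block_size + slot_in_blk) * head_dim * elem_size
}
```

Within one block, consecutive positions of the same head are contiguous, which is why the gather loop can copy a whole `chunk` of rows per head in a single call.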
@@ -529,16 +616,26 @@ impl PagedKVCache {
|
||||
|
||||
// ----- Swapping (vLLM-style preemption to pinned host memory) -----
|
||||
|
||||
pub fn free_cpu_blocks(&self) -> usize { self.cpu_allocator.free_count() }
|
||||
pub fn swap_enabled(&self) -> bool { !self.cpu_k_pools.is_empty() }
|
||||
pub fn free_cpu_blocks(&self) -> usize {
|
||||
self.cpu_allocator.free_count()
|
||||
}
|
||||
pub fn swap_enabled(&self) -> bool {
|
||||
!self.cpu_k_pools.is_empty()
|
||||
}
|
||||
|
||||
pub fn is_swapped(&self, slot: usize) -> bool {
|
||||
matches!(self.seq_states[slot].as_ref().map(|s| s.location), Some(Location::Cpu))
|
||||
matches!(
|
||||
self.seq_states[slot].as_ref().map(|s| s.location),
|
||||
Some(Location::Cpu)
|
||||
)
|
||||
}
|
||||
|
||||
/// Number of physical blocks currently held by `slot` (in either pool).
|
||||
pub fn block_count(&self, slot: usize) -> usize {
|
||||
self.seq_states[slot].as_ref().map(|s| s.block_ids.len()).unwrap_or(0)
|
||||
self.seq_states[slot]
|
||||
.as_ref()
|
||||
.map(|s| s.block_ids.len())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Whether a swapped sequence at `slot` can be brought back (enough free GPU blocks).
|
||||
@@ -554,11 +651,17 @@ impl PagedKVCache {
|
||||
/// Evict `slot`'s KV from GPU to pinned host memory and free its GPU blocks.
|
||||
/// The slot stays registered (location = Cpu); the sequence is paused.
|
||||
pub fn swap_out(&mut self, slot: usize) -> Result<(), &'static str> {
|
||||
let state = self.seq_states[slot].as_ref().ok_or("swap_out: empty slot")?;
|
||||
if state.location == Location::Cpu { return Ok(()); }
|
||||
let state = self.seq_states[slot]
|
||||
.as_ref()
|
||||
.ok_or("swap_out: empty slot")?;
|
||||
if state.location == Location::Cpu {
|
||||
return Ok(());
|
||||
}
|
||||
let gpu_ids = state.block_ids.clone();
|
||||
let n = gpu_ids.len();
|
||||
if !self.cpu_allocator.can_alloc(n) { return Err("swap_out: CPU pool full"); }
|
||||
if !self.cpu_allocator.can_alloc(n) {
|
||||
return Err("swap_out: CPU pool full");
|
||||
}
|
||||
|
||||
let cpu_ids: Vec<u32> = (0..n)
|
||||
.map(|_| self.cpu_allocator.alloc().expect("checked can_alloc"))
|
||||
@@ -570,10 +673,18 @@ impl PagedKVCache {
|
||||
let g_off = gpu_ids[i] as usize * bb;
|
||||
let c_off = cpu_ids[i] as usize * bb;
|
||||
self.k_pools[layer]
|
||||
.copy_to_host_at(&mut self.cpu_k_pools[layer].as_mut_slice()[c_off..c_off + bb], g_off, bb)
|
||||
.copy_to_host_at(
|
||||
&mut self.cpu_k_pools[layer].as_mut_slice()[c_off..c_off + bb],
|
||||
g_off,
|
||||
bb,
|
||||
)
|
||||
.unwrap();
|
||||
self.v_pools[layer]
|
||||
.copy_to_host_at(&mut self.cpu_v_pools[layer].as_mut_slice()[c_off..c_off + bb], g_off, bb)
|
||||
.copy_to_host_at(
|
||||
&mut self.cpu_v_pools[layer].as_mut_slice()[c_off..c_off + bb],
|
||||
g_off,
|
||||
bb,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -589,11 +700,17 @@ impl PagedKVCache {
|
||||
|
||||
/// Bring `slot`'s KV back from host to GPU and free its CPU blocks.
|
||||
pub fn swap_in(&mut self, slot: usize) -> Result<(), &'static str> {
|
||||
let state = self.seq_states[slot].as_ref().ok_or("swap_in: empty slot")?;
|
||||
if state.location == Location::Gpu { return Ok(()); }
|
||||
let state = self.seq_states[slot]
|
||||
.as_ref()
|
||||
.ok_or("swap_in: empty slot")?;
|
||||
if state.location == Location::Gpu {
|
||||
return Ok(());
|
||||
}
|
||||
let cpu_ids = state.block_ids.clone();
|
||||
let n = cpu_ids.len();
|
||||
if !self.allocator.can_alloc(n) { return Err("swap_in: GPU pool full"); }
|
||||
if !self.allocator.can_alloc(n) {
|
||||
return Err("swap_in: GPU pool full");
|
||||
}
|
||||
|
||||
let gpu_ids: Vec<u32> = (0..n)
|
||||
.map(|_| self.allocator.alloc().expect("checked can_alloc"))
|
||||
@@ -605,10 +722,18 @@ impl PagedKVCache {
|
||||
let g_off = gpu_ids[i] as usize * bb;
|
||||
let c_off = cpu_ids[i] as usize * bb;
|
||||
self.k_pools[layer]
|
||||
.copy_from_host_at(&self.cpu_k_pools[layer].as_slice()[c_off..c_off + bb], g_off, bb)
|
||||
.copy_from_host_at(
|
||||
&self.cpu_k_pools[layer].as_slice()[c_off..c_off + bb],
|
||||
g_off,
|
||||
bb,
|
||||
)
|
||||
.unwrap();
|
||||
self.v_pools[layer]
|
||||
.copy_from_host_at(&self.cpu_v_pools[layer].as_slice()[c_off..c_off + bb], g_off, bb)
|
||||
.copy_from_host_at(
|
||||
&self.cpu_v_pools[layer].as_slice()[c_off..c_off + bb],
|
||||
g_off,
|
||||
bb,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -623,7 +748,12 @@ impl PagedKVCache {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn tensor_from_owned_buf(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
|
||||
unsafe fn tensor_from_owned_buf(
|
||||
buf: GpuBuffer,
|
||||
shape: &[usize],
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Tensor {
|
||||
use smallvec::SmallVec;
|
||||
use xserv_tensor::shape::contiguous_strides;
|
||||
use xserv_tensor::storage::Storage;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
use half::bf16;
|
||||
use std::collections::HashMap;
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
use crate::config::ModelConfig;
|
||||
use crate::gpt2::KVCache;
|
||||
@@ -13,7 +13,7 @@ pub struct Qwen3 {
|
||||
embed_tokens: Tensor,
|
||||
layers: Vec<Qwen3Block>,
|
||||
norm: Tensor,
|
||||
lm_head_t: Tensor, // precomputed transpose
|
||||
lm_head_t: Tensor, // precomputed transpose
|
||||
rope_cache: RopeCache,
|
||||
// Tensor parallelism. `tp` is None (or world==1) for single-GPU; otherwise
|
||||
// this rank holds 1/world of the heads and AllReduces after o_proj/down_proj.
|
||||
@@ -28,22 +28,29 @@ pub struct Qwen3 {
|
||||
}
|
||||
|
||||
struct Qwen3Block {
|
||||
input_norm: Tensor, // [hidden]
|
||||
input_norm: Tensor, // [hidden]
|
||||
qkv_proj_wt: Tensor, // FUSED: [hidden, (H+2*KV)*D] — Q|K|V columns
|
||||
q_dim: usize, // num_heads * head_dim (Q slice boundary)
|
||||
kv_dim: usize, // num_kv_heads * head_dim (K/V slice size)
|
||||
o_proj_wt: Tensor, // TRANSPOSED: [num_heads*head_dim, hidden]
|
||||
q_norm: Tensor, // [head_dim]
|
||||
k_norm: Tensor, // [head_dim]
|
||||
post_norm: Tensor, // [hidden]
|
||||
gate_up_proj_wt: Tensor, // FUSED: [hidden, 2*intermediate]
|
||||
down_proj_wt: Tensor, // TRANSPOSED: [intermediate, hidden]
|
||||
o_proj_wt: Tensor, // TRANSPOSED: [num_heads*head_dim, hidden]
|
||||
q_norm: Tensor, // [head_dim]
|
||||
k_norm: Tensor, // [head_dim]
|
||||
post_norm: Tensor, // [hidden]
|
||||
gate_up_proj_wt: Tensor, // FUSED: [hidden, 2*intermediate]
|
||||
down_proj_wt: Tensor, // TRANSPOSED: [intermediate, hidden]
|
||||
}
|
||||
|
||||
impl Qwen3Block {
|
||||
fn q_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, 0, self.q_dim) }
|
||||
fn k_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, self.q_dim, self.kv_dim) }
|
||||
fn v_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, self.q_dim + self.kv_dim, self.kv_dim) }
|
||||
fn q_proj_wt(&self) -> Tensor {
|
||||
self.qkv_proj_wt.narrow(1, 0, self.q_dim)
|
||||
}
|
||||
fn k_proj_wt(&self) -> Tensor {
|
||||
self.qkv_proj_wt.narrow(1, self.q_dim, self.kv_dim)
|
||||
}
|
||||
fn v_proj_wt(&self) -> Tensor {
|
||||
self.qkv_proj_wt
|
||||
.narrow(1, self.q_dim + self.kv_dim, self.kv_dim)
|
||||
}
|
||||
fn gate_proj_wt(&self) -> Tensor {
|
||||
let half = self.gate_up_proj_wt.shape()[1] / 2;
|
||||
self.gate_up_proj_wt.narrow(1, 0, half)
|
||||
@@ -80,18 +87,31 @@ impl Qwen3 {
|
||||
crate::init_kernels();
|
||||
let dev = Device::Cuda(device);
|
||||
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
|
||||
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
w.remove(name)
|
||||
.unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
};
|
||||
// Replicated weight: upload whole to this rank's device.
|
||||
let repl = |t: Tensor| -> Tensor { t.to_device(dev) };
|
||||
// column-parallel: keep this rank's rows of [out, in], upload, transpose → [in, out/world].
|
||||
let col = |t: Tensor| -> Tensor { shard_rows(&t, rank, world).to_device(dev).transpose(0, 1).contiguous() };
|
||||
let col = |t: Tensor| -> Tensor {
|
||||
shard_rows(&t, rank, world)
|
||||
.to_device(dev)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
};
|
||||
// row-parallel: keep this rank's cols of [out, in], upload, transpose → [in/world, out].
|
||||
let row = |t: Tensor| -> Tensor { shard_cols(&t, rank, world).to_device(dev).transpose(0, 1).contiguous() };
|
||||
let row = |t: Tensor| -> Tensor {
|
||||
shard_cols(&t, rank, world)
|
||||
.to_device(dev)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
};
|
||||
|
||||
let embed_tokens = repl(take(&mut w, "model.embed_tokens.weight"));
|
||||
let norm = repl(take(&mut w, "model.norm.weight"));
|
||||
let lm_head_t = repl(take(&mut w, "lm_head.weight")).transpose(0, 1).contiguous();
|
||||
let lm_head_t = repl(take(&mut w, "lm_head.weight"))
|
||||
.transpose(0, 1)
|
||||
.contiguous();
|
||||
|
||||
let rope_cache = RopeCache::new(
|
||||
config.max_seq_len(),
|
||||
@@ -102,7 +122,10 @@ impl Qwen3 {
|
||||
let num_layers = config.num_layers();
|
||||
let mut layers = Vec::with_capacity(num_layers);
|
||||
if rank == 0 {
|
||||
eprintln!("Loading+sharding weights for {} layers (world={world})...", num_layers);
|
||||
eprintln!(
|
||||
"Loading+sharding weights for {} layers (world={world})...",
|
||||
num_layers
|
||||
);
|
||||
}
|
||||
for i in 0..num_layers {
|
||||
let p = format!("model.layers.{i}");
|
||||
@@ -126,7 +149,10 @@ impl Qwen3 {
|
||||
o_proj_wt: row(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))),
|
||||
q_norm: repl(take(&mut w, &format!("{p}.self_attn.q_norm.weight"))),
|
||||
k_norm: repl(take(&mut w, &format!("{p}.self_attn.k_norm.weight"))),
|
||||
post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))),
|
||||
post_norm: repl(take(
|
||||
&mut w,
|
||||
&format!("{p}.post_attention_layernorm.weight"),
|
||||
)),
|
||||
gate_up_proj_wt,
|
||||
down_proj_wt: row(take(&mut w, &format!("{p}.mlp.down_proj.weight"))),
|
||||
});
|
||||
@@ -165,7 +191,10 @@ impl Qwen3 {
|
||||
let dev = Device::Cuda(device);
|
||||
assert!(num_stages >= 1);
|
||||
let num_layers = config.num_layers();
|
||||
assert!(num_layers % num_stages == 0, "num_layers {num_layers} not divisible by pp {num_stages}");
|
||||
assert!(
|
||||
num_layers % num_stages == 0,
|
||||
"num_layers {num_layers} not divisible by pp {num_stages}"
|
||||
);
|
||||
let per_stage = num_layers / num_stages;
|
||||
let lo = stage * per_stage;
|
||||
let hi = lo + per_stage;
|
||||
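The pipeline-parallel layer split above assigns each stage an equal, contiguous range of layers. A minimal sketch of that arithmetic; `stage_layers` is a hypothetical helper, and the layer counts in any usage are illustrative:

```rust
use std::ops::Range;

/// Half-open range of layer indices owned by `stage` out of `num_stages`.
/// Panics if the layers do not divide evenly, mirroring the assert in the diff.
pub fn stage_layers(num_layers: usize, num_stages: usize, stage: usize) -> Range<usize> {
    assert!(num_stages >= 1);
    assert!(stage < num_stages, "stage {stage} out of range");
    assert!(
        num_layers % num_stages == 0,
        "num_layers {num_layers} not divisible by pp {num_stages}"
    );
    let per_stage = num_layers / num_stages;
    let lo = stage * per_stage;
    lo..lo + per_stage
}
```

The first stage (`lo == 0`) additionally owns the embedding, and the last stage owns the final norm and `lm_head`, as the `is_first_stage` / `is_last_stage` branches in the next hunk show.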
@@ -173,16 +202,29 @@ impl Qwen3 {
 let is_last_stage = stage == num_stages - 1;

 let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
-w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
+w.remove(name)
+.unwrap_or_else(|| panic!("missing weight: {name}"))
 };
 let repl = |t: Tensor| -> Tensor { t.to_device(dev) };
 // Pre-transpose like the TP path's `col`/`row` do for world==1 (no shard).
 let wt = |t: Tensor| -> Tensor { t.to_device(dev).transpose(0, 1).contiguous() };
 let placeholder = || Tensor::from_slice(&[bf16::ZERO], &[1, 1]).to_device(dev);

-let embed_tokens = if is_first_stage { repl(take(&mut w, "model.embed_tokens.weight")) } else { placeholder() };
-let norm = if is_last_stage { repl(take(&mut w, "model.norm.weight")) } else { placeholder() };
-let lm_head_t = if is_last_stage { wt(take(&mut w, "lm_head.weight")) } else { placeholder() };
+let embed_tokens = if is_first_stage {
+repl(take(&mut w, "model.embed_tokens.weight"))
+} else {
+placeholder()
+};
+let norm = if is_last_stage {
+repl(take(&mut w, "model.norm.weight"))
+} else {
+placeholder()
+};
+let lm_head_t = if is_last_stage {
+wt(take(&mut w, "lm_head.weight"))
+} else {
+placeholder()
+};

 let rope_cache = RopeCache::new(
 config.max_seq_len(),
@@ -217,7 +259,10 @@ impl Qwen3 {
 o_proj_wt: wt(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))),
 q_norm: repl(take(&mut w, &format!("{p}.self_attn.q_norm.weight"))),
 k_norm: repl(take(&mut w, &format!("{p}.self_attn.k_norm.weight"))),
-post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))),
+post_norm: repl(take(
+&mut w,
+&format!("{p}.post_attention_layernorm.weight"),
+)),
 gate_up_proj_wt,
 down_proj_wt: wt(take(&mut w, &format!("{p}.mlp.down_proj.weight"))),
 });
@@ -252,8 +297,12 @@ impl Qwen3 {
 matmul_2d(&x, &self.lm_head_t)
 }

-pub fn pp_is_first(&self) -> bool { self.is_first_stage }
-pub fn pp_is_last(&self) -> bool { self.is_last_stage }
+pub fn pp_is_first(&self) -> bool {
+self.is_first_stage
+}
+pub fn pp_is_last(&self) -> bool {
+self.is_last_stage
+}

 /// PP prefill over THIS stage's layers. `x` is `[S, hidden]` (stage 0: from
 /// `embed`; otherwise received from the previous stage). Writes K/V for this
@@ -276,7 +325,9 @@ impl Qwen3 {
 paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
 paged_cache.advance_seq_len(slot, new_tokens);

-let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
+let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
+.map(|p| p as u32)
+.collect();

 for (layer_idx, layer) in self.layers.iter().enumerate() {
 let residual = x.clone();
@@ -285,7 +336,9 @@ impl Qwen3 {
 let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
 let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
 let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
-let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
+let v = qkv
+.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
+.contiguous();

 let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
 let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
@@ -305,10 +358,12 @@ impl Qwen3 {
 let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
 let attn_out = flash_attention(&q, &k_full, &v_full, true);

-let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
+let attn_merged =
+xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
 let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);

-let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
+let (normed, x_new) =
+xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
 let residual = x_new.clone();

 let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
@@ -356,7 +411,9 @@ impl Qwen3 {
 let qkv_all = matmul_2d(&normed, &layer.qkv_proj_wt);
 let q_all = qkv_all.narrow(1, 0, layer.q_dim).contiguous();
 let k_all = qkv_all.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
-let v_all = qkv_all.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
+let v_all = qkv_all
+.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
+.contiguous();

 let mut q_rows: Vec<Tensor> = Vec::with_capacity(batch);
 for b in 0..batch {
@@ -394,14 +451,23 @@ impl Qwen3 {
 let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;

 let attn_out = xserv_kernels::paged_decode_attention(
-&q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
-batch, num_heads, num_kv_heads, head_dim, max_blocks,
+&q_4d,
+k_pool_ptr,
+v_pool_ptr,
+bt_ptr,
+cl_ptr,
+batch,
+num_heads,
+num_kv_heads,
+head_dim,
+max_blocks,
 );

 let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]);
 let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);

-let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
+let (normed, x_new) =
+xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
 let residual = x_new.clone();

 let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
@@ -441,7 +507,9 @@ impl Qwen3 {
 let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;

 let mut x = embedding(&self.embed_tokens, token_ids);
-let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
+let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
+.map(|p| p as u32)
+.collect();

 for (layer_idx, layer) in self.layers.iter().enumerate() {
 let residual = x.clone();
@@ -450,7 +518,9 @@ impl Qwen3 {
 let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
 let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
 let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
-let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
+let v = qkv
+.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
+.contiguous();

 let q = reshape_heads(&q, new_tokens, num_heads, head_dim);
 let k = reshape_heads(&k, new_tokens, num_kv_heads, head_dim);
@@ -531,7 +601,9 @@ impl Qwen3 {
 let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
 let q_all = qkv.narrow(1, 0, layer.q_dim).contiguous();
 let k_all = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
-let v_all = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
+let v_all = qkv
+.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
+.contiguous();

 // Per-sequence: reshape, qk-norm, RoPE, KV cache, attention, merge
 let mut attn_outputs: Vec<Tensor> = Vec::with_capacity(batch);
@@ -583,7 +655,8 @@ impl Qwen3 {
 let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);

 // Fused add + rmsnorm
-let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
+let (normed, x_new) =
+xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
 let residual = x_new.clone();

 let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
@@ -662,13 +735,15 @@ impl Qwen3 {
 let qkv = matmul_2d(&normed, &layer.qkv_proj_wt); // [B, (H+2*KV)*D]
 let q_dim = num_heads * head_dim;
 let kv_dim = num_kv_heads * head_dim;
-let q_all = qkv.narrow(1, 0, q_dim); // [B, H*D] (view)
-let k_all = qkv.narrow(1, q_dim, kv_dim); // [B, KV*D] (view)
+let q_all = qkv.narrow(1, 0, q_dim); // [B, H*D] (view)
+let k_all = qkv.narrow(1, q_dim, kv_dim); // [B, KV*D] (view)
 let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);

 // Per-head RMSNorm on contiguous copies (narrow views are strided).
 let q_flat = q_all.contiguous().reshape(&[batch * num_heads, head_dim]);
-let k_flat = k_all.contiguous().reshape(&[batch * num_kv_heads, head_dim]);
+let k_flat = k_all
+.contiguous()
+.reshape(&[batch * num_kv_heads, head_dim]);
 let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
 let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);

@@ -688,8 +763,16 @@ impl Qwen3 {
 let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
 let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
 let attn_out = xserv_kernels::paged_decode_attention(
-&q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
-batch, num_heads, num_kv_heads, head_dim, max_blocks,
+&q_4d,
+k_pool_ptr,
+v_pool_ptr,
+bt_ptr,
+cl_ptr,
+batch,
+num_heads,
+num_kv_heads,
+head_dim,
+max_blocks,
 );

 // attn_out shape [B, H, 1, D] is contiguous-equivalent to [B, H*D].
@@ -697,7 +780,8 @@ impl Qwen3 {
 let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
 self.all_reduce(&attn_proj); // TP: sum partial attention outputs

-let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
+let (normed, x_new) =
+xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
 let residual = x_new.clone();

 // Fused gate+up projection: one GEMV instead of two.
@@ -743,7 +827,9 @@ impl Qwen3 {
 paged_cache.advance_seq_len(slot, new_tokens);

 let mut x = embedding(&self.embed_tokens, token_ids);
-let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
+let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
+.map(|p| p as u32)
+.collect();

 for (layer_idx, layer) in self.layers.iter().enumerate() {
 let residual = x.clone();
@@ -752,7 +838,9 @@ impl Qwen3 {
 let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
 let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
 let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
-let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
+let v = qkv
+.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
+.contiguous();

 let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
 let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
@@ -773,11 +861,13 @@ impl Qwen3 {
 let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
 let attn_out = flash_attention(&q, &k_full, &v_full, true);

-let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
+let attn_merged =
+xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
 let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
 self.all_reduce(&attn_proj);

-let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
+let (normed, x_new) =
+xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
 let residual = x_new.clone();

 let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
@@ -798,14 +888,16 @@ impl Qwen3 {
 pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
 let new_tokens = token_ids.len();
 let pos_offset = cache.seq_len();
 let hidden = self.config.hidden();

 let num_heads = self.config.num_heads();
 let num_kv_heads = self.config.num_kv_heads();
 let head_dim = self.config.head_dim();
 let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;

 let mut x = embedding(&self.embed_tokens, token_ids);
-let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
+let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
+.map(|p| p as u32)
+.collect();

 for (layer_idx, layer) in self.layers.iter().enumerate() {
 let residual = x.clone();
@@ -814,7 +906,9 @@ impl Qwen3 {
 let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
 let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
 let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
-let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
+let v = qkv
+.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim)
+.contiguous();

 let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
 let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
@@ -834,10 +928,12 @@ impl Qwen3 {
 let (k_full, v_full) = cache.get_kv_len(layer_idx, pos_offset + new_tokens);

 let attn_out = flash_attention(&q, &k_full, &v_full, true);
-let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
+let attn_merged =
+xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
 let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);

-let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
+let (normed, x_new) =
+xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
 let residual = x_new.clone();

 let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
@@ -856,28 +952,33 @@ impl Qwen3 {

 /// Extract weight pointers for CUDA Graph capture.
 pub fn layer_weight_ptrs(&self) -> Vec<crate::decode_graph::LayerWeightPtrs> {
-self.layers.iter().map(|l| crate::decode_graph::LayerWeightPtrs {
-input_norm: l.input_norm.data_ptr() as *const std::ffi::c_void,
-q_proj_wt: l.q_proj_wt().data_ptr() as *const std::ffi::c_void,
-k_proj_wt: l.k_proj_wt().data_ptr() as *const std::ffi::c_void,
-v_proj_wt: l.v_proj_wt().data_ptr() as *const std::ffi::c_void,
-o_proj_wt: l.o_proj_wt.data_ptr() as *const std::ffi::c_void,
-q_norm: l.q_norm.data_ptr() as *const std::ffi::c_void,
-k_norm: l.k_norm.data_ptr() as *const std::ffi::c_void,
-post_norm: l.post_norm.data_ptr() as *const std::ffi::c_void,
-gate_proj_wt: l.gate_proj_wt().data_ptr() as *const std::ffi::c_void,
-up_proj_wt: l.up_proj_wt().data_ptr() as *const std::ffi::c_void,
-down_proj_wt: l.down_proj_wt.data_ptr() as *const std::ffi::c_void,
-}).collect()
+self.layers
+.iter()
+.map(|l| crate::decode_graph::LayerWeightPtrs {
+input_norm: l.input_norm.data_ptr() as *const std::ffi::c_void,
+q_proj_wt: l.q_proj_wt().data_ptr() as *const std::ffi::c_void,
+k_proj_wt: l.k_proj_wt().data_ptr() as *const std::ffi::c_void,
+v_proj_wt: l.v_proj_wt().data_ptr() as *const std::ffi::c_void,
+o_proj_wt: l.o_proj_wt.data_ptr() as *const std::ffi::c_void,
+q_norm: l.q_norm.data_ptr() as *const std::ffi::c_void,
+k_norm: l.k_norm.data_ptr() as *const std::ffi::c_void,
+post_norm: l.post_norm.data_ptr() as *const std::ffi::c_void,
+gate_proj_wt: l.gate_proj_wt().data_ptr() as *const std::ffi::c_void,
+up_proj_wt: l.up_proj_wt().data_ptr() as *const std::ffi::c_void,
+down_proj_wt: l.down_proj_wt.data_ptr() as *const std::ffi::c_void,
+})
+.collect()
 }

 /// Get pointers needed for CUDA Graph capture.
-pub fn graph_capture_ptrs(&self) -> (
-*const std::ffi::c_void, // norm weight
-*const std::ffi::c_void, // lm_head_t
-*const std::ffi::c_void, // embed_tokens
-*const std::ffi::c_void, // rope cos
-*const std::ffi::c_void, // rope sin
+pub fn graph_capture_ptrs(
+&self,
+) -> (
+*const std::ffi::c_void, // norm weight
+*const std::ffi::c_void, // lm_head_t
+*const std::ffi::c_void, // embed_tokens
+*const std::ffi::c_void, // rope cos
+*const std::ffi::c_void, // rope sin
 ) {
 (
 self.norm.data_ptr() as *const std::ffi::c_void,
@@ -895,11 +996,16 @@ impl Qwen3 {
 /// (column-parallel split: split the OUTPUT dim). `world==1` returns the whole.
 /// Input must be a contiguous CPU (or device) BF16 tensor.
 fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
-if world == 1 { return t.clone(); }
+if world == 1 {
+return t.clone();
+}
 let shape = t.shape();
 assert_eq!(shape.len(), 2, "shard_rows expects 2D weight");
 let (rows, cols) = (shape[0], shape[1]);
-assert!(rows % world == 0, "rows {rows} not divisible by world {world}");
+assert!(
+rows % world == 0,
+"rows {rows} not divisible by world {world}"
+);
 let local = rows / world;
 let host = t.to_device(Device::Cpu);
 let data = host.as_slice::<bf16>();
@@ -911,11 +1017,16 @@ fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor {
 /// Keep this rank's column-block of a 2D `[rows, cols]` BF16 tensor (row-parallel
 /// split: split the INPUT dim). Strided copy. `world==1` returns the whole.
 fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor {
-if world == 1 { return t.clone(); }
+if world == 1 {
+return t.clone();
+}
 let shape = t.shape();
 assert_eq!(shape.len(), 2, "shard_cols expects 2D weight");
 let (rows, cols) = (shape[0], shape[1]);
-assert!(cols % world == 0, "cols {cols} not divisible by world {world}");
+assert!(
+cols % world == 0,
+"cols {cols} not divisible by world {world}"
+);
 let local = cols / world;
 let c0 = rank * local;
 let host = t.to_device(Device::Cpu);
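The two sharding helpers above differ only in which dimension they split. The following is a minimal sketch of both on a flat row-major `f32` buffer rather than the repository's `Tensor` type; `shard_rows_flat` and `shard_cols_flat` are hypothetical names, and the bodies are an assumption about what the elided parts of the hunks do.

```rust
/// Column-parallel split: keep this rank's contiguous block of rows
/// from a row-major `[rows, cols]` buffer. Hypothetical stand-in for `shard_rows`.
fn shard_rows_flat(data: &[f32], rows: usize, cols: usize, rank: usize, world: usize) -> Vec<f32> {
    assert_eq!(data.len(), rows * cols);
    assert!(rows % world == 0, "rows {rows} not divisible by world {world}");
    let local = rows / world;
    let start = rank * local * cols;
    // Rows are contiguous in row-major layout, so one slice copy is enough.
    data[start..start + local * cols].to_vec()
}

/// Row-parallel split: keep this rank's block of columns. Each row contributes
/// a separate segment, so this is a strided copy. Hypothetical stand-in for `shard_cols`.
fn shard_cols_flat(data: &[f32], rows: usize, cols: usize, rank: usize, world: usize) -> Vec<f32> {
    assert_eq!(data.len(), rows * cols);
    assert!(cols % world == 0, "cols {cols} not divisible by world {world}");
    let local = cols / world;
    let c0 = rank * local;
    let mut out = Vec::with_capacity(rows * local);
    for r in 0..rows {
        let base = r * cols + c0;
        out.extend_from_slice(&data[base..base + local]);
    }
    out
}
```

The row split is a single contiguous copy, while the column split copies one segment per row, which matches the "Strided copy" note in the `shard_cols` doc comment.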
@@ -1009,7 +1120,9 @@ fn transpose_from_rope(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u
 }

 fn repeat_kv(x: &Tensor, n_rep: usize) -> Tensor {
-if n_rep == 1 { return x.clone(); }
+if n_rep == 1 {
+return x.clone();
+}
 let kv_heads = x.shape()[1];
 let seq_len = x.shape()[2];
 let head_dim = x.shape()[3];
@@ -1065,11 +1178,16 @@ fn concat_rows(rows: &[Tensor]) -> Tensor {
 let src_buf = row.storage().gpu_buffer();
 let src_offset = row.offset() * elem_size;
 let dst_offset = b * row_bytes;
-out_buf.copy_from_device_at(src_buf, src_offset, dst_offset, row_bytes).unwrap();
+out_buf
+.copy_from_device_at(src_buf, src_offset, dst_offset, row_bytes)
+.unwrap();
 }

 // Wrap in a Tensor
-let device_id = match device { Device::Cuda(id) => id, _ => panic!("expected CUDA device") };
+let device_id = match device {
+Device::Cuda(id) => id,
+_ => panic!("expected CUDA device"),
+};
 unsafe {
 crate::kv_cache::tensor_from_gpu_buffer_pub(out_buf, &[batch, cols], dtype, device_id)
 }
@@ -1082,12 +1200,15 @@ fn cat_cols(tensors: &[&Tensor]) -> Tensor {
 let dtype = tensors[0].dtype();
 let device = tensors[0].device();
 let elem = dtype.size_bytes();
-let total_cols: usize = tensors.iter().map(|t| {
-assert_eq!(t.ndim(), 2);
-assert_eq!(t.shape()[0], rows);
-assert!(t.is_contiguous());
-t.shape()[1]
-}).sum();
+let total_cols: usize = tensors
+.iter()
+.map(|t| {
+assert_eq!(t.ndim(), 2);
+assert_eq!(t.shape()[0], rows);
+assert!(t.is_contiguous());
+t.shape()[1]
+})
+.sum();
 let out = Tensor::empty(&[rows, total_cols], dtype, device);
 let dst_base = out.data_ptr() as *mut u8;
 for r in 0..rows {
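The shape checks and the per-row loop in `cat_cols` suggest a row-by-row column concatenation. Below is a minimal sketch of that idea on flat row-major `f32` slices; `cat_cols_flat` is a hypothetical name, and the loop body is an assumption because the hunk cuts off at `for r in 0..rows {`.

```rust
/// Concatenate row-major matrices along the column axis.
/// Each part is `(data, cols)` and must hold exactly `rows * cols` elements.
/// Hypothetical CPU analogue of `cat_cols`; the real function works on device tensors.
fn cat_cols_flat(parts: &[(&[f32], usize)], rows: usize) -> Vec<f32> {
    let total_cols: usize = parts
        .iter()
        .map(|&(data, cols)| {
            assert_eq!(data.len(), rows * cols, "part has the wrong number of elements");
            cols
        })
        .sum();
    let mut out = Vec::with_capacity(rows * total_cols);
    for r in 0..rows {
        // Output row r is row r of every part, laid end to end.
        for &(data, cols) in parts {
            out.extend_from_slice(&data[r * cols..(r + 1) * cols]);
        }
    }
    out
}
```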
@@ -1126,7 +1247,9 @@ pub fn sample_greedy(logits: &Tensor) -> u32 {
 let seq_len = logits.shape()[0];
 let data = logits_cpu.as_slice::<bf16>();
 let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
-last.iter().enumerate()
+last.iter()
+.enumerate()
 .max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
-.map(|(i, _)| i as u32).unwrap()
+.map(|(i, _)| i as u32)
+.unwrap()
 }

@@ -11,7 +11,11 @@ pub struct SamplingParams {

 impl Default for SamplingParams {
 fn default() -> Self {
-Self { temperature: 0.0, top_k: 0, top_p: 1.0 }
+Self {
+temperature: 0.0,
+top_k: 0,
+top_p: 1.0,
+}
 }
 }

@@ -19,6 +23,18 @@ impl Default for SamplingParams {
 /// Uses the last position's logits. Handles both F32 and BF16 dtypes.
 pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
 assert_eq!(logits.ndim(), 2);
+// Greedy fast path: GPU argmax + 4-byte D2H instead of copying the whole
+// [seq, vocab] logits to the host and scanning it (~201k bf16/token).
+// NaN logits lose every `>` comparison in the kernel, matching the
+// NaN-safe host argmax below.
+if params.temperature == 0.0
+&& logits.dtype() == DType::BF16
+&& matches!(logits.device(), Device::Cuda(_))
+&& logits.is_contiguous()
+{
+let ids = xserv_kernels::argmax_bf16_to_host(logits);
+return *ids.last().unwrap();
+}
 let vocab_size = logits.shape()[1];
 let seq_len = logits.shape()[0];
 let logits_cpu = logits.to_device(Device::Cpu);
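The new comment says NaN logits lose every `>` comparison. A minimal host-side sketch of an argmax with that property follows; `argmax_nan_safe` is a hypothetical name, and the tie-breaking (first maximum wins) and the all-NaN result (index 0) are assumptions, since neither the kernel nor the host argmax body appears in this hunk.

```rust
/// Argmax that ignores NaN: `v > best_val` is false whenever `v` is NaN,
/// so a NaN entry can never replace the current best.
/// Hypothetical helper; assumes a non-empty row for a meaningful result.
fn argmax_nan_safe(row: &[f32]) -> usize {
    let mut best = 0usize;
    let mut best_val = f32::NEG_INFINITY;
    for (i, &v) in row.iter().enumerate() {
        if v > best_val {
            best = i;
            best_val = v;
        }
    }
    best
}
```

By contrast, a comparator built on `partial_cmp(..).unwrap()` panics as soon as one operand is NaN, because `partial_cmp` returns `None` in that case.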
@@ -122,9 +138,14 @@ pub fn sample_greedy_penalized(logits: &Tensor, recent: &[u32], penalty: f32) ->
 let seq_len = logits.shape()[0];
 let logits_cpu = logits.to_device(Device::Cpu);
 let mut last_row: Vec<f32> = match logits.dtype() {
-DType::F32 => logits_cpu.as_slice::<f32>()[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec(),
-DType::BF16 => logits_cpu.as_slice::<bf16>()[(seq_len - 1) * vocab_size..seq_len * vocab_size]
-.iter().map(|v| v.to_f32()).collect(),
+DType::F32 => {
+logits_cpu.as_slice::<f32>()[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec()
+}
+DType::BF16 => logits_cpu.as_slice::<bf16>()
+[(seq_len - 1) * vocab_size..seq_len * vocab_size]
+.iter()
+.map(|v| v.to_f32())
+.collect(),
 _ => panic!("unsupported dtype for sampling: {:?}", logits.dtype()),
 };
 if penalty > 1.0 {

@@ -72,7 +72,10 @@ impl ChatTemplate {
 let source = std::fs::read_to_string(&jinja_path)
 .unwrap_or_else(|e| panic!("failed to read {}: {e}", jinja_path.display()));
 eprintln!("[chat-template] loaded from {}", jinja_path.display());
-return Self { source, model_type: model_type.to_string() };
+return Self {
+source,
+model_type: model_type.to_string(),
+};
 }

 // 2. Try tokenizer_config.json → chat_template field
@@ -82,7 +85,10 @@ impl ChatTemplate {
 if let Ok(v) = serde_json::from_str::<serde_json::Value>(&data) {
 if let Some(ct) = v.get("chat_template").and_then(|v| v.as_str()) {
 eprintln!("[chat-template] loaded from tokenizer_config.json");
-return Self { source: ct.to_string(), model_type: model_type.to_string() };
+return Self {
+source: ct.to_string(),
+model_type: model_type.to_string(),
+};
 }
 }
 }
@@ -90,7 +96,10 @@ impl ChatTemplate {

 // 3. No template found — use empty source, will fall back to hardcoded
 eprintln!("[chat-template] no Jinja template found, using hardcoded fallback");
-Self { source: String::new(), model_type: model_type.to_string() }
+Self {
+source: String::new(),
+model_type: model_type.to_string(),
+}
 }

 pub fn render(&self, messages: &[Message]) -> String {
@@ -206,7 +215,10 @@ fn build_prompt_gpt_oss(messages: &[Message]) -> String {
 prompt.push_str("<|start|>system<|message|>");
 prompt.push_str("You are ChatGPT, a large language model trained by OpenAI.\n");
 prompt.push_str("Knowledge cutoff: 2024-06\n");
-prompt.push_str(&format!("Current date: {}\n\n", strftime_now("%Y-%m-%d".to_string())));
+prompt.push_str(&format!(
+"Current date: {}\n\n",
+strftime_now("%Y-%m-%d".to_string())
+));
 prompt.push_str("Reasoning: low\n\n");
 prompt.push_str("# Valid channels: analysis, commentary, final. Channel must be included for every message.");
 prompt.push_str("<|end|>");
@@ -334,13 +346,11 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
 "completion_tokens": completion_token_count,
 "total_tokens": prompt_token_count + completion_token_count
 }
-})).into_response()
+}))
+.into_response()
 }

-fn chat_stream(
-state: Arc<AppState>,
-req: ChatRequest,
-) -> Response {
+fn chat_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
 let id = format!("chatcmpl-{}", Uuid::new_v4());
 let model_name = state.model_name.clone();
 let created = unix_timestamp();
@@ -356,7 +366,8 @@ fn chat_stream(
 if prompt_tokens.len() >= max_seq_len {
 return bad_request(format!(
 "prompt is {} tokens, exceeds max_seq_len {}",
-prompt_tokens.len(), max_seq_len
+prompt_tokens.len(),
+max_seq_len
 ));
 }
 let max_tokens = req.max_tokens.min(max_seq_len - prompt_tokens.len());
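The order of the length check and the subtraction in this hunk matters because both values are `usize`. A minimal sketch of the same clamp as a standalone function, assuming a plain `String` error in place of the `bad_request` response (`clamp_max_tokens` is a hypothetical name):

```rust
/// Clamp the requested completion length to the room left in the context window.
/// The `>=` check must come first: `max_seq_len - prompt_len` on `usize`
/// would underflow (and panic in debug builds) for an over-long prompt.
fn clamp_max_tokens(prompt_len: usize, requested: usize, max_seq_len: usize) -> Result<usize, String> {
    if prompt_len >= max_seq_len {
        return Err(format!(
            "prompt is {} tokens, exceeds max_seq_len {}",
            prompt_len, max_seq_len
        ));
    }
    Ok(requested.min(max_seq_len - prompt_len))
}
```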
@@ -413,7 +424,9 @@ fn chat_stream(
 }
 });

-Sse::new(ReceiverStream::new(sse_rx)).keep_alive(KeepAlive::default()).into_response()
+Sse::new(ReceiverStream::new(sse_rx))
+.keep_alive(KeepAlive::default())
+.into_response()
 }

 fn validate_request(req: &ChatRequest, model_name: &str) -> Option<Response> {
@@ -436,8 +449,13 @@ fn validate_request(req: &ChatRequest, model_name: &str) -> Option<Response> {
 /// prior handler panicked) and returns a clean 503 instead of panicking when the
 /// engine thread is gone, so one dead engine doesn't cascade into every request.
 fn submit_to_engine(state: &AppState, req: GenerateRequest) -> Result<(), Response> {
-let sender = state.engine_sender.lock().unwrap_or_else(|e| e.into_inner());
-sender.send(req).map_err(|_| service_unavailable("inference engine is not available"))
+let sender = state
+.engine_sender
+.lock()
+.unwrap_or_else(|e| e.into_inner());
+sender
+.send(req)
+.map_err(|_| service_unavailable("inference engine is not available"))
 }

 fn service_unavailable(message: impl Into<String>) -> Response {

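The pattern in `submit_to_engine` (recover a poisoned lock, then turn a closed channel into an error) can be reproduced with the standard library alone. A minimal sketch, assuming a `u32` payload and a `String` error in place of `GenerateRequest` and the 503 response; `submit` is a hypothetical name:

```rust
use std::sync::{mpsc, Mutex};

/// Send `req` through a mutex-guarded channel sender.
/// A poisoned mutex is recovered with `PoisonError::into_inner`, and a
/// disconnected receiver becomes an `Err` instead of a panic.
fn submit(sender: &Mutex<mpsc::Sender<u32>>, req: u32) -> Result<(), String> {
    // `lock()` returns Err only when another thread panicked while holding the lock;
    // the guarded `Sender` is still usable, so take the guard anyway.
    let guard = sender.lock().unwrap_or_else(|e| e.into_inner());
    guard
        .send(req)
        .map_err(|_| "inference engine is not available".to_string())
}
```

Recovering the guard is reasonable here because a `Sender` has no invariant that a panicking holder could leave half-updated.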
@@ -1,10 +1,10 @@
 use std::collections::VecDeque;
 use std::path::Path;
-use std::sync::mpsc;
 use std::sync::Once;
+use std::sync::mpsc;
 use std::time::Instant;
-use xserv_model::{ModelConfig, PagedKVCache, Qwen3, SamplingParams, sample, BLOCK_SIZE};
 use xserv_model::loader;
+use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, SamplingParams, sample};
 use xserv_tensor::{DType, Device};
 use xserv_tokenizer::Tokenizer;

@@ -109,12 +109,23 @@ impl Engine {
 (total_blocks * bytes_per_block) as f64 / 1e9,
 info.free_memory as f64 / 1e9,
 );
-Self { model, config, tokenizer, max_batch_size, max_seq_len, paged_cache }
+Self {
+model,
+config,
+tokenizer,
+max_batch_size,
+max_seq_len,
+paged_cache,
+}
 }

-pub fn tokenizer(&self) -> &Tokenizer { &self.tokenizer }
+pub fn tokenizer(&self) -> &Tokenizer {
+&self.tokenizer
+}

-pub fn max_seq_len(&self) -> usize { self.max_seq_len }
+pub fn max_seq_len(&self) -> usize {
+self.max_seq_len
+}

 /// Main scheduler loop. Receives requests from channel, manages concurrent sequences.
 ///
@@ -134,7 +145,8 @@ impl Engine {

 loop {
 // Step 1: Remove finished sequences and return their slots.
-let finished_slots: Vec<usize> = running.iter()
+let finished_slots: Vec<usize> = running
+.iter()
 .filter(|s| is_finished(s))
 .filter_map(|s| s.seq_slot)
 .collect();
@@ -147,10 +159,16 @@ impl Engine {
 // room (oldest first). They resume decoding from where they paused.
 while running.len() < self.max_batch_size && !swapped.is_empty() {
 let slot = swapped[0].seq_slot.expect("swapped slot");
-if !self.paged_cache.can_swap_in(slot) { break; }
+if !self.paged_cache.can_swap_in(slot) {
+break;
+}
 self.paged_cache.swap_in(slot).expect("swap_in");
 let seq = swapped.remove(0);
-eprintln!("[scheduler] swapped in seq {} ({} blocks)", seq.id, self.paged_cache.block_count(slot));
+eprintln!(
+"[scheduler] swapped in seq {} ({} blocks)",
+seq.id,
+self.paged_cache.block_count(slot)
+);
 running.push(seq);
 }

@@ -161,14 +179,22 @@ impl Engine {
 let mut avail = self.paged_cache.free_blocks();
 let decode_reserve = running.len();
 while running.len() < self.max_batch_size {
-let Some(front) = waiting.front() else { break; };
+let Some(front) = waiting.front() else {
+break;
+};
 let prompt_blocks = front.prompt_tokens.len().div_ceil(BLOCK_SIZE).max(1);
-if avail < prompt_blocks + decode_reserve { break; }
+if avail < prompt_blocks + decode_reserve {
+break;
+}
 let free_slot = (0..self.paged_cache.max_seqs())
 .find(|&s| self.paged_cache.is_slot_free(s));
-let Some(slot) = free_slot else { break; };
+let Some(slot) = free_slot else {
+break;
+};
 let mut seq = waiting.pop_front().unwrap();
-self.paged_cache.register_sequence(slot).expect("register paged slot");
+self.paged_cache
+.register_sequence(slot)
+.expect("register paged slot");
 seq.seq_slot = Some(slot);
 running.push(seq);
 avail -= prompt_blocks; // projected free after this seq prefills
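The admission loop above only admits a waiting prompt when its blocks fit alongside a one-block reserve per already-running sequence. A minimal sketch of that accounting with plain integers follows. `BLOCK_SIZE = 16` is an assumed value, `count_admitted` is a hypothetical name, and the free-slot lookup is left out.

```rust
/// Assumed block size for this sketch; the real constant lives in `xserv_model`.
const BLOCK_SIZE: usize = 16;

/// Blocks needed to prefill a prompt: ceiling division, at least one block.
fn prompt_blocks(prompt_len: usize) -> usize {
    prompt_len.div_ceil(BLOCK_SIZE).max(1)
}

/// Count how many waiting prompts (FIFO order) can be admitted.
/// `decode_reserve` is fixed at the number of sequences running before admission,
/// mirroring `let decode_reserve = running.len();` in the hunk.
fn count_admitted(waiting: &[usize], free_blocks: usize, running: usize, max_batch: usize) -> usize {
    let mut avail = free_blocks;
    let decode_reserve = running;
    let mut admitted = 0;
    while running + admitted < max_batch {
        let Some(&len) = waiting.get(admitted) else {
            break;
        };
        let need = prompt_blocks(len);
        if avail < need + decode_reserve {
            break; // head-of-line prompt does not fit; stop rather than skip it
        }
        avail -= need; // projected free blocks after this prompt prefills
        admitted += 1;
    }
    admitted
}
```

Stopping at the first prompt that does not fit preserves arrival order, at the cost of leaving smaller prompts behind it waiting.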
@@ -199,7 +225,9 @@ impl Engine {
 if !seq.prefilled {
 let slot = seq.seq_slot.expect("slot");
 let logits = self.model.forward_prefill_paged(
-&seq.prompt_tokens, slot, &mut self.paged_cache,
+&seq.prompt_tokens,
+slot,
+&mut self.paged_cache,
 );
 let next = sample(&logits, &seq.sampling);
 seq.generated_tokens.push(next);
@@ -219,13 +247,18 @@ impl Engine {
 && !newly_prefilled.contains(&running[p].id)
 && running[p].seq_slot.is_some()
 });
-let Some(pos) = victim else { break; };
+let Some(pos) = victim else {
+break;
+};
 let seq = running.remove(pos);
 let slot = seq.seq_slot.unwrap();
 if self.paged_cache.can_swap_out(slot) {
 let nblocks = self.paged_cache.block_count(slot);
 self.paged_cache.swap_out(slot).expect("swap_out");
-eprintln!("[scheduler] preempt: swapped out seq {} ({nblocks} blocks) to host", seq.id);
+eprintln!(
+"[scheduler] preempt: swapped out seq {} ({nblocks} blocks) to host",
+seq.id
+);
 swapped.push(seq);
 needed = decode_block_need(&self.paged_cache, &running, &newly_prefilled);
 } else {
@@ -235,7 +268,9 @@ impl Engine {
 }

 // Step 5c: Batched paged decode for the surviving prefilled sequences.
-let decode_indices: Vec<usize> = running.iter().enumerate()
+let decode_indices: Vec<usize> = running
+.iter()
+.enumerate()
 .filter(|(_, s)| s.prefilled && !newly_prefilled.contains(&s.id))
 .map(|(i, _)| i)
 .collect();
@@ -246,25 +281,32 @@ impl Engine {
|
||||
eprintln!("[scheduler] paged decode active");
|
||||
});
|
||||
|
||||
let tokens: Vec<u32> = decode_indices.iter()
|
||||
let tokens: Vec<u32> = decode_indices
|
||||
.iter()
|
||||
.map(|&i| *running[i].generated_tokens.last().unwrap())
|
||||
.collect();
|
||||
let positions: Vec<usize> = decode_indices.iter()
|
||||
let positions: Vec<usize> = decode_indices
|
||||
.iter()
|
||||
.map(|&i| self.paged_cache.seq_len(running[i].seq_slot.unwrap()))
|
||||
.collect();
|
||||
let slots: Vec<usize> = decode_indices.iter()
|
||||
let slots: Vec<usize> = decode_indices
|
||||
.iter()
|
||||
.map(|&i| running[i].seq_slot.unwrap())
|
||||
.collect();
|
||||
|
||||
let logits = self.model.forward_decode_paged(
|
||||
&tokens, &positions, &slots, &mut self.paged_cache,
|
||||
&tokens,
|
||||
&positions,
|
||||
&slots,
|
||||
&mut self.paged_cache,
|
||||
);
|
||||
|
||||
// Fast path: every active sequence is greedy → run argmax on
|
||||
// the GPU and only D2H the chosen token ids (a few bytes per
|
||||
// sequence) instead of the full [B, vocab_size] BF16 logits
|
||||
// (~1.2 MB for B=4, Qwen3 vocab=152K).
|
||||
let all_greedy = decode_indices.iter()
|
||||
let all_greedy = decode_indices
|
||||
.iter()
|
||||
.all(|&i| running[i].sampling.temperature == 0.0);
|
||||
if all_greedy {
|
||||
let next_ids = xserv_kernels::argmax_bf16_to_host(&logits);
|
||||
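The comment above motivates the greedy fast path by transfer size. The arithmetic can be sketched on the CPU as below. This is not the `xserv_kernels::argmax_bf16_to_host` implementation, which this diff does not show: `logits_bytes`, `token_id_bytes` and `argmax_rows` are illustrative names, and the vocabulary size 152,000 only approximates the "152K" in the comment.

```rust
/// Bytes needed to copy a full [batch, vocab] logits tensor to the host.
fn logits_bytes(batch: usize, vocab: usize, elem_bytes: usize) -> usize {
    batch * vocab * elem_bytes
}

/// Bytes needed when only one u32 token id per sequence is copied back.
fn token_id_bytes(batch: usize) -> usize {
    batch * std::mem::size_of::<u32>()
}

/// CPU reference for row-wise greedy argmax over row-major [batch, vocab] logits.
/// Ties resolve to the lowest index.
fn argmax_rows(logits: &[f32], vocab: usize) -> Vec<u32> {
    assert!(vocab > 0 && logits.len() % vocab == 0, "logits must be [batch, vocab]");
    logits
        .chunks(vocab)
        .map(|row| {
            let mut best = 0usize;
            for (i, &v) in row.iter().enumerate() {
                if v > row[best] {
                    best = i;
                }
            }
            best as u32
        })
        .collect()
}
```

For B=4 and 2-byte BF16 logits this gives roughly 1.2 MB for the full tensor against 16 bytes of token ids.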
@@ -285,11 +327,15 @@ impl Engine {
|
||||
let row_start = j * vocab_size;
|
||||
let row_logits = &data[row_start..row_start + vocab_size];
|
||||
let next = if running[i].sampling.temperature == 0.0 {
|
||||
row_logits.iter().enumerate()
|
||||
row_logits
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(idx, _)| idx as u32).unwrap()
|
||||
.map(|(idx, _)| idx as u32)
|
||||
.unwrap()
|
||||
} else {
|
||||
let row_tensor = xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
|
||||
let row_tensor =
|
||||
xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
|
||||
sample(&row_tensor, &running[i].sampling)
|
||||
};
|
||||
running[i].generated_tokens.push(next);
|
||||
@@ -334,7 +380,8 @@ impl Engine {
|
||||
/// Total additional GPU blocks the next decode step needs across all
|
||||
/// currently-decoding (prefilled, not just-prefilled) sequences.
|
||||
fn decode_block_need(paged: &PagedKVCache, running: &[Sequence], newly_prefilled: &[u64]) -> usize {
|
||||
running.iter()
|
||||
running
|
||||
.iter()
|
||||
.filter(|s| s.prefilled && !newly_prefilled.contains(&s.id))
|
||||
.filter_map(|s| s.seq_slot)
|
||||
.map(|slot| paged.additional_blocks_needed(slot, 1))
|
||||
@@ -372,8 +419,12 @@ fn send_token_if_nonempty(seq: &Sequence, text: String) {
|
||||
}
|
||||
|
||||
fn is_finished(seq: &Sequence) -> bool {
|
||||
if seq.generated_tokens.is_empty() { return false; }
|
||||
if seq.generated_tokens.is_empty() {
|
||||
return false;
|
||||
}
|
||||
let last = *seq.generated_tokens.last().unwrap();
|
||||
if seq.generated_tokens.len() >= seq.max_tokens { return true; }
|
||||
if seq.generated_tokens.len() >= seq.max_tokens {
|
||||
return true;
|
||||
}
|
||||
seq.sender.is_closed() || seq.eos_token_id == Some(last)
|
||||
}
|
||||
|
||||
@@ -3,10 +3,13 @@ mod engine;
|
||||
mod pp_engine;
|
||||
mod tp_engine;
|
||||
|
||||
use axum::{routing::{get, post}, Extension, Router};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{mpsc, Arc, Mutex};
|
||||
use axum::{
|
||||
Extension, Router,
|
||||
routing::{get, post},
|
||||
};
|
||||
use engine::GenerateRequest;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex, mpsc};
|
||||
use xserv_model::ModelConfig;
|
||||
|
||||
pub struct AppState {
|
||||
@@ -21,40 +24,48 @@ pub struct AppState {
|
||||
async fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N] [--pp N]");
|
||||
eprintln!(
|
||||
"Usage: xserv-server <model-dir> [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N] [--pp N]"
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let port: u16 = args.iter()
|
||||
let port: u16 = args
|
||||
.iter()
|
||||
.position(|a| a == "--port")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(8080);
|
||||
let max_batch: usize = args.iter()
|
||||
let max_batch: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--max-batch")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(4)
|
||||
.max(1);
|
||||
let requested_max_seq_len: usize = args.iter()
|
||||
let requested_max_seq_len: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--max-seq-len")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(2048)
|
||||
.max(1);
|
||||
let swap_space_gb: usize = args.iter()
|
||||
let swap_space_gb: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--swap-space-gb")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(8);
|
||||
let tp: usize = args.iter()
|
||||
let tp: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--tp")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1)
|
||||
.max(1);
|
||||
let pp: usize = args.iter()
|
||||
let pp: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--pp")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
@@ -65,6 +76,15 @@ async fn main() {
         std::process::exit(1);
     }
     let model_config = ModelConfig::from_file(&model_dir.join("config.json"));
+    // gpt-oss is only implemented in the TP engine; route it there even at
+    // tp=1 (single-rank world) so quantized models can serve on one GPU.
+    let is_gpt_oss = model_config.model_type.as_deref() == Some("gpt_oss");
+    if pp > 1 && is_gpt_oss {
+        eprintln!(
+            "gpt-oss is not supported by the pipeline-parallel engine (Qwen3 only); use --tp instead"
+        );
+        std::process::exit(1);
+    }
     let model_max_seq_len = model_config.max_seq_len();
     if model_max_seq_len == 0 {
         eprintln!("model config has invalid max_seq_len=0");
@@ -77,7 +97,8 @@ async fn main() {
|
||||
);
|
||||
}
|
||||
|
||||
let model_name = model_dir.file_name()
|
||||
let model_name = model_dir
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
@@ -91,8 +112,13 @@ async fn main() {
         if pp > 1 {
             // Pipeline-parallel path: stage-0 coordinator + worker stage threads.
             pp_engine::run_pp(&model_dir_clone, pp, max_seq_len, rx);
-        } else if tp <= 1 {
-            let mut engine = engine::Engine::load_with_swap(&model_dir_clone, max_batch, max_seq_len, swap_space_gb);
+        } else if tp <= 1 && !is_gpt_oss {
+            let mut engine = engine::Engine::load_with_swap(
+                &model_dir_clone,
+                max_batch,
+                max_seq_len,
+                swap_space_gb,
+            );
             engine.run(rx);
         } else {
             // Tensor-parallel path: rank-0 coordinator + worker rank threads.
|
||||
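Taken together, the two `main.rs` hunks above change which engine a launch configuration reaches. The decision can be sketched as a pure function. `EngineKind` and `select_engine` are hypothetical names introduced only for illustration; the real code branches inline and exits the process on the unsupported combination.

```rust
/// Engine a launch configuration is routed to (illustrative names).
#[derive(Debug, PartialEq, Eq)]
enum EngineKind {
    Pipeline,
    SingleGpu,
    Tensor,
}

/// Mirrors the branch order in the diff: reject gpt-oss with pp > 1, then
/// prefer the pipeline engine, then the single-GPU engine, else tensor parallel.
fn select_engine(tp: usize, pp: usize, is_gpt_oss: bool) -> Result<EngineKind, String> {
    if pp > 1 && is_gpt_oss {
        return Err("gpt-oss is not supported by the pipeline-parallel engine".to_string());
    }
    if pp > 1 {
        Ok(EngineKind::Pipeline)
    } else if tp <= 1 && !is_gpt_oss {
        Ok(EngineKind::SingleGpu)
    } else {
        // Covers tp >= 2 for any model, and gpt-oss at tp = 1 (single-rank world).
        Ok(EngineKind::Tensor)
    }
}
```

The notable case is `select_engine(1, 1, true)`, which now lands on the tensor-parallel engine with a world size of one.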
@@ -15,15 +15,15 @@
|
||||
|
||||
use std::ffi::c_void;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::mpsc;
|
||||
use std::sync::Arc;
|
||||
use std::sync::mpsc;
|
||||
use std::thread;
|
||||
|
||||
use half::bf16;
|
||||
use xserv_distributed::{PpContext, UniqueId};
|
||||
use xserv_model::loader;
|
||||
use xserv_model::sampling::SamplingParams;
|
||||
use xserv_model::{sample, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
|
||||
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, sample};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -38,9 +38,16 @@ enum PpCommand {
|
||||
Free(usize),
|
||||
/// Receive `[n_tokens, hidden]` from the previous stage, run this stage's
|
||||
/// layers; if last stage, sample with `sampling` and return the token.
|
||||
Prefill { n_tokens: usize, slot: usize, sampling: SamplingParams },
|
||||
Prefill {
|
||||
n_tokens: usize,
|
||||
slot: usize,
|
||||
sampling: SamplingParams,
|
||||
},
|
||||
/// Receive `[1, hidden]`, run this stage's layers; last stage samples.
|
||||
Decode { slot: usize, sampling: SamplingParams },
|
||||
Decode {
|
||||
slot: usize,
|
||||
sampling: SamplingParams,
|
||||
},
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
@@ -76,9 +83,21 @@ fn build_stage(
|
||||
let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE);
|
||||
let total_blocks = max_blocks_per_seq + 8; // v1 serial: one active sequence
|
||||
let cache = PagedKVCache::new(
|
||||
&stage_config, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device,
|
||||
&stage_config,
|
||||
total_blocks,
|
||||
0,
|
||||
4,
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
device,
|
||||
);
|
||||
StageCtx { model, cache, pp, hidden: config.hidden(), device }
|
||||
StageCtx {
|
||||
model,
|
||||
cache,
|
||||
pp,
|
||||
hidden: config.hidden(),
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate a zeroed `[n, hidden]` device tensor and receive into it from `peer`.
|
||||
@@ -110,7 +129,15 @@ fn worker_loop(
|
||||
ack_tx: mpsc::Sender<()>,
|
||||
token_tx: mpsc::Sender<u32>,
|
||||
) {
|
||||
let mut sc = build_stage(&model_dir, &config, stage, world, stage as u32, max_seq_len, id);
|
||||
let mut sc = build_stage(
|
||||
&model_dir,
|
||||
&config,
|
||||
stage,
|
||||
world,
|
||||
stage as u32,
|
||||
max_seq_len,
|
||||
id,
|
||||
);
|
||||
let is_last = stage == world - 1;
|
||||
let prev = stage - 1;
|
||||
let next = stage + 1;
|
||||
@@ -125,7 +152,11 @@ fn worker_loop(
|
||||
sc.cache.free_sequence(slot);
|
||||
let _ = ack_tx.send(());
|
||||
}
|
||||
PpCommand::Prefill { n_tokens, slot, sampling } => {
|
||||
PpCommand::Prefill {
|
||||
n_tokens,
|
||||
slot,
|
||||
sampling,
|
||||
} => {
|
||||
let x = recv_hidden(&sc, n_tokens, prev);
|
||||
let x = sc.model.forward_layers_prefill(x, slot, &mut sc.cache);
|
||||
if is_last {
|
||||
@@ -155,7 +186,12 @@ fn worker_loop(
|
||||
|
||||
/// Run the PP coordinator (stage 0) on the calling thread. Spawns worker stages
|
||||
/// 1..world and consumes generation requests from `rx`.
|
||||
pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver<GenerateRequest>) {
|
||||
pub fn run_pp(
|
||||
model_dir: &Path,
|
||||
world: usize,
|
||||
max_seq_len: usize,
|
||||
rx: mpsc::Receiver<GenerateRequest>,
|
||||
) {
|
||||
assert!(world >= 2, "run_pp requires world >= 2");
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
assert!(
|
||||
@@ -179,7 +215,17 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
let model_dir = model_dir.to_path_buf();
|
||||
let config = config.clone();
|
||||
thread::spawn(move || {
|
||||
worker_loop(stage, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx, token_tx);
|
||||
worker_loop(
|
||||
stage,
|
||||
world,
|
||||
id,
|
||||
model_dir,
|
||||
config,
|
||||
max_seq_len,
|
||||
ctx_rx,
|
||||
ack_tx,
|
||||
token_tx,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -207,11 +253,14 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
wait_acks(&ack_rx);
|
||||
|
||||
// Prefill: embed prompt, run stage-0 layers, push hidden into the pipe.
|
||||
broadcast(&cmd_txs, PpCommand::Prefill {
|
||||
n_tokens: req.prompt_tokens.len(),
|
||||
slot,
|
||||
sampling: req.sampling.clone(),
|
||||
});
|
||||
broadcast(
|
||||
&cmd_txs,
|
||||
PpCommand::Prefill {
|
||||
n_tokens: req.prompt_tokens.len(),
|
||||
slot,
|
||||
sampling: req.sampling.clone(),
|
||||
},
|
||||
);
|
||||
let x = sc.model.embed(&req.prompt_tokens);
|
||||
let x = sc.model.forward_layers_prefill(x, slot, &mut sc.cache);
|
||||
send_hidden(&sc, &x, next_peer);
|
||||
@@ -228,7 +277,13 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
if generated >= req.max_tokens {
|
||||
break "length";
|
||||
}
|
||||
broadcast(&cmd_txs, PpCommand::Decode { slot, sampling: req.sampling.clone() });
|
||||
broadcast(
|
||||
&cmd_txs,
|
||||
PpCommand::Decode {
|
||||
slot,
|
||||
sampling: req.sampling.clone(),
|
||||
},
|
||||
);
|
||||
let x = sc.model.embed(&[next]);
|
||||
let x = sc.model.forward_layers_decode(x, &[slot], &mut sc.cache);
|
||||
send_hidden(&sc, &x, next_peer);
|
||||
@@ -239,9 +294,14 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
|
||||
let tail = tokenizer.flush_decode_stream(&mut decode_buf);
|
||||
if !tail.is_empty() {
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token { id: next, text: tail });
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token {
|
||||
id: next,
|
||||
text: tail,
|
||||
});
|
||||
}
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Done { finish_reason: finish.to_string() });
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Done {
|
||||
finish_reason: finish.to_string(),
|
||||
});
|
||||
|
||||
broadcast(&cmd_txs, PpCommand::Free(slot));
|
||||
sc.cache.free_sequence(slot);
|
||||
@@ -258,6 +318,8 @@ fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, buf: &
|
||||
}
|
||||
let text = tokenizer.decode_token_stream(token_id, buf);
|
||||
if !text.is_empty() {
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
|
||||
let _ = req
|
||||
.sender
|
||||
.blocking_send(GenerateEvent::Token { id: token_id, text });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,13 +13,16 @@
|
||||
//! work; the single-GPU `Engine` still handles TP=1.
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::mpsc;
|
||||
use std::sync::Arc;
|
||||
use std::sync::mpsc;
|
||||
use std::thread;
|
||||
|
||||
use xserv_distributed::{TpContext, UniqueId};
|
||||
use xserv_model::loader;
|
||||
use xserv_model::{sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE};
|
||||
use xserv_model::{
|
||||
BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, sample,
|
||||
sample_greedy_penalized,
|
||||
};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -29,8 +32,15 @@ use crate::engine::{GenerateEvent, GenerateRequest};
|
||||
enum TpCommand {
|
||||
Register(usize),
|
||||
Free(usize),
|
||||
Prefill { tokens: Vec<u32>, slot: usize },
|
||||
Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
|
||||
Prefill {
|
||||
tokens: Vec<u32>,
|
||||
slot: usize,
|
||||
},
|
||||
Decode {
|
||||
tokens: Vec<u32>,
|
||||
positions: Vec<usize>,
|
||||
slots: Vec<usize>,
|
||||
},
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
@@ -40,14 +50,25 @@ enum TpModel {
|
||||
}
|
||||
|
||||
impl TpModel {
|
||||
fn forward_prefill_paged(&self, tokens: &[u32], slot: usize, cache: &mut PagedKVCache) -> Tensor {
|
||||
fn forward_prefill_paged(
|
||||
&self,
|
||||
tokens: &[u32],
|
||||
slot: usize,
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
match self {
|
||||
TpModel::Qwen3(m) => m.forward_prefill_paged(tokens, slot, cache),
|
||||
TpModel::GptOss(m) => m.forward_prefill_paged(tokens, slot, cache),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward_decode_paged(&self, tokens: &[u32], positions: &[usize], slots: &[usize], cache: &mut PagedKVCache) -> Tensor {
|
||||
fn forward_decode_paged(
|
||||
&self,
|
||||
tokens: &[u32],
|
||||
positions: &[usize],
|
||||
slots: &[usize],
|
||||
cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
match self {
|
||||
TpModel::Qwen3(m) => m.forward_decode_paged(tokens, positions, slots, cache),
|
||||
TpModel::GptOss(m) => m.forward_decode_paged(tokens, positions, slots, cache),
|
||||
@@ -58,6 +79,20 @@ impl TpModel {
 struct RankCtx {
     model: TpModel,
     cache: PagedKVCache,
+    decoder: GraphedGptOssDecoder,
+}
+
+/// Decode one step: gpt-oss batch=1 goes through the CUDA-graph decoder
+/// (lazy capture, replay thereafter); everything else runs eager.
+fn rank_decode(rc: &mut RankCtx, tokens: &[u32], positions: &[usize], slots: &[usize]) -> Tensor {
+    match &rc.model {
+        TpModel::GptOss(m) => rc
+            .decoder
+            .decode(m, tokens, positions, slots, &mut rc.cache),
+        TpModel::Qwen3(_) => rc
+            .model
+            .forward_decode_paged(tokens, positions, slots, &mut rc.cache),
+    }
 }
 
 fn build_rank(
@@ -71,17 +106,42 @@ fn build_rank(
 ) -> RankCtx {
     let weights = loader::load_model_dir(model_dir, Device::Cpu);
     let model = if config.is_moe() {
-        TpModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, rank, world, device, tp))
+        TpModel::GptOss(GptOss::from_weights_tp(
+            config.clone(),
+            weights,
+            rank,
+            world,
+            device,
+            tp,
+        ))
     } else {
-        TpModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, rank, world, device, tp))
+        TpModel::Qwen3(Qwen3::from_weights_tp(
+            config.clone(),
+            weights,
+            rank,
+            world,
+            device,
+            tp,
+        ))
     };
     let local_kv = config.num_kv_heads() / world;
     let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
     let total_blocks = max_blocks_per_seq + 8;
     let cache = PagedKVCache::new_tp(
-        config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device,
+        config,
+        local_kv,
+        total_blocks,
+        0,
+        4,
+        max_blocks_per_seq,
+        DType::BF16,
+        device,
     );
-    RankCtx { model, cache }
+    RankCtx {
+        model,
+        cache,
+        decoder: GraphedGptOssDecoder::new(),
+    }
 }
 
 fn worker_loop(
@@ -95,7 +155,15 @@ fn worker_loop(
|
||||
ack_tx: mpsc::Sender<()>,
|
||||
) {
|
||||
let tp = Arc::new(TpContext::init(rank, world, id, rank as u32));
|
||||
let mut rc = build_rank(&model_dir, &config, rank, world, rank as u32, max_seq_len, Some(tp));
|
||||
let mut rc = build_rank(
|
||||
&model_dir,
|
||||
&config,
|
||||
rank,
|
||||
world,
|
||||
rank as u32,
|
||||
max_seq_len,
|
||||
Some(tp),
|
||||
);
|
||||
while let Ok(cmd) = cmd_rx.recv() {
|
||||
match cmd {
|
||||
TpCommand::Register(slot) => {
|
||||
@@ -105,8 +173,12 @@ fn worker_loop(
             TpCommand::Prefill { tokens, slot } => {
                 let _ = rc.model.forward_prefill_paged(&tokens, slot, &mut rc.cache);
             }
-            TpCommand::Decode { tokens, positions, slots } => {
-                let _ = rc.model.forward_decode_paged(&tokens, &positions, &slots, &mut rc.cache);
+            TpCommand::Decode {
+                tokens,
+                positions,
+                slots,
+            } => {
+                let _ = rank_decode(&mut rc, &tokens, &positions, &slots);
             }
             TpCommand::Shutdown => {
                 let _ = ack_tx.send(());
@@ -119,8 +191,15 @@ fn worker_loop(
 
 /// Run the TP coordinator (rank 0) on the calling thread. Spawns worker ranks
 /// internally and consumes generation requests from `rx`.
-pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver<GenerateRequest>) {
-    assert!(world >= 2, "run_tp requires world >= 2");
+pub fn run_tp(
+    model_dir: &Path,
+    world: usize,
+    max_seq_len: usize,
+    rx: mpsc::Receiver<GenerateRequest>,
+) {
+    // world=1 is a valid single-rank configuration (gpt-oss has no
+    // single-GPU engine path; NCCL init and all_reduce no-op at world=1).
+    assert!(world >= 1, "run_tp requires world >= 1");
     let config = ModelConfig::from_file(&model_dir.join("config.json"));
     assert!(
         config.num_kv_heads() % world == 0,
@@ -140,7 +219,16 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
let model_dir = model_dir.to_path_buf();
|
||||
let config = config.clone();
|
||||
thread::spawn(move || {
|
||||
worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx);
|
||||
worker_loop(
|
||||
rank,
|
||||
world,
|
||||
id,
|
||||
model_dir,
|
||||
config,
|
||||
max_seq_len,
|
||||
ctx_rx,
|
||||
ack_tx,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -153,10 +241,14 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
// models loop under pure greedy when numerics diverge from the reference).
|
||||
// Off by default; XSERV_REP_PENALTY>1 enables it over the last
|
||||
// XSERV_REP_WINDOW generated tokens. Applied only on the greedy path.
|
||||
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY").ok()
|
||||
.and_then(|s| s.parse().ok()).unwrap_or(1.0);
|
||||
let rep_window: usize = std::env::var("XSERV_REP_WINDOW").ok()
|
||||
.and_then(|s| s.parse().ok()).unwrap_or(128);
|
||||
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1.0);
|
||||
let rep_window: usize = std::env::var("XSERV_REP_WINDOW")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(128);
|
||||
let pick = |logits: &Tensor, sp: &xserv_model::SamplingParams, history: &[u32]| -> u32 {
|
||||
if rep_penalty > 1.0 && sp.temperature == 0.0 {
|
||||
let start = history.len().saturating_sub(rep_window);
|
||||
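The comments in this hunk describe an optional repetition penalty applied over the last `XSERV_REP_WINDOW` generated tokens on the greedy path. The body of `sample_greedy_penalized` is not part of this diff, so the following is only a sketch under an assumption: it uses the common scheme of dividing positive logits and multiplying negative logits by the penalty for tokens seen in the window. `greedy_with_rep_penalty` is a hypothetical name.

```rust
/// Greedy pick with a windowed repetition penalty (assumed scheme, see lead-in).
/// A penalty <= 1.0 disables the adjustment, matching the "off by default" note.
fn greedy_with_rep_penalty(logits: &[f32], history: &[u32], penalty: f32, window: usize) -> u32 {
    assert!(!logits.is_empty(), "logits must be non-empty");
    let mut adjusted = logits.to_vec();
    if penalty > 1.0 {
        // Only the last `window` generated tokens are penalized.
        let start = history.len().saturating_sub(window);
        let mut seen = vec![false; adjusted.len()];
        for &tok in &history[start..] {
            let t = tok as usize;
            if t < adjusted.len() && !seen[t] {
                seen[t] = true; // penalize each distinct token once
                adjusted[t] = if adjusted[t] > 0.0 {
                    adjusted[t] / penalty
                } else {
                    adjusted[t] * penalty
                };
            }
        }
    }
    // Plain argmax over the adjusted logits; ties go to the lowest index.
    let mut best = 0usize;
    for (i, &v) in adjusted.iter().enumerate() {
        if v > adjusted[best] {
            best = i;
        }
    }
    best as u32
}
```

With logits `[1.0, 3.0, 2.5]` and token 1 already in the window, a penalty of 1.5 lowers its logit to 2.0, so the pick moves from token 1 to token 2.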
@@ -185,8 +277,16 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
wait_acks(&ack_rx);
|
||||
|
||||
// Prefill.
|
||||
broadcast(&cmd_txs, TpCommand::Prefill { tokens: req.prompt_tokens.clone(), slot });
|
||||
let logits = rc.model.forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache);
|
||||
broadcast(
|
||||
&cmd_txs,
|
||||
TpCommand::Prefill {
|
||||
tokens: req.prompt_tokens.clone(),
|
||||
slot,
|
||||
},
|
||||
);
|
||||
let logits = rc
|
||||
.model
|
||||
.forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache);
|
||||
wait_acks(&ack_rx);
|
||||
let mut gen_ids: Vec<u32> = Vec::new();
|
||||
let mut next = pick(&logits, &req.sampling, &gen_ids);
|
||||
@@ -204,8 +304,15 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
                 break "length";
             }
             let pos = rc.cache.seq_len(slot);
-            broadcast(&cmd_txs, TpCommand::Decode { tokens: vec![next], positions: vec![pos], slots: vec![slot] });
-            let logits = rc.model.forward_decode_paged(&[next], &[pos], &[slot], &mut rc.cache);
+            broadcast(
+                &cmd_txs,
+                TpCommand::Decode {
+                    tokens: vec![next],
+                    positions: vec![pos],
+                    slots: vec![slot],
+                },
+            );
+            let logits = rank_decode(&mut rc, &[next], &[pos], &[slot]);
             wait_acks(&ack_rx);
             next = pick(&logits, &req.sampling, &gen_ids);
             gen_ids.push(next);
@@ -215,9 +322,14 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece
|
||||
|
||||
let tail = tokenizer.flush_decode_stream(&mut decode_buf);
|
||||
if !tail.is_empty() {
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token { id: next, text: tail });
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token {
|
||||
id: next,
|
||||
text: tail,
|
||||
});
|
||||
}
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Done { finish_reason: finish.to_string() });
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Done {
|
||||
finish_reason: finish.to_string(),
|
||||
});
|
||||
|
||||
broadcast(&cmd_txs, TpCommand::Free(slot));
|
||||
rc.cache.free_sequence(slot);
|
||||
@@ -234,6 +346,8 @@ fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, buf: &
|
||||
}
|
||||
let text = tokenizer.decode_token_stream(token_id, buf);
|
||||
if !text.is_empty() {
|
||||
let _ = req.sender.blocking_send(GenerateEvent::Token { id: token_id, text });
|
||||
let _ = req
|
||||
.sender
|
||||
.blocking_send(GenerateEvent::Token { id: token_id, text });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,18 +43,30 @@ pub trait TensorDType: Copy + Send + Sync + 'static {
|
||||
|
||||
impl TensorDType for f32 {
|
||||
const DTYPE: DType = DType::F32;
|
||||
fn to_f64(self) -> f64 { self as f64 }
|
||||
fn from_f64(v: f64) -> Self { v as f32 }
|
||||
fn to_f64(self) -> f64 {
|
||||
self as f64
|
||||
}
|
||||
fn from_f64(v: f64) -> Self {
|
||||
v as f32
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorDType for f16 {
|
||||
const DTYPE: DType = DType::F16;
|
||||
fn to_f64(self) -> f64 { self.to_f32() as f64 }
|
||||
fn from_f64(v: f64) -> Self { f16::from_f32(v as f32) }
|
||||
fn to_f64(self) -> f64 {
|
||||
self.to_f32() as f64
|
||||
}
|
||||
fn from_f64(v: f64) -> Self {
|
||||
f16::from_f32(v as f32)
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorDType for bf16 {
|
||||
const DTYPE: DType = DType::BF16;
|
||||
fn to_f64(self) -> f64 { self.to_f32() as f64 }
|
||||
fn from_f64(v: f64) -> Self { bf16::from_f32(v as f32) }
|
||||
fn to_f64(self) -> f64 {
|
||||
self.to_f32() as f64
|
||||
}
|
||||
fn from_f64(v: f64) -> Self {
|
||||
bf16::from_f32(v as f32)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,4 +6,4 @@ pub mod tensor;
|
||||
pub use dtype::{DType, TensorDType};
|
||||
pub use shape::Dims;
|
||||
pub use storage::{Device, Storage};
|
||||
pub use tensor::{register_gpu_contiguous, Tensor};
|
||||
pub use tensor::{Tensor, register_gpu_contiguous};
|
||||
|
||||
@@ -46,8 +46,16 @@ pub fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Dims> {
|
||||
let ndim = a.len().max(b.len());
|
||||
let mut result = SmallVec::with_capacity(ndim);
|
||||
for i in 0..ndim {
|
||||
let da = if i < ndim - a.len() { 1 } else { a[i - (ndim - a.len())] };
|
||||
let db = if i < ndim - b.len() { 1 } else { b[i - (ndim - b.len())] };
|
||||
let da = if i < ndim - a.len() {
|
||||
1
|
||||
} else {
|
||||
a[i - (ndim - a.len())]
|
||||
};
|
||||
let db = if i < ndim - b.len() {
|
||||
1
|
||||
} else {
|
||||
b[i - (ndim - b.len())]
|
||||
};
|
||||
if da == db {
|
||||
result.push(da);
|
||||
} else if da == 1 {
|
||||
@@ -100,8 +108,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_broadcast_shape() {
|
||||
assert_eq!(broadcast_shape(&[3, 1], &[1, 4]).unwrap().as_slice(), &[3, 4]);
|
||||
assert_eq!(broadcast_shape(&[2, 3, 4], &[4]).unwrap().as_slice(), &[2, 3, 4]);
|
||||
assert_eq!(
|
||||
broadcast_shape(&[3, 1], &[1, 4]).unwrap().as_slice(),
|
||||
&[3, 4]
|
||||
);
|
||||
assert_eq!(
|
||||
broadcast_shape(&[2, 3, 4], &[4]).unwrap().as_slice(),
|
||||
&[2, 3, 4]
|
||||
);
|
||||
assert_eq!(broadcast_shape(&[1], &[5, 3]).unwrap().as_slice(), &[5, 3]);
|
||||
assert!(broadcast_shape(&[3], &[4]).is_none());
|
||||
}
|
||||
@@ -109,6 +123,9 @@ mod tests {
|
||||
#[test]
|
||||
fn test_broadcast_strides() {
|
||||
// [3,1] with strides [1,1] broadcast to [3,4]
|
||||
assert_eq!(broadcast_strides(&[3, 1], &[1, 1], &[3, 4]).as_slice(), &[1, 0]);
|
||||
assert_eq!(
|
||||
broadcast_strides(&[3, 1], &[1, 1], &[3, 4]).as_slice(),
|
||||
&[1, 0]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
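The `broadcast_shape` hunk and its tests follow the usual right-aligned broadcasting rule: shapes are padded on the left with 1s, and each pair of dimensions must be equal or contain a 1. A self-contained sketch is given below. It returns `Vec<usize>` instead of the crate's `Dims` type, and the `broadcast_strides` body is an assumption inferred from the single test case shown in the diff (a broadcast dimension gets stride 0).

```rust
/// Right-aligned broadcast of two shapes; `None` if any dimension pair conflicts.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let ndim = a.len().max(b.len());
    let mut result = Vec::with_capacity(ndim);
    for i in 0..ndim {
        // Missing leading dimensions behave as size 1.
        let da = if i < ndim - a.len() { 1 } else { a[i - (ndim - a.len())] };
        let db = if i < ndim - b.len() { 1 } else { b[i - (ndim - b.len())] };
        if da == db {
            result.push(da);
        } else if da == 1 {
            result.push(db);
        } else if db == 1 {
            result.push(da);
        } else {
            return None;
        }
    }
    Some(result)
}

/// Strides for viewing `shape`/`strides` as `target`: stretched or newly
/// added dimensions get stride 0 so every index reads the same element.
fn broadcast_strides(shape: &[usize], strides: &[usize], target: &[usize]) -> Vec<usize> {
    assert!(target.len() >= shape.len() && shape.len() == strides.len());
    let pad = target.len() - shape.len();
    (0..target.len())
        .map(|i| {
            if i < pad || (shape[i - pad] == 1 && target[i] != 1) {
                0
            } else {
                strides[i - pad]
            }
        })
        .collect()
}
```

The assertions in the diff's test module carry over directly, for example `broadcast_shape(&[3, 1], &[1, 4])` yields `[3, 4]` and `broadcast_shape(&[3], &[4])` yields `None`.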
@@ -33,8 +33,20 @@ impl Tensor {
|
||||
// --- Creation ---
|
||||
|
||||
/// Create a tensor from raw components (for advanced use like GPU KV cache).
|
||||
pub fn from_storage(storage: Storage, shape: Dims, strides: Dims, offset: usize, dtype: DType) -> Self {
|
||||
Self { storage, shape, strides, offset, dtype }
|
||||
pub fn from_storage(
|
||||
storage: Storage,
|
||||
shape: Dims,
|
||||
strides: Dims,
|
||||
offset: usize,
|
||||
dtype: DType,
|
||||
) -> Self {
|
||||
Self {
|
||||
storage,
|
||||
shape,
|
||||
strides,
|
||||
offset,
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_slice<T: TensorDType>(data: &[T], shape: &[usize]) -> Self {
|
||||
@@ -60,7 +72,10 @@ impl Tensor {
|
||||
data.len(),
|
||||
numel * dtype.size_bytes(),
|
||||
"raw bytes length {} != expected {} (numel={} * elem_size={})",
|
||||
data.len(), numel * dtype.size_bytes(), numel, dtype.size_bytes()
|
||||
data.len(),
|
||||
numel * dtype.size_bytes(),
|
||||
numel,
|
||||
dtype.size_bytes()
|
||||
);
|
||||
Self {
|
||||
storage: Storage::cpu(data.to_vec()),
|
||||
@@ -112,14 +127,28 @@ impl Tensor {
|
||||
|
||||
// --- Properties ---
|
||||
|
||||
pub fn shape(&self) -> &[usize] { &self.shape }
|
||||
pub fn strides(&self) -> &[usize] { &self.strides }
|
||||
pub fn dtype(&self) -> DType { self.dtype }
|
||||
pub fn ndim(&self) -> usize { self.shape.len() }
|
||||
pub fn numel(&self) -> usize { shape::num_elements(&self.shape) }
|
||||
pub fn offset(&self) -> usize { self.offset }
|
||||
pub fn shape(&self) -> &[usize] {
|
||||
&self.shape
|
||||
}
|
||||
pub fn strides(&self) -> &[usize] {
|
||||
&self.strides
|
||||
}
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
pub fn ndim(&self) -> usize {
|
||||
self.shape.len()
|
||||
}
|
||||
pub fn numel(&self) -> usize {
|
||||
shape::num_elements(&self.shape)
|
||||
}
|
||||
pub fn offset(&self) -> usize {
|
||||
self.offset
|
||||
}
|
||||
|
||||
pub fn device(&self) -> Device { self.storage.device() }
|
||||
pub fn device(&self) -> Device {
|
||||
self.storage.device()
|
||||
}
|
||||
|
||||
pub fn is_contiguous(&self) -> bool {
|
||||
shape::is_contiguous(&self.shape, &self.strides)
|
||||
@@ -193,7 +222,11 @@ impl Tensor {
|
||||
shape::contiguous_strides(&new_shape)
|
||||
} else {
|
||||
let mut s = self.strides.clone();
|
||||
let stride_val = if dim < self.strides.len() { self.strides[dim] } else { 1 };
|
||||
let stride_val = if dim < self.strides.len() {
|
||||
self.strides[dim]
|
||||
} else {
|
||||
1
|
||||
};
|
||||
s.insert(dim, stride_val);
|
||||
s
|
||||
};
|
||||
@@ -230,7 +263,12 @@ impl Tensor {
|
||||
let ndim = self.ndim();
|
||||
let mut idx = vec![0usize; ndim];
|
||||
for flat in 0..numel {
|
||||
let src_offset = self.offset + idx.iter().zip(self.strides.iter()).map(|(i, s)| i * s).sum::<usize>();
|
||||
let src_offset = self.offset
|
||||
+ idx
|
||||
.iter()
|
||||
.zip(self.strides.iter())
|
||||
.map(|(i, s)| i * s)
|
||||
.sum::<usize>();
|
||||
let src_byte_offset = src_offset * elem_size;
|
||||
let dst_byte_offset = flat * elem_size;
|
||||
dst[dst_byte_offset..dst_byte_offset + elem_size]
|
||||
@@ -261,7 +299,10 @@ impl Tensor {
|
||||
}
|
||||
// Transfer the raw storage (preserving strides/offset).
|
||||
// Non-contiguous layout is preserved — the user can call contiguous() after.
|
||||
let new_storage = self.storage.to_device(device).expect("device transfer failed");
|
||||
let new_storage = self
|
||||
.storage
|
||||
.to_device(device)
|
||||
.expect("device transfer failed");
|
||||
Self {
|
||||
storage: new_storage,
|
||||
shape: self.shape.clone(),
|
||||
@@ -310,14 +351,20 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn storage(&self) -> &Storage { &self.storage }
|
||||
pub fn storage(&self) -> &Storage {
|
||||
&self.storage
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Tensor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f, "Tensor(shape={:?}, dtype={}, device={}, contiguous={})",
|
||||
self.shape.as_slice(), self.dtype, self.device(), self.is_contiguous()
|
||||
f,
|
||||
"Tensor(shape={:?}, dtype={}, device={}, contiguous={})",
|
||||
self.shape.as_slice(),
|
||||
self.dtype,
|
||||
self.device(),
|
||||
self.is_contiguous()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,11 @@ fn test_zeros_and_ones() {
|
||||
|
||||
#[test]
|
||||
fn test_bf16_tensor() {
|
||||
let data: Vec<bf16> = vec![bf16::from_f32(1.0), bf16::from_f32(2.5), bf16::from_f32(-3.0)];
|
||||
let data: Vec<bf16> = vec![
|
||||
bf16::from_f32(1.0),
|
||||
bf16::from_f32(2.5),
|
||||
bf16::from_f32(-3.0),
|
||||
];
|
||||
let t = Tensor::from_slice(&data, &[3]);
|
||||
assert_eq!(t.dtype(), DType::BF16);
|
||||
let out = t.as_slice::<bf16>();
|
||||
|
||||
@@ -95,11 +95,15 @@ impl Tokenizer {
|
||||
let (a_str, b_str) = match entry {
|
||||
MergeEntry::Str(s) => {
|
||||
let parts: Vec<&str> = s.splitn(2, ' ').collect();
|
||||
if parts.len() != 2 { continue; }
|
||||
if parts.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
(parts[0].to_string(), parts[1].to_string())
|
||||
}
|
||||
MergeEntry::Pair(v) => {
|
||||
if v.len() != 2 { continue; }
|
||||
if v.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
(v[0].clone(), v[1].clone())
|
||||
}
|
||||
};
|
||||
@@ -174,7 +178,10 @@ impl Tokenizer {
|
||||
if byte_fallback {
|
||||
Regex::new(r"[\p{L}\p{N}]+|[^\s\p{L}\p{N}]|\s+").unwrap()
|
||||
} else {
|
||||
Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+").unwrap()
|
||||
Regex::new(
|
||||
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+",
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -262,7 +269,9 @@ impl Tokenizer {
|
||||
|
||||
// BPE merges
|
||||
loop {
|
||||
if token_ids.len() < 2 { break; }
|
||||
if token_ids.len() < 2 {
|
||||
break;
|
||||
}
|
||||
let mut best_rank = usize::MAX;
|
||||
let mut best_idx = 0;
|
||||
for i in 0..token_ids.len() - 1 {
|
||||
@@ -273,12 +282,15 @@ impl Tokenizer {
|
||||
}
|
||||
}
|
||||
}
|
||||
if best_rank == usize::MAX { break; }
|
||||
if best_rank == usize::MAX {
|
||||
break;
|
||||
}
|
||||
|
||||
let merged_bytes = [
|
||||
self.decoder[token_ids[best_idx] as usize].as_slice(),
|
||||
self.decoder[token_ids[best_idx + 1] as usize].as_slice(),
|
||||
].concat();
|
||||
]
|
||||
.concat();
|
||||
let merged_id = *self.encoder.get(&merged_bytes).unwrap_or_else(|| {
|
||||
panic!("merged token not in vocab");
|
||||
});
|
||||
@@ -389,14 +401,13 @@ fn unicode_to_byte(c: char) -> u8 {
|
||||
m
|
||||
});
|
||||
|
||||
*map.get(&(c as u32)).unwrap_or_else(|| {
|
||||
panic!("unmapped unicode char U+{:04X} in tokenizer", c as u32)
|
||||
})
|
||||
*map.get(&(c as u32))
|
||||
.unwrap_or_else(|| panic!("unmapped unicode char U+{:04X} in tokenizer", c as u32))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{take_valid_utf8, Tokenizer};
|
||||
use super::{Tokenizer, take_valid_utf8};
|
||||
|
||||
#[test]
|
||||
fn qwen_added_tokens_are_indivisible_and_im_end_is_eos() {
|
||||
|
||||
@@ -87,6 +87,17 @@ __global__ void add_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b,
|
||||
if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(a[idx]) + __bfloat162float(b[idx]));
|
||||
}
|
||||
|
||||
// Row-broadcast bias add: out[r, c] = x[r, c] + bias[c]
|
||||
__global__ void bias_add_2d_bf16_kernel(
|
||||
const __nv_bfloat16* __restrict__ x, const __nv_bfloat16* __restrict__ bias,
|
||||
__nv_bfloat16* __restrict__ out, int rows, int cols
|
||||
) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= rows * cols) return;
|
||||
float v = __bfloat162float(x[idx]) + __bfloat162float(bias[idx % cols]);
|
||||
out[idx] = __float2bfloat16(v);
|
||||
}
|
||||
|
||||
// Element-wise mul: out = a * b
|
||||
__global__ void mul_f32_kernel(const float* a, const float* b, float* out, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
@@ -159,6 +170,14 @@ void launch_add_bf16(const void* a, const void* b, void* out, int n, void* strea
|
||||
(const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
void launch_bias_add_2d_bf16(const void* x, const void* bias, void* out, int rows, int cols, void* stream) {
|
||||
int n = rows * cols;
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
bias_add_2d_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (const __nv_bfloat16*)bias, (__nv_bfloat16*)out, rows, cols);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
void launch_mul_f32(const void* a, const void* b, void* out, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
|
||||
@@ -93,11 +93,9 @@ __global__ void moe_replicate_bf16_kernel(
|
||||
int total = local_experts * num_tokens * hidden;
|
||||
if (idx >= total) return;
|
||||
|
||||
int expert = idx / (num_tokens * hidden);
|
||||
int remainder = idx % (num_tokens * hidden);
|
||||
// x_rep[expert, token, dim] = x[token, dim]
|
||||
x_rep[idx] = x[remainder];
|
||||
(void)expert; // suppress unused warning
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
|
||||
254
csrc/moe/moe_sparse.cu
Normal file
@@ -0,0 +1,254 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cstdint>
|
||||
#include "../common.cuh"
|
||||
|
||||
// ============================================================
|
||||
// Sparse MoE decode GEMVs — compute ONLY the routed experts.
|
||||
//
|
||||
// The dense path replicates x across all local experts and runs a
|
||||
// batched GEMM, reading every expert's weights per token. Decode is
|
||||
// memory-bound, so reading only the top-k routed experts' weights
|
||||
// (~2 of 16 local on average at TP=2) is a ~8x byte reduction.
|
||||
//
|
||||
// Each block handles one (token, slot) pair's tile of output columns.
|
||||
// It reads topk_ids[token, slot] from device memory (no host sync),
|
||||
// and exits early if the expert is not owned by this rank. The early
|
||||
// return is BLOCK-UNIFORM (every thread sees the same topk_ids value
|
||||
// and returns before the shared-memory staging + __syncthreads), so
|
||||
// it is safe — unlike the divergent-return bug fixed in gemv.cu.
|
||||
//
|
||||
// Outputs for non-local slots are NEVER written (uninitialized memory,
|
||||
// possibly NaN bit patterns). Downstream consumers must SKIP non-local
|
||||
// slots rather than multiply by zero (NaN * 0 = NaN).
|
||||
//
|
||||
// Per-expert weight scale and bias are fused into the epilogue:
|
||||
// y[t, slot, n] = acc * w_scale[lid] + bias[lid, n]
|
||||
// which matches the dense path's GEMM -> moe_bias_add_3d sequence.
|
||||
//
|
||||
// Activation addressing (x_per_slot):
|
||||
// gate_up: all slots of a token share x[token, :] (x_per_slot=0)
|
||||
// down: each slot has its own activation row
|
||||
// x[token * top_k + slot, :] (x_per_slot=1)
|
||||
// ============================================================
|
||||
|
||||
#define SPARSE_TILE_N 8 // output columns per block (= warps per block)
|
||||
|
||||
// Weights FP8 E4M3 [local_experts, N, K], activations BF16 (W8A16).
|
||||
// Decode is memory-bound (~2 FLOP/byte), so dequant-in-registers GEMV
|
||||
// loses nothing to tensor cores and skips activation quantization.
|
||||
__global__ void moe_sparse_gemv_fp8_bf16_kernel(
|
||||
const __nv_bfloat16* __restrict__ x, // [T, K] or [T*top_k, K]
|
||||
const __nv_fp8_e4m3* __restrict__ w, // [local_experts, N, K]
|
||||
const float* __restrict__ w_scales, // [local_experts]
|
||||
const __nv_bfloat16* __restrict__ bias, // [local_experts, N]
|
||||
const int* __restrict__ topk_ids, // [T, top_k] global expert ids
|
||||
__nv_bfloat16* __restrict__ y, // [T, top_k, N]
|
||||
int N, int K, int top_k,
|
||||
int expert_start, int local_experts,
|
||||
int x_per_slot
|
||||
) {
|
||||
int token = blockIdx.z;
|
||||
int slot = blockIdx.y;
|
||||
int eid = topk_ids[token * top_k + slot];
|
||||
int lid = eid - expert_start;
|
||||
if (lid < 0 || lid >= local_experts) return; // block-uniform: safe
|
||||
|
||||
extern __shared__ float xs[]; // [K] activation row as float
|
||||
const __nv_bfloat16* xrow =
|
||||
x + (long long)(x_per_slot ? token * top_k + slot : token) * K;
|
||||
for (int i = threadIdx.x; i < K; i += blockDim.x) {
|
||||
xs[i] = __bfloat162float(xrow[i]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int n = blockIdx.x * SPARSE_TILE_N + (threadIdx.x >> 5);
|
||||
if (n >= N) return; // after __syncthreads: safe
|
||||
int lane = threadIdx.x & 31;
|
||||
|
||||
// One warp per output column; uint4 = 16 FP8 weights per lane, the
|
||||
// warp covers 512 contiguous bytes per iteration (coalesced).
|
||||
const uint8_t* wrow = (const uint8_t*)w + ((long long)lid * N + n) * K;
|
||||
float acc = 0.0f;
|
||||
for (int i = lane; i < (K >> 4); i += 32) {
|
||||
uint4 packed = *(const uint4*)(wrow + (long long)i * 16);
|
||||
const __nv_fp8_e4m3* pw = (const __nv_fp8_e4m3*)&packed;
|
||||
const float* xk = xs + i * 16;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 16; j++) {
|
||||
acc += xk[j] * float(pw[j]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int o = 16; o > 0; o >>= 1) {
|
||||
acc += __shfl_down_sync(0xffffffffu, acc, o);
|
||||
}
|
||||
if (lane == 0) {
|
||||
float v = acc * w_scales[lid]
|
||||
+ __bfloat162float(bias[(long long)lid * N + n]);
|
||||
y[((long long)token * top_k + slot) * N + n] = __float2bfloat16(v);
|
||||
}
|
||||
}
|
||||
|
||||
// MXFP4 W4A16 variant: packed E2M1 nibbles + per-32 UE8M0 block scale,
|
||||
// same structure as batched_gemv_mxfp4_bf16_kernel but expert-indexed
|
||||
// via topk_ids and with fused per-expert bias.
|
||||
#define MXFP4_BLOCK 32
|
||||
|
||||
__device__ __constant__ float kSparseFp4Levels[8] =
|
||||
{0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
|
||||
|
||||
__device__ __forceinline__ float sparse_fp4_to_float(uint8_t code) {
|
||||
float mag = kSparseFp4Levels[code & 0x7];
|
||||
return (code & 0x8) ? -mag : mag;
|
||||
}
|
||||
|
||||
__global__ void moe_sparse_gemv_mxfp4_bf16_kernel(
|
||||
const __nv_bfloat16* __restrict__ x, // [T, K] or [T*top_k, K]
|
||||
const uint8_t* __restrict__ w_packed, // [local_experts, N, K/2]
|
||||
const uint8_t* __restrict__ w_scales, // [local_experts, N, K/32]
|
||||
const __nv_bfloat16* __restrict__ bias, // [local_experts, N]
|
||||
const int* __restrict__ topk_ids, // [T, top_k]
|
||||
__nv_bfloat16* __restrict__ y, // [T, top_k, N]
|
||||
int N, int K, int top_k,
|
||||
int expert_start, int local_experts,
|
||||
int x_per_slot
|
||||
) {
|
||||
int token = blockIdx.z;
|
||||
int slot = blockIdx.y;
|
||||
int eid = topk_ids[token * top_k + slot];
|
||||
int lid = eid - expert_start;
|
||||
if (lid < 0 || lid >= local_experts) return; // block-uniform: safe
|
||||
|
||||
extern __shared__ float xs[];
|
||||
const __nv_bfloat16* xrow =
|
||||
x + (long long)(x_per_slot ? token * top_k + slot : token) * K;
|
||||
for (int i = threadIdx.x; i < K; i += blockDim.x) {
|
||||
xs[i] = __bfloat162float(xrow[i]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int n = blockIdx.x * SPARSE_TILE_N + (threadIdx.x >> 5);
|
||||
if (n >= N) return;
|
||||
int lane = threadIdx.x & 31;
|
||||
int nblk = K / MXFP4_BLOCK;
|
||||
|
||||
const uint8_t* wp = w_packed + ((long long)lid * N + n) * (K >> 1);
|
||||
const uint8_t* ws = w_scales + ((long long)lid * N + n) * nblk;
|
||||
|
||||
float acc = 0.0f;
|
||||
for (int blk = lane; blk < nblk; blk += 32) {
|
||||
float scale = exp2f((float)((int)ws[blk] - 127));
|
||||
uint4 packed = *(const uint4*)(wp + (long long)blk * 16); // 32 nibbles
|
||||
const uint8_t* pb = (const uint8_t*)&packed;
|
||||
const float* xk = xs + blk * MXFP4_BLOCK;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i++) {
|
||||
uint8_t b = pb[i];
|
||||
acc += xk[2 * i] * (sparse_fp4_to_float(b & 0xF) * scale);
|
||||
acc += xk[2 * i + 1] * (sparse_fp4_to_float(b >> 4) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int o = 16; o > 0; o >>= 1) {
|
||||
acc += __shfl_down_sync(0xffffffffu, acc, o);
|
||||
}
|
||||
if (lane == 0) {
|
||||
float v = acc + __bfloat162float(bias[(long long)lid * N + n]);
|
||||
y[((long long)token * top_k + slot) * N + n] = __float2bfloat16(v);
|
||||
}
|
||||
}
|
||||
|
||||
// Weighted sum over the slot axis: out[t, d] = sum over local slots of
|
||||
// topk_weights[t, k] * down[t, k, d]. Non-local slots hold uninitialized
|
||||
// memory and are SKIPPED (not multiplied by zero).
|
||||
__global__ void moe_weighted_sum_sparse_bf16_kernel(
|
||||
const __nv_bfloat16* __restrict__ down, // [T, top_k, hidden]
|
||||
const int* __restrict__ topk_ids, // [T, top_k]
|
||||
const float* __restrict__ topk_weights, // [T, top_k]
|
||||
__nv_bfloat16* __restrict__ out, // [T, hidden]
|
||||
int num_tokens, int hidden, int top_k,
|
||||
int expert_start, int local_experts
|
||||
) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = num_tokens * hidden;
|
||||
if (idx >= total) return;
|
||||
|
||||
int token = idx / hidden;
|
||||
int dim = idx % hidden;
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int k = 0; k < top_k; k++) {
|
||||
int lid = topk_ids[token * top_k + k] - expert_start;
|
||||
if (lid >= 0 && lid < local_experts) {
|
||||
float w = topk_weights[token * top_k + k];
|
||||
float v = __bfloat162float(
|
||||
down[((long long)token * top_k + k) * hidden + dim]);
|
||||
sum += w * v;
|
||||
}
|
||||
}
|
||||
out[idx] = __float2bfloat16(sum);
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_moe_sparse_gemv_fp8_bf16(
|
||||
const void* x, const void* w, const void* w_scales, const void* bias,
|
||||
const void* topk_ids, void* y,
|
||||
int num_tokens, int N, int K, int top_k,
|
||||
int expert_start, int local_experts, int x_per_slot,
|
||||
void* stream
|
||||
) {
|
||||
dim3 grid((N + SPARSE_TILE_N - 1) / SPARSE_TILE_N, top_k, num_tokens);
|
||||
int block = SPARSE_TILE_N * 32;
|
||||
size_t smem = (size_t)K * sizeof(float);
|
||||
moe_sparse_gemv_fp8_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (const __nv_fp8_e4m3*)w,
|
||||
(const float*)w_scales, (const __nv_bfloat16*)bias,
|
||||
(const int*)topk_ids, (__nv_bfloat16*)y,
|
||||
N, K, top_k, expert_start, local_experts, x_per_slot
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
void launch_moe_sparse_gemv_mxfp4_bf16(
|
||||
const void* x, const void* w_packed, const void* w_scales, const void* bias,
|
||||
const void* topk_ids, void* y,
|
||||
int num_tokens, int N, int K, int top_k,
|
||||
int expert_start, int local_experts, int x_per_slot,
|
||||
void* stream
|
||||
) {
|
||||
dim3 grid((N + SPARSE_TILE_N - 1) / SPARSE_TILE_N, top_k, num_tokens);
|
||||
int block = SPARSE_TILE_N * 32;
|
||||
size_t smem = (size_t)K * sizeof(float);
|
||||
moe_sparse_gemv_mxfp4_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (const uint8_t*)w_packed,
|
||||
(const uint8_t*)w_scales, (const __nv_bfloat16*)bias,
|
||||
(const int*)topk_ids, (__nv_bfloat16*)y,
|
||||
N, K, top_k, expert_start, local_experts, x_per_slot
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
void launch_moe_weighted_sum_sparse_bf16(
|
||||
const void* down, const void* topk_ids, const void* topk_weights,
|
||||
void* out,
|
||||
int num_tokens, int hidden, int top_k,
|
||||
int expert_start, int local_experts,
|
||||
void* stream
|
||||
) {
|
||||
int total = num_tokens * hidden;
|
||||
int block = 256;
|
||||
int grid = (total + block - 1) / block;
|
||||
moe_weighted_sum_sparse_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)down,
|
||||
(const int*)topk_ids, (const float*)topk_weights,
|
||||
(__nv_bfloat16*)out,
|
||||
num_tokens, hidden, top_k, expert_start, local_experts
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1748,6 +1748,27 @@ Text → Tokenizer → Text Tokens ────────────→
|
||||
|
||||
---
|
||||
|
||||
## Actual progress log (where it diverged from the original plan, updated 2026-06)

Phases 0-17 were completed as planned. From Phase 18 on, the actual route left the original plan above
(speculative decoding and multimodal were deferred) and went toward MoE + quantization + sparsity:

| Actual phase | Content | Docs |
|---|---|---|
| 18 | Pipeline Parallelism (PP=2/4) | `18-pipeline-parallelism.md`, `benchmarks/pp-sweep.md` |
| 19 | **gpt-oss-20b MoE**: harmony format, attention sinks + sliding window, YaRN; two real-world CUDA bugs (prefill sinks NaN, uninitialized smem in GEMV); GSM8K 94.5%, on par with llama.cpp; FP8 W8A8 / MXFP4 W4A16 quantization | `19-gpt-oss-moe.md`, `benchmarks/{fp8-quantization,mxfp4-and-llama-decode}.md` |
| 20 | **Sparse top-k MoE decode**: compute only the routed experts, decode 13.9→7.0 ms; at TP=2 both decode and TTFT beat llama.cpp in the same configuration; single-GPU gpt-oss serving | `20-sparse-moe.md`, `benchmarks/sparse-moe.md` |
| 21 | **decode CUDA Graph + GPU argmax**: the whole decode step is recorded into one graph and replayed (thread-local launch stream, retained-warmup allocation strategy, NCCL capture); greedy sampling switches to GPU argmax. TPOT 7.5→5.9 ms (TP=1) / 5.8 ms (TP=2); TP=2 leads llama across the board (1.26-1.47×), the TP=1 gap shrinks from 2.5× to 2.0× | `21-cuda-graph-decode.md` |

**Next candidates (ordered by expected gain):**

| Candidate phase | Content | Expected |
|---|---|---|
| 22 | **Non-expert weight quantization**: qkv/o + lm_head (1.16GB/token) are still BF16 | saves another ~1.5 ms of TPOT |
| 23 | **Sparse prefill** (per-expert permute + grouped GEMM) | long-prompt TTFT 51-75 → ~30 ms |
| 24 | Server-side harmony channel separation (streamed `reasoning_content`, matching llama-server behavior) | API usability |
| - | Speculative Decoding, multimodal (originally 16/19) | deferred |
|
||||
|
||||
## Milestone summary

| Milestone | Phase | Acceptance criteria |
|
||||
@@ -1757,7 +1778,9 @@ Text → Tokenizer → Text Tokens ────────────→
|
||||
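| ③ E2E API | 13 | HTTP streaming API, callable from the Python OpenAI SDK, correct under 10 concurrent requests |
| ④ Performance target | 15 | throughput >= 50% of vLLM, profiling report done |
| ⑤ Multi-GPU inference | 17 | TP=2/4 inference correct on the same GPU group, scaling benchmark done |
| ⑥ Multimodal | 19 | image input → text answer, end-to-end through the API |
| ⑥ MoE model (actual) | 19 | gpt-oss-20b correct end-to-end, GSM8K on par with llama.cpp ✅ |
| ⑦ Performance overtake (actual) | 20 | decode faster than llama.cpp in the same configuration (achieved at TP=2; single GPU is a Phase 21+ goal) ✅ |
| ⑧ Multimodal | deferred | image input → text answer, end-to-end through the API |

## External dependency list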
|
||||
|
||||
|
||||
118
docs/19-gpt-oss-moe.md
Normal file
@@ -0,0 +1,118 @@
|
||||
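# Phase 19: gpt-oss-20b, MoE model support and two CUDA debugging case studies

> Goal: support OpenAI gpt-oss-20b (32-expert, top-4 MoE), match llama.cpp's GSM8K accuracy,
> and use the model as the vehicle for FP8 / MXFP4 quantization. This document was written after the fact and focuses on the **pitfalls we hit**:
> the process of tracking down two textbook-grade CUDA bugs is more instructive than the conclusions themselves.
>
> Follow-up: `docs/20-sparse-moe.md` (sparsification); benchmark data is in
> `docs/benchmarks/{fp8-quantization,mxfp4-and-llama-decode,sparse-moe}.md`.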
|
||||
|
||||
## 1. Model architecture (differences from Qwen3)

gpt-oss-20b (`config.json`, verified on dash5):

| Item | Value | Notes |
|---|---|---|
| layers / hidden | 24 / 2880 | this is where the **non**-multiple-of-128 dimension comes from (2880 = 22.5×128) |
| heads | 64 Q / 8 KV, head_dim **64** | head_dim ≠ hidden/heads (64×64=4096>2880), GQA n_rep=8 |
| MoE | 32 experts, top-4, expert inter 2880 | the router is a plain Linear [2880→32] + bias |
| attention | **alternating sliding(128)/full**, layer 0 is sliding | every layer has **attention sinks** (one learnable scalar per head) |
| RoPE | YaRN (theta 150000, factor 32, orig 4096) | attn_factor = 0.1·ln(32)+1, multiplied into cos/sin |
| Activation | clamped GLU | gate=gu[::2], up=gu[1::2] (**interleaved**), gate≤7, up∈[-7,7], glu=gate·σ(1.702·gate), h=(up+1)·glu |
| Vocabulary | 201088 | EOS is a **list** [200002,199999,200012] = `<\|return\|>`/`<\|endoftext\|>`/`<\|call\|>` |
| Other | attention_bias=true | q/k/v/o all carry a bias (Qwen3 has none) |

**Harmony chat format**: gpt-oss does not use an ordinary chat template. Its output is split into channels
(`analysis` = chain of thought, `final` = the actual answer), delimited by the control tokens `<|start|>/<|channel|>/<|message|>/<|end|>`.
Three pitfalls: (1) the system message must contain the canonical lines such as `Reasoning:`; without them the model is out of distribution and
channel selection becomes unstable; (2) repetition penalty penalizes control tokens that have to repeat,
so the model emits only analysis and never reaches final (disabled by default for MoE models); (3) the server must stop on any of several EOS tokens.
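
The activation row and the YaRN `attn_factor` in the table above can be written down as a small CPU reference. This is a minimal sketch based only on the formulas in the table; the function names `clamped_glu` and `yarn_attn_factor` are hypothetical and are not taken from the xserv code base.

```rust
/// Clamped, interleaved GLU: `gu` stores (gate, up) pairs back to back,
/// i.e. gate = gu[0], gu[2], ... and up = gu[1], gu[3], ...
pub fn clamped_glu(gu: &[f32]) -> Vec<f32> {
    const LIMIT: f32 = 7.0;
    const ALPHA: f32 = 1.702;
    gu.chunks_exact(2)
        .map(|pair| {
            let gate = pair[0].min(LIMIT); // gate <= 7 (clamped from above only)
            let up = pair[1].clamp(-LIMIT, LIMIT); // up in [-7, 7]
            let glu = gate / (1.0 + (-ALPHA * gate).exp()); // gate * sigmoid(1.702 * gate)
            (up + 1.0) * glu // h = (up + 1) * glu
        })
        .collect()
}

/// YaRN attention factor from the table: 0.1 * ln(factor) + 1.
pub fn yarn_attn_factor(factor: f32) -> f32 {
    0.1 * factor.ln() + 1.0
}
```

For `factor = 32` the attention factor evaluates to roughly 1.35; a reference like this is mainly useful for checking a GPU kernel on small inputs.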
|
||||
|
||||
## 2. MoE forward pass (dense version, before Phase 20)

```text
router GEMV → topk_softmax (GPU) → moe_replicate (copy to every local expert)
  → batched GEMM gate_up → bias → GLU → batched GEMM down → bias
  → weighted_sum (top-4 only) → all-reduce
```

Key points: the top-k expert ids always stay on the GPU (`topk_ids`) and the host never synchronizes on them;
the cost of the dense path (every token reads all expert weights) is addressed in Phase 20 with a sparse GEMV.
TP uses **expert parallelism**: rank r owns experts [r·E/world, (r+1)·E/world),
weighted_sum filters out non-local hits by the `expert_start + local_experts` range,
and the all-reduce adds up each rank's partial sum. This requires "skip" semantics rather than "multiply by 0".
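
The expert-parallel split and the "skip" rule can be illustrated on the CPU. The sketch below is an assumption-laden simplification: it handles one token and one hidden dimension, stores one output value per local expert, and uses the hypothetical names `local_expert_range` and `partial_weighted_sum`.

```rust
/// Half-open range [start, end) of experts owned by `rank`
/// when `num_experts` are split evenly over `world` ranks.
pub fn local_expert_range(rank: usize, world: usize, num_experts: usize) -> (usize, usize) {
    let per_rank = num_experts / world;
    (rank * per_rank, (rank + 1) * per_rank)
}

/// Partial weighted sum on one rank. `expert_out[i]` is the output of
/// local expert `expert_start + i`; experts owned by other ranks have
/// no entry here at all, so they are skipped rather than multiplied by 0.
pub fn partial_weighted_sum(
    topk_ids: &[usize],
    topk_weights: &[f32],
    expert_out: &[f32],
    expert_start: usize,
) -> f32 {
    let mut sum = 0.0;
    for (&eid, &w) in topk_ids.iter().zip(topk_weights) {
        if eid >= expert_start && eid - expert_start < expert_out.len() {
            sum += w * expert_out[eid - expert_start];
        }
    }
    sum
}
```

Adding the partial sums of all ranks (the job of the all-reduce) yields the full top-k weighted sum, because every routed expert is local to exactly one rank.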
|
||||
|
||||
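## 3. CUDA debugging case ①: prefill NaN (flash-attention sinks)

**Symptom**: after prefilling a long prompt (≳192 tokens) the output is all NaN → argmax lands on
token 201087 (`max_by` returns the last element on ties) or token 0 (`!`). Short prompts behave perfectly.

**How it was located**: add a NaN check after every stage (behind an environment variable, removed afterwards)
and bisect to the first place that produces NaN: the output of layer 0's `flash_attention_sinks`,
whose q/k inputs are clean → the bug is inside the kernel.

**Root cause**: the causal skip logic only drops kv tiles that are "entirely in the future"; a **past** tile that has slid
completely out of the sliding window (128) is still processed, and every key in it is masked to -inf
→ `row_max = -inf` → in the online softmax `expf(-inf-(-inf)) = NaN`,
and the correction term `0·NaN` of the next valid tile poisons the whole row.

**Fix**: skip any tile with `row_max == -INFINITY` outright (its contribution is zero).
**Lesson**: the "empty tile" is a standard boundary condition of online softmax; the decode kernel had guarded
against it all along (`local_max==-INFINITY` guard) and the prefill kernel missed it:
**two implementations of the same logic need the same boundary tests**. The ~192-token trigger threshold explains the
puzzling pattern of "short tests all pass, long conversations always blow up".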
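
The empty-tile failure can be reproduced on the CPU with a scalar online softmax. The following is a minimal sketch, not the kernel itself: `online_softmax_avg` is a hypothetical helper, values are scalars instead of head vectors, and the guard is expressed as "the running maximum is still -inf after including this tile".

```rust
/// Softmax-weighted average of `values` by `scores`, processed tile by tile
/// with the usual online-softmax rescaling. Masked positions carry -inf.
/// `skip_empty` enables the guard for tiles that leave the running max at -inf.
pub fn online_softmax_avg(scores: &[f32], values: &[f32], tile: usize, skip_empty: bool) -> f32 {
    let mut row_max = f32::NEG_INFINITY;
    let mut denom = 0.0f32;
    let mut acc = 0.0f32;
    for (s_tile, v_tile) in scores.chunks(tile).zip(values.chunks(tile)) {
        let tile_max = s_tile.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let new_max = row_max.max(tile_max);
        if skip_empty && new_max == f32::NEG_INFINITY {
            continue; // nothing visible yet: the tile contributes zero
        }
        // Without the guard this is exp(-inf - (-inf)) = exp(NaN) = NaN.
        let correction = (row_max - new_max).exp();
        denom *= correction;
        acc *= correction;
        for (&s, &v) in s_tile.iter().zip(v_tile) {
            let p = (s - new_max).exp();
            denom += p;
            acc += p * v;
        }
        row_max = new_max;
    }
    acc / denom
}
```

With scores `[-inf, -inf, 0, 0]` and a tile size of 2, the unguarded version returns NaN because the first tile is fully masked, while the guarded version averages only the two visible values.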
|
||||
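## 4. CUDA debugging case ②: intermittent garbage during decode (uninitialized shared memory in GEMV)

**Symptom**: for the same prompt, ~70% of runs suddenly emit
`!!!!`/token 201087/NaN logits in the second conversation turn or during long generations. It is **intermittent**,
so it is not a deterministic logic error but a race or an uninitialized read. Only gpt-oss is affected; Qwen3 never reproduces it.

**Locating it**: checking stage by stage, the first bad value is the decode o_proj output
(maxabs≈1e33) with clean inputs → the M=1 GEMV kernel.

**Root cause** (`gemv.cu`):

```cuda
if (col >= N) return;   // ← before the cooperative load of x_shared and __syncthreads!
...cooperative load + __syncthreads()...
```

When `N % 128 != 0`, the out-of-range threads of the last block exit early and **take no part** in
loading shared memory; the in-range threads then read uninitialized smem (and calling `__syncthreads` after some
threads have exited is UB). Trigger condition: matrices with n=2880 (o_proj, MoE gate_up/down),
since 2880 % 128 ≠ 0; every Qwen3 dimension is 128-aligned, **which is why "only gpt-oss
is unstable"**. q/k/v (4096) and lm_head (201088) are aligned and were spared.

**Fix**: all threads finish the load and the barrier first; the `col >= N` check moves to **after** the
syncthreads.

**Lesson**: any early return before `__syncthreads()` must be **block-uniform**.
The Phase 20 sparse GEMV follows this rule deliberately (the whole block exits together based on a single
`topk_ids` value, before the load).

**Verification after the fixes**: on the full GSM8K set (1319 problems), xserv scores 94.5% vs llama.cpp 94.4%,
statistically the same level, which shows that the two kernel bugs account for the entire earlier 55% vs 95% gap.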
|
||||
|
||||
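## 5. Quantization (details in the benchmark docs)

- **FP8 W8A8** (`tools/quantize_fp8.py`): per-expert scalar scale, weights stored transposed
  as [E,N,K] to feed cuBLASLt (Blackwell requires transA=T). Two performance pitfalls:
  (1) rebuilding the plan and running the heuristic on every call → slower than BF16; fix = per-shape
  plan cache; (2) launching ~768 small GEMMs expert by expert; fix = a single strided-batched
  call plus moving the scale into a fused post-scale kernel. Final result: 1.41× vs BF16.
- **MXFP4 W4A16** (`tools/quantize_mxfp4.py`): E2M1 + per-32 UE8M0 block scale,
  13GB model, greedy output matches BF16 verbatim, but the hand-written dequant-GEMV cannot beat
  cuBLASLt FP8 (poor bandwidth efficiency), so it is positioned as the memory-saving option.
- Detection: the format is recognized automatically from the safetensors dtype / scale rank, so the loader needs no configuration.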
|
||||
|
||||
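## 6. Tools that came out of this phase

- `bench-gpt-oss`: in-process inference + `--forced` (teacher-forced prefill
  top-1) / `--forced-decode` (per-position top-1 along a reference trajectory), a sharp tool for
  separating "the forward pass is wrong" from "the greedy trajectory diverged".
- `tools/eval_gsm8k_fast.py` (persistent xserv-chat pipe),
  `tools/xserv_vs_llama.py` (warm-server head-to-head on the same machine, counting llama's
  reasoning_content).
- Experience: **greedy decoding is not bit-for-bit reproducible** (cuBLAS nondeterminism flips late argmaxes),
  so multi-GPU correctness should be checked by comparing "single GPU ×2 + multi-GPU ×2" against each other,
  and accuracy should be measured on a benchmark set rather than by a verbatim diff.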
|
||||
160
docs/20-sparse-moe.md
Normal file
@@ -0,0 +1,160 @@
|
||||
# Phase 20: Sparse MoE Decode, computing only the routed experts

> Goal: eliminate the useless weight reads of dense MoE so that decode TPOT catches up with and passes llama.cpp.
> Prerequisites: Phase 19 (gpt-oss MoE correctness), FP8 W8A8 / MXFP4 W4A16 quantization
> (see `docs/benchmarks/fp8-quantization.md`, `docs/benchmarks/mxfp4-and-llama-decode.md`).
|
||||
|
||||
## 1. Status quo: what dense MoE wastes

gpt-oss-20b is a 32-expert, top-4 MoE: the router picks 4 experts per token,
so in theory each token only needs to read 4/32 = 12.5% of the expert weights. But `moe_forward`
(`crates/xserv-model/src/gpt_oss.rs`) is currently a **dense** implementation:

```text
1. router GEMV            [T, 2880] → [T, 32]
2. topk_softmax (GPU)     → topk_ids [T,4], topk_weights [T,4]
3. moe_replicate          x copied 16 times → [16, T, 2880]      ← the waste starts here
4. batched GEMM gate_up   all 16 local experts are computed      ← reads 16 sets of weights
5. bias + GLU
6. batched GEMM down      all 16 local experts are computed      ← reads 16 sets of weights
7. bias
8. moe_weighted_sum       picks only the top-4 for the weighted sum, the other 12 are discarded
9. all-reduce
```

Why it was written this way: batched GEMM (cuBLAS strided-batched) needs a regular
`[E, T, K]` shape; the top-4 expert ids live on the **GPU** (`topk_ids`), so the host does not know
which ones to pick, and picking them would make the shape irregular anyway. Dense is a reasonable
"get correctness first" starting point, but every token reads the weights of all 16 experts from VRAM.

### Byte ledger (decode, per token, TP=2 with 16 local experts per GPU)

Per layer, per expert: gate_up `[2880, 5760]` + down `[2880, 2880]` ≈ 24.9 M parameters.

| Scheme | Expert bytes per GPU per token | Relative |
|---|---|---|
| xserv dense FP8 (status quo) | 16 × 24.9 MB × 24 layers ≈ **9.6 GB** | 1× |
| xserv sparse FP8 (this phase) | ~2 × 24.9 MB × 24 layers ≈ **1.2 GB** | 1/8 |
| llama.cpp sparse MXFP4 | ~2 × 12.5 MB × 24 layers ≈ **0.6 GB** | 1/16 |

(With the top-4 spread uniformly over 2 GPUs, each GPU expects 2 hits; strictly speaking each layer pays for the
max of the two GPUs' hit counts, whose expectation is ≈ 2.7, still a ~6-8× saving.)

Measured cross-check: FP8 dense TP=2 TPOT is 13.1 ms, of which expert GEMM ≈ 9.6 GB ÷ ~1 TB/s
≈ 9.5 ms and the rest (attention, qkv/o, lm_head, 48 PCIe all-reduces) ≈ 3.5 ms.
**Expert weight reads take ~3/4 of TPOT, and that is the entire gap to llama.cpp (6.6 ms).**
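
The byte ledger is plain arithmetic and can be re-derived in a few lines. The sketch below assumes uniform routing of 4 distinct experts over 32 (16 per GPU) and 1 byte per FP8 parameter; all function names are hypothetical.

```rust
/// n choose k as f64 (intended for small inputs only).
pub fn choose(n: u64, k: u64) -> f64 {
    (0..k).fold(1.0, |acc, i| acc * (n - i) as f64 / (i + 1) as f64)
}

/// Parameters of one expert in one layer:
/// gate_up [hidden, 2 * inter] + down [inter, hidden].
pub fn expert_params(hidden: u64, inter: u64) -> u64 {
    hidden * 2 * inter + inter * hidden
}

/// Expert-weight bytes read per token on one GPU.
pub fn expert_bytes_per_token(experts_read: f64, params: u64, bytes_per_param: f64, layers: u64) -> f64 {
    experts_read * params as f64 * bytes_per_param * layers as f64
}

/// E[max(X, top_k - X)], where X is the number of routed experts that land on
/// GPU 0 when `top_k` distinct experts are drawn uniformly from 2 * per_gpu.
pub fn expected_max_hits(per_gpu: u64, top_k: u64) -> f64 {
    let total = choose(2 * per_gpu, top_k);
    (0..=top_k)
        .map(|x| {
            let p = choose(per_gpu, x) * choose(per_gpu, top_k - x) / total;
            p * x.max(top_k - x) as f64
        })
        .sum()
}
```

Under these assumptions `expert_params(2880, 2880)` is 24,883,200, the dense FP8 path reads about 9.56 GB per token, the sparse path about 1.19 GB, and the expected per-layer maximum of the two GPUs' hit counts comes out at about 2.70. Real routing is not uniform, so the last figure is only a rough guide.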
|
||||
|
||||
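## 2. Roofline: why "fewer bytes = less time" at M=1

A decode GEMV (M=1) performs only 2 FLOPs (one multiply-add) per byte of FP8 weight it reads.
RTX 5090: ~1.8 TB/s of memory bandwidth (GDDR7), ~210 TFLOPS of BF16 compute. The arithmetic
intensity needed to saturate compute is ~100 FLOP/byte, and GEMV has only 2. Conclusions:

1. **Decode is entirely memory-bound**, tensor cores cannot help → a hand-written W8A16 GEMV
   (FP8 weights, activations kept in BF16) will not lose to cuBLASLt's W8A8 tensor-core GEMM;
   it also saves the activation-quantization kernel and is more accurate (activations carry no quantization error).
2. There is only one direction to optimize: **read fewer bytes**. Sparse (×8) and 4-bit (×2) are orthogonal
   and can be stacked. This phase does sparse first and supports both the FP8 and MXFP4 weight formats.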
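
The roofline argument reduces to two divisions. This is a minimal sketch using the approximate hardware figures quoted in this section; `ridge_point` and `roofline_seconds` are hypothetical helper names.

```rust
/// Arithmetic intensity (FLOP/byte) above which a kernel stops being memory-bound.
pub fn ridge_point(peak_flops: f64, bandwidth_bytes_per_s: f64) -> f64 {
    peak_flops / bandwidth_bytes_per_s
}

/// Roofline lower bound on run time: the slower of the compute and memory limits.
pub fn roofline_seconds(flops: f64, bytes: f64, peak_flops: f64, bandwidth_bytes_per_s: f64) -> f64 {
    (flops / peak_flops).max(bytes / bandwidth_bytes_per_s)
}
```

With 210 TFLOPS and 1.8 TB/s the ridge point is about 117 FLOP/byte, far above the 2 FLOP/byte of an FP8 GEMV, so reading 1.2 GB of expert weights is bounded by roughly 0.67 ms of memory time while the compute term is negligible.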
|
||||
|
||||
## 3. Sparse design: let the kernel index weights by topk_ids itself

Key observation: `topk_ids` is already on the GPU. The host does not need to know who was selected:
**each block of the GEMV kernel reads `topk_ids[token, slot]` itself and
addresses the weights of that expert directly**, and the whole block exits if the expert is not on this GPU.
Zero host synchronization, and the pipeline stays fully asynchronous (this was checked earlier: the decode loop has no per-layer sync).

New data flow (enabled when `num_tokens ≤ 8`):

```text
x [T, 2880]
  ├─ router → topk_ids/weights [T, 4]        (unchanged)
  ├─ sparse GEMV gate_up → [T, 4, 5760]      bias fused, non-local slots not written
  ├─ GLU → [T*4, 2880]
  ├─ sparse GEMV down → [T, 4, 2880]         bias fused, non-local slots not written
  └─ weighted_sum_sparse → [T, 2880]         accumulates local slots only
all-reduce                                    (unchanged)
```

`moe_replicate` and the standalone bias kernel disappear on the sparse path; the FP8 path also drops
`quantize_bf16_to_fp8_rowwise`.

### Kernel design (`csrc/moe/moe_sparse.cu`)

`moe_sparse_gemv_{fp8,mxfp4}_bf16_kernel`:

- **grid = (N/8, top_k, tokens)**, block = 8 warps × 32 lanes.
  Each block handles 8 output columns of one (token, slot) pair, **one warp per output**.
- The block first reads `eid = topk_ids[token*top_k + slot]` and converts it to `lid = eid - expert_start`;
  if that is outside `[0, local_experts)` the whole block returns.
- A block that hits cooperatively loads the activation row (K=2880 BF16 values → float) into shared memory
  (11.25 KB), calls `__syncthreads()`, and then each warp computes a dot product along K:
  each lane reads 16 bytes of weights per `uint4` load (FP8 = 16 weights, MXFP4 = 32 nibbles),
  and the 32 lanes of a warp are contiguous → a 512 B coalesced transaction.
- Epilogue (lane 0): `y = acc * w_scale[lid] + bias[lid, n]`. The per-expert
  scale and bias are both fused here, which is semantically equivalent, step for step, to the dense path's
  "GEMM → bias add → routing weight" (the HF reference implementation also adds the bias before multiplying by the routing weight).
- gate_up and down share one kernel and use `x_per_slot` to select the activation addressing:
  for gate_up the 4 slots share `x[token]`; for down each slot reads its own `act[token*4+slot]`.

### Two safety points that are easy to get wrong

1. **The early return must be block-uniform.** The garbage-output GEMV bug from Phase 19
   (commit `3b9e32e`) was caused precisely by "some threads return before `__syncthreads()`", which led to
   reads of uninitialized shared memory. Here the return happens **before** the smem load and the whole
   block exits together based on a single `topk_ids` value: no divergence, so it is legal and safe.
2. **weighted-sum must "skip" non-local slots, not "multiply by 0".** The GEMV output of a non-local slot
   is never written (uninitialized device memory, possibly a NaN bit pattern), and GLU computes garbage on top of it.
   `NaN × 0 = NaN`, so the sum kernel skips with `if (local) sum += w*v` and
   the garbage never enters the data flow (the same goes for the dense path's `moe_weighted_sum`).
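
A CPU reference makes the indexing and the skip rule concrete. The sketch below is not the kernel: weights are assumed to be already dequantized to `f32`, each (token, slot) loop iteration stands in for one block column group, and unwritten outputs are modeled as `None` instead of uninitialized memory. `SparseGemv` and `weighted_sum_sparse` are hypothetical names.

```rust
/// CPU reference for the sparse MoE GEMV on one rank.
pub struct SparseGemv<'a> {
    pub w: &'a [f32],       // [local_experts, N, K], dequantized
    pub w_scale: &'a [f32], // [local_experts]
    pub bias: &'a [f32],    // [local_experts, N]
    pub n: usize,
    pub k: usize,
    pub expert_start: usize,
}

impl SparseGemv<'_> {
    /// Returns y[T, top_k, N]; `None` marks a slot this rank never writes.
    pub fn run(&self, x: &[f32], topk_ids: &[usize], top_k: usize, x_per_slot: bool) -> Vec<Option<f32>> {
        let local_experts = self.w_scale.len();
        let tokens = topk_ids.len() / top_k;
        let mut y = vec![None; tokens * top_k * self.n];
        for token in 0..tokens {
            for slot in 0..top_k {
                let eid = topk_ids[token * top_k + slot];
                // The whole (token, slot) unit exits if the expert is not local.
                if eid < self.expert_start || eid - self.expert_start >= local_experts {
                    continue;
                }
                let lid = eid - self.expert_start;
                // gate_up: slots share x[token]; down: one activation row per slot.
                let row = if x_per_slot { token * top_k + slot } else { token };
                let xrow = &x[row * self.k..(row + 1) * self.k];
                for col in 0..self.n {
                    let wrow = &self.w[(lid * self.n + col) * self.k..][..self.k];
                    let acc: f32 = xrow.iter().zip(wrow).map(|(a, b)| a * b).sum();
                    // Fused epilogue: per-expert scale, then bias.
                    y[(token * top_k + slot) * self.n + col] =
                        Some(acc * self.w_scale[lid] + self.bias[lid * self.n + col]);
                }
            }
        }
        y
    }
}

/// Weighted sum over slots; unwritten slots are skipped, not multiplied by zero.
pub fn weighted_sum_sparse(y: &[Option<f32>], topk_weights: &[f32], top_k: usize, n: usize) -> Vec<f32> {
    let tokens = topk_weights.len() / top_k;
    let mut out = vec![0.0; tokens * n];
    for t in 0..tokens {
        for s in 0..top_k {
            for d in 0..n {
                if let Some(v) = y[(t * top_k + s) * n + d] {
                    out[t * n + d] += topk_weights[t * top_k + s] * v;
                }
            }
        }
    }
    out
}
```

Modeling unwritten slots as `None` turns the "skip, do not multiply by zero" rule into something the type system enforces, which is convenient when comparing against GPU output in a test.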
|
||||
|
||||
## 4. Why prefill stays dense

A dense batched GEMM reads the 16 sets of weights **once** and serves all M tokens;
sparse GEMV re-reads its own ~2 sets **per token**. Byte crossover point:

```text
sparse reads M × 2 sets  vs  dense reads 16 sets   →   M ≈ 8  (TP=2)
```

Beyond M > 8 dense reads fewer bytes (and the GEMM becomes compute-bound, where tensor cores start to pay off).
So sparse is only enabled for `num_tokens ≤ 8`, which covers decode (multi-request decode merged by continuous
batching is also small-M) and very short re-prefills. A true sparse prefill
(a grouped GEMM that permutes/gathers tokens by expert, as vLLM does) is left to
a later phase; its main gain is long-prompt TTFT.
|
||||
|
||||
## 5. Measured results (2026-06-12, full data in `docs/benchmarks/sparse-moe.md`)

In-process decode (bench-gpt-oss, greedy, 96 tokens):

| | TPOT | tok/s |
|---|---|---|
| dense FP8 TP=2 (baseline) | 13.9 ms | 72 |
| **sparse FP8 TP=2** | **7.6 ms (1.8×)** | **132** |
| sparse MXFP4 TP=2 | 8.4 ms | 118 |
| sparse FP8 TP=1 (single GPU) | 7.8 ms | 128 |

Warm-server head-to-head against llama.cpp (`tools/xserv_vs_llama.py`):

- **TP=2 vs TP=2: xserv comes out ahead across the board for the first time.** TPOT 7.19-7.32 ms vs llama
  7.54-8.42 ms; short/medium-prompt TTFT also leads (35/49 vs 63/65 ms).
- **TP=1 vs TP=1: llama wins big** (2.88-3.22 ms vs 7.0-7.2 ms, 347 vs 140
  tok/s). Single GPU is llama's best configuration: its cross-GPU split loses ~5 ms per token
  on PCIe, while on a single GPU its "whole model in 4-bit + CUDA graph replay of the whole token"
  advantage pays off in full. xserv's remaining ~7 ms ≈ ~3 ms of memory traffic (the non-expert weights are still
  BF16, including the 1.16 GB lm_head) + ~4 ms of launch overhead (~200 kernel
  launches/token, no CUDA graph).
- **Correctness: GSM8K-100 = 96%** (dense FP8 91% / BF16 90%, within greedy noise,
  no regression).

Lesson: the earlier conclusion that "CUDA graph ≈ useless (~0.5-1.5 ms)" was relative to the 13 ms
dense TPOT; once the expert cost was cut, launch overhead became the largest single item.
|
||||
|
||||
## 6. Next steps (ordered by gain)

1. **decode CUDA graph** (~2-4 ms): currently the largest single item.
2. **Non-expert weight quantization** (~1-1.5 ms): qkv/o + lm_head are still BF16, so every token
   reads ~2.3 GB needlessly; llama is 4-bit across the whole model.
3. **sparse prefill** (grouped GEMM): long-prompt TTFT 94-120 ms → llama's
   ~30 ms range.
4. **W4A4 FP4 tensor core / bandwidth-tuned MXFP4 GEMV**: make 4-bit experts actually
   faster than FP8 (currently 8.4 vs 7.6 ms; GEMV efficiency cancels the byte advantage).
|
||||
111
docs/21-cuda-graph-decode.md
Normal file
@@ -0,0 +1,111 @@
|
||||
# Phase 21: gpt-oss decode CUDA Graph + GPU argmax

> Goal: remove the fixed per-token overhead of decode. After Phase 20 TPOT is ~7 ms, of which
> actual GPU compute is only a part; the rest is ~200 kernel launches and per-token
> host work. This phase captures the **entire decode step into one CUDA graph** and replays it with one
> `cudaGraphLaunch` per token; along the way greedy sampling switches to a GPU argmax.
>
> Implementation: `crates/xserv-model/src/gpt_oss_graph.rs` (~150 lines) + three pieces of infrastructure.
|
||||
|
||||
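## 1. What a CUDA Graph is and why it has constraints

After `cudaStreamBeginCapture`, kernels issued to that stream are not executed but **recorded**;
`EndCapture + Instantiate` yields an executable graph; from then on a single `cudaGraphLaunch` per step
replays all ~200 kernels, and host-side overhead drops from ~200 launches to 1.

The price is three hard constraints, each of which maps to an engineering problem:

1. **Stable addresses**: every pointer baked into the graph at record time must still be valid and point to the right data at replay;
2. **No "unsafe" calls during capture**: `cudaMalloc` / synchronous memcpy / `cudaDeviceSynchronize`
   all make the capture fail (error 900);
3. **Fixed shapes**: grid dimensions are baked in, so a shape change requires re-recording.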
|
||||
|
||||
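## 2. Why xserv's decode was already "almost" capturable as a whole graph

Checking the inputs of a decode step one by one shows that most of them already have stable addresses:

| Input that changes each step | Address | How the content is updated |
|---|---|---|
| block table / context lens | resident GPU buffers of PagedKVCache ✓ | `decode_prepare` does the H2D copy outside the graph |
| KV write position | the scatter kernel **reads context_lens on the GPU** ✓ | same as above |
| attention read range | the paged kernel reads the same buffers ✓ | same as above |
| MoE expert selection | sparse GEMV reads the `topk_ids` just written inside the graph ✓ | data dependency, supported naturally |
| token id / position | ✗ previously uploaded from a host slice each step | **what this phase changes** |

In other words, the "data-driven" designs of Phase 11 (paged KV) and Phase 20 (sparse MoE)
had unintentionally already paved the way for graph capture. The only two things that need to change are the
embedding's token id and RoPE's position: each gets a device-buffer variant (`embedding_device_ids` /
`rope_inplace_device_pos`), with id/pos stored in two resident 4-byte buffers that are updated outside the graph each step.

Structure after the refactor:

```text
forward_decode_paged = decode_prepare (host bookkeeping, outside the graph)
                     + upload ids/pos (outside the graph)
                     + decode_core (pure GPU, capturable as a whole)
                     + advance_seq_len (host bookkeeping, outside the graph)
```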
|
||||
|
||||
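## 3. Three engineering problems

### 3.1 The null stream cannot be captured → thread-local launch stream

Every kernel in the codebase is launched on the legacy null stream, whereas capture must happen on an explicit
stream. The fix: `xserv_cuda::stream` gains a **thread-local current stream**
(null by default, so behavior is byte-for-byte identical to before), and all kernel wrappers, cuBLAS's
`cublasSetStream`, and the NCCL collectives now read it. The capture code uses an RAII guard
(`push_stream`) to install the capture stream, and the previous stream is restored automatically when recording ends.
Ordering correctness: the explicit stream is created in the default (blocking) mode, so the legacy stream synchronizes
implicitly with it in both directions, and the out-of-graph H2D and sampling memcpys are naturally ordered with the replay.
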
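The thread-local stream and its RAII guard can be sketched in plain Rust. This is a minimal model, not xserv's code: only the name `push_stream` comes from the text above, while `StreamHandle`, `StreamGuard`, `current_stream`, and `launch_kernel` are hypothetical, and an integer stands in for a real `cudaStream_t`.

```rust
use std::cell::Cell;

/// Stand-in for a `cudaStream_t`; 0 plays the role of the legacy null stream.
type StreamHandle = usize;
const NULL_STREAM: StreamHandle = 0;

thread_local! {
    // Per-thread "current launch stream"; defaults to the null stream.
    static CURRENT_STREAM: Cell<StreamHandle> = Cell::new(NULL_STREAM);
}

/// What every kernel wrapper would read instead of hard-coding the null stream.
fn current_stream() -> StreamHandle {
    CURRENT_STREAM.with(|s| s.get())
}

/// RAII guard: remembers the previous stream and restores it on drop.
struct StreamGuard {
    prev: StreamHandle,
}

/// Install `stream` as this thread's current stream until the guard is dropped.
fn push_stream(stream: StreamHandle) -> StreamGuard {
    let prev = CURRENT_STREAM.with(|s| s.replace(stream));
    StreamGuard { prev }
}

impl Drop for StreamGuard {
    fn drop(&mut self) {
        CURRENT_STREAM.with(|s| s.set(self.prev));
    }
}

/// Toy "kernel launch" that only reports which stream it would use.
fn launch_kernel() -> StreamHandle {
    current_stream()
}

fn main() {
    println!("before capture: stream {}", launch_kernel());
    {
        let _guard = push_stream(7); // hypothetical capture stream handle
        println!("during capture: stream {}", launch_kernel());
    }
    println!("after capture: stream {}", launch_kernel());
}
```

Because the slot is thread-local, a guard installed on one TP rank thread does not affect launches issued by any other thread.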
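### 3.2 cudaMalloc is forbidden during capture → two-stage "retained warmup"

Intermediate tensors come from a caching allocator; any pool miss during capture triggers
`cudaMalloc` → error 900. The first implementation tripped on exactly this: **the quarantine mechanism itself created
pool misses** (blocks freed during capture were quarantined, so the next layer's same-size allocation found no block).

The fix is to first run the same step eagerly, but **with quarantine enabled** (`begin_retain`):
every freed block is held back instead of returning to the pool → after the run, exactly "every block this step needs" has accumulated outside the pool;
those blocks are returned to the pool in one batch, and only then does capture begin. Capture repeats exactly the same allocation sequence,
so every allocation hits the pool and no cudaMalloc happens at all.
(Executing the same step twice is harmless: the KV scatter rewrites the same values to the same positions.)
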
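### 3.3 Memory referenced by replay must not be taken by anyone else → quarantine

The intermediate buffers recorded by the capture were dropped on the host side long ago, yet every replay of the graph reads and writes those
addresses. If they returned to the allocation pool and a later prefill picked them up and held them long-term, the result would be double-write corruption.
So blocks freed during capture go into a `RetainedBlocks` quarantine that is owned by the graph object and
returned only when the graph is destroyed; for the lifetime of the graph, this memory is locked for its exclusive use.
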
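The interaction between the retained warmup (3.2) and the quarantine (3.3) can be reproduced with a toy allocator. This sketch is a simplified model under stated assumptions, not xserv's allocator: blocks are integer ids keyed by size, a counter stands in for `cudaMalloc`, and apart from `begin_retain` every name (`Pool`, `end_retain`, `run_step`, `mallocs_during_capture`) is hypothetical.

```rust
use std::collections::HashMap;

/// A block is (size, id); real code would hold a device pointer instead.
type Block = (usize, u32);

#[derive(Default)]
struct Pool {
    free_lists: HashMap<usize, Vec<u32>>, // size -> ids of free blocks
    retained: Vec<Block>,                 // blocks held back while retaining
    retain: bool,
    next_id: u32,
    device_mallocs: usize,                // simulated cudaMalloc calls
}

impl Pool {
    fn alloc(&mut self, size: usize) -> Block {
        if let Some(id) = self.free_lists.get_mut(&size).and_then(|v| v.pop()) {
            return (size, id); // pool hit
        }
        self.device_mallocs += 1; // pool miss: this is where cudaMalloc would run
        self.next_id += 1;
        (size, self.next_id)
    }

    fn free(&mut self, block: Block) {
        if self.retain {
            self.retained.push(block); // quarantined, not reusable
        } else {
            self.free_lists.entry(block.0).or_default().push(block.1);
        }
    }

    fn begin_retain(&mut self) {
        self.retain = true;
    }

    /// Stop retaining and hand over everything that was held back.
    fn end_retain(&mut self) -> Vec<Block> {
        self.retain = false;
        std::mem::take(&mut self.retained)
    }
}

/// One decode step: each layer allocates a temporary and frees it again.
fn run_step(pool: &mut Pool, layer_sizes: &[usize]) {
    for &size in layer_sizes {
        let block = pool.alloc(size);
        pool.free(block);
    }
}

/// Number of simulated cudaMalloc calls that happen during "capture".
fn mallocs_during_capture(warmup: bool) -> usize {
    let sizes: [usize; 4] = [4096, 4096, 8192, 4096];
    let mut pool = Pool::default();
    run_step(&mut pool, &sizes); // ordinary eager step populates the pool
    if warmup {
        pool.begin_retain();
        run_step(&mut pool, &sizes); // eager warmup with quarantine on
        for block in pool.end_retain() {
            pool.free(block); // return the whole batch to the pool
        }
    }
    let before = pool.device_mallocs;
    pool.begin_retain(); // capture: freed blocks are quarantined
    run_step(&mut pool, &sizes);
    let _owned_by_graph = pool.end_retain(); // held until the graph is destroyed
    pool.device_mallocs - before
}

fn main() {
    println!("no warmup:   {} mallocs in capture", mallocs_during_capture(false));
    println!("with warmup: {} mallocs in capture", mallocs_during_capture(true));
}
```

Without the warmup, the second same-size allocation misses because the first block is already quarantined; with it, the pool holds one block per allocation and capture never misses.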
### 3.4 Two smaller points

- **THREAD_LOCAL capture mode**: in GLOBAL mode, a cudaMalloc from any thread
  poisons the capture; concurrent capture on the TP rank threads must use THREAD_LOCAL.
- **NCCL can be captured**: it is enough to issue each rank's `ncclAllReduce` on the capture stream;
  TP=2 worked on the first try (each rank records its own graph, and the collectives pair up naturally at replay).

## 4. An unexpected lesson: launch overhead is smaller than assumed; argmax is the bigger cost

Measured A/B (in-process, FP8, 96 tokens):

| | TP=1 | TP=2 |
|---|---|---|
| eager + host argmax (end of Phase 20) | 7.5 ms | 7.6 ms |
| graph + host argmax | 6.9 ms | 6.9 ms |
| eager + **GPU argmax** | 6.5 ms | n/a |
| **graph + GPU argmax** | **5.9 ms** | **5.8 ms** |

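- **The graph saved only ~0.6 ms**: the decode loop was already fully asynchronous, so most launches were hidden behind
  GPU execution, and the "~200 launches ≈ 4 ms" estimate was wrong. Measure, don't guess.
- **GPU argmax saved ~1 ms**: greedy sampling used to synchronously copy the [1, 201088]
  logits (402 KB) back to the host for every token and then scan 201K bf16 values. The argmax kernel written back in Phase 15
  (in-kernel reduction + a 4-byte D2H) had been sitting in the repo without ever being wired into `sample()`.
- Detail: GPU argmax and the host `max_by` pick different indices on ties between **exactly equal** logits,
  so the greedy trajectory forks at some tied token; the output is equally valid (verified on GSM8K).
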
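The tie-breaking difference is easy to reproduce on the host. Rust's `Iterator::max_by` returns the last of several equal maxima, whereas a reduction that only replaces the best element on a strict `>` keeps the first. The sketch below assumes the GPU kernel behaves like the "first index" variant, which the text does not confirm; it uses `f32` instead of bf16, and both function names are hypothetical.

```rust
use std::cmp::Ordering;

/// Host-style scan: `max_by` returns the LAST element among equal maxima.
fn argmax_max_by(logits: &[f32]) -> usize {
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal))
        .map(|(i, _)| i)
        .unwrap_or(0)
}

/// Strict-greater scan: keeps the FIRST index among equal maxima.
/// Assumption: stands in for the tie rule of a GPU reduction.
fn argmax_first(logits: &[f32]) -> usize {
    let mut best = 0;
    for (i, &v) in logits.iter().enumerate() {
        if v > logits[best] {
            best = i;
        }
    }
    best
}

fn main() {
    let tied = [0.5_f32, 2.0, 1.0, 2.0]; // indices 1 and 3 tie exactly
    println!("max_by picks {}", argmax_max_by(&tied));
    println!("first-index picks {}", argmax_first(&tied));
}
```

Both results are valid argmax indices, which is why the two greedy trajectories can diverge while remaining equally legitimate.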
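## 5. Results and remaining bottlenecks

See the Phase 21 subsection of `docs/benchmarks/sparse-moe.md` (that file is authoritative for the warm-server head-to-head numbers against llama).
Breakdown of the remaining TPOT: ~3 ms is HBM bytes (the non-expert weights
are still BF16, including the 1.16 GB lm_head; **Phase 22 quantizes them**), and the rest is GEMV
bandwidth efficiency and attention. The gap to llama's single-GPU 2.9 ms comes mainly from "whole-model 4-bit".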
111
docs/benchmarks/sparse-moe.md
Normal file
@@ -0,0 +1,111 @@
# Sparse MoE decode: 1.8× over dense; beats llama.cpp at TP=2 (gpt-oss-20b, RTX 5090)

Phase 20 (`docs/20-sparse-moe.md`): decode computes only the routed top-4
experts via fused expert-indexed GEMVs (`csrc/moe/moe_sparse.cu`) instead of
the dense all-local-expert batched GEMM. FP8 weights run W8A16 (weights FP8,
activations BF16; decode is memory-bound, so tensor cores are irrelevant at M=1);
MXFP4 runs W4A16. The dense path is retained for prefill / `num_tokens > 8` and, via
`XSERV_DENSE_MOE=1`, for A/B comparison.

## In-process decode (bench-gpt-oss, greedy, 96 tokens)

| config | TPOT | tok/s |
|---|---|---|
| dense FP8 TP=2 (baseline) | 13.9 ms | 72 |
| **sparse FP8 TP=2** | **7.6 ms** | **132** |
| sparse MXFP4 TP=2 | 8.4 ms | 118 |
| sparse FP8 TP=1 (one 5090) | 7.8 ms | 128 |
| sparse MXFP4 TP=1 | 8.9 ms | 113 |

- Sparse FP8 = **1.8× over dense** (13.9 ms → 7.6 ms at TP=2). Greedy output stays coherent.
- TP=1 ≈ TP=2: expert reads are now so small that the PCIe all-reduce eats the
  TP gain, so single-GPU serving becomes the attractive deployment.
- MXFP4 reads half the bytes of FP8 but stays slower: the 4-bit dequant GEMV
  has lower effective bandwidth (the same fixed inefficiency seen in the dense
  MXFP4 experiments); at sparse sizes both are partly launch/latency-bound.

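The tok/s column and the headline ratio follow directly from TPOT. A small sketch of that arithmetic, using values copied from the table above (the helper names `tok_per_s` and `speedup` are illustrative, not part of the benchmark tooling):

```rust
/// Tokens per second implied by a time-per-output-token in milliseconds.
fn tok_per_s(tpot_ms: f64) -> f64 {
    1000.0 / tpot_ms
}

/// Speedup of `new_ms` relative to `old_ms` (a value above 1 means faster).
fn speedup(old_ms: f64, new_ms: f64) -> f64 {
    old_ms / new_ms
}

fn main() {
    // TPOT values copied from the in-process table.
    let rows = [
        ("dense FP8 TP=2", 13.9),
        ("sparse FP8 TP=2", 7.6),
        ("sparse FP8 TP=1", 7.8),
    ];
    for (name, tpot_ms) in rows {
        println!("{}: {:.0} tok/s", name, tok_per_s(tpot_ms));
    }
    println!("sparse vs dense: {:.2}x", speedup(13.9, 7.6));
}
```

Recomputing from the rounded TPOT can differ by one token per second from a table entry, since the published tok/s figures were presumably derived from unrounded timings.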
## Head-to-head vs llama.cpp (tools/xserv_vs_llama.py, warm servers, TP=2, GPUs 0-1, 6 reps, 256 tok)

| prompt | metric | xserv sparse FP8 | llama MXFP4 | xserv vs llama |
|---|---|---|---|---|
| short | TTFT | **35.3 ms** | 62.7 ms | 1.78× faster |
| short | TPOT | **7.32 ms** | 8.42 ms | 1.15× faster |
| medium | TTFT | **49.4 ms** | 65.0 ms | 1.32× faster |
| medium | TPOT | **7.19 ms** | 7.54 ms | 1.05× faster |
| medium | tok/s | **139.1** | 132.7 | |
| long (1.6k) | TTFT | 94.1 ms | **44.7 ms** | 0.48× (llama wins) |
| long | TPOT | **7.25 ms** | 7.64 ms | 1.05× faster |

**Decode TPOT now beats llama.cpp at every prompt length** (it was 2× slower,
13.1 vs 6.6 ms, before sparse). Remaining loss: long-prompt TTFT. Prefill is
still the dense all-expert GEMM; sparse/grouped prefill is the next phase.

**Post-review fixes** (same harness, rerun): removing three leftover
`cudaDeviceSynchronize` calls from the decode hot path and replacing the CPU-tiled
prefill bias-add (96 D2H/H2D round-trips per prefill) with a GPU broadcast
kernel improved both axes: TPOT 7.19–7.32 → **6.99–7.21 ms**, TTFT
short/medium/long 35/49/94 → **29/42/79 ms**. GSM8K-50: 94% (unchanged).

## TP=1 head-to-head (single 5090; server now routes gpt-oss tp=1 to the TP engine)

| prompt | metric | xserv sparse FP8 | llama MXFP4 |
|---|---|---|---|
| short | TTFT / TPOT | 42.8 ms / 7.00 ms | **34.5 ms / 3.22 ms** |
| medium | TTFT / TPOT | 57.1 ms / 7.19 ms | **37.3 ms / 2.89 ms** |
| long | TTFT / TPOT | 119.6 ms / 7.20 ms | **27.8 ms / 2.88 ms** |
| | tok/s | 139–143 | **311–347** |

**Single-GPU is llama.cpp's sweet spot, and it wins 2.2–2.5× on TPOT.** Two structural
reasons, both instructive:

1. llama TP=2 (7.5–8.4 ms) is much WORSE than its TP=1 (2.9 ms): its PCIe
   cross-GPU split costs ~5 ms/token. xserv's NCCL all-reduce is cheap enough
   that TP=2 ≈ TP=1 (7.2 vs 7.0 ms), but xserv's single-GPU floor is high.
2. xserv TP=1 reads ~4.7 GB/token (experts FP8 2.4 GB + **non-expert weights
   still BF16** ~2.3 GB, half of that the 201k-vocab lm_head) ≈ 3.1 ms of pure
   HBM time; the other ~4 ms is launch overhead (~200 kernels/token, no CUDA
   graphs) + BF16 GEMV efficiency. llama reads ~1.3 GB (everything MXFP4) and
   replays the whole token as one CUDA graph.

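The "4.7 GB ≈ 3.1 ms" estimate in point 2 is a bytes-over-bandwidth calculation. The sketch below backs out the effective bandwidth that pair of numbers implies and splits the time between the expert and non-expert reads. It assumes a single uniform bandwidth for all reads, which is a simplification; the function names are illustrative.

```rust
/// Effective bandwidth in GB/s implied by moving `gb` gigabytes in `ms` milliseconds.
fn implied_bandwidth_gb_s(gb: f64, ms: f64) -> f64 {
    gb / (ms / 1000.0)
}

/// Time in milliseconds to stream `gb` gigabytes at `bw_gb_s` GB/s.
fn stream_time_ms(gb: f64, bw_gb_s: f64) -> f64 {
    gb / bw_gb_s * 1000.0
}

fn main() {
    // 4.7 GB per token and 3.1 ms are the figures quoted in the text.
    let bw = implied_bandwidth_gb_s(4.7, 3.1);
    println!("implied effective bandwidth: {:.0} GB/s", bw);
    // Split by weight class, using the 2.4 GB / 2.3 GB breakdown from the text.
    println!("FP8 experts (2.4 GB): {:.2} ms", stream_time_ms(2.4, bw));
    println!("BF16 non-expert weights (2.3 GB): {:.2} ms", stream_time_ms(2.3, bw));
}
```

Under this model the BF16 non-expert weights account for roughly half of the pure HBM time.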
## Correctness

- Greedy generations are coherent across prompts (FP8/MXFP4, TP=1/2).
- Sparse FP8 is W8A16, whereas dense is W8A8: activations are no longer quantized, so
  tokens are not expected to be byte-identical to dense; quality is checked with
  GSM8K instead.
- **GSM8K-100 (greedy, TP=2, `tools/eval_gsm8k_fast.py`): 96/100 = 96.0%** vs
  dense FP8 91.0% / BF16 90.0%: no regression (within greedy-nondeterminism
  noise; W8A16 removes activation-quantization error, so ≥ dense is expected).
  The average of 1.3 s/problem also reflects the decode speedup.

## Phase 21 update: decode CUDA graph + GPU argmax (docs/21-cuda-graph-decode.md)

The whole batch=1 decode step now replays as one CUDA graph, and greedy
sampling uses the GPU argmax kernel (a 4-byte D2H instead of a 402 KB logits
copy + 201k-element host scan). In-process A/B: graph −0.6 ms, GPU argmax
−1.0 ms. Warm-server head-to-head (same harness/GPUs, 6 reps):

| metric | xserv FP8 (graph) | llama MXFP4 | comparison |
|---|---|---|---|
| TP=2 TPOT | **5.76–5.89 ms** (170–174 tok/s) | 7.42–8.45 ms | **xserv 1.26–1.47×** |
| TP=2 TTFT s/m/l | **25 / 28 / 51 ms** | 63 / 66 / 45 ms | xserv 2.4× s/m; long ~par |
| TP=1 TPOT | 5.78–5.95 ms | **2.80–3.22 ms** | llama 2.0× (was 2.5×) |
| TP=1 TTFT s/m | **32 / 35 ms** | 34 / 36 ms | xserv slightly ahead |

GSM8K-50 through the graph path: 47/50 = 94% (unchanged). Note: GPU argmax
breaks exact-tie logits differently than the host scan, so greedy trajectories
can legitimately diverge at a tie token.

## Remaining gaps / next levers (to catch llama TP=1 at 2.8 ms)

Per-token fixed overhead is now mostly gone; the residual ~5.8 ms is
dominated by HBM bytes and kernel efficiency. In impact order:

1. **Quantize non-expert weights** (~1.5 ms): attn qkv/o + the 1.16 GB BF16
   lm_head are read every token; quantize them to FP8/MXFP4, as llama quantizes everything.
2. **GEMV/attention bandwidth tuning**: effective BW of the hand GEMVs is
   well under peak; llama's 2.8 ms implies ~85%+ efficiency on ~1.3 GB.
3. **Sparse prefill** (permute tokens by expert + grouped GEMM): long-prompt
   TTFT 51–75 ms → llama's ~30 ms territory.
4. **W4A4 FP4 tensor cores / bandwidth-tuned MXFP4 GEMV**: make 4-bit experts
   actually beat FP8.
@@ -81,7 +81,8 @@ def main():
     parser.add_argument("--max-tokens", type=int, default=512, help="Max generation tokens")
     parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism")
     parser.add_argument("--offset", type=int, default=0, help="Start from problem N")
-    parser.add_argument("--gpu", type=int, default=0, help="GPU device index")
+    parser.add_argument("--gpu", type=str, default="0",
+                        help="CUDA_VISIBLE_DEVICES value, e.g. '0' or '2,3' (must cover --tp ranks)")
     args = parser.parse_args()
 
     if not DATA_PATH.exists():